├── .github ├── update_dependencies.sh └── workflows │ ├── debug.yml │ └── lint.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── congestion_meta1 ├── README.md ├── bandwidth.go ├── bandwidth_sampler.go ├── bbr_sender.go ├── clock.go ├── cubic.go ├── cubic_sender.go ├── hybrid_slow_start.go ├── minmax.go ├── pacer.go └── windowed_filter.go ├── congestion_meta2 ├── README.md ├── bandwidth.go ├── bandwidth_sampler.go ├── bbr_sender.go ├── clock.go ├── minmax_go120.go ├── minmax_go121.go ├── pacer.go ├── packet_number_indexed_queue.go ├── ringbuffer.go └── windowed_filter.go ├── go.mod ├── go.sum ├── hysteria ├── client.go ├── client_packet.go ├── congestion │ ├── brutal.go │ └── pacer.go ├── hop.go ├── packet.go ├── packet_wait.go ├── protocol.go ├── service.go ├── service_packet.go └── xplus.go ├── hysteria2 ├── client.go ├── client_packet.go ├── internal │ └── protocol │ │ ├── http.go │ │ ├── padding.go │ │ └── proxy.go ├── packet.go ├── packet_wait.go ├── salamander.go ├── service.go └── service_packet.go ├── quic.go ├── tuic ├── address.go ├── client.go ├── client_packet.go ├── congestion.go ├── packet.go ├── packet_wait.go ├── protocol.go ├── service.go └── service_packet.go └── workflows └── debug.yml /.github/update_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PROJECTS=$(dirname "$0")/../.. 
4 | go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD) 5 | go mod tidy 6 | -------------------------------------------------------------------------------- /.github/workflows/debug.yml: -------------------------------------------------------------------------------- 1 | name: Debug build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - '.github/**' 11 | - '!.github/workflows/debug.yml' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | 17 | jobs: 18 | build: 19 | name: Linux Debug build 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | - name: Setup Go 27 | uses: actions/setup-go@v4 28 | with: 29 | go-version: ^1.22 30 | - name: Build 31 | run: | 32 | make test 33 | build_go120: 34 | name: Linux Debug build (Go 1.20) 35 | runs-on: ubuntu-latest 36 | steps: 37 | - name: Checkout 38 | uses: actions/checkout@v4 39 | with: 40 | fetch-depth: 0 41 | - name: Setup Go 42 | uses: actions/setup-go@v4 43 | with: 44 | go-version: ~1.20 45 | continue-on-error: true 46 | - name: Build 47 | run: | 48 | make test 49 | build_go121: 50 | name: Linux Debug build (Go 1.21) 51 | runs-on: ubuntu-latest 52 | steps: 53 | - name: Checkout 54 | uses: actions/checkout@v4 55 | with: 56 | fetch-depth: 0 57 | - name: Setup Go 58 | uses: actions/setup-go@v4 59 | with: 60 | go-version: ~1.21 61 | continue-on-error: true 62 | - name: Build 63 | run: | 64 | make test 65 | build__windows: 66 | name: Windows Debug build 67 | runs-on: windows-latest 68 | steps: 69 | - name: Checkout 70 | uses: actions/checkout@v4 71 | with: 72 | fetch-depth: 0 73 | - name: Setup Go 74 | uses: actions/setup-go@v4 75 | with: 76 | go-version: ^1.22 77 | continue-on-error: true 78 | - name: Build 79 | run: | 80 | make test 81 | build_darwin: 82 | name: macOS Debug build 83 | runs-on: macos-latest 84 | steps: 85 | - name: Checkout 86 | uses: actions/checkout@v4 87 | 
with: 88 | fetch-depth: 0 89 | - name: Setup Go 90 | uses: actions/setup-go@v4 91 | with: 92 | go-version: ^1.22 93 | continue-on-error: true 94 | - name: Build 95 | run: | 96 | make test -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - '.github/**' 11 | - '!.github/workflows/lint.yml' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | 17 | jobs: 18 | build: 19 | name: Build 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | - name: Setup Go 27 | uses: actions/setup-go@v4 28 | with: 29 | go-version: ^1.22 30 | - name: Cache go module 31 | uses: actions/cache@v3 32 | with: 33 | path: | 34 | ~/go/pkg/mod 35 | key: go-${{ hashFiles('**/go.sum') }} 36 | - name: golangci-lint 37 | uses: golangci/golangci-lint-action@v3 38 | with: 39 | version: latest -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /vendor/ 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: true 3 | enable: 4 | - gofumpt 5 | - govet 6 | - gci 7 | - staticcheck 8 | - paralleltest 9 | - ineffassign 10 | 11 | linters-settings: 12 | gci: 13 | custom-order: true 14 | sections: 15 | - standard 16 | - prefix(github.com/sagernet/) 17 | - default 18 | staticcheck: 19 | checks: 20 | - all 21 | - -SA1003 22 | 23 | run: 24 | go: "1.23" -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Copyright (C) 2022 by nekohasekai 2 | 3 | This program is free software: you can redistribute it and/or modify 4 | it under the terms of the GNU General Public License as published by 5 | the Free Software Foundation, either version 3 of the License, or 6 | (at your option) any later version. 7 | 8 | This program is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | 13 | You should have received a copy of the GNU General Public License 14 | along with this program. If not, see <https://www.gnu.org/licenses/>. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | fmt: 2 | @gofumpt -l -w . 3 | @gofmt -s -w . 4 | @gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" . 5 | 6 | fmt_install: 7 | go install -v mvdan.cc/gofumpt@latest 8 | go install -v github.com/daixiang0/gci@latest 9 | 10 | lint: 11 | GOOS=linux golangci-lint run ./... 12 | GOOS=android golangci-lint run ./... 13 | GOOS=windows golangci-lint run ./... 14 | GOOS=darwin golangci-lint run ./... 15 | GOOS=freebsd golangci-lint run ./... 16 | 17 | lint_install: 18 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 19 | 20 | test: 21 | go test -v ./... -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sing-quic 2 | 3 | quic-go API wrapper and QUIC based protocol implementations.
4 | -------------------------------------------------------------------------------- /congestion_meta1/README.md: -------------------------------------------------------------------------------- 1 | # congestion 2 | 3 | mod from https://github.com/MetaCubeX/Clash.Meta/tree/53f9e1ee7104473da2b4ff5da29965563084482d/transport/tuic/congestion -------------------------------------------------------------------------------- /congestion_meta1/bandwidth.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | ) 9 | 10 | // Bandwidth of a connection 11 | type Bandwidth uint64 12 | 13 | const infBandwidth Bandwidth = math.MaxUint64 14 | 15 | const ( 16 | // BitsPerSecond is 1 bit per second 17 | BitsPerSecond Bandwidth = 1 18 | // BytesPerSecond is 1 byte per second 19 | BytesPerSecond = 8 * BitsPerSecond 20 | ) 21 | 22 | // BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta 23 | func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { 24 | return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond 25 | } 26 | -------------------------------------------------------------------------------- /congestion_meta1/clock.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import "time" 4 | 5 | // A Clock returns the current time 6 | type Clock interface { 7 | Now() time.Time 8 | } 9 | 10 | // DefaultClock implements the Clock interface using the Go stdlib clock. 
11 | type DefaultClock struct { 12 | TimeFunc func() time.Time 13 | } 14 | 15 | var _ Clock = DefaultClock{} 16 | 17 | // Now gets the current time 18 | func (c DefaultClock) Now() time.Time { 19 | return c.TimeFunc() 20 | } 21 | -------------------------------------------------------------------------------- /congestion_meta1/cubic.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | ) 9 | 10 | // This cubic implementation is based on the one found in Chromiums's QUIC 11 | // implementation, in the files net/quic/congestion_control/cubic.{hh,cc}. 12 | 13 | // Constants based on TCP defaults. 14 | // The following constants are in 2^10 fractions of a second instead of ms to 15 | // allow a 10 shift right to divide. 16 | 17 | // 1024*1024^3 (first 1024 is from 0.100^3) 18 | // where 0.100 is 100 ms which is the scaling round trip time. 19 | const ( 20 | cubeScale = 40 21 | cubeCongestionWindowScale = 410 22 | cubeFactor congestion.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize 23 | // TODO: when re-enabling cubic, make sure to use the actual packet size here 24 | maxDatagramSize = congestion.ByteCount(congestion.InitialPacketSizeIPv4) 25 | ) 26 | 27 | const defaultNumConnections = 1 28 | 29 | // Default Cubic backoff factor 30 | const beta float32 = 0.7 31 | 32 | // Additional backoff factor when loss occurs in the concave part of the Cubic 33 | // curve. This additional backoff factor is expected to give up bandwidth to 34 | // new concurrent flows and speed up convergence. 35 | const betaLastMax float32 = 0.85 36 | 37 | // Cubic implements the cubic algorithm from TCP 38 | type Cubic struct { 39 | clock Clock 40 | 41 | // Number of connections to simulate. 42 | numConnections int 43 | 44 | // Time when this cycle started, after last loss event. 
45 | epoch time.Time 46 | 47 | // Max congestion window used just before last loss event. 48 | // Note: to improve fairness to other streams an additional back off is 49 | // applied to this value if the new value is below our latest value. 50 | lastMaxCongestionWindow congestion.ByteCount 51 | 52 | // Number of acked bytes since the cycle started (epoch). 53 | ackedBytesCount congestion.ByteCount 54 | 55 | // TCP Reno equivalent congestion window in packets. 56 | estimatedTCPcongestionWindow congestion.ByteCount 57 | 58 | // Origin point of cubic function. 59 | originPointCongestionWindow congestion.ByteCount 60 | 61 | // Time to origin point of cubic function in 2^10 fractions of a second. 62 | timeToOriginPoint uint32 63 | 64 | // Last congestion window in packets computed by cubic function. 65 | lastTargetCongestionWindow congestion.ByteCount 66 | } 67 | 68 | // NewCubic returns a new Cubic instance 69 | func NewCubic(clock Clock) *Cubic { 70 | c := &Cubic{ 71 | clock: clock, 72 | numConnections: defaultNumConnections, 73 | } 74 | c.Reset() 75 | return c 76 | } 77 | 78 | // Reset is called after a timeout to reset the cubic state 79 | func (c *Cubic) Reset() { 80 | c.epoch = time.Time{} 81 | c.lastMaxCongestionWindow = 0 82 | c.ackedBytesCount = 0 83 | c.estimatedTCPcongestionWindow = 0 84 | c.originPointCongestionWindow = 0 85 | c.timeToOriginPoint = 0 86 | c.lastTargetCongestionWindow = 0 87 | } 88 | 89 | func (c *Cubic) alpha() float32 { 90 | // TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that 91 | // beta here is a cwnd multiplier, and is equal to 1-beta from the paper. 
92 | // We derive the equivalent alpha for an N-connection emulation as: 93 | b := c.beta() 94 | return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b) 95 | } 96 | 97 | func (c *Cubic) beta() float32 { 98 | // kNConnectionBeta is the backoff factor after loss for our N-connection 99 | // emulation, which emulates the effective backoff of an ensemble of N 100 | // TCP-Reno connections on a single loss event. The effective multiplier is 101 | // computed as: 102 | return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections) 103 | } 104 | 105 | func (c *Cubic) betaLastMax() float32 { 106 | // betaLastMax is the additional backoff factor after loss for our 107 | // N-connection emulation, which emulates the additional backoff of 108 | // an ensemble of N TCP-Reno connections on a single loss event. The 109 | // effective multiplier is computed as: 110 | return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections) 111 | } 112 | 113 | // OnApplicationLimited is called on ack arrival when sender is unable to use 114 | // the available congestion window. Resets Cubic state during quiescence. 115 | func (c *Cubic) OnApplicationLimited() { 116 | // When sender is not using the available congestion window, the window does 117 | // not grow. But to be RTT-independent, Cubic assumes that the sender has been 118 | // using the entire window during the time since the beginning of the current 119 | // "epoch" (the end of the last loss recovery period). Since 120 | // application-limited periods break this assumption, we reset the epoch when 121 | // in such a period. This reset effectively freezes congestion window growth 122 | // through application-limited periods and allows Cubic growth to continue 123 | // when the entire window is being used. 124 | c.epoch = time.Time{} 125 | } 126 | 127 | // CongestionWindowAfterPacketLoss computes a new congestion window to use after 128 | // a loss event. 
Returns the new congestion window in packets. The new 129 | // congestion window is a multiplicative decrease of our current window. 130 | func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow congestion.ByteCount) congestion.ByteCount { 131 | if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow { 132 | // We never reached the old max, so assume we are competing with another 133 | // flow. Use our extra back off factor to allow the other flow to go up. 134 | c.lastMaxCongestionWindow = congestion.ByteCount(c.betaLastMax() * float32(currentCongestionWindow)) 135 | } else { 136 | c.lastMaxCongestionWindow = currentCongestionWindow 137 | } 138 | c.epoch = time.Time{} // Reset time. 139 | return congestion.ByteCount(float32(currentCongestionWindow) * c.beta()) 140 | } 141 | 142 | // CongestionWindowAfterAck computes a new congestion window to use after a received ACK. 143 | // Returns the new congestion window in packets. The new congestion window 144 | // follows a cubic function that depends on the time passed since last 145 | // packet loss. 146 | func (c *Cubic) CongestionWindowAfterAck( 147 | ackedBytes congestion.ByteCount, 148 | currentCongestionWindow congestion.ByteCount, 149 | delayMin time.Duration, 150 | eventTime time.Time, 151 | ) congestion.ByteCount { 152 | c.ackedBytesCount += ackedBytes 153 | 154 | if c.epoch.IsZero() { 155 | // First ACK after a loss event. 156 | c.epoch = eventTime // Start of epoch. 157 | c.ackedBytesCount = ackedBytes // Reset count. 158 | // Reset estimated_tcp_congestion_window_ to be in sync with cubic. 
159 | c.estimatedTCPcongestionWindow = currentCongestionWindow 160 | if c.lastMaxCongestionWindow <= currentCongestionWindow { 161 | c.timeToOriginPoint = 0 162 | c.originPointCongestionWindow = currentCongestionWindow 163 | } else { 164 | c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow)))) 165 | c.originPointCongestionWindow = c.lastMaxCongestionWindow 166 | } 167 | } 168 | 169 | // Change the time unit from microseconds to 2^10 fractions per second. Take 170 | // the round trip time in account. This is done to allow us to use shift as a 171 | // divide operator. 172 | elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000) 173 | 174 | // Right-shifts of negative, signed numbers have implementation-dependent 175 | // behavior, so force the offset to be positive, as is done in the kernel. 176 | offset := int64(c.timeToOriginPoint) - elapsedTime 177 | if offset < 0 { 178 | offset = -offset 179 | } 180 | 181 | deltaCongestionWindow := congestion.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale 182 | var targetCongestionWindow congestion.ByteCount 183 | if elapsedTime > int64(c.timeToOriginPoint) { 184 | targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow 185 | } else { 186 | targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow 187 | } 188 | // Limit the CWND increase to half the acked bytes. 189 | targetCongestionWindow = Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2) 190 | 191 | // Increase the window by approximately Alpha * 1 MSS of bytes every 192 | // time we ack an estimated tcp window of bytes. For small 193 | // congestion windows (less than 25), the formula below will 194 | // increase slightly slower than linearly per estimated tcp window 195 | // of bytes. 
196 | c.estimatedTCPcongestionWindow += congestion.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow)) 197 | c.ackedBytesCount = 0 198 | 199 | // We have a new cubic congestion window. 200 | c.lastTargetCongestionWindow = targetCongestionWindow 201 | 202 | // Compute target congestion_window based on cubic target and estimated TCP 203 | // congestion_window, use highest (fastest). 204 | if targetCongestionWindow < c.estimatedTCPcongestionWindow { 205 | targetCongestionWindow = c.estimatedTCPcongestionWindow 206 | } 207 | return targetCongestionWindow 208 | } 209 | 210 | // SetNumConnections sets the number of emulated connections 211 | func (c *Cubic) SetNumConnections(n int) { 212 | c.numConnections = n 213 | } 214 | -------------------------------------------------------------------------------- /congestion_meta1/cubic_sender.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | "github.com/sagernet/quic-go/logging" 9 | ) 10 | 11 | const ( 12 | maxBurstPackets = 3 13 | renoBeta = 0.7 // Reno backoff factor. 14 | minCongestionWindowPackets = 2 15 | initialCongestionWindow = 32 16 | ) 17 | 18 | const ( 19 | InvalidPacketNumber congestion.PacketNumber = -1 20 | MaxCongestionWindowPackets = 20000 21 | MaxByteCount = congestion.ByteCount(1<<62 - 1) 22 | ) 23 | 24 | type cubicSender struct { 25 | hybridSlowStart HybridSlowStart 26 | rttStats congestion.RTTStatsProvider 27 | cubic *Cubic 28 | pacer *pacer 29 | clock Clock 30 | 31 | reno bool 32 | 33 | // Track the largest packet that has been sent. 34 | largestSentPacketNumber congestion.PacketNumber 35 | 36 | // Track the largest packet that has been acked. 37 | largestAckedPacketNumber congestion.PacketNumber 38 | 39 | // Track the largest packet number outstanding when a CWND cutback occurs. 
40 | largestSentAtLastCutback congestion.PacketNumber 41 | 42 | // Whether the last loss event caused us to exit slowstart. 43 | // Used for stats collection of slowstartPacketsLost 44 | lastCutbackExitedSlowstart bool 45 | 46 | // Congestion window in bytes. 47 | congestionWindow congestion.ByteCount 48 | 49 | // Slow start congestion window in bytes, aka ssthresh. 50 | slowStartThreshold congestion.ByteCount 51 | 52 | // ACK counter for the Reno implementation. 53 | numAckedPackets uint64 54 | 55 | initialCongestionWindow congestion.ByteCount 56 | initialMaxCongestionWindow congestion.ByteCount 57 | 58 | maxDatagramSize congestion.ByteCount 59 | 60 | lastState logging.CongestionState 61 | tracer *logging.ConnectionTracer 62 | } 63 | 64 | var _ congestion.CongestionControl = &cubicSender{} 65 | 66 | // NewCubicSender makes a new cubic sender 67 | func NewCubicSender( 68 | clock Clock, 69 | initialMaxDatagramSize congestion.ByteCount, 70 | reno bool, 71 | tracer *logging.ConnectionTracer, 72 | ) *cubicSender { 73 | return newCubicSender( 74 | clock, 75 | reno, 76 | initialMaxDatagramSize, 77 | initialCongestionWindow*initialMaxDatagramSize, 78 | MaxCongestionWindowPackets*initialMaxDatagramSize, 79 | tracer, 80 | ) 81 | } 82 | 83 | func newCubicSender( 84 | clock Clock, 85 | reno bool, 86 | initialMaxDatagramSize, 87 | initialCongestionWindow, 88 | initialMaxCongestionWindow congestion.ByteCount, 89 | tracer *logging.ConnectionTracer, 90 | ) *cubicSender { 91 | c := &cubicSender{ 92 | largestSentPacketNumber: InvalidPacketNumber, 93 | largestAckedPacketNumber: InvalidPacketNumber, 94 | largestSentAtLastCutback: InvalidPacketNumber, 95 | initialCongestionWindow: initialCongestionWindow, 96 | initialMaxCongestionWindow: initialMaxCongestionWindow, 97 | congestionWindow: initialCongestionWindow, 98 | slowStartThreshold: MaxByteCount, 99 | cubic: NewCubic(clock), 100 | clock: clock, 101 | reno: reno, 102 | tracer: tracer, 103 | maxDatagramSize: initialMaxDatagramSize, 
104 | } 105 | c.pacer = newPacer(c.BandwidthEstimate) 106 | if c.tracer != nil { 107 | c.lastState = logging.CongestionStateSlowStart 108 | c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) 109 | } 110 | return c 111 | } 112 | 113 | func (c *cubicSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { 114 | c.rttStats = provider 115 | } 116 | 117 | // TimeUntilSend returns when the next packet should be sent. 118 | func (c *cubicSender) TimeUntilSend(_ congestion.ByteCount) time.Time { 119 | return c.pacer.TimeUntilSend() 120 | } 121 | 122 | func (c *cubicSender) HasPacingBudget(now time.Time) bool { 123 | return c.pacer.Budget(now) >= c.maxDatagramSize 124 | } 125 | 126 | func (c *cubicSender) maxCongestionWindow() congestion.ByteCount { 127 | return c.maxDatagramSize * MaxCongestionWindowPackets 128 | } 129 | 130 | func (c *cubicSender) minCongestionWindow() congestion.ByteCount { 131 | return c.maxDatagramSize * minCongestionWindowPackets 132 | } 133 | 134 | func (c *cubicSender) OnPacketSent( 135 | sentTime time.Time, 136 | _ congestion.ByteCount, 137 | packetNumber congestion.PacketNumber, 138 | bytes congestion.ByteCount, 139 | isRetransmittable bool, 140 | ) { 141 | c.pacer.SentPacket(sentTime, bytes) 142 | if !isRetransmittable { 143 | return 144 | } 145 | c.largestSentPacketNumber = packetNumber 146 | c.hybridSlowStart.OnPacketSent(packetNumber) 147 | } 148 | 149 | func (c *cubicSender) CanSend(bytesInFlight congestion.ByteCount) bool { 150 | return bytesInFlight < c.GetCongestionWindow() 151 | } 152 | 153 | func (c *cubicSender) InRecovery() bool { 154 | return c.largestAckedPacketNumber != InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback 155 | } 156 | 157 | func (c *cubicSender) InSlowStart() bool { 158 | return c.GetCongestionWindow() < c.slowStartThreshold 159 | } 160 | 161 | func (c *cubicSender) GetCongestionWindow() congestion.ByteCount { 162 | return c.congestionWindow 163 | } 164 | 165 | 
// MaybeExitSlowStart checks, on ACK receipt, whether hybrid slow start signals
// that slow start should end; if so, ssthresh is set to the current window so
// that InSlowStart reports false from then on.
func (c *cubicSender) MaybeExitSlowStart() {
	if c.InSlowStart() &&
		c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
		// exit slow start
		c.slowStartThreshold = c.congestionWindow
		c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
	}
}

// OnPacketAcked records the highest acked packet and grows the congestion
// window — unless the sender is in recovery (an outstanding cutback has not
// yet been fully acked), in which case growth is suppressed.
func (c *cubicSender) OnPacketAcked(
	ackedPacketNumber congestion.PacketNumber,
	ackedBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount,
	eventTime time.Time,
) {
	c.largestAckedPacketNumber = Max(ackedPacketNumber, c.largestAckedPacketNumber)
	if c.InRecovery() {
		return
	}
	c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
	if c.InSlowStart() {
		c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
	}
}

// OnCongestionEvent reduces the congestion window after a loss and enters
// recovery (tracked via largestSentAtLastCutback).
func (c *cubicSender) OnCongestionEvent(packetNumber congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
	// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
	// already sent should be treated as a single loss event, since it's expected.
	if packetNumber <= c.largestSentAtLastCutback {
		return
	}
	c.lastCutbackExitedSlowstart = c.InSlowStart()
	c.maybeTraceStateChange(logging.CongestionStateRecovery)

	if c.reno {
		c.congestionWindow = congestion.ByteCount(float64(c.congestionWindow) * renoBeta)
	} else {
		c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
	}
	if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
		c.congestionWindow = minCwnd
	}
	c.slowStartThreshold = c.congestionWindow
	c.largestSentAtLastCutback = c.largestSentPacketNumber
	// reset packet count from congestion avoidance mode. We start
	// counting again when we're out of recovery.
	c.numAckedPackets = 0
}

// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
	_ congestion.PacketNumber,
	ackedBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount,
	eventTime time.Time,
) {
	// Do not increase the congestion window unless the sender is close to using
	// the current window.
	if !c.isCwndLimited(priorInFlight) {
		c.cubic.OnApplicationLimited()
		c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
		return
	}
	if c.congestionWindow >= c.maxCongestionWindow() {
		return
	}
	if c.InSlowStart() {
		// TCP slow start, exponential growth, increase by one for each ACK.
		c.congestionWindow += c.maxDatagramSize
		c.maybeTraceStateChange(logging.CongestionStateSlowStart)
		return
	}
	// Congestion avoidance
	c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
	if c.reno {
		// Classic Reno congestion avoidance: grow by one datagram for each
		// full window's worth of acked packets.
		c.numAckedPackets++
		if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
			c.congestionWindow += c.maxDatagramSize
			c.numAckedPackets = 0
		}
	} else {
		c.congestionWindow = Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
	}
}

// isCwndLimited reports whether the sender is using enough of the current
// window to justify growing it (in-flight bytes near the window, or only a
// small burst's worth of budget remaining).
func (c *cubicSender) isCwndLimited(bytesInFlight congestion.ByteCount) bool {
	congestionWindow := c.GetCongestionWindow()
	if bytesInFlight >= congestionWindow {
		return true
	}
	availableBytes := congestionWindow - bytesInFlight
	slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
	return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}

// BandwidthEstimate returns the current bandwidth estimate
// (congestion window divided by smoothed RTT).
func (c *cubicSender) BandwidthEstimate() Bandwidth {
	if c.rttStats == nil {
		return infBandwidth
	}
	srtt := c.rttStats.SmoothedRTT()
	if srtt == 0 {
		// If we haven't measured an rtt, the bandwidth estimate is unknown.
		return infBandwidth
	}
	return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}

// OnRetransmissionTimeout is called on an retransmission timeout.
// It halves ssthresh and collapses the window to its minimum, but only when
// packets were actually retransmitted.
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
	c.largestSentAtLastCutback = InvalidPacketNumber
	if !packetsRetransmitted {
		return
	}
	c.hybridSlowStart.Restart()
	c.cubic.Reset()
	c.slowStartThreshold = c.congestionWindow / 2
	c.congestionWindow = c.minCongestionWindow()
}

// OnConnectionMigration is called when the connection is migrated (?)
288 | func (c *cubicSender) OnConnectionMigration() { 289 | c.hybridSlowStart.Restart() 290 | c.largestSentPacketNumber = InvalidPacketNumber 291 | c.largestAckedPacketNumber = InvalidPacketNumber 292 | c.largestSentAtLastCutback = InvalidPacketNumber 293 | c.lastCutbackExitedSlowstart = false 294 | c.cubic.Reset() 295 | c.numAckedPackets = 0 296 | c.congestionWindow = c.initialCongestionWindow 297 | c.slowStartThreshold = c.initialMaxCongestionWindow 298 | } 299 | 300 | func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { 301 | if c.tracer == nil || new == c.lastState { 302 | return 303 | } 304 | c.tracer.UpdatedCongestionState(new) 305 | c.lastState = new 306 | } 307 | 308 | func (c *cubicSender) SetMaxDatagramSize(s congestion.ByteCount) { 309 | if s < c.maxDatagramSize { 310 | panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s)) 311 | } 312 | cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow() 313 | c.maxDatagramSize = s 314 | if cwndIsMinCwnd { 315 | c.congestionWindow = c.minCongestionWindow() 316 | } 317 | c.pacer.SetMaxDatagramSize(s) 318 | } 319 | -------------------------------------------------------------------------------- /congestion_meta1/hybrid_slow_start.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/sagernet/quic-go/congestion" 7 | ) 8 | 9 | // Note(pwestin): the magic clamping numbers come from the original code in 10 | // tcp_cubic.c. 11 | const hybridStartLowWindow = congestion.ByteCount(16) 12 | 13 | // Number of delay samples for detecting the increase of delay. 14 | const hybridStartMinSamples = uint32(8) 15 | 16 | // Exit slow start if the min rtt has increased by more than 1/8th. 17 | const hybridStartDelayFactorExp = 3 // 2^3 = 8 18 | // The original paper specifies 2 and 8ms, but those have changed over time. 
const (
	// Lower clamp for the delay-increase threshold: 4 ms.
	hybridStartDelayMinThresholdUs = int64(4000)
	// Upper clamp for the delay-increase threshold: 16 ms.
	hybridStartDelayMaxThresholdUs = int64(16000)
)

// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
	endPacketNumber      congestion.PacketNumber // last packet of the current round
	lastSentPacketNumber congestion.PacketNumber // highest packet number sent so far
	started              bool                    // whether a receive round is in progress
	currentMinRTT        time.Duration           // min RTT seen in the current round
	rttSampleCount       uint32                  // RTT samples taken this round (capped use at hybridStartMinSamples)
	hystartFound         bool                    // delay increase detected; exit is latched
}

// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent congestion.PacketNumber) {
	s.endPacketNumber = lastSent
	s.currentMinRTT = 0
	s.rttSampleCount = 0
	s.started = true
}

// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack congestion.PacketNumber) bool {
	return s.endPacketNumber < ack
}

// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow congestion.ByteCount) bool {
	if !s.started {
		// Time to start the hybrid slow start.
		s.StartReceiveRound(s.lastSentPacketNumber)
	}
	if s.hystartFound {
		return true
	}
	// Second detection parameter - delay increase detection.
	// Compare the minimum delay (s.currentMinRTT) of the current
	// burst of packets relative to the minimum delay during the session.
	// Note: we only look at the first few(8) packets in each burst, since we
	// only want to compare the lowest RTT of the burst relative to previous
	// bursts.
	s.rttSampleCount++
	if s.rttSampleCount <= hybridStartMinSamples {
		if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
			s.currentMinRTT = latestRTT
		}
	}
	// We only need to check this once per round.
	if s.rttSampleCount == hybridStartMinSamples {
		// Divide minRTT by 8 to get a rtt increase threshold for exiting.
		minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
		// Ensure the rtt threshold is never less than 4ms or more than 16ms
		// (see hybridStartDelayMinThresholdUs / hybridStartDelayMaxThresholdUs).
		minRTTincreaseThresholdUs = Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
		minRTTincreaseThreshold := time.Duration(Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond

		if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
			s.hystartFound = true
		}
	}
	// Exit from slow start if the cwnd is greater than 16 and
	// increasing delay is found.
	return congestionWindow >= hybridStartLowWindow && s.hystartFound
}

// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber congestion.PacketNumber) {
	s.lastSentPacketNumber = packetNumber
}

// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
97 | func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber congestion.PacketNumber) { 98 | if s.IsEndOfRound(ackedPacketNumber) { 99 | s.started = false 100 | } 101 | } 102 | 103 | // Started returns true if started 104 | func (s *HybridSlowStart) Started() bool { 105 | return s.started 106 | } 107 | 108 | // Restart the slow start phase 109 | func (s *HybridSlowStart) Restart() { 110 | s.started = false 111 | s.hystartFound = false 112 | } 113 | -------------------------------------------------------------------------------- /congestion_meta1/minmax.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "golang.org/x/exp/constraints" 8 | ) 9 | 10 | // InfDuration is a duration of infinite length 11 | const InfDuration = time.Duration(math.MaxInt64) 12 | 13 | func Max[T constraints.Ordered](a, b T) T { 14 | if a < b { 15 | return b 16 | } 17 | return a 18 | } 19 | 20 | func Min[T constraints.Ordered](a, b T) T { 21 | if a < b { 22 | return a 23 | } 24 | return b 25 | } 26 | 27 | // MinNonZeroDuration return the minimum duration that's not zero. 
28 | func MinNonZeroDuration(a, b time.Duration) time.Duration { 29 | if a == 0 { 30 | return b 31 | } 32 | if b == 0 { 33 | return a 34 | } 35 | return Min(a, b) 36 | } 37 | 38 | // AbsDuration returns the absolute value of a time duration 39 | func AbsDuration(d time.Duration) time.Duration { 40 | if d >= 0 { 41 | return d 42 | } 43 | return -d 44 | } 45 | 46 | // MinTime returns the earlier time 47 | func MinTime(a, b time.Time) time.Time { 48 | if a.After(b) { 49 | return b 50 | } 51 | return a 52 | } 53 | 54 | // MinNonZeroTime returns the earlist time that is not time.Time{} 55 | // If both a and b are time.Time{}, it returns time.Time{} 56 | func MinNonZeroTime(a, b time.Time) time.Time { 57 | if a.IsZero() { 58 | return b 59 | } 60 | if b.IsZero() { 61 | return a 62 | } 63 | return MinTime(a, b) 64 | } 65 | 66 | // MaxTime returns the later time 67 | func MaxTime(a, b time.Time) time.Time { 68 | if a.After(b) { 69 | return a 70 | } 71 | return b 72 | } 73 | -------------------------------------------------------------------------------- /congestion_meta1/pacer.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | ) 9 | 10 | const ( 11 | initialMaxDatagramSize = congestion.ByteCount(1252) 12 | MinPacingDelay = time.Millisecond 13 | TimerGranularity = time.Millisecond 14 | maxBurstSizePackets = 10 15 | ) 16 | 17 | // The pacer implements a token bucket pacing algorithm. 18 | type pacer struct { 19 | budgetAtLastSent congestion.ByteCount 20 | maxDatagramSize congestion.ByteCount 21 | lastSentTime time.Time 22 | getAdjustedBandwidth func() uint64 // in bytes/s 23 | } 24 | 25 | func newPacer(getBandwidth func() Bandwidth) *pacer { 26 | p := &pacer{ 27 | maxDatagramSize: initialMaxDatagramSize, 28 | getAdjustedBandwidth: func() uint64 { 29 | // Bandwidth is in bits/s. We need the value in bytes/s. 
30 | bw := uint64(getBandwidth() / BytesPerSecond) 31 | // Use a slightly higher value than the actual measured bandwidth. 32 | // RTT variations then won't result in under-utilization of the congestion window. 33 | // Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire, 34 | // provided the congestion window is fully utilized and acknowledgments arrive at regular intervals. 35 | return bw * 5 / 4 36 | }, 37 | } 38 | p.budgetAtLastSent = p.maxBurstSize() 39 | return p 40 | } 41 | 42 | func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { 43 | budget := p.Budget(sendTime) 44 | if size > budget { 45 | p.budgetAtLastSent = 0 46 | } else { 47 | p.budgetAtLastSent = budget - size 48 | } 49 | p.lastSentTime = sendTime 50 | } 51 | 52 | func (p *pacer) Budget(now time.Time) congestion.ByteCount { 53 | if p.lastSentTime.IsZero() { 54 | return p.maxBurstSize() 55 | } 56 | budget := p.budgetAtLastSent + (congestion.ByteCount(p.getAdjustedBandwidth())*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 57 | return Min(p.maxBurstSize(), budget) 58 | } 59 | 60 | func (p *pacer) maxBurstSize() congestion.ByteCount { 61 | return Max( 62 | congestion.ByteCount(uint64((MinPacingDelay+TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, 63 | maxBurstSizePackets*p.maxDatagramSize, 64 | ) 65 | } 66 | 67 | // TimeUntilSend returns when the next packet should be sent. 68 | // It returns the zero value of time.Time if a packet can be sent immediately. 
69 | func (p *pacer) TimeUntilSend() time.Time { 70 | if p.budgetAtLastSent >= p.maxDatagramSize { 71 | return time.Time{} 72 | } 73 | return p.lastSentTime.Add(Max( 74 | MinPacingDelay, 75 | time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, 76 | )) 77 | } 78 | 79 | func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { 80 | p.maxDatagramSize = s 81 | } 82 | -------------------------------------------------------------------------------- /congestion_meta1/windowed_filter.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | // WindowedFilter Use the following to construct a windowed filter object of type T. 4 | // For example, a min filter using QuicTime as the time type: 5 | // 6 | // WindowedFilter, QuicTime, QuicTime::Delta> ObjectName; 7 | // 8 | // A max filter using 64-bit integers as the time type: 9 | // 10 | // WindowedFilter, uint64_t, int64_t> ObjectName; 11 | // 12 | // Specifically, this template takes four arguments: 13 | // 1. T -- type of the measurement that is being filtered. 14 | // 2. Compare -- MinFilter or MaxFilter, depending on the type of filter 15 | // desired. 16 | // 3. TimeT -- the type used to represent timestamps. 17 | // 4. TimeDeltaT -- the type used to represent continuous time intervals between 18 | // two timestamps. Has to be the type of (a - b) if both |a| and |b| are 19 | // of type TimeT. 20 | type WindowedFilter struct { 21 | // Time length of window. 22 | windowLength int64 23 | estimates []Sample 24 | comparator func(int64, int64) bool 25 | } 26 | 27 | type Sample struct { 28 | sample int64 29 | time int64 30 | } 31 | 32 | // Compares two values and returns true if the first is greater than or equal 33 | // to the second. 
34 | func MaxFilter(a, b int64) bool { 35 | return a >= b 36 | } 37 | 38 | // Compares two values and returns true if the first is less than or equal 39 | // to the second. 40 | func MinFilter(a, b int64) bool { 41 | return a <= b 42 | } 43 | 44 | func NewWindowedFilter(windowLength int64, comparator func(int64, int64) bool) *WindowedFilter { 45 | return &WindowedFilter{ 46 | windowLength: windowLength, 47 | estimates: make([]Sample, 3), 48 | comparator: comparator, 49 | } 50 | } 51 | 52 | // Changes the window length. Does not update any current samples. 53 | func (f *WindowedFilter) SetWindowLength(windowLength int64) { 54 | f.windowLength = windowLength 55 | } 56 | 57 | func (f *WindowedFilter) GetBest() int64 { 58 | return f.estimates[0].sample 59 | } 60 | 61 | func (f *WindowedFilter) GetSecondBest() int64 { 62 | return f.estimates[1].sample 63 | } 64 | 65 | func (f *WindowedFilter) GetThirdBest() int64 { 66 | return f.estimates[2].sample 67 | } 68 | 69 | func (f *WindowedFilter) Update(sample int64, time int64) { 70 | if f.estimates[0].time == 0 || f.comparator(sample, f.estimates[0].sample) || (time-f.estimates[2].time) > f.windowLength { 71 | f.Reset(sample, time) 72 | return 73 | } 74 | 75 | if f.comparator(sample, f.estimates[1].sample) { 76 | f.estimates[1].sample = sample 77 | f.estimates[1].time = time 78 | f.estimates[2].sample = sample 79 | f.estimates[2].time = time 80 | } else if f.comparator(sample, f.estimates[2].sample) { 81 | f.estimates[2].sample = sample 82 | f.estimates[2].time = time 83 | } 84 | 85 | // Expire and update estimates as necessary. 86 | if time-f.estimates[0].time > f.windowLength { 87 | // The best estimate hasn't been updated for an entire window, so promote 88 | // second and third best estimates. 
89 | f.estimates[0].sample = f.estimates[1].sample 90 | f.estimates[0].time = f.estimates[1].time 91 | f.estimates[1].sample = f.estimates[2].sample 92 | f.estimates[1].time = f.estimates[2].time 93 | f.estimates[2].sample = sample 94 | f.estimates[2].time = time 95 | // Need to iterate one more time. Check if the new best estimate is 96 | // outside the window as well, since it may also have been recorded a 97 | // long time ago. Don't need to iterate once more since we cover that 98 | // case at the beginning of the method. 99 | if time-f.estimates[0].time > f.windowLength { 100 | f.estimates[0].sample = f.estimates[1].sample 101 | f.estimates[0].time = f.estimates[1].time 102 | f.estimates[1].sample = f.estimates[2].sample 103 | f.estimates[1].time = f.estimates[2].time 104 | } 105 | return 106 | } 107 | if f.estimates[1].sample == f.estimates[0].sample && time-f.estimates[1].time > f.windowLength>>2 { 108 | // A quarter of the window has passed without a better sample, so the 109 | // second-best estimate is taken from the second quarter of the window. 110 | f.estimates[1].sample = sample 111 | f.estimates[1].time = time 112 | f.estimates[2].sample = sample 113 | f.estimates[2].time = time 114 | return 115 | } 116 | 117 | if f.estimates[2].sample == f.estimates[1].sample && time-f.estimates[2].time > f.windowLength>>1 { 118 | // We've passed a half of the window without a better estimate, so take 119 | // a third-best estimate from the second half of the window. 
120 | f.estimates[2].sample = sample 121 | f.estimates[2].time = time 122 | } 123 | } 124 | 125 | func (f *WindowedFilter) Reset(newSample int64, newTime int64) { 126 | f.estimates[0].sample = newSample 127 | f.estimates[0].time = newTime 128 | f.estimates[1].sample = newSample 129 | f.estimates[1].time = newTime 130 | f.estimates[2].sample = newSample 131 | f.estimates[2].time = newTime 132 | } 133 | -------------------------------------------------------------------------------- /congestion_meta2/README.md: -------------------------------------------------------------------------------- 1 | # congestion 2 | 3 | mod from https://github.com/MetaCubeX/Clash.Meta/tree/dbaee284e4310aea71344f5154b174bd0333b657/transport/tuic/congestion_v2 -------------------------------------------------------------------------------- /congestion_meta2/bandwidth.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | ) 9 | 10 | const ( 11 | infBandwidth = Bandwidth(math.MaxUint64) 12 | ) 13 | 14 | // Bandwidth of a connection 15 | type Bandwidth uint64 16 | 17 | const ( 18 | // BitsPerSecond is 1 bit per second 19 | BitsPerSecond Bandwidth = 1 20 | // BytesPerSecond is 1 byte per second 21 | BytesPerSecond = 8 * BitsPerSecond 22 | ) 23 | 24 | // BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta 25 | func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { 26 | return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond 27 | } 28 | -------------------------------------------------------------------------------- /congestion_meta2/clock.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import "time" 4 | 5 | // A Clock returns the current time 6 | type Clock interface { 7 | Now() time.Time 8 | } 9 
| 10 | // DefaultClock implements the Clock interface using the Go stdlib clock. 11 | type DefaultClock struct { 12 | TimeFunc func() time.Time 13 | } 14 | 15 | var _ Clock = DefaultClock{} 16 | 17 | // Now gets the current time 18 | func (c DefaultClock) Now() time.Time { 19 | return c.TimeFunc() 20 | } 21 | -------------------------------------------------------------------------------- /congestion_meta2/minmax_go120.go: -------------------------------------------------------------------------------- 1 | //go:build !go1.21 2 | 3 | package congestion 4 | 5 | import "golang.org/x/exp/constraints" 6 | 7 | func Max[T constraints.Ordered](a, b T) T { 8 | if a < b { 9 | return b 10 | } 11 | return a 12 | } 13 | 14 | func Min[T constraints.Ordered](a, b T) T { 15 | if a < b { 16 | return a 17 | } 18 | return b 19 | } 20 | -------------------------------------------------------------------------------- /congestion_meta2/minmax_go121.go: -------------------------------------------------------------------------------- 1 | //go:build go1.21 2 | 3 | package congestion 4 | 5 | import "cmp" 6 | 7 | func Max[T cmp.Ordered](a, b T) T { 8 | return max(a, b) 9 | } 10 | 11 | func Min[T cmp.Ordered](a, b T) T { 12 | return min(a, b) 13 | } 14 | -------------------------------------------------------------------------------- /congestion_meta2/pacer.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "math" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go/congestion" 8 | ) 9 | 10 | const ( 11 | maxBurstPackets = 10 12 | ) 13 | 14 | // Pacer implements a token bucket pacing algorithm. 
15 | type Pacer struct { 16 | budgetAtLastSent congestion.ByteCount 17 | maxDatagramSize congestion.ByteCount 18 | lastSentTime time.Time 19 | getBandwidth func() congestion.ByteCount // in bytes/s 20 | } 21 | 22 | func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { 23 | p := &Pacer{ 24 | budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4, 25 | maxDatagramSize: congestion.InitialPacketSizeIPv4, 26 | getBandwidth: getBandwidth, 27 | } 28 | return p 29 | } 30 | 31 | func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { 32 | budget := p.Budget(sendTime) 33 | if size > budget { 34 | p.budgetAtLastSent = 0 35 | } else { 36 | p.budgetAtLastSent = budget - size 37 | } 38 | p.lastSentTime = sendTime 39 | } 40 | 41 | func (p *Pacer) Budget(now time.Time) congestion.ByteCount { 42 | if p.lastSentTime.IsZero() { 43 | return p.maxBurstSize() 44 | } 45 | budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 46 | if budget < 0 { // protect against overflows 47 | budget = congestion.ByteCount(1<<62 - 1) 48 | } 49 | return Min(p.maxBurstSize(), budget) 50 | } 51 | 52 | func (p *Pacer) maxBurstSize() congestion.ByteCount { 53 | return Max( 54 | congestion.ByteCount((congestion.MinPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, 55 | maxBurstPackets*p.maxDatagramSize, 56 | ) 57 | } 58 | 59 | // TimeUntilSend returns when the next packet should be sent. 60 | // It returns the zero value of time.Time if a packet can be sent immediately. 
61 | func (p *Pacer) TimeUntilSend() time.Time { 62 | if p.budgetAtLastSent >= p.maxDatagramSize { 63 | return time.Time{} 64 | } 65 | return p.lastSentTime.Add(Max( 66 | congestion.MinPacingDelay, 67 | time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ 68 | float64(p.getBandwidth())))*time.Nanosecond, 69 | )) 70 | } 71 | 72 | func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { 73 | p.maxDatagramSize = s 74 | } 75 | -------------------------------------------------------------------------------- /congestion_meta2/packet_number_indexed_queue.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "github.com/sagernet/quic-go/congestion" 5 | ) 6 | 7 | // packetNumberIndexedQueue is a queue of mostly continuous numbered entries 8 | // which supports the following operations: 9 | // - adding elements to the end of the queue, or at some point past the end 10 | // - removing elements in any order 11 | // - retrieving elements 12 | // If all elements are inserted in order, all of the operations above are 13 | // amortized O(1) time. 14 | // 15 | // Internally, the data structure is a deque where each element is marked as 16 | // present or not. The deque starts at the lowest present index. Whenever an 17 | // element is removed, it's marked as not present, and the front of the deque is 18 | // cleared of elements that are not present. 19 | // 20 | // The tail of the queue is not cleared due to the assumption of entries being 21 | // inserted in order, though removing all elements of the queue will return it 22 | // to its initial state. 23 | // 24 | // Note that this data structure is inherently hazardous, since an addition of 25 | // just two entries will cause it to consume all of the memory available. 26 | // Because of that, it is not a general-purpose container and should not be used 27 | // as one. 
28 | 29 | type entryWrapper[T any] struct { 30 | present bool 31 | entry T 32 | } 33 | 34 | type packetNumberIndexedQueue[T any] struct { 35 | entries RingBuffer[entryWrapper[T]] 36 | numberOfPresentEntries int 37 | firstPacket congestion.PacketNumber 38 | } 39 | 40 | func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] { 41 | q := &packetNumberIndexedQueue[T]{ 42 | firstPacket: invalidPacketNumber, 43 | } 44 | 45 | q.entries.Init(size) 46 | 47 | return q 48 | } 49 | 50 | // Emplace inserts data associated |packet_number| into (or past) the end of the 51 | // queue, filling up the missing intermediate entries as necessary. Returns 52 | // true if the element has been inserted successfully, false if it was already 53 | // in the queue or inserted out of order. 54 | func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool { 55 | if packetNumber == invalidPacketNumber || entry == nil { 56 | return false 57 | } 58 | 59 | if p.IsEmpty() { 60 | p.entries.PushBack(entryWrapper[T]{ 61 | present: true, 62 | entry: *entry, 63 | }) 64 | p.numberOfPresentEntries = 1 65 | p.firstPacket = packetNumber 66 | return true 67 | } 68 | 69 | // Do not allow insertion out-of-order. 70 | if packetNumber <= p.LastPacket() { 71 | return false 72 | } 73 | 74 | // Handle potentially missing elements. 75 | offset := int(packetNumber - p.FirstPacket()) 76 | if gap := offset - p.entries.Len(); gap > 0 { 77 | for i := 0; i < gap; i++ { 78 | p.entries.PushBack(entryWrapper[T]{}) 79 | } 80 | } 81 | 82 | p.entries.PushBack(entryWrapper[T]{ 83 | present: true, 84 | entry: *entry, 85 | }) 86 | p.numberOfPresentEntries++ 87 | return true 88 | } 89 | 90 | // GetEntry Retrieve the entry associated with the packet number. Returns the pointer 91 | // to the entry in case of success, or nullptr if the entry does not exist. 
92 | func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T { 93 | ew := p.getEntryWraper(packetNumber) 94 | if ew == nil { 95 | return nil 96 | } 97 | 98 | return &ew.entry 99 | } 100 | 101 | // Remove, Same as above, but if an entry is present in the queue, also call f(entry) 102 | // before removing it. 103 | func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool { 104 | ew := p.getEntryWraper(packetNumber) 105 | if ew == nil { 106 | return false 107 | } 108 | if f != nil { 109 | f(ew.entry) 110 | } 111 | ew.present = false 112 | p.numberOfPresentEntries-- 113 | 114 | if packetNumber == p.FirstPacket() { 115 | p.clearup() 116 | } 117 | 118 | return true 119 | } 120 | 121 | // RemoveUpTo, but not including |packet_number|. 122 | // Unused slots in the front are also removed, which means when the function 123 | // returns, |first_packet()| can be larger than |packet_number|. 124 | func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) { 125 | for !p.entries.Empty() && 126 | p.firstPacket != invalidPacketNumber && 127 | p.firstPacket < packetNumber { 128 | if p.entries.Front().present { 129 | p.numberOfPresentEntries-- 130 | } 131 | p.entries.PopFront() 132 | p.firstPacket++ 133 | } 134 | p.clearup() 135 | 136 | return 137 | } 138 | 139 | // IsEmpty return if queue is empty. 140 | func (p *packetNumberIndexedQueue[T]) IsEmpty() bool { 141 | return p.numberOfPresentEntries == 0 142 | } 143 | 144 | // NumberOfPresentEntries returns the number of entries in the queue. 145 | func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int { 146 | return p.numberOfPresentEntries 147 | } 148 | 149 | // EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is 150 | // proportional to the memory usage of the queue. 
151 | func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int { 152 | return p.entries.Len() 153 | } 154 | 155 | // LastPacket returns packet number of the first entry in the queue. 156 | func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { 157 | return p.firstPacket 158 | } 159 | 160 | // LastPacket returns packet number of the last entry ever inserted in the queue. Note that the 161 | // entry in question may have already been removed. Zero if the queue is 162 | // empty. 163 | func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { 164 | if p.IsEmpty() { 165 | return invalidPacketNumber 166 | } 167 | 168 | return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) 169 | } 170 | 171 | func (p *packetNumberIndexedQueue[T]) clearup() { 172 | for !p.entries.Empty() && !p.entries.Front().present { 173 | p.entries.PopFront() 174 | p.firstPacket++ 175 | } 176 | if p.entries.Empty() { 177 | p.firstPacket = invalidPacketNumber 178 | } 179 | } 180 | 181 | func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { 182 | if packetNumber == invalidPacketNumber || 183 | p.IsEmpty() || 184 | packetNumber < p.firstPacket { 185 | return nil 186 | } 187 | 188 | offset := int(packetNumber - p.firstPacket) 189 | if offset >= p.entries.Len() { 190 | return nil 191 | } 192 | 193 | ew := p.entries.Offset(offset) 194 | if ew == nil || !ew.present { 195 | return nil 196 | } 197 | 198 | return ew 199 | } 200 | -------------------------------------------------------------------------------- /congestion_meta2/ringbuffer.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | // A RingBuffer is a ring buffer. 4 | // It acts as a heap that doesn't cause any allocations. 
5 | type RingBuffer[T any] struct { 6 | ring []T 7 | headPos, tailPos int 8 | full bool 9 | } 10 | 11 | // Init preallocs a buffer with a certain size. 12 | func (r *RingBuffer[T]) Init(size int) { 13 | r.ring = make([]T, size) 14 | } 15 | 16 | // Len returns the number of elements in the ring buffer. 17 | func (r *RingBuffer[T]) Len() int { 18 | if r.full { 19 | return len(r.ring) 20 | } 21 | if r.tailPos >= r.headPos { 22 | return r.tailPos - r.headPos 23 | } 24 | return r.tailPos - r.headPos + len(r.ring) 25 | } 26 | 27 | // Empty says if the ring buffer is empty. 28 | func (r *RingBuffer[T]) Empty() bool { 29 | return !r.full && r.headPos == r.tailPos 30 | } 31 | 32 | // PushBack adds a new element. 33 | // If the ring buffer is full, its capacity is increased first. 34 | func (r *RingBuffer[T]) PushBack(t T) { 35 | if r.full || len(r.ring) == 0 { 36 | r.grow() 37 | } 38 | r.ring[r.tailPos] = t 39 | r.tailPos++ 40 | if r.tailPos == len(r.ring) { 41 | r.tailPos = 0 42 | } 43 | if r.tailPos == r.headPos { 44 | r.full = true 45 | } 46 | } 47 | 48 | // PopFront returns the next element. 49 | // It must not be called when the buffer is empty, that means that 50 | // callers might need to check if there are elements in the buffer first. 51 | func (r *RingBuffer[T]) PopFront() T { 52 | if r.Empty() { 53 | panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") 54 | } 55 | r.full = false 56 | t := r.ring[r.headPos] 57 | r.ring[r.headPos] = *new(T) 58 | r.headPos++ 59 | if r.headPos == len(r.ring) { 60 | r.headPos = 0 61 | } 62 | return t 63 | } 64 | 65 | // Offset returns the offset element. 66 | // It must not be called when the buffer is empty, that means that 67 | // callers might need to check if there are elements in the buffer first 68 | // and check if the index larger than buffer length. 
69 | func (r *RingBuffer[T]) Offset(index int) *T { 70 | if r.Empty() || index >= r.Len() { 71 | panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index") 72 | } 73 | offset := (r.headPos + index) % len(r.ring) 74 | return &r.ring[offset] 75 | } 76 | 77 | // Front returns the front element. 78 | // It must not be called when the buffer is empty, that means that 79 | // callers might need to check if there are elements in the buffer first. 80 | func (r *RingBuffer[T]) Front() *T { 81 | if r.Empty() { 82 | panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue") 83 | } 84 | return &r.ring[r.headPos] 85 | } 86 | 87 | // Back returns the back element. 88 | // It must not be called when the buffer is empty, that means that 89 | // callers might need to check if there are elements in the buffer first. 90 | func (r *RingBuffer[T]) Back() *T { 91 | if r.Empty() { 92 | panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue") 93 | } 94 | return r.Offset(r.Len() - 1) 95 | } 96 | 97 | // Grow the maximum size of the queue. 98 | // This method assume the queue is full. 99 | func (r *RingBuffer[T]) grow() { 100 | oldRing := r.ring 101 | newSize := len(oldRing) * 2 102 | if newSize == 0 { 103 | newSize = 1 104 | } 105 | r.ring = make([]T, newSize) 106 | headLen := copy(r.ring, oldRing[r.headPos:]) 107 | copy(r.ring[headLen:], oldRing[:r.headPos]) 108 | r.headPos, r.tailPos, r.full = 0, len(oldRing), false 109 | } 110 | 111 | // Clear removes all elements. 
112 | func (r *RingBuffer[T]) Clear() { 113 | var zeroValue T 114 | for i := range r.ring { 115 | r.ring[i] = zeroValue 116 | } 117 | r.headPos, r.tailPos, r.full = 0, 0, false 118 | } 119 | -------------------------------------------------------------------------------- /congestion_meta2/windowed_filter.go: -------------------------------------------------------------------------------- 1 | package congestion 2 | 3 | import ( 4 | "golang.org/x/exp/constraints" 5 | ) 6 | 7 | // Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) 8 | // estimate of a stream of samples over some fixed time interval. (E.g., 9 | // the minimum RTT over the past five minutes.) The algorithm keeps track of 10 | // the best, second best, and third best min (or max) estimates, maintaining an 11 | // invariant that the measurement time of the n'th best >= n-1'th best. 12 | 13 | // The algorithm works as follows. On a reset, all three estimates are set to 14 | // the same sample. The second best estimate is then recorded in the second 15 | // quarter of the window, and a third best estimate is recorded in the second 16 | // half of the window, bounding the worst case error when the true min is 17 | // monotonically increasing (or true max is monotonically decreasing) over the 18 | // window. 19 | // 20 | // A new best sample replaces all three estimates, since the new best is lower 21 | // (or higher) than everything else in the window and it is the most recent. 22 | // The window thus effectively gets reset on every new min. The same property 23 | // holds true for second best and third best estimates. Specifically, when a 24 | // sample arrives that is better than the second best but not better than the 25 | // best, it replaces the second and third best estimates but not the best 26 | // estimate. Similarly, a sample that is better than the third best estimate 27 | // but not the other estimates replaces only the third best estimate. 
28 | // 29 | // Finally, when the best expires, it is replaced by the second best, which in 30 | // turn is replaced by the third best. The newest sample replaces the third 31 | // best. 32 | 33 | type WindowedFilterValue interface { 34 | any 35 | } 36 | 37 | type WindowedFilterTime interface { 38 | constraints.Integer | constraints.Float 39 | } 40 | 41 | type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct { 42 | // Time length of window. 43 | windowLength T 44 | estimates []entry[V, T] 45 | comparator func(V, V) int 46 | } 47 | 48 | type entry[V WindowedFilterValue, T WindowedFilterTime] struct { 49 | sample V 50 | time T 51 | } 52 | 53 | // Compares two values and returns true if the first is greater than or equal 54 | // to the second. 55 | func MaxFilter[O constraints.Ordered](a, b O) int { 56 | if a > b { 57 | return 1 58 | } else if a < b { 59 | return -1 60 | } 61 | return 0 62 | } 63 | 64 | // Compares two values and returns true if the first is less than or equal 65 | // to the second. 66 | func MinFilter[O constraints.Ordered](a, b O) int { 67 | if a < b { 68 | return 1 69 | } else if a > b { 70 | return -1 71 | } 72 | return 0 73 | } 74 | 75 | func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] { 76 | return &WindowedFilter[V, T]{ 77 | windowLength: windowLength, 78 | estimates: make([]entry[V, T], 3, 3), 79 | comparator: comparator, 80 | } 81 | } 82 | 83 | // Changes the window length. Does not update any current samples. 
// SetWindowLength changes the window length. Does not update any current samples.
func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) {
	f.windowLength = windowLength
}

// GetBest returns the best estimate currently in the window.
func (f *WindowedFilter[V, T]) GetBest() V {
	return f.estimates[0].sample
}

// GetSecondBest returns the second best estimate currently in the window.
func (f *WindowedFilter[V, T]) GetSecondBest() V {
	return f.estimates[1].sample
}

// GetThirdBest returns the third best estimate currently in the window.
func (f *WindowedFilter[V, T]) GetThirdBest() V {
	return f.estimates[2].sample
}

// Update updates best estimates with |sample|, and expires and updates best
// estimates as necessary.
func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) {
	// Reset all estimates if they have not yet been initialized, if new sample
	// is a new best, or if the newest recorded estimate is too old.
	// (The zero-value comparison detects the uninitialized state; it relies on
	// the comparator reporting equality against V's zero value.)
	if f.comparator(f.estimates[0].sample, *new(V)) == 0 ||
		f.comparator(newSample, f.estimates[0].sample) >= 0 ||
		newTime-f.estimates[2].time > f.windowLength {
		f.Reset(newSample, newTime)
		return
	}

	// The new sample is not a new best, but may displace the second best
	// (which also replaces the third best) or just the third best.
	if f.comparator(newSample, f.estimates[1].sample) >= 0 {
		f.estimates[1] = entry[V, T]{newSample, newTime}
		f.estimates[2] = f.estimates[1]
	} else if f.comparator(newSample, f.estimates[2].sample) >= 0 {
		f.estimates[2] = entry[V, T]{newSample, newTime}
	}

	// Expire and update estimates as necessary.
	if newTime-f.estimates[0].time > f.windowLength {
		// The best estimate hasn't been updated for an entire window, so promote
		// second and third best estimates.
		f.estimates[0] = f.estimates[1]
		f.estimates[1] = f.estimates[2]
		f.estimates[2] = entry[V, T]{newSample, newTime}
		// Need to iterate one more time. Check if the new best estimate is
		// outside the window as well, since it may also have been recorded a
		// long time ago. Don't need to iterate once more since we cover that
		// case at the beginning of the method.
		if newTime-f.estimates[0].time > f.windowLength {
			f.estimates[0] = f.estimates[1]
			f.estimates[1] = f.estimates[2]
		}
		return
	}
	if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 &&
		newTime-f.estimates[1].time > f.windowLength/4 {
		// A quarter of the window has passed without a better sample, so the
		// second-best estimate is taken from the second quarter of the window.
		f.estimates[1] = entry[V, T]{newSample, newTime}
		f.estimates[2] = f.estimates[1]
		return
	}

	if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 &&
		newTime-f.estimates[2].time > f.windowLength/2 {
		// We've passed a half of the window without a better estimate, so take
		// a third-best estimate from the second half of the window.
		f.estimates[2] = entry[V, T]{newSample, newTime}
	}
}

// Reset resets all estimates to new sample.
func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) {
	f.estimates[2] = entry[V, T]{newSample, newTime}
	f.estimates[1] = f.estimates[2]
	f.estimates[0] = f.estimates[1]
}

// Clear discards all samples, returning the filter to its uninitialized state.
func (f *WindowedFilter[V, T]) Clear() {
	f.estimates = make([]entry[V, T], 3, 3)
}
github.com/quic-go/qpack v0.4.0 // indirect 18 | github.com/quic-go/qtls-go1-20 v0.4.1 // indirect 19 | golang.org/x/mod v0.20.0 // indirect 20 | golang.org/x/net v0.34.0 // indirect 21 | golang.org/x/sync v0.10.0 // indirect 22 | golang.org/x/sys v0.29.0 // indirect 23 | golang.org/x/text v0.21.0 // indirect 24 | golang.org/x/tools v0.24.0 // indirect 25 | ) 26 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= 5 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= 6 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= 7 | github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0= 8 | github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= 9 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 10 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 11 | github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a h1:fEBsGL/sjAuJrgah5XqmmYsTLzJp/TO9Lhy39gkverk= 12 | github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= 13 | github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= 14 | github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= 15 | github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= 16 | 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 17 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 18 | github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= 19 | github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= 20 | github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= 21 | github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= 22 | github.com/sagernet/quic-go v0.49.0-beta.1 h1:3LdoCzVVfYRibZns1tYWSIoB65fpTmrwy+yfK8DQ8Jk= 23 | github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8WHNsRs71b3Lt1+p/U= 24 | github.com/sagernet/sing v0.6.5 h1:TBKTK6Ms0/MNTZm+cTC2hhKunE42XrNIdsxcYtWqeUU= 25 | github.com/sagernet/sing v0.6.5/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= 26 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 27 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 28 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 29 | golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= 30 | golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= 31 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= 32 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= 33 | golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= 34 | golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 35 | golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= 36 | golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= 37 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 38 | 
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 39 | golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= 40 | golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 41 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 42 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 43 | golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= 44 | golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= 45 | google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= 46 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 47 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 48 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 49 | -------------------------------------------------------------------------------- /hysteria/client.go: -------------------------------------------------------------------------------- 1 | package hysteria 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "math" 7 | "net" 8 | "os" 9 | "runtime" 10 | "strconv" 11 | "strings" 12 | "sync" 13 | "time" 14 | 15 | "github.com/sagernet/quic-go" 16 | "github.com/sagernet/sing-quic" 17 | hyCC "github.com/sagernet/sing-quic/hysteria/congestion" 18 | "github.com/sagernet/sing/common/baderror" 19 | "github.com/sagernet/sing/common/bufio" 20 | "github.com/sagernet/sing/common/debug" 21 | E "github.com/sagernet/sing/common/exceptions" 22 | "github.com/sagernet/sing/common/logger" 23 | M "github.com/sagernet/sing/common/metadata" 24 | N "github.com/sagernet/sing/common/network" 25 | aTLS "github.com/sagernet/sing/common/tls" 26 | ) 27 | 28 | type ClientOptions struct { 29 | Context context.Context 30 | Dialer N.Dialer 31 | Logger logger.Logger 32 | BrutalDebug bool 33 | ServerAddress 
M.Socksaddr 34 | ServerPorts []string 35 | HopInterval time.Duration 36 | SendBPS uint64 37 | ReceiveBPS uint64 38 | XPlusPassword string 39 | Password string 40 | TLSConfig aTLS.Config 41 | UDPDisabled bool 42 | 43 | // Legacy options 44 | 45 | ConnReceiveWindow uint64 46 | StreamReceiveWindow uint64 47 | DisableMTUDiscovery bool 48 | } 49 | 50 | type Client struct { 51 | ctx context.Context 52 | dialer N.Dialer 53 | logger logger.Logger 54 | brutalDebug bool 55 | serverAddr M.Socksaddr 56 | serverPorts []uint16 57 | hopInterval time.Duration 58 | sendBPS uint64 59 | receiveBPS uint64 60 | xplusPassword string 61 | password string 62 | tlsConfig aTLS.Config 63 | quicConfig *quic.Config 64 | udpDisabled bool 65 | 66 | connAccess sync.RWMutex 67 | conn *clientQUICConnection 68 | } 69 | 70 | func NewClient(options ClientOptions) (*Client, error) { 71 | quicConfig := &quic.Config{ 72 | DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 73 | EnableDatagrams: true, 74 | InitialStreamReceiveWindow: DefaultStreamReceiveWindow, 75 | MaxStreamReceiveWindow: DefaultStreamReceiveWindow, 76 | InitialConnectionReceiveWindow: DefaultConnReceiveWindow, 77 | MaxConnectionReceiveWindow: DefaultConnReceiveWindow, 78 | MaxIdleTimeout: DefaultMaxIdleTimeout, 79 | KeepAlivePeriod: DefaultKeepAlivePeriod, 80 | } 81 | if options.StreamReceiveWindow != 0 { 82 | quicConfig.InitialStreamReceiveWindow = options.StreamReceiveWindow 83 | quicConfig.MaxStreamReceiveWindow = options.StreamReceiveWindow 84 | } 85 | if options.ConnReceiveWindow != 0 { 86 | quicConfig.InitialConnectionReceiveWindow = options.ConnReceiveWindow 87 | quicConfig.MaxConnectionReceiveWindow = options.ConnReceiveWindow 88 | } 89 | if options.DisableMTUDiscovery { 90 | quicConfig.DisablePathMTUDiscovery = true 91 | } 92 | if len(options.TLSConfig.NextProtos()) == 0 { 93 | options.TLSConfig.SetNextProtos([]string{DefaultALPN}) 94 | } 95 
| if options.SendBPS == 0 { 96 | return nil, E.New("missing upload speed") 97 | } else if options.SendBPS < MinSpeedBPS { 98 | return nil, E.New("invalid upload speed") 99 | } 100 | if options.ReceiveBPS == 0 { 101 | return nil, E.New("missing download speed") 102 | } else if options.ReceiveBPS < MinSpeedBPS { 103 | return nil, E.New("invalid download speed") 104 | } 105 | var serverPorts []uint16 106 | if len(options.ServerPorts) > 0 { 107 | var err error 108 | serverPorts, err = ParsePorts(options.ServerPorts) 109 | if err != nil { 110 | return nil, err 111 | } 112 | } 113 | return &Client{ 114 | ctx: options.Context, 115 | dialer: options.Dialer, 116 | logger: options.Logger, 117 | brutalDebug: options.BrutalDebug, 118 | serverAddr: options.ServerAddress, 119 | serverPorts: serverPorts, 120 | hopInterval: options.HopInterval, 121 | sendBPS: options.SendBPS, 122 | receiveBPS: options.ReceiveBPS, 123 | xplusPassword: options.XPlusPassword, 124 | password: options.Password, 125 | tlsConfig: options.TLSConfig, 126 | quicConfig: quicConfig, 127 | udpDisabled: options.UDPDisabled, 128 | }, nil 129 | } 130 | 131 | func ParsePorts(serverPorts []string) ([]uint16, error) { 132 | var portList []uint16 133 | for _, portRange := range serverPorts { 134 | if !strings.Contains(portRange, ":") { 135 | return nil, E.New("bad port range: ", portRange) 136 | } 137 | subIndex := strings.Index(portRange, ":") 138 | var ( 139 | start, end uint64 140 | err error 141 | ) 142 | if subIndex > 0 { 143 | start, err = strconv.ParseUint(portRange[:subIndex], 10, 16) 144 | if err != nil { 145 | return nil, E.Cause(err, E.Cause(err, "bad port range: ", portRange)) 146 | } 147 | } 148 | if subIndex == len(portRange)-1 { 149 | end = math.MaxUint16 150 | } else { 151 | end, err = strconv.ParseUint(portRange[subIndex+1:], 10, 16) 152 | if err != nil { 153 | return nil, E.Cause(err, E.Cause(err, "bad port range: ", portRange)) 154 | } 155 | } 156 | for i := start; i <= end; i++ { 157 | portList = 
// offer returns the current QUIC connection if it is still usable, otherwise
// establishes a new one. Double-checked locking: the fast path reads c.conn
// without the lock, then re-checks under connAccess before dialing.
func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) {
	conn := c.conn
	if conn != nil && conn.active() {
		return conn, nil
	}
	c.connAccess.Lock()
	defer c.connAccess.Unlock()
	conn = c.conn
	if conn != nil && conn.active() {
		return conn, nil
	}
	conn, err := c.offerNew(ctx)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

// offerNew dials the server (optionally with XPlus obfuscation and/or port
// hopping), performs the hysteria handshake on a control stream, and installs
// the negotiated congestion controller. Caller must hold c.connAccess.
func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) {
	dialFunc := func(serverAddr M.Socksaddr) (net.PacketConn, error) {
		udpConn, err := c.dialer.DialContext(c.ctx, "udp", serverAddr)
		if err != nil {
			return nil, err
		}
		var packetConn net.PacketConn
		packetConn = bufio.NewUnbindPacketConn(udpConn)
		if c.xplusPassword != "" {
			// Wrap with XPlus XOR obfuscation when a password is set.
			packetConn = NewXPlusPacketConn(packetConn, []byte(c.xplusPassword))
		}
		return packetConn, nil
	}
	var (
		packetConn net.PacketConn
		err        error
	)
	if len(c.serverPorts) == 0 {
		packetConn, err = dialFunc(c.serverAddr)
	} else {
		packetConn, err = NewHopPacketConn(dialFunc, c.serverAddr, c.serverPorts, c.hopInterval)
	}
	if err != nil {
		return nil, err
	}
	quicConn, err := qtls.Dial(c.ctx, packetConn, c.serverAddr, c.tlsConfig, c.quicConfig)
	if err != nil {
		packetConn.Close()
		return nil, err
	}
	controlStream, err := quicConn.OpenStreamSync(ctx)
	if err != nil {
		packetConn.Close()
		return nil, err
	}
	// Hysteria handshake: advertise our rates, read the server's verdict.
	err = WriteClientHello(controlStream, ClientHello{
		SendBPS: c.sendBPS,
		RecvBPS: c.receiveBPS,
		Auth:    c.password,
	})
	if err != nil {
		packetConn.Close()
		return nil, err
	}
	serverHello, err := ReadServerHello(controlStream)
	if err != nil {
		packetConn.Close()
		return nil, err
	}
	if !serverHello.OK {
		packetConn.Close()
		return nil, E.New("remote error: ", serverHello.Message)
	}
	// Send at min(server's advertised receive rate, our configured upload rate).
	quicConn.SetCongestionControl(hyCC.NewBrutalSender(uint64(math.Min(float64(serverHello.RecvBPS), float64(c.sendBPS))), c.brutalDebug, c.logger))
	conn := &clientQUICConnection{
		quicConn:    quicConn,
		rawConn:     packetConn,
		connDone:    make(chan struct{}),
		udpDisabled: !quicConn.ConnectionState().SupportsDatagrams,
		udpConnMap:  make(map[uint32]*udpPacketConn),
	}
	if !c.udpDisabled {
		go c.loopMessages(conn)
	}
	c.conn = conn
	return conn, nil
}

// DialConn opens a TCP-proxy stream to destination over the shared connection.
// The hysteria request is written lazily on the first Write (see clientConn).
func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
	conn, err := c.offer(ctx)
	if err != nil {
		return nil, err
	}
	stream, err := conn.quicConn.OpenStream()
	if err != nil {
		return nil, err
	}
	return &clientConn{
		Stream:      stream,
		destination: destination,
	}, nil
}

// ListenPacket opens a UDP session to destination: a control stream carries
// the request/response and is then held open as the session keep-alive, while
// datagrams flow via loopMessages into the returned packet conn.
func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if c.udpDisabled {
		return nil, os.ErrInvalid
	}
	conn, err := c.offer(ctx)
	if err != nil {
		return nil, err
	}
	if conn.udpDisabled {
		return nil, E.New("UDP disabled by server")
	}
	stream, err := conn.quicConn.OpenStream()
	if err != nil {
		return nil, err
	}
	buffer := WriteClientRequest(ClientRequest{
		UDP:  true,
		Host: destination.AddrString(),
		Port: destination.Port,
	}, nil)
	_, err = stream.Write(buffer.Bytes())
	buffer.Release()
	if err != nil {
		stream.Close()
		return nil, err
	}
	response, err := ReadServerResponse(stream)
	if err != nil {
		stream.Close()
		return nil, err
	}
	if !response.OK {
		stream.Close()
		return nil, E.New("remote error: ", response.Message)
	}
	clientPacketConn := newUDPPacketConn(c.ctx, conn.quicConn, func() {
		// Session teardown: close the hold stream and deregister the session.
		stream.CancelRead(0)
		stream.Close()
		conn.udpAccess.Lock()
		delete(conn.udpConnMap, response.UDPSessionID)
		conn.udpAccess.Unlock()
	})
	conn.udpAccess.Lock()
	if debug.Enabled {
		// Debug-only sanity check: the server must not reuse a live session ID.
		if _, connExists := conn.udpConnMap[response.UDPSessionID]; connExists {
			stream.Close()
			return nil, E.New("udp session id duplicated")
		}
	}
	conn.udpConnMap[response.UDPSessionID] = clientPacketConn
	conn.udpAccess.Unlock()
	clientPacketConn.sessionID = response.UDPSessionID
	// Hold the stream open; when the server closes it, the session is dead.
	go func() {
		holdBuffer := make([]byte, 1024)
		for {
			_, hErr := stream.Read(holdBuffer)
			if hErr != nil {
				break
			}
		}
		clientPacketConn.closeWithError(E.Cause(net.ErrClosed, "hold stream closed"))
	}()
	return clientPacketConn, nil
}

// CloseWithError tears down the current connection, if any. Always returns nil.
func (c *Client) CloseWithError(err error) error {
	conn := c.conn
	if conn != nil {
		conn.closeWithError(err)
	}
	return nil
}

// clientQUICConnection bundles the QUIC connection with its underlying packet
// conn and the per-session UDP demultiplexing state.
type clientQUICConnection struct {
	quicConn    quic.Connection
	rawConn     io.Closer
	closeOnce   sync.Once
	connDone    chan struct{}
	connErr     error
	udpDisabled bool
	udpAccess   sync.RWMutex
	udpConnMap  map[uint32]*udpPacketConn
}

// active reports whether the connection is still usable (neither the QUIC
// context nor our own done channel has been closed).
func (c *clientQUICConnection) active() bool {
	select {
	case <-c.quicConn.Context().Done():
		return false
	default:
	}
	select {
	case <-c.connDone:
		return false
	default:
	}
	return true
}

// closeWithError records err, signals connDone, and closes both the QUIC
// connection and the raw packet conn. Safe to call multiple times.
func (c *clientQUICConnection) closeWithError(err error) {
	c.closeOnce.Do(func() {
		c.connErr = err
		close(c.connDone)
		_ = c.quicConn.CloseWithError(0, "")
		_ = c.rawConn.Close()
	})
}
M.Socksaddr 374 | requestWritten bool 375 | responseRead bool 376 | } 377 | 378 | func (c *clientConn) NeedHandshake() bool { 379 | return !c.requestWritten 380 | } 381 | 382 | func (c *clientConn) Read(p []byte) (n int, err error) { 383 | if c.responseRead { 384 | n, err = c.Stream.Read(p) 385 | return n, baderror.WrapQUIC(err) 386 | } 387 | response, err := ReadServerResponse(c.Stream) 388 | if err != nil { 389 | return 0, baderror.WrapQUIC(err) 390 | } 391 | if !response.OK { 392 | err = E.New("remote error: ", response.Message) 393 | return 394 | } 395 | c.responseRead = true 396 | n, err = c.Stream.Read(p) 397 | return n, baderror.WrapQUIC(err) 398 | } 399 | 400 | func (c *clientConn) Write(p []byte) (n int, err error) { 401 | if !c.requestWritten { 402 | buffer := WriteClientRequest(ClientRequest{ 403 | UDP: false, 404 | Host: c.destination.AddrString(), 405 | Port: c.destination.Port, 406 | }, p) 407 | defer buffer.Release() 408 | _, err = c.Stream.Write(buffer.Bytes()) 409 | if err != nil { 410 | return 411 | } 412 | c.requestWritten = true 413 | return len(p), nil 414 | } 415 | n, err = c.Stream.Write(p) 416 | return n, baderror.WrapQUIC(err) 417 | } 418 | 419 | func (c *clientConn) LocalAddr() net.Addr { 420 | return M.Socksaddr{} 421 | } 422 | 423 | func (c *clientConn) RemoteAddr() net.Addr { 424 | return M.Socksaddr{} 425 | } 426 | 427 | func (c *clientConn) Close() error { 428 | c.Stream.CancelRead(0) 429 | return c.Stream.Close() 430 | } 431 | -------------------------------------------------------------------------------- /hysteria/client_packet.go: -------------------------------------------------------------------------------- 1 | package hysteria 2 | 3 | import E "github.com/sagernet/sing/common/exceptions" 4 | 5 | func (c *Client) loopMessages(conn *clientQUICConnection) { 6 | for { 7 | message, err := conn.quicConn.ReceiveDatagram(c.ctx) 8 | if err != nil { 9 | conn.closeWithError(E.Cause(err, "receive message")) 10 | return 11 | } 12 | go 
// loopMessages receives QUIC datagrams for the life of the connection and
// dispatches each one on its own goroutine. Any receive or handling error
// tears down the whole connection.
func (c *Client) loopMessages(conn *clientQUICConnection) {
	for {
		message, err := conn.quicConn.ReceiveDatagram(c.ctx)
		if err != nil {
			conn.closeWithError(E.Cause(err, "receive message"))
			return
		}
		go func() {
			hErr := c.handleMessage(conn, message)
			if hErr != nil {
				conn.closeWithError(E.Cause(hErr, "handle message"))
			}
		}()
	}
}

// handleMessage decodes a raw datagram into a pooled udpMessage and hands it
// to the connection for session routing. The message is released on decode
// failure.
func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error {
	message := allocMessage()
	err := decodeUDPMessage(message, data)
	if err != nil {
		// NOTE(review): this calls message.release() while handleUDPMessage
		// below calls message.releaseMessage() — confirm both exist and which
		// is intended (definitions are outside this file view).
		message.release()
		return E.Cause(err, "decode UDP message")
	}
	conn.handleUDPMessage(message)
	return nil
}

// handleUDPMessage routes a decoded message to its session's packet conn,
// releasing it when the session is unknown or already closed.
func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) {
	c.udpAccess.RLock()
	udpConn, loaded := c.udpConnMap[message.sessionID]
	c.udpAccess.RUnlock()
	if !loaded {
		message.releaseMessage()
		return
	}
	select {
	case <-udpConn.ctx.Done():
		message.releaseMessage()
		return
	default:
	}
	udpConn.inputPacket(message)
}
const (
	initMaxDatagramSize = 1252

	pktInfoSlotCount           = 5 // slot index is based on seconds, so this is basically how many seconds we sample
	minSampleCount             = 50
	minAckRate                 = 0.8
	congestionWindowMultiplier = 2
	debugPrintInterval         = 2
)

// Compile-time check that BrutalSender implements CongestionControlEx.
var _ congestion.CongestionControlEx = &BrutalSender{}

// BrutalSender is Hysteria's fixed-rate ("brutal") congestion controller:
// it paces at a configured bytes/s target, compensated by the observed ACK
// rate, instead of probing for bandwidth.
type BrutalSender struct {
	rttStats        congestion.RTTStatsProvider
	bps             congestion.ByteCount
	maxDatagramSize congestion.ByteCount
	pacer           *pacer

	// Per-second ack/loss counters used to estimate the ACK rate.
	pktInfoSlots          [pktInfoSlotCount]pktInfo
	ackRate               float64
	debug                 bool
	logger                logger.Logger
	lastAckPrintTimestamp int64
}

// pktInfo accumulates ack/loss counts for one one-second slot.
type pktInfo struct {
	Timestamp int64
	AckCount  uint64
	LossCount uint64
}

// NewBrutalSender creates a sender targeting bps bytes/s. The pacer's rate is
// the target divided by the current ACK rate, so losses are compensated by
// sending proportionally faster.
func NewBrutalSender(bps uint64, debug bool, logger logger.Logger) *BrutalSender {
	bs := &BrutalSender{
		bps:             congestion.ByteCount(bps),
		maxDatagramSize: initMaxDatagramSize,
		ackRate:         1,
		debug:           debug,
		logger:          logger,
	}
	bs.pacer = newPacer(func() congestion.ByteCount {
		return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
	})
	return bs
}

// SetRTTStatsProvider wires in the RTT statistics used for window sizing.
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
	b.rttStats = rttStats
}

// TimeUntilSend delegates pacing decisions to the token-bucket pacer.
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
	return b.pacer.TimeUntilSend()
}

// HasPacingBudget reports whether at least one full datagram may be sent now.
func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
	return b.pacer.Budget(now) >= b.maxDatagramSize
}

// CanSend reports whether bytesInFlight is below the congestion window.
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
	return bytesInFlight < b.GetCongestionWindow()
}

// GetCongestionWindow sizes the window as rate * RTT * multiplier / ackRate.
// Before an RTT sample exists it returns a small fixed window.
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
	rtt := b.rttStats.SmoothedRTT()
	if rtt <= 0 {
		// No RTT estimate yet; presumably a conservative bootstrap value —
		// TODO confirm why 10240 specifically.
		return 10240
	}
	return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
}

// OnPacketSent only feeds the pacer; there is no cwnd state to update.
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
	packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
) {
	b.pacer.SentPacket(sentTime, bytes)
}

// OnPacketAcked is unused; ack accounting happens in OnCongestionEventEx.
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount, eventTime time.Time,
) {
	// Stub
}

// OnCongestionEvent is unused; loss accounting happens in OnCongestionEventEx.
func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount,
) {
	// Stub
}
// OnCongestionEventEx folds this event's ack/loss counts into the one-second
// slot for eventTime, then recomputes the ACK rate.
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time,
	ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
	currentTimestamp := eventTime.Unix()
	// One slot per second, ring-indexed by the timestamp.
	slot := currentTimestamp % pktInfoSlotCount
	if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
		b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
		b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
	} else {
		// uninitialized slot or too old, reset
		b.pktInfoSlots[slot].Timestamp = currentTimestamp
		b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
		b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
	}
	b.updateAckRate(currentTimestamp)
}

// SetMaxDatagramSize propagates the new datagram size to the pacer.
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
	b.maxDatagramSize = size
	b.pacer.SetMaxDatagramSize(size)
	if b.debug {
		b.debugPrint("SetMaxDatagramSize: %d", size)
	}
}

// updateAckRate recomputes b.ackRate from the slots within the sampling
// window: 1 when there are too few samples, clamped to minAckRate otherwise.
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
	// Ignore slots older than the sampling window (pktInfoSlotCount seconds).
	minTimestamp := currentTimestamp - pktInfoSlotCount
	var ackCount, lossCount uint64
	for _, info := range b.pktInfoSlots {
		if info.Timestamp < minTimestamp {
			continue
		}
		ackCount += info.AckCount
		lossCount += info.LossCount
	}
	if ackCount+lossCount < minSampleCount {
		b.ackRate = 1
		if b.canPrintAckRate(currentTimestamp) {
			b.lastAckPrintTimestamp = currentTimestamp
			b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
				ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
		}
		return
	}
	rate := float64(ackCount) / float64(ackCount+lossCount)
	if rate < minAckRate {
		// Cap the pacing compensation: never assume worse than minAckRate.
		b.ackRate = minAckRate
		if b.canPrintAckRate(currentTimestamp) {
			b.lastAckPrintTimestamp = currentTimestamp
			b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
				rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
		}
		return
	}
	b.ackRate = rate
	if b.canPrintAckRate(currentTimestamp) {
		b.lastAckPrintTimestamp = currentTimestamp
		b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
			rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
	}
}

// InSlowStart always reports false: brutal has no slow start phase.
func (b *BrutalSender) InSlowStart() bool {
	return false
}

// InRecovery always reports false: brutal has no recovery phase.
func (b *BrutalSender) InRecovery() bool {
	return false
}

// MaybeExitSlowStart is a no-op: there is no slow start to exit.
func (b *BrutalSender) MaybeExitSlowStart() {}

// OnRetransmissionTimeout is a no-op: the send rate is fixed.
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}

// canPrintAckRate rate-limits debug output to one line per debugPrintInterval
// seconds, and only when debugging is enabled.
func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
	return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
}

// debugPrint logs a formatted message under the "[brutal]" prefix.
func (b *BrutalSender) debugPrint(format string, a ...any) {
	b.logger.Debug("[brutal] ", fmt.Sprintf(format, a...))
}

// maxDuration returns the larger of two durations.
func maxDuration(a, b time.Duration) time.Duration {
	if a > b {
		return a
	}
	return b
}
16 | type pacer struct { 17 | budgetAtLastSent congestion.ByteCount 18 | maxDatagramSize congestion.ByteCount 19 | lastSentTime time.Time 20 | getBandwidth func() congestion.ByteCount // in bytes/s 21 | } 22 | 23 | func newPacer(getBandwidth func() congestion.ByteCount) *pacer { 24 | p := &pacer{ 25 | budgetAtLastSent: maxBurstPackets * initMaxDatagramSize, 26 | maxDatagramSize: initMaxDatagramSize, 27 | getBandwidth: getBandwidth, 28 | } 29 | return p 30 | } 31 | 32 | func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { 33 | budget := p.Budget(sendTime) 34 | if size > budget { 35 | p.budgetAtLastSent = 0 36 | } else { 37 | p.budgetAtLastSent = budget - size 38 | } 39 | p.lastSentTime = sendTime 40 | } 41 | 42 | func (p *pacer) Budget(now time.Time) congestion.ByteCount { 43 | if p.lastSentTime.IsZero() { 44 | return p.maxBurstSize() 45 | } 46 | budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 47 | if budget < 0 { // protect against overflows 48 | budget = congestion.ByteCount(1<<62 - 1) 49 | } 50 | return minByteCount(p.maxBurstSize(), budget) 51 | } 52 | 53 | func (p *pacer) maxBurstSize() congestion.ByteCount { 54 | return maxByteCount( 55 | congestion.ByteCount((maxBurstPacingDelayMultiplier*minPacingDelay).Nanoseconds())*p.getBandwidth()/1e9, 56 | maxBurstPackets*p.maxDatagramSize, 57 | ) 58 | } 59 | 60 | // TimeUntilSend returns when the next packet should be sent. 61 | // It returns the zero value of time.Time if a packet can be sent immediately. 62 | func (p *pacer) TimeUntilSend() time.Time { 63 | if p.budgetAtLastSent >= p.maxDatagramSize { 64 | return time.Time{} 65 | } 66 | diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) 67 | bw := uint64(p.getBandwidth()) 68 | // We might need to round up this value. 69 | // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. 
d := diff / bw
	// this is effectively a math.Ceil, but using only integer math
	if diff%bw > 0 {
		d++
	}
	return p.lastSentTime.Add(maxDuration(congestion.MinPacingDelay, time.Duration(d)*time.Nanosecond))
}

// SetMaxDatagramSize records the datagram size the pacer uses for its budget.
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
	p.maxDatagramSize = s
}

// maxByteCount returns the larger of a and b.
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
	if a < b {
		return b
	}
	return a
}

// minByteCount returns the smaller of a and b.
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
	if a < b {
		return a
	}
	return b
}
--------------------------------------------------------------------------------
/hysteria/hop.go:
--------------------------------------------------------------------------------
package hysteria

import (
	"errors"
	"math/rand"
	"net"
	"os"
	"sync"
	"syscall"
	"time"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
)

const (
	packetQueueSize    = 1024             // capacity of the queue feeding ReadFrom
	udpBufferSize      = 2048             // per-read receive buffer size
	defaultHopInterval = 30 * time.Second // used when the caller passes interval == 0
)

// HopPacketConn is a net.PacketConn that periodically "hops" to a new
// destination port chosen at random from a fixed list, dialing a fresh
// underlying connection each interval while keeping the previous one open
// for one more cycle so in-flight replies are not lost.
type HopPacketConn struct {
	dialFunc        func(M.Socksaddr) (net.PacketConn, error)
	destination     M.Socksaddr
	ports           []uint16
	interval        time.Duration
	access          sync.Mutex // guards prevConn/currentConn/done and the buffer sizes
	prevConn        net.PacketConn
	currentConn     net.PacketConn
	portIndex       int
	readBufferSize  int
	writeBufferSize int
	packetChan      chan *buf.Buffer
	errChan         chan error
	doneChan        chan struct{}
	done            bool
}

// NewHopPacketConn dials the first connection and starts the receive and hop
// loops. An interval of 0 selects defaultHopInterval.
func NewHopPacketConn(
	dialFunc func(M.Socksaddr) (net.PacketConn, error),
	destination M.Socksaddr,
	ports []uint16,
	interval time.Duration,
) (*HopPacketConn, error) {
	if interval == 0 {
		interval = defaultHopInterval
	}
	hopConn := &HopPacketConn{
		dialFunc:    dialFunc,
		destination: destination,
		ports:       ports,
		interval:    interval,
		packetChan:  make(chan *buf.Buffer, packetQueueSize),
		errChan:     make(chan error),
		doneChan:    make(chan struct{}),
	}
	currentConn, err := dialFunc(hopConn.nextAddr())
	if err != nil {
		return nil, err
	}
	hopConn.currentConn = currentConn
	go hopConn.recvLoop(currentConn)
	go hopConn.hopLoop()
	return hopConn, nil
}

// nextAddr picks a random port from the configured list and returns the
// destination with that port substituted.
// NOTE(review): rand.Intn panics when ports is empty — callers are expected
// to pass a non-empty list; confirm upstream validation.
func (c *HopPacketConn) nextAddr() M.Socksaddr {
	c.portIndex = rand.Intn(len(c.ports))
	return M.Socksaddr{
		Addr: c.destination.Addr,
		Fqdn: c.destination.Fqdn,
		Port: c.ports[c.portIndex],
	}
}

// recvLoop reads datagrams from one underlying connection and queues them on
// packetChan; packets are dropped when the queue is full. The loop exits on
// any read error.
func (c *HopPacketConn) recvLoop(conn net.PacketConn) {
	for {
		buffer := buf.NewSize(udpBufferSize)
		n, _, err := conn.ReadFrom(buffer.FreeBytes())
		if err != nil {
			buffer.Release()
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				// Only pass through timeout errors here, not permanent errors
				// like connection closed. Connection close is normal as we close
				// the old connection to exit this loop every time we hop.
				// NOTE(review): errChan is unbuffered — this send blocks until
				// a reader calls ReadFrom; confirm that is acceptable.
				c.errChan <- netErr
			}
			return
		}
		buffer.Truncate(n)
		select {
		case c.packetChan <- buffer:
		default:
			buffer.Release()
		}
	}
}

// hopLoop triggers a hop every interval until the connection is closed.
func (c *HopPacketConn) hopLoop() {
	ticker := time.NewTicker(c.interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			c.hop()
		case <-c.doneChan:
			return
		}
	}
}

// hop dials a connection to the next port, closes the connection from two
// hops ago, and starts a receive loop on the new connection. Dial failures
// are ignored so the existing connection keeps being used.
func (c *HopPacketConn) hop() {
	c.access.Lock()
	defer c.access.Unlock()
	if c.done {
		return
	}
	nextAddr := c.nextAddr()
	newConn, err := c.dialFunc(nextAddr)
	if err != nil {
		return
	}
	if c.prevConn != nil {
		c.prevConn.Close()
	}
	c.prevConn = c.currentConn
	c.currentConn = newConn
	// Re-apply caller-requested socket buffer sizes to the fresh connection.
	if c.readBufferSize > 0 {
		_ = trySetReadBuffer(newConn, c.readBufferSize)
	}
	if c.writeBufferSize > 0 {
		_ = trySetWriteBuffer(newConn, c.writeBufferSize)
	}
	go c.recvLoop(newConn)
}

// ReadFrom returns the next queued datagram; the reported address is a
// placeholder because packets may arrive via any hopped connection.
func (c *HopPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	for {
		select {
		case packet := <-c.packetChan:
			n = copy(b, packet.Bytes())
			packet.Release()
			return n, (*hopFakeAddr)(nil), nil
		case err = <-c.errChan:
			return 0, nil, err
		case <-c.doneChan:
			return 0, nil, net.ErrClosed
		}
	}
}

// WriteTo sends on the current connection; the caller-supplied address is
// ignored because the destination is fixed at dial time.
func (c *HopPacketConn) WriteTo(b []byte, _ net.Addr) (n int, err error) {
	c.access.Lock()
	defer c.access.Unlock()
	if c.done {
		return 0, net.ErrClosed
	}
	return c.currentConn.WriteTo(b, (*hopFakeAddr)(nil))
}

// Close shuts down both live connections and signals the loops to exit; it
// is safe to call more than once.
func (c *HopPacketConn) Close() error {
	c.access.Lock()
	defer c.access.Unlock()
	if c.done {
		return nil
	}
	if c.prevConn != nil {
		_ = c.prevConn.Close()
	}
	err := c.currentConn.Close()
	close(c.doneChan)
	c.done = true
	return err
}
// LocalAddr returns the local address of the connection currently in use.
func (c *HopPacketConn) LocalAddr() net.Addr {
	c.access.Lock()
	defer c.access.Unlock()
	return c.currentConn.LocalAddr()
}

// SetDeadline applies the deadline to both the previous and current
// connections (best effort on the previous one).
func (c *HopPacketConn) SetDeadline(t time.Time) error {
	c.access.Lock()
	defer c.access.Unlock()
	if c.prevConn != nil {
		_ = c.prevConn.SetDeadline(t)
	}
	return c.currentConn.SetDeadline(t)
}

// SetReadDeadline applies the read deadline to both live connections.
func (c *HopPacketConn) SetReadDeadline(t time.Time) error {
	c.access.Lock()
	defer c.access.Unlock()
	if c.prevConn != nil {
		_ = c.prevConn.SetReadDeadline(t)
	}
	return c.currentConn.SetReadDeadline(t)
}

// SetWriteDeadline applies the write deadline to both live connections.
func (c *HopPacketConn) SetWriteDeadline(t time.Time) error {
	c.access.Lock()
	defer c.access.Unlock()
	if c.prevConn != nil {
		_ = c.prevConn.SetWriteDeadline(t)
	}
	return c.currentConn.SetWriteDeadline(t)
}

// SetReadBuffer records the requested size so future hops inherit it, and
// applies it to both live connections when they support it.
func (c *HopPacketConn) SetReadBuffer(bytes int) error {
	c.access.Lock()
	defer c.access.Unlock()
	c.readBufferSize = bytes
	if c.prevConn != nil {
		_ = trySetReadBuffer(c.prevConn, bytes)
	}
	return trySetReadBuffer(c.currentConn, bytes)
}

// SetWriteBuffer records the requested size so future hops inherit it, and
// applies it to both live connections when they support it.
func (c *HopPacketConn) SetWriteBuffer(bytes int) error {
	c.access.Lock()
	defer c.access.Unlock()
	c.writeBufferSize = bytes
	if c.prevConn != nil {
		_ = trySetWriteBuffer(c.prevConn, bytes)
	}
	return trySetWriteBuffer(c.currentConn, bytes)
}

// SyscallConn exposes the raw connection of the current underlying conn, or
// os.ErrInvalid when it is not a syscall.Conn.
// NOTE(review): the returned RawConn may outlive the current hop — confirm
// callers tolerate that.
func (c *HopPacketConn) SyscallConn() (syscall.RawConn, error) {
	c.access.Lock()
	defer c.access.Unlock()
	rawConn, isRawConn := common.Cast[syscall.Conn](c.currentConn)
	if !isRawConn {
		return nil, os.ErrInvalid
	}
	return rawConn.SyscallConn()
}

// trySetReadBuffer applies SetReadBuffer when pc supports it; otherwise it
// is a silent no-op.
func trySetReadBuffer(pc any, bytes int) error {
	udpConn, isUDPConn := common.Cast[interface {
		SetReadBuffer(bytes int) error
	}](pc)
	if !isUDPConn {
		return nil
	}
	return udpConn.SetReadBuffer(bytes)
}

// trySetWriteBuffer applies SetWriteBuffer when pc supports it; otherwise it
// is a silent no-op.
func trySetWriteBuffer(pc any, bytes int) error {
	udpConn, isUDPConn := common.Cast[interface {
		SetWriteBuffer(bytes int) error
	}](pc)
	if !isUDPConn {
		return nil
	}
	return udpConn.SetWriteBuffer(bytes)
}

// hopFakeAddr is the placeholder net.Addr reported by ReadFrom/WriteTo,
// since the real remote varies per hop.
type hopFakeAddr struct{}

func (a *hopFakeAddr) Network() string {
	return "udphop"
}

func (a *hopFakeAddr) String() string {
	return ""
}
--------------------------------------------------------------------------------
/hysteria/packet.go:
--------------------------------------------------------------------------------
package hysteria

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"io"
	"math"
	"net"
	"os"
	"sync"
	"time"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/atomic"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/cache"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/pipe"
)

// udpMessagePool recycles udpMessage header structs (not their data buffers).
var udpMessagePool = sync.Pool{
	New: func() interface{} {
		return new(udpMessage)
	},
}

// allocMessage takes a message from the pool and marks it live so that
// release() returns it to the pool exactly once.
func allocMessage() *udpMessage {
	message := udpMessagePool.Get().(*udpMessage)
	message.referenced = true
	return message
}

// releaseMessages releases every non-nil message header in the slice.
func releaseMessages(messages []*udpMessage) {
	for _, message := range messages {
		if message != nil {
			message.release()
		}
	}
}

// udpMessage is one hysteria UDP datagram (possibly a fragment of a larger
// packet) together with its wire-header fields.
type udpMessage struct {
	sessionID     uint32
	packetID      uint16
	fragmentID    uint8
	fragmentTotal uint8
	host          string
	port          uint16
	data          *buf.Buffer
	referenced    bool // true while the struct is checked out of the pool
}

// release returns the header struct to the pool without touching data.
func (m *udpMessage) release() {
	if
!m.referenced {
		return
	}
	// Zero before pooling so no stale buffer/host references are retained.
	*m = udpMessage{}
	udpMessagePool.Put(m)
}

// releaseMessage releases both the payload buffer and the header struct.
func (m *udpMessage) releaseMessage() {
	m.data.Release()
	m.release()
}

// pack serializes the message into a freshly allocated buffer in wire order:
// sessionID(4) hostLen(2) host port(2) packetID(2) fragID(1) fragTotal(1) dataLen(2) data.
func (m *udpMessage) pack() *buf.Buffer {
	buffer := buf.NewSize(m.headerSize() + m.data.Len())
	common.Must(
		binary.Write(buffer, binary.BigEndian, m.sessionID),
		binary.Write(buffer, binary.BigEndian, uint16(len(m.host))),
		common.Error(buffer.WriteString(m.host)),
		binary.Write(buffer, binary.BigEndian, m.port),
		binary.Write(buffer, binary.BigEndian, m.packetID),
		binary.Write(buffer, binary.BigEndian, m.fragmentID),
		binary.Write(buffer, binary.BigEndian, m.fragmentTotal),
		binary.Write(buffer, binary.BigEndian, uint16(m.data.Len())),
		common.Error(buffer.Write(m.data.Bytes())),
	)
	return buffer
}

// headerSize is the wire-header length: 14 fixed bytes plus the host string.
func (m *udpMessage) headerSize() int {
	return 14 + len(m.host)
}

// fragUDPMessage splits message into fragments whose payloads each fit in
// maxPacketSize after the header; the fragments alias the original payload
// bytes rather than copying them.
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
	udpMTU := maxPacketSize - message.headerSize()
	if message.data.Len() <= udpMTU {
		return []*udpMessage{message}
	}
	var fragments []*udpMessage
	originPacket := message.data.Bytes()
	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
		fragment := allocMessage()
		*fragment = *message
		if remaining > udpMTU {
			fragment.data = buf.As(originPacket[:udpMTU])
			originPacket = originPacket[udpMTU:]
		} else {
			fragment.data = buf.As(originPacket)
			originPacket = nil
		}
		fragments = append(fragments, fragment)
	}
	fragmentTotal := uint16(len(fragments))
	for index, fragment := range fragments {
		fragment.fragmentID = uint8(index)
		fragment.fragmentTotal = uint8(fragmentTotal)
		/*if index > 0 {
			fragment.destination = ""
			// not work in hysteria
		}*/
	}
	return fragments
}

// udpPacketConn multiplexes one UDP relay session over QUIC datagrams,
// handling fragmentation/reassembly and read deadlines.
type udpPacketConn struct {
	ctx             context.Context
	cancel          common.ContextCancelCauseFunc
	sessionID       uint32
	quicConn        quic.Connection
	data            chan *udpMessage
	udpMTU          int
	packetId        atomic.Uint32
	closeOnce       sync.Once
	defragger       *udpDefragger
	onDestroy       func()
	readWaitOptions N.ReadWaitOptions
	readDeadline    pipe.Deadline
}

// newUDPPacketConn wires a session onto quicConn; onDestroy runs exactly
// once when the connection is closed.
func newUDPPacketConn(ctx context.Context, quicConn quic.Connection, onDestroy func()) *udpPacketConn {
	ctx, cancel := common.ContextWithCancelCause(ctx)
	return &udpPacketConn{
		ctx:      ctx,
		cancel:   cancel,
		quicConn: quicConn,
		data:     make(chan *udpMessage, 64),
		// Conservative initial MTU minus 3 bytes of datagram overhead.
		udpMTU:       1200 - 3,
		defragger:    newUDPDefragger(),
		onDestroy:    onDestroy,
		readDeadline: pipe.MakeDeadline(),
	}
}

// ReadPacket copies the next datagram into buffer and reports its source.
func (c *udpPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
	select {
	case p := <-c.data:
		_, err = buffer.ReadOnceFrom(p.data)
		destination = M.ParseSocksaddrHostPort(p.host, p.port)
		p.releaseMessage()
		return
	case <-c.ctx.Done():
		return M.Socksaddr{}, io.ErrClosedPipe
	case <-c.readDeadline.Wait():
		return M.Socksaddr{}, os.ErrDeadlineExceeded
	}
}

// ReadFrom implements net.PacketConn over the session's datagram queue.
func (c *udpPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	select {
	case pkt := <-c.data:
		n = copy(p, pkt.data.Bytes())
		destination := M.ParseSocksaddrHostPort(pkt.host, pkt.port)
		if destination.IsFqdn() {
			addr = destination
		} else {
			addr = destination.UDPAddr()
		}
		pkt.releaseMessage()
		return n, addr, nil
	case <-c.ctx.Done():
		return 0, nil, io.ErrClosedPipe
	case <-c.readDeadline.Wait():
		return 0, nil, os.ErrDeadlineExceeded
	}
}

// WritePacket sends buffer to destination, fragmenting when it exceeds the
// assumed MTU and retrying with the peer-reported limit on a
// DatagramTooLargeError. Takes ownership of buffer.
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
	select {
	case <-c.ctx.Done():
		return net.ErrClosed
	default:
	}
	if buffer.Len() > 0xffff {
		return &quic.DatagramTooLargeError{MaxDatagramPayloadSize: 0xffff}
	}
	packetId := uint16(c.packetId.Add(1) % math.MaxUint16)
	message := allocMessage()
	*message = udpMessage{
		sessionID:     c.sessionID,
		packetID:      packetId,
		fragmentTotal: 1,
		host:          destination.AddrString(),
		port:          destination.Port,
		data:          buffer,
	}
	defer message.releaseMessage()
	var err error
	if buffer.Len() > c.udpMTU-message.headerSize() {
		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
	} else {
		err = c.writePacket(message)
	}
	if err == nil {
		return nil
	}
	var tooLargeErr *quic.DatagramTooLargeError
	if !errors.As(err, &tooLargeErr) {
		return err
	}
	// Retry once, fragmented to the peer-advertised maximum payload size.
	return c.writePackets(fragUDPMessage(message, int(tooLargeErr.MaxDatagramPayloadSize-3)))
}

// WriteTo implements net.PacketConn; same fragmentation strategy as
// WritePacket but p stays owned by the caller (the payload is aliased).
func (c *udpPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	select {
	case <-c.ctx.Done():
		return 0, net.ErrClosed
	default:
	}
	if len(p) > 0xffff {
		return 0, &quic.DatagramTooLargeError{MaxDatagramPayloadSize: 0xffff}
	}
	packetId := uint16(c.packetId.Add(1) % math.MaxUint16)
	message := allocMessage()
	destination := M.SocksaddrFromNet(addr)
	*message = udpMessage{
		sessionID:     c.sessionID,
		packetID:      packetId,
		fragmentTotal: 1,
		host:          destination.AddrString(),
		port:          destination.Port,
		data:          buf.As(p),
	}
	if len(p) > c.udpMTU-message.headerSize() {
		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
		if err == nil {
			return len(p), nil
		}
	} else {
		err = c.writePacket(message)
	}
	if err == nil {
		return len(p), nil
	}
	var tooLargeErr *quic.DatagramTooLargeError
	if !errors.As(err,
&tooLargeErr) {
		return
	}
	err = c.writePackets(fragUDPMessage(message, int(tooLargeErr.MaxDatagramPayloadSize-3)))
	if err == nil {
		return len(p), nil
	}
	return
}

// inputPacket queues a fully reassembled message for readers; multi-fragment
// messages go through the defragger first. Messages are dropped when the
// queue is full.
func (c *udpPacketConn) inputPacket(message *udpMessage) {
	if message.fragmentTotal <= 1 {
		select {
		case c.data <- message:
		default:
		}
	} else {
		newMessage := c.defragger.feed(message)
		if newMessage != nil {
			select {
			case c.data <- newMessage:
			default:
			}
		}
	}
}

// writePackets sends each fragment in order, stopping at the first error;
// all fragment headers are released afterwards.
func (c *udpPacketConn) writePackets(messages []*udpMessage) error {
	defer releaseMessages(messages)
	for _, message := range messages {
		err := c.writePacket(message)
		if err != nil {
			return err
		}
	}
	return nil
}

// writePacket serializes one message and sends it as a single QUIC datagram.
func (c *udpPacketConn) writePacket(message *udpMessage) error {
	buffer := message.pack()
	defer buffer.Release()
	return c.quicConn.SendDatagram(buffer.Bytes())
}

func (c *udpPacketConn) Close() error {
	c.closeWithError(os.ErrClosed)
	return nil
}

// closeWithError cancels the session context and runs onDestroy exactly once.
func (c *udpPacketConn) closeWithError(err error) {
	c.closeOnce.Do(func() {
		c.cancel(err)
		c.onDestroy()
	})
}

func (c *udpPacketConn) LocalAddr() net.Addr {
	return c.quicConn.LocalAddr()
}

// SetDeadline is unsupported; only read deadlines are implemented.
func (c *udpPacketConn) SetDeadline(t time.Time) error {
	return os.ErrInvalid
}

func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
	c.readDeadline.Set(t)
	return nil
}

// SetWriteDeadline is unsupported.
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
	return os.ErrInvalid
}

// udpDefragger reassembles fragmented messages keyed by packet ID, using an
// LRU cache that evicts (and releases) incomplete packets.
type udpDefragger struct {
	packetMap *cache.LruCache[uint16, *packetItem]
}

func newUDPDefragger() *udpDefragger {
	return &udpDefragger{
		packetMap: cache.New(
			cache.WithAge[uint16, *packetItem](10),
			cache.WithUpdateAgeOnGet[uint16, *packetItem](),
			cache.WithEvict[uint16, *packetItem](func(key uint16, value *packetItem) {
				releaseMessages(value.messages)
			}),
		),
	}
}

// packetItem tracks the fragments received so far for one packet ID.
type packetItem struct {
	access   sync.Mutex
	messages []*udpMessage
	count    uint8
}

// feed adds one fragment; it returns the reassembled message once every
// fragment has arrived and nil otherwise. A fragment whose fragmentTotal
// does not match the stored slice resets the partial packet.
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
	if m.fragmentTotal <= 1 {
		return m
	}
	if m.fragmentID >= m.fragmentTotal {
		// Malformed fragment index; drop it.
		return nil
	}
	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
	item.access.Lock()
	defer item.access.Unlock()
	if int(m.fragmentTotal) != len(item.messages) {
		releaseMessages(item.messages)
		item.messages = make([]*udpMessage, m.fragmentTotal)
		item.count = 1
		item.messages[m.fragmentID] = m
		return nil
	}
	if item.messages[m.fragmentID] != nil {
		// Duplicate fragment; ignore.
		return nil
	}
	item.messages[m.fragmentID] = m
	item.count++
	if int(item.count) != len(item.messages) {
		return nil
	}
	// All fragments present: concatenate payloads in fragment order.
	newMessage := allocMessage()
	newMessage.sessionID = m.sessionID
	newMessage.packetID = m.packetID
	newMessage.host = item.messages[0].host
	newMessage.port = item.messages[0].port
	var finalLength int
	for _, message := range item.messages {
		finalLength += message.data.Len()
	}
	if finalLength > 0 {
		newMessage.data = buf.NewSize(finalLength)
		for _, message := range item.messages {
			newMessage.data.Write(message.data.Bytes())
			message.releaseMessage()
		}
		item.messages = nil
		return newMessage
	} else {
		newMessage.release()
		for _, message := range item.messages {
			message.releaseMessage()
		}
	}
	item.messages = nil
	return nil
}

func newPacketItem() *packetItem {
	return new(packetItem)
}

// decodeUDPMessage parses a wire-format datagram into message; the payload
// aliases data rather than copying it, so data must outlive message.
func decodeUDPMessage(message *udpMessage, data []byte) error {
	reader := bytes.NewReader(data)
	err := binary.Read(reader, binary.BigEndian, &message.sessionID)
	if err != nil {
		return err
	}
	var hostLen uint16
	err = binary.Read(reader, binary.BigEndian, &hostLen)
	if err != nil {
		return err
	}
	hostBytes := make([]byte, hostLen)
	_, err = io.ReadFull(reader, hostBytes)
	if err != nil {
		return err
	}
	message.host = string(hostBytes)
	err = binary.Read(reader, binary.BigEndian, &message.port)
	if err != nil {
		return err
	}
	err = binary.Read(reader, binary.BigEndian, &message.packetID)
	if err != nil {
		return err
	}
	err = binary.Read(reader, binary.BigEndian, &message.fragmentID)
	if err != nil {
		return err
	}
	err = binary.Read(reader, binary.BigEndian, &message.fragmentTotal)
	if err != nil {
		return err
	}
	var dataLen uint16
	err = binary.Read(reader, binary.BigEndian, &dataLen)
	if err != nil {
		return err
	}
	if reader.Len() != int(dataLen) {
		return E.New("invalid data length")
	}
	message.data = buf.As(data[len(data)-reader.Len():])
	return nil
}
--------------------------------------------------------------------------------
/hysteria/packet_wait.go:
--------------------------------------------------------------------------------
package hysteria

import (
	"io"
	"os"

	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// InitializeReadWaiter stores the read-wait options; a copy is only needed
// when the caller requires headroom around the payload.
func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
	c.readWaitOptions = options
	return options.NeedHeadroom()
}

// WaitReadPacket returns the next datagram, either re-buffered with the
// requested headroom or handing over the message's own payload buffer.
func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
	select {
	case p := <-c.data:
destination = M.ParseSocksaddrHostPort(p.host, p.port)
		if c.readWaitOptions.NeedHeadroom() {
			buffer = c.readWaitOptions.NewPacketBuffer()
			_, err = buffer.Write(p.data.Bytes())
			p.releaseMessage()
			if err != nil {
				buffer.Release()
				return
			}
			c.readWaitOptions.PostReturn(buffer)
		} else {
			// Hand over the payload buffer directly; only the header struct
			// goes back to the pool.
			buffer = p.data
			p.release()
		}
		return
	case <-c.ctx.Done():
		return nil, M.Socksaddr{}, io.ErrClosedPipe
	case <-c.readDeadline.Wait():
		return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
	}
}
--------------------------------------------------------------------------------
/hysteria/protocol.go:
--------------------------------------------------------------------------------
package hysteria

import (
	"encoding/binary"
	"io"
	"time"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
)

const (
	MbpsToBps                  = 125000
	MinSpeedBPS                = 16384
	DefaultALPN                = "hysteria"
	DefaultStreamReceiveWindow = 8388608                            // 8MB
	DefaultConnReceiveWindow   = DefaultStreamReceiveWindow * 5 / 2 // 20MB
	DefaultMaxIdleTimeout      = 30 * time.Second
	DefaultKeepAlivePeriod     = 10 * time.Second
)

const (
	ProtocolVersion        = 3
	ProtocolTimeout        = 10 * time.Second
	ErrorCodeGeneric       = 0
	ErrorCodeProtocolError = 1
	ErrorCodeAuthError     = 2
)

// ClientHello is the first message on the control stream: requested rates
// plus the authentication string.
type ClientHello struct {
	SendBPS uint64
	RecvBPS uint64
	Auth    string
}

// WriteClientHello serializes and writes a ClientHello in wire order:
// version(1) sendBPS(8) recvBPS(8) authLen(2) auth.
func WriteClientHello(stream io.Writer, hello ClientHello) error {
	var requestLen int
	requestLen += 1 // version
	requestLen += 8 // sendBPS
	requestLen += 8 // recvBPS
	requestLen += 2 // auth len
	requestLen += len(hello.Auth)
	request := buf.NewSize(requestLen)
	defer request.Release()
	common.Must(
		request.WriteByte(ProtocolVersion),
		binary.Write(request, binary.BigEndian, hello.SendBPS),
		binary.Write(request, binary.BigEndian, hello.RecvBPS),
		binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))),
		common.Error(request.WriteString(hello.Auth)),
	)
	return common.Error(stream.Write(request.Bytes()))
}

// ReadClientHello parses a ClientHello, rejecting version mismatches and
// zero rates.
func ReadClientHello(reader io.Reader) (*ClientHello, error) {
	var version uint8
	err := binary.Read(reader, binary.BigEndian, &version)
	if err != nil {
		return nil, err
	}
	if version != ProtocolVersion {
		return nil, E.New("unsupported client version: ", version)
	}
	var clientHello ClientHello
	err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS)
	if err != nil {
		return nil, err
	}
	err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS)
	if err != nil {
		return nil, err
	}
	if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
		return nil, E.New("invalid rate from client")
	}
	var authLen uint16
	err = binary.Read(reader, binary.BigEndian, &authLen)
	if err != nil {
		return nil, err
	}
	authBytes := make([]byte, authLen)
	_, err = io.ReadFull(reader, authBytes)
	if err != nil {
		return nil, err
	}
	clientHello.Auth = string(authBytes)
	return &clientHello, nil
}

// ServerHello is the server's reply: accept/reject plus negotiated rates.
type ServerHello struct {
	OK      bool
	SendBPS uint64
	RecvBPS uint64
	Message string
}

// ReadServerHello parses a ServerHello in wire order:
// ok(1) sendBPS(8) recvBPS(8) messageLen(2) message.
func ReadServerHello(stream io.Reader) (*ServerHello, error) {
	var responseLen int
	responseLen += 1 // ok
	responseLen += 8 // sendBPS
	responseLen += 8 // recvBPS
	responseLen += 2 // message len
	response := buf.NewSize(responseLen)
	defer response.Release()
	_, err := response.ReadFullFrom(stream, responseLen)
	if err != nil {
		return nil, err
	}
	var serverHello ServerHello
	serverHello.OK =
response.Byte(0) == 1
	serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9))
	serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17))
	messageLen := binary.BigEndian.Uint16(response.Range(17, 19))
	if messageLen == 0 {
		return &serverHello, nil
	}
	message := make([]byte, messageLen)
	_, err = io.ReadFull(stream, message)
	if err != nil {
		return nil, err
	}
	serverHello.Message = string(message)
	return &serverHello, nil
}

// WriteServerHello serializes and writes a ServerHello.
func WriteServerHello(stream io.Writer, hello ServerHello) error {
	var responseLen int
	responseLen += 1 // ok
	responseLen += 8 // sendBPS
	responseLen += 8 // recvBPS
	responseLen += 2 // message len
	responseLen += len(hello.Message)
	response := buf.NewSize(responseLen)
	defer response.Release()
	if hello.OK {
		common.Must(response.WriteByte(1))
	} else {
		common.Must(response.WriteByte(0))
	}
	common.Must(
		binary.Write(response, binary.BigEndian, hello.SendBPS),
		binary.Write(response, binary.BigEndian, hello.RecvBPS),
		binary.Write(response, binary.BigEndian, uint16(len(hello.Message))),
		common.Error(response.WriteString(hello.Message)),
	)
	return common.Error(stream.Write(response.Bytes()))
}

// ClientRequest opens one proxied connection, TCP or UDP, to Host:Port.
type ClientRequest struct {
	UDP  bool
	Host string
	Port uint16
}

// ReadClientRequest parses a ClientRequest in wire order:
// udp(1) hostLen(2) host port(2).
func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
	var clientRequest ClientRequest
	err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP)
	if err != nil {
		return nil, err
	}
	var hostLen uint16
	err = binary.Read(stream, binary.BigEndian, &hostLen)
	if err != nil {
		return nil, err
	}
	host := make([]byte, hostLen)
	_, err = io.ReadFull(stream, host)
	if err != nil {
		return nil, err
	}
	clientRequest.Host = string(host)
	err = binary.Read(stream, binary.BigEndian, &clientRequest.Port)
	if err != nil {
		return nil, err
	}
	return &clientRequest, nil
}

// WriteClientRequest serializes a ClientRequest followed by an optional
// first payload into a single buffer; the caller owns the returned buffer.
func WriteClientRequest(request ClientRequest, payload []byte) *buf.Buffer {
	var requestLen int
	requestLen += 1 // udp
	requestLen += 2 // host len
	requestLen += len(request.Host)
	requestLen += 2 // port
	buffer := buf.NewSize(requestLen + len(payload))
	if request.UDP {
		common.Must(buffer.WriteByte(1))
	} else {
		common.Must(buffer.WriteByte(0))
	}
	common.Must(
		binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
		common.Error(buffer.WriteString(request.Host)),
		binary.Write(buffer, binary.BigEndian, request.Port),
		common.Error(buffer.Write(payload)),
	)
	return buffer
}

// ServerResponse answers a ClientRequest; UDPSessionID is set for UDP.
type ServerResponse struct {
	OK           bool
	UDPSessionID uint32
	Message      string
}

// ReadServerResponse parses a ServerResponse in wire order:
// ok(1) sessionID(4) messageLen(2) message.
func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
	var responseLen int
	responseLen += 1 // ok
	responseLen += 4 // udp session id
	responseLen += 2 // message len
	response := buf.NewSize(responseLen)
	defer response.Release()
	_, err := response.ReadFullFrom(stream, responseLen)
	if err != nil {
		return nil, err
	}
	var serverResponse ServerResponse
	serverResponse.OK = response.Byte(0) == 1
	serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5))
	messageLen := binary.BigEndian.Uint16(response.Range(5, 7))
	if messageLen == 0 {
		return &serverResponse, nil
	}
	message := make([]byte, messageLen)
	_, err = io.ReadFull(stream, message)
	if err != nil {
		return nil, err
	}
	serverResponse.Message = string(message)
	return &serverResponse, nil
}

// WriteServerResponse serializes and writes a ServerResponse on the stream.
func WriteServerResponse(stream quic.Stream, response ServerResponse) error {
	var responseLen int
	responseLen += 1 // ok
	responseLen += 4 // udp session id
	responseLen += 2 // message len
	responseLen += len(response.Message)
	buffer := buf.NewSize(responseLen)
	defer buffer.Release()
	if response.OK {
		common.Must(buffer.WriteByte(1))
	} else {
		common.Must(buffer.WriteByte(0))
	}
	common.Must(
		binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
		binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
		common.Error(buffer.WriteString(response.Message)),
	)
	return common.Error(stream.Write(buffer.Bytes()))
}
--------------------------------------------------------------------------------
/hysteria/service.go:
--------------------------------------------------------------------------------
package hysteria

import (
	"context"
	"errors"
	"io"
	"math"
	"net"
	"os"
	"runtime"
	"sync"
	"time"

	"github.com/sagernet/quic-go"
	"github.com/sagernet/sing-quic"
	hyCC "github.com/sagernet/sing-quic/hysteria/congestion"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/auth"
	"github.com/sagernet/sing/common/baderror"
	"github.com/sagernet/sing/common/canceler"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	aTLS "github.com/sagernet/sing/common/tls"
)

// ServiceOptions configures a hysteria server instance.
type ServiceOptions struct {
	Context       context.Context
	Logger        logger.Logger
	BrutalDebug   bool
	SendBPS       uint64
	ReceiveBPS    uint64
	XPlusPassword string
	TLSConfig     aTLS.ServerConfig
	UDPDisabled   bool
	UDPTimeout    time.Duration
	Handler       ServerHandler

	// Legacy options

	ConnReceiveWindow   uint64
	StreamReceiveWindow uint64
MaxIncomingStreams  int64
	DisableMTUDiscovery bool
}

// ServerHandler receives the proxied TCP and UDP connections.
type ServerHandler interface {
	N.TCPConnectionHandlerEx
	N.UDPConnectionHandlerEx
}

// Service is a hysteria server generic over the user type U; users are
// looked up by their auth password.
type Service[U comparable] struct {
	ctx           context.Context
	logger        logger.Logger
	brutalDebug   bool
	sendBPS       uint64
	receiveBPS    uint64
	xplusPassword string
	tlsConfig     aTLS.ServerConfig
	quicConfig    *quic.Config
	userMap       map[string]U
	udpDisabled   bool
	udpTimeout    time.Duration
	handler       ServerHandler
	quicListener  io.Closer
}

// NewService validates options and builds the QUIC configuration from the
// protocol defaults, applying the legacy overrides when set. SendBPS and
// ReceiveBPS are mandatory.
func NewService[U comparable](options ServiceOptions) (*Service[U], error) {
	quicConfig := &quic.Config{
		// MTU discovery is enabled only on the listed platforms.
		DisablePathMTUDiscovery:        !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
		EnableDatagrams:                !options.UDPDisabled,
		MaxIncomingStreams:             1 << 60,
		InitialStreamReceiveWindow:     DefaultStreamReceiveWindow,
		MaxStreamReceiveWindow:         DefaultStreamReceiveWindow,
		InitialConnectionReceiveWindow: DefaultConnReceiveWindow,
		MaxConnectionReceiveWindow:     DefaultConnReceiveWindow,
		MaxIdleTimeout:                 DefaultMaxIdleTimeout,
		KeepAlivePeriod:                DefaultKeepAlivePeriod,
	}
	if options.StreamReceiveWindow != 0 {
		quicConfig.InitialStreamReceiveWindow = options.StreamReceiveWindow
		quicConfig.MaxStreamReceiveWindow = options.StreamReceiveWindow
	}
	if options.ConnReceiveWindow != 0 {
		quicConfig.InitialConnectionReceiveWindow = options.ConnReceiveWindow
		quicConfig.MaxConnectionReceiveWindow = options.ConnReceiveWindow
	}
	if options.MaxIncomingStreams > 0 {
		quicConfig.MaxIncomingStreams = int64(options.MaxIncomingStreams)
	}
	if options.DisableMTUDiscovery {
		quicConfig.DisablePathMTUDiscovery = true
	}
	if len(options.TLSConfig.NextProtos()) == 0 {
		options.TLSConfig.SetNextProtos([]string{DefaultALPN})
	}
	if options.SendBPS == 0 {
		return nil, E.New("missing upload speed configuration")
	}
	if options.ReceiveBPS == 0 {
		return nil, E.New("missing download speed configuration")
	}
	return &Service[U]{
		ctx:           options.Context,
		logger:        options.Logger,
		brutalDebug:   options.BrutalDebug,
		sendBPS:       options.SendBPS,
		receiveBPS:    options.ReceiveBPS,
		xplusPassword: options.XPlusPassword,
		tlsConfig:     options.TLSConfig,
		quicConfig:    quicConfig,
		userMap:       make(map[string]U),
		handler:       options.Handler,
		udpDisabled:   options.UDPDisabled,
		udpTimeout:    options.UDPTimeout,
	}, nil
}

// UpdateUsers replaces the password->user table; userList and passwordList
// are parallel slices and must be the same length.
func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) {
	userMap := make(map[string]U)
	for i, user := range userList {
		userMap[passwordList[i]] = user
	}
	s.userMap = userMap
}

// Start wraps conn with XPlus obfuscation when configured, then begins
// accepting QUIC connections on it in a background goroutine.
func (s *Service[U]) Start(conn net.PacketConn) error {
	if s.xplusPassword != "" {
		conn = NewXPlusPacketConn(conn, []byte(s.xplusPassword))
	}
	listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
	if err != nil {
		return err
	}
	s.quicListener = listener
	go s.loopConnections(listener)
	return nil
}

func (s *Service[U]) Close() error {
	return common.Close(
		s.quicListener,
	)
}

// loopConnections accepts QUIC connections until the listener closes and
// spawns a session handler per connection.
func (s *Service[U]) loopConnections(listener qtls.Listener) {
	for {
		connection, err := listener.Accept(s.ctx)
		if err != nil {
			if E.IsClosedOrCanceled(err) || errors.Is(err, quic.ErrServerClosed) {
				s.logger.Debug(E.Cause(err, "listener closed"))
			} else {
				s.logger.Error(E.Cause(err, "listener closed"))
			}
			return
		}
		session := &serverSession[U]{
			Service:    s,
			ctx:        s.ctx,
			quicConn:   connection,
			source:     M.SocksaddrFromNet(connection.RemoteAddr()).Unwrap(),
			connDone:   make(chan struct{}),
			udpConnMap: make(map[uint32]*udpPacketConn),
		}
		go session.handleConnection()
	}
}

// serverSession is the per-connection server state.
type serverSession[U comparable] struct {
	*Service[U]
	ctx          context.Context
	quicConn     quic.Connection
	source       M.Socksaddr
	connAccess   sync.Mutex // guards connDone/connErr
	connDone     chan struct{}
	connErr      error
	authUser     U
	udpAccess    sync.RWMutex // guards udpConnMap/udpSessionID
	udpConnMap   map[uint32]*udpPacketConn
	udpSessionID uint32
}

// handleConnection performs the hello exchange under ProtocolTimeout,
// authenticates the client, installs the brutal congestion controller at the
// negotiated rate, then serves streams (and datagrams unless UDP is off).
func (s *serverSession[U]) handleConnection() {
	ctx, cancel := context.WithTimeout(s.ctx, ProtocolTimeout)
	controlStream, err := s.quicConn.AcceptStream(ctx)
	cancel()
	if err != nil {
		s.closeWithError0(ErrorCodeProtocolError, err)
		return
	}
	clientHello, err := ReadClientHello(controlStream)
	if err != nil {
		s.closeWithError0(ErrorCodeProtocolError, E.Cause(err, "read client hello"))
		return
	}
	user, loaded := s.userMap[clientHello.Auth]
	if !loaded {
		WriteServerHello(controlStream, ServerHello{
			OK:      false,
			Message: "Wrong password",
		})
		s.closeWithError0(ErrorCodeAuthError, E.New("authentication failed, auth_str=", clientHello.Auth))
		return
	}
	err = WriteServerHello(controlStream, ServerHello{
		OK:      true,
		SendBPS: s.sendBPS,
		RecvBPS: s.receiveBPS,
	})
	if err != nil {
		s.closeWithError(err)
		return
	}
	s.authUser = user
	// Cap the send rate at what the client reports it can receive.
	s.quicConn.SetCongestionControl(hyCC.NewBrutalSender(uint64(math.Min(float64(s.sendBPS), float64(clientHello.RecvBPS))), s.brutalDebug, s.logger))
	if !s.udpDisabled {
		go s.loopMessages()
	}
	s.loopStreams()
}

// loopStreams accepts request streams until the connection ends, handling
// each on its own goroutine.
func (s *serverSession[U]) loopStreams() {
	for {
		stream, err := s.quicConn.AcceptStream(s.ctx)
		if err != nil {
			return
		}
		go func() {
			// NOTE(review): this assigns the loop-scoped err shared with the
			// accept loop and sibling goroutines — looks like a data race;
			// should probably be a goroutine-local variable. Confirm upstream.
			err = s.handleStream(stream)
			if err != nil {
				stream.CancelRead(0)
				stream.Close()
s.logger.Error(E.Cause(err, "handle stream request")) 235 | } 236 | }() 237 | } 238 | } 239 | 240 | func (s *serverSession[U]) handleStream(stream quic.Stream) error { 241 | request, err := ReadClientRequest(stream) 242 | if err != nil { 243 | return E.New("read TCP request") 244 | } 245 | ctx := auth.ContextWithUser(s.ctx, s.authUser) 246 | if !request.UDP { 247 | s.handler.NewConnectionEx(ctx, &serverConn{Stream: stream}, s.source, M.ParseSocksaddrHostPort(request.Host, request.Port), nil) 248 | } else { 249 | if s.udpDisabled { 250 | return WriteServerResponse(stream, ServerResponse{ 251 | OK: false, 252 | Message: "UDP disabled by server", 253 | }) 254 | } 255 | var sessionID uint32 256 | udpConn := newUDPPacketConn(ctx, s.quicConn, func() { 257 | stream.CancelRead(0) 258 | stream.Close() 259 | s.udpAccess.Lock() 260 | delete(s.udpConnMap, sessionID) 261 | s.udpAccess.Unlock() 262 | }) 263 | s.udpAccess.Lock() 264 | s.udpSessionID++ 265 | sessionID = s.udpSessionID 266 | udpConn.sessionID = sessionID 267 | s.udpConnMap[sessionID] = udpConn 268 | s.udpAccess.Unlock() 269 | err = WriteServerResponse(stream, ServerResponse{ 270 | OK: true, 271 | UDPSessionID: sessionID, 272 | }) 273 | if err != nil { 274 | udpConn.closeWithError(E.Cause(err, "write server response")) 275 | return err 276 | } 277 | newCtx, newConn := canceler.NewPacketConn(udpConn.ctx, udpConn, s.udpTimeout) 278 | go s.handler.NewPacketConnectionEx(newCtx, newConn, s.source, M.ParseSocksaddrHostPort(request.Host, request.Port), nil) 279 | holdBuffer := make([]byte, 1024) 280 | for { 281 | _, hErr := stream.Read(holdBuffer) 282 | if hErr != nil { 283 | break 284 | } 285 | } 286 | udpConn.closeWithError(E.Cause(net.ErrClosed, "hold stream closed")) 287 | } 288 | return nil 289 | } 290 | 291 | func (s *serverSession[U]) closeWithError(err error) { 292 | s.closeWithError0(ErrorCodeGeneric, err) 293 | } 294 | 295 | func (s *serverSession[U]) closeWithError0(errorCode int, err error) { 296 | 
s.connAccess.Lock() 297 | defer s.connAccess.Unlock() 298 | select { 299 | case <-s.connDone: 300 | return 301 | default: 302 | s.connErr = err 303 | close(s.connDone) 304 | } 305 | if E.IsClosedOrCanceled(err) { 306 | s.logger.Debug(E.Cause(err, "connection failed")) 307 | } else { 308 | s.logger.Error(E.Cause(err, "connection failed")) 309 | } 310 | switch errorCode { 311 | case ErrorCodeProtocolError: 312 | _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "protocol error") 313 | case ErrorCodeAuthError: 314 | _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "auth error") 315 | default: 316 | _ = s.quicConn.CloseWithError(quic.ApplicationErrorCode(errorCode), "") 317 | } 318 | } 319 | 320 | type serverConn struct { 321 | quic.Stream 322 | responseWritten bool 323 | } 324 | 325 | func (c *serverConn) HandshakeFailure(err error) error { 326 | if c.responseWritten { 327 | return os.ErrInvalid 328 | } 329 | c.responseWritten = true 330 | return WriteServerResponse(c.Stream, ServerResponse{ 331 | OK: false, 332 | Message: err.Error(), 333 | }) 334 | } 335 | 336 | func (c *serverConn) HandshakeSuccess() error { 337 | if c.responseWritten { 338 | return nil 339 | } 340 | c.responseWritten = true 341 | return WriteServerResponse(c.Stream, ServerResponse{ 342 | OK: true, 343 | }) 344 | } 345 | 346 | func (c *serverConn) Read(p []byte) (n int, err error) { 347 | n, err = c.Stream.Read(p) 348 | return n, baderror.WrapQUIC(err) 349 | } 350 | 351 | func (c *serverConn) Write(p []byte) (n int, err error) { 352 | if !c.responseWritten { 353 | c.responseWritten = true 354 | err = WriteServerResponse(c.Stream, ServerResponse{ 355 | OK: true, 356 | }) 357 | if err != nil { 358 | return 0, baderror.WrapQUIC(err) 359 | } 360 | } 361 | n, err = c.Stream.Write(p) 362 | return n, baderror.WrapQUIC(err) 363 | } 364 | 365 | func (c *serverConn) LocalAddr() net.Addr { 366 | return M.Socksaddr{} 367 | } 368 | 369 | func (c *serverConn) RemoteAddr() 
net.Addr { 370 | return M.Socksaddr{} 371 | } 372 | 373 | func (c *serverConn) Close() error { 374 | c.Stream.CancelRead(0) 375 | return c.Stream.Close() 376 | } 377 | -------------------------------------------------------------------------------- /hysteria/service_packet.go: -------------------------------------------------------------------------------- 1 | package hysteria 2 | 3 | import ( 4 | "github.com/sagernet/sing/common" 5 | E "github.com/sagernet/sing/common/exceptions" 6 | ) 7 | 8 | func (s *serverSession[U]) loopMessages() { 9 | for { 10 | message, err := s.quicConn.ReceiveDatagram(s.ctx) 11 | if err != nil { 12 | s.closeWithError(E.Cause(err, "receive message")) 13 | return 14 | } 15 | hErr := s.handleMessage(message) 16 | if hErr != nil { 17 | s.closeWithError(E.Cause(hErr, "handle message")) 18 | return 19 | } 20 | } 21 | } 22 | 23 | func (s *serverSession[U]) handleMessage(data []byte) error { 24 | message := allocMessage() 25 | err := decodeUDPMessage(message, data) 26 | if err != nil { 27 | message.release() 28 | return E.Cause(err, "decode UDP message") 29 | } 30 | return s.handleUDPMessage(message) 31 | } 32 | 33 | func (s *serverSession[U]) handleUDPMessage(message *udpMessage) error { 34 | s.udpAccess.RLock() 35 | udpConn, loaded := s.udpConnMap[message.sessionID] 36 | s.udpAccess.RUnlock() 37 | if !loaded || common.Done(udpConn.ctx) { 38 | message.release() 39 | return E.New("unknown session iD: ", message.sessionID) 40 | } 41 | udpConn.inputPacket(message) 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /hysteria/xplus.go: -------------------------------------------------------------------------------- 1 | package hysteria 2 | 3 | import ( 4 | "crypto/sha256" 5 | "math/rand" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | "github.com/sagernet/sing/common" 11 | "github.com/sagernet/sing/common/buf" 12 | "github.com/sagernet/sing/common/bufio" 13 | M 
"github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)

// xplusSaltLen is the length of the random salt prepended to every packet.
const xplusSaltLen = 16

// NewXPlusPacketConn wraps conn with Hysteria's XPlus UDP obfuscation: each
// packet is prefixed with a random salt, and the payload is XOR-ed with
// SHA-256(key || salt). A vectorised variant is returned when the underlying
// conn supports vectorised writes, avoiding a payload copy on send.
func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn {
	vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn)
	if isVectorised {
		return &VectorisedXPlusConn{
			XPlusPacketConn: XPlusPacketConn{
				PacketConn: conn,
				key:        key,
				rand:       rand.New(rand.NewSource(time.Now().UnixNano())),
			},
			writer: vectorisedWriter,
		}
	} else {
		return &XPlusPacketConn{
			PacketConn: conn,
			key:        key,
			rand:       rand.New(rand.NewSource(time.Now().UnixNano())),
		}
	}
}

// XPlusPacketConn applies XPlus obfuscation on top of an arbitrary
// net.PacketConn. rand is guarded by randAccess because *rand.Rand is not
// safe for concurrent use.
type XPlusPacketConn struct {
	net.PacketConn
	key        []byte
	randAccess sync.Mutex
	rand       *rand.Rand
}

// ReadFrom reads one obfuscated packet, strips the salt, and de-obfuscates
// the payload in place at the start of p. Undersized packets yield n == 0.
func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	n, addr, err = c.PacketConn.ReadFrom(p)
	if err != nil {
		return
	} else if n < xplusSaltLen {
		n = 0
		return
	}
	key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...))
	// Only process the n bytes actually received: the original ranged over
	// p[xplusSaltLen:], XOR-ing stale buffer contents past the datagram and
	// overwriting bytes of p beyond the returned length.
	for i := range p[xplusSaltLen:n] {
		p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size]
	}
	n -= xplusSaltLen
	return
}

// WriteTo obfuscates p into a freshly allocated buffer (salt + XOR-ed
// payload) and sends it. n reports the size of the packet as written.
func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	// can't use unsafe buffer on WriteTo
	buffer := buf.NewSize(len(p) + xplusSaltLen)
	defer buffer.Release()
	salt := buffer.Extend(xplusSaltLen)
	c.randAccess.Lock()
	_, _ = c.rand.Read(salt)
	c.randAccess.Unlock()
	key := sha256.Sum256(append(c.key, salt...))
	for i := range p {
		common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size]))
	}
	return c.PacketConn.WriteTo(buffer.Bytes(), addr)
}

func (c *XPlusPacketConn) Upstream() any {
	return c.PacketConn
}

// VectorisedXPlusConn obfuscates the payload in place and sends salt and
// payload as separate vectors, avoiding the copy WriteTo performs.
type VectorisedXPlusConn struct {
	XPlusPacketConn
	writer N.VectorisedPacketWriter
}
85 | 86 | func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 87 | header := buf.NewSize(xplusSaltLen) 88 | defer header.Release() 89 | salt := header.Extend(xplusSaltLen) 90 | c.randAccess.Lock() 91 | _, _ = c.rand.Read(salt) 92 | c.randAccess.Unlock() 93 | key := sha256.Sum256(append(c.key, salt...)) 94 | for i := range p { 95 | p[i] ^= key[i%sha256.Size] 96 | } 97 | return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr)) 98 | } 99 | 100 | func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { 101 | header := buf.NewSize(xplusSaltLen) 102 | defer header.Release() 103 | salt := header.Extend(xplusSaltLen) 104 | c.randAccess.Lock() 105 | _, _ = c.rand.Read(salt) 106 | c.randAccess.Unlock() 107 | key := sha256.Sum256(append(c.key, salt...)) 108 | var index int 109 | for _, buffer := range buffers { 110 | data := buffer.Bytes() 111 | for i := range data { 112 | data[i] ^= key[index%sha256.Size] 113 | index++ 114 | } 115 | } 116 | buffers = append([]*buf.Buffer{header}, buffers...) 
117 | return c.writer.WriteVectorisedPacket(buffers, destination) 118 | } 119 | -------------------------------------------------------------------------------- /hysteria2/client.go: -------------------------------------------------------------------------------- 1 | package hysteria2 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "net/http" 8 | "net/url" 9 | "os" 10 | "runtime" 11 | "sync" 12 | "time" 13 | 14 | "github.com/sagernet/quic-go" 15 | "github.com/sagernet/quic-go/congestion" 16 | "github.com/sagernet/quic-go/http3" 17 | "github.com/sagernet/sing-quic" 18 | congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" 19 | congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" 20 | "github.com/sagernet/sing-quic/hysteria" 21 | hyCC "github.com/sagernet/sing-quic/hysteria/congestion" 22 | "github.com/sagernet/sing-quic/hysteria2/internal/protocol" 23 | "github.com/sagernet/sing/common/baderror" 24 | "github.com/sagernet/sing/common/bufio" 25 | E "github.com/sagernet/sing/common/exceptions" 26 | "github.com/sagernet/sing/common/logger" 27 | M "github.com/sagernet/sing/common/metadata" 28 | N "github.com/sagernet/sing/common/network" 29 | "github.com/sagernet/sing/common/ntp" 30 | aTLS "github.com/sagernet/sing/common/tls" 31 | ) 32 | 33 | type ClientOptions struct { 34 | Context context.Context 35 | Dialer N.Dialer 36 | Logger logger.Logger 37 | BrutalDebug bool 38 | ServerAddress M.Socksaddr 39 | ServerPorts []string 40 | HopInterval time.Duration 41 | SendBPS uint64 42 | ReceiveBPS uint64 43 | SalamanderPassword string 44 | Password string 45 | TLSConfig aTLS.Config 46 | UDPDisabled bool 47 | } 48 | 49 | type Client struct { 50 | ctx context.Context 51 | dialer N.Dialer 52 | logger logger.Logger 53 | brutalDebug bool 54 | serverAddr M.Socksaddr 55 | serverPorts []uint16 56 | hopInterval time.Duration 57 | sendBPS uint64 58 | receiveBPS uint64 59 | salamanderPassword string 60 | password string 61 | tlsConfig aTLS.Config 62 | 
quicConfig *quic.Config 63 | udpDisabled bool 64 | 65 | connAccess sync.RWMutex 66 | conn *clientQUICConnection 67 | } 68 | 69 | func NewClient(options ClientOptions) (*Client, error) { 70 | quicConfig := &quic.Config{ 71 | DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 72 | EnableDatagrams: !options.UDPDisabled, 73 | InitialStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, 74 | MaxStreamReceiveWindow: hysteria.DefaultStreamReceiveWindow, 75 | InitialConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, 76 | MaxConnectionReceiveWindow: hysteria.DefaultConnReceiveWindow, 77 | MaxIdleTimeout: hysteria.DefaultMaxIdleTimeout, 78 | KeepAlivePeriod: hysteria.DefaultKeepAlivePeriod, 79 | } 80 | if len(options.TLSConfig.NextProtos()) == 0 { 81 | options.TLSConfig.SetNextProtos([]string{http3.NextProtoH3}) 82 | } 83 | var serverPorts []uint16 84 | if len(options.ServerPorts) > 0 { 85 | var err error 86 | serverPorts, err = hysteria.ParsePorts(options.ServerPorts) 87 | if err != nil { 88 | return nil, err 89 | } 90 | } 91 | return &Client{ 92 | ctx: options.Context, 93 | dialer: options.Dialer, 94 | logger: options.Logger, 95 | brutalDebug: options.BrutalDebug, 96 | serverAddr: options.ServerAddress, 97 | serverPorts: serverPorts, 98 | hopInterval: options.HopInterval, 99 | sendBPS: options.SendBPS, 100 | receiveBPS: options.ReceiveBPS, 101 | salamanderPassword: options.SalamanderPassword, 102 | password: options.Password, 103 | tlsConfig: options.TLSConfig, 104 | quicConfig: quicConfig, 105 | udpDisabled: options.UDPDisabled, 106 | }, nil 107 | } 108 | 109 | func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { 110 | conn := c.conn 111 | if conn != nil && conn.active() { 112 | return conn, nil 113 | } 114 | c.connAccess.Lock() 115 | defer c.connAccess.Unlock() 116 | conn = c.conn 117 | if conn != nil && conn.active() { 118 | return conn, nil 119 | 
} 120 | conn, err := c.offerNew(ctx) 121 | if err != nil { 122 | return nil, err 123 | } 124 | return conn, nil 125 | } 126 | 127 | func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { 128 | dialFunc := func(serverAddr M.Socksaddr) (net.PacketConn, error) { 129 | udpConn, err := c.dialer.DialContext(c.ctx, "udp", serverAddr) 130 | if err != nil { 131 | return nil, err 132 | } 133 | var packetConn net.PacketConn 134 | packetConn = bufio.NewUnbindPacketConn(udpConn) 135 | if c.salamanderPassword != "" { 136 | packetConn = NewSalamanderConn(packetConn, []byte(c.salamanderPassword)) 137 | } 138 | return packetConn, nil 139 | } 140 | var ( 141 | packetConn net.PacketConn 142 | err error 143 | ) 144 | if len(c.serverPorts) == 0 { 145 | packetConn, err = dialFunc(c.serverAddr) 146 | } else { 147 | packetConn, err = hysteria.NewHopPacketConn(dialFunc, c.serverAddr, c.serverPorts, c.hopInterval) 148 | } 149 | if err != nil { 150 | return nil, err 151 | } 152 | var quicConn quic.EarlyConnection 153 | http3Transport, err := qtls.CreateTransport(packetConn, &quicConn, c.serverAddr, c.tlsConfig, c.quicConfig) 154 | if err != nil { 155 | packetConn.Close() 156 | return nil, err 157 | } 158 | request := &http.Request{ 159 | Method: http.MethodPost, 160 | URL: &url.URL{ 161 | Scheme: "https", 162 | Host: protocol.URLHost, 163 | Path: protocol.URLPath, 164 | }, 165 | Header: make(http.Header), 166 | } 167 | protocol.AuthRequestToHeader(request.Header, protocol.AuthRequest{Auth: c.password, Rx: c.receiveBPS}) 168 | response, err := http3Transport.RoundTrip(request.WithContext(ctx)) 169 | if err != nil { 170 | if quicConn != nil { 171 | quicConn.CloseWithError(0, "") 172 | } 173 | packetConn.Close() 174 | return nil, err 175 | } 176 | response.Body.Close() 177 | if response.StatusCode != protocol.StatusAuthOK { 178 | if quicConn != nil { 179 | quicConn.CloseWithError(0, "") 180 | } 181 | packetConn.Close() 182 | return nil, E.New("authentication failed, 
status code: ", response.StatusCode) 183 | } 184 | authResponse := protocol.AuthResponseFromHeader(response.Header) 185 | actualTx := authResponse.Rx 186 | if actualTx == 0 || actualTx > c.sendBPS { 187 | actualTx = c.sendBPS 188 | } 189 | if !authResponse.RxAuto && actualTx > 0 { 190 | quicConn.SetCongestionControl(hyCC.NewBrutalSender(actualTx, c.brutalDebug, c.logger)) 191 | } else { 192 | timeFunc := ntp.TimeFuncFromContext(c.ctx) 193 | if timeFunc == nil { 194 | timeFunc = time.Now 195 | } 196 | quicConn.SetCongestionControl(congestion_meta2.NewBbrSender( 197 | congestion_meta2.DefaultClock{TimeFunc: timeFunc}, 198 | congestion.ByteCount(quicConn.Config().InitialPacketSize), 199 | congestion.ByteCount(congestion_meta1.InitialCongestionWindow), 200 | )) 201 | } 202 | conn := &clientQUICConnection{ 203 | quicConn: quicConn, 204 | rawConn: packetConn, 205 | connDone: make(chan struct{}), 206 | udpDisabled: !authResponse.UDPEnabled, 207 | udpConnMap: make(map[uint32]*udpPacketConn), 208 | } 209 | if !c.udpDisabled { 210 | go c.loopMessages(conn) 211 | } 212 | c.conn = conn 213 | return conn, nil 214 | } 215 | 216 | func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { 217 | conn, err := c.offer(ctx) 218 | if err != nil { 219 | return nil, err 220 | } 221 | stream, err := conn.quicConn.OpenStream() 222 | if err != nil { 223 | return nil, err 224 | } 225 | return &clientConn{ 226 | Stream: stream, 227 | destination: destination, 228 | }, nil 229 | } 230 | 231 | func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { 232 | if c.udpDisabled { 233 | return nil, os.ErrInvalid 234 | } 235 | conn, err := c.offer(ctx) 236 | if err != nil { 237 | return nil, err 238 | } 239 | if conn.udpDisabled { 240 | return nil, E.New("UDP disabled by server") 241 | } 242 | var sessionID uint32 243 | clientPacketConn := newUDPPacketConn(c.ctx, conn.quicConn, func() { 244 | conn.udpAccess.Lock() 245 | delete(conn.udpConnMap, 
sessionID) 246 | conn.udpAccess.Unlock() 247 | }) 248 | conn.udpAccess.Lock() 249 | sessionID = conn.udpSessionID 250 | conn.udpSessionID++ 251 | conn.udpConnMap[sessionID] = clientPacketConn 252 | conn.udpAccess.Unlock() 253 | clientPacketConn.sessionID = sessionID 254 | return clientPacketConn, nil 255 | } 256 | 257 | func (c *Client) CloseWithError(err error) error { 258 | conn := c.conn 259 | if conn != nil { 260 | conn.closeWithError(err) 261 | } 262 | return nil 263 | } 264 | 265 | type clientQUICConnection struct { 266 | quicConn quic.Connection 267 | rawConn io.Closer 268 | closeOnce sync.Once 269 | connDone chan struct{} 270 | connErr error 271 | udpDisabled bool 272 | udpAccess sync.RWMutex 273 | udpConnMap map[uint32]*udpPacketConn 274 | udpSessionID uint32 275 | } 276 | 277 | func (c *clientQUICConnection) active() bool { 278 | select { 279 | case <-c.quicConn.Context().Done(): 280 | return false 281 | default: 282 | } 283 | select { 284 | case <-c.connDone: 285 | return false 286 | default: 287 | } 288 | return true 289 | } 290 | 291 | func (c *clientQUICConnection) closeWithError(err error) { 292 | c.closeOnce.Do(func() { 293 | c.connErr = err 294 | close(c.connDone) 295 | _ = c.quicConn.CloseWithError(0, "") 296 | _ = c.rawConn.Close() 297 | }) 298 | } 299 | 300 | type clientConn struct { 301 | quic.Stream 302 | destination M.Socksaddr 303 | requestWritten bool 304 | responseRead bool 305 | } 306 | 307 | func (c *clientConn) NeedHandshake() bool { 308 | return !c.requestWritten 309 | } 310 | 311 | func (c *clientConn) Read(p []byte) (n int, err error) { 312 | if c.responseRead { 313 | n, err = c.Stream.Read(p) 314 | return n, baderror.WrapQUIC(err) 315 | } 316 | status, errorMessage, err := protocol.ReadTCPResponse(c.Stream) 317 | if err != nil { 318 | return 0, baderror.WrapQUIC(err) 319 | } 320 | if !status { 321 | err = E.New("remote error: ", errorMessage) 322 | return 323 | } 324 | c.responseRead = true 325 | n, err = c.Stream.Read(p) 326 | 
return n, baderror.WrapQUIC(err) 327 | } 328 | 329 | func (c *clientConn) Write(p []byte) (n int, err error) { 330 | if !c.requestWritten { 331 | buffer := protocol.WriteTCPRequest(c.destination.String(), p) 332 | defer buffer.Release() 333 | _, err = c.Stream.Write(buffer.Bytes()) 334 | if err != nil { 335 | return 336 | } 337 | c.requestWritten = true 338 | return len(p), nil 339 | } 340 | n, err = c.Stream.Write(p) 341 | return n, baderror.WrapQUIC(err) 342 | } 343 | 344 | func (c *clientConn) LocalAddr() net.Addr { 345 | return M.Socksaddr{} 346 | } 347 | 348 | func (c *clientConn) RemoteAddr() net.Addr { 349 | return M.Socksaddr{} 350 | } 351 | 352 | func (c *clientConn) Close() error { 353 | c.Stream.CancelRead(0) 354 | return c.Stream.Close() 355 | } 356 | -------------------------------------------------------------------------------- /hysteria2/client_packet.go: -------------------------------------------------------------------------------- 1 | package hysteria2 2 | 3 | import E "github.com/sagernet/sing/common/exceptions" 4 | 5 | func (c *Client) loopMessages(conn *clientQUICConnection) { 6 | for { 7 | message, err := conn.quicConn.ReceiveDatagram(c.ctx) 8 | if err != nil { 9 | conn.closeWithError(E.Cause(err, "receive message")) 10 | return 11 | } 12 | go func() { 13 | hErr := c.handleMessage(conn, message) 14 | if hErr != nil { 15 | conn.closeWithError(E.Cause(hErr, "handle message")) 16 | } 17 | }() 18 | } 19 | } 20 | 21 | func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { 22 | message := allocMessage() 23 | err := decodeUDPMessage(message, data) 24 | if err != nil { 25 | message.release() 26 | return E.Cause(err, "decode UDP message") 27 | } 28 | conn.handleUDPMessage(message) 29 | return nil 30 | } 31 | 32 | func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { 33 | c.udpAccess.RLock() 34 | udpConn, loaded := c.udpConnMap[message.sessionID] 35 | c.udpAccess.RUnlock() 36 | if !loaded { 37 | 
message.releaseMessage() 38 | return 39 | } 40 | select { 41 | case <-udpConn.ctx.Done(): 42 | message.releaseMessage() 43 | return 44 | default: 45 | } 46 | udpConn.inputPacket(message) 47 | } 48 | -------------------------------------------------------------------------------- /hysteria2/internal/protocol/http.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | ) 7 | 8 | const ( 9 | URLHost = "hysteria" 10 | URLPath = "/auth" 11 | 12 | RequestHeaderAuth = "Hysteria-Auth" 13 | ResponseHeaderUDPEnabled = "Hysteria-UDP" 14 | CommonHeaderCCRX = "Hysteria-CC-RX" 15 | CommonHeaderPadding = "Hysteria-Padding" 16 | 17 | StatusAuthOK = 233 18 | ) 19 | 20 | // AuthRequest is what client sends to server for authentication. 21 | type AuthRequest struct { 22 | Auth string 23 | Rx uint64 // 0 = unknown, client asks server to use bandwidth detection 24 | } 25 | 26 | // AuthResponse is what server sends to client when authentication is passed. 
27 | type AuthResponse struct { 28 | UDPEnabled bool 29 | Rx uint64 // 0 = unlimited 30 | RxAuto bool // true = server asks client to use bandwidth detection 31 | } 32 | 33 | func AuthRequestFromHeader(h http.Header) AuthRequest { 34 | rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) 35 | return AuthRequest{ 36 | Auth: h.Get(RequestHeaderAuth), 37 | Rx: rx, 38 | } 39 | } 40 | 41 | func AuthRequestToHeader(h http.Header, req AuthRequest) { 42 | h.Set(RequestHeaderAuth, req.Auth) 43 | h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) 44 | h.Set(CommonHeaderPadding, authRequestPadding.String()) 45 | } 46 | 47 | func AuthResponseFromHeader(h http.Header) AuthResponse { 48 | resp := AuthResponse{} 49 | resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) 50 | rxStr := h.Get(CommonHeaderCCRX) 51 | if rxStr == "auto" { 52 | // Special case for server requesting client to use bandwidth detection 53 | resp.RxAuto = true 54 | } else { 55 | resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) 56 | } 57 | return resp 58 | } 59 | 60 | func AuthResponseToHeader(h http.Header, resp AuthResponse) { 61 | h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) 62 | if resp.RxAuto { 63 | h.Set(CommonHeaderCCRX, "auto") 64 | } else { 65 | h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) 66 | } 67 | h.Set(CommonHeaderPadding, authResponsePadding.String()) 68 | } 69 | -------------------------------------------------------------------------------- /hysteria2/internal/protocol/padding.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | const ( 8 | paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" 9 | ) 10 | 11 | // padding specifies a half-open range [Min, Max). 
12 | type padding struct { 13 | Min int 14 | Max int 15 | } 16 | 17 | func (p padding) String() string { 18 | n := p.Min + rand.Intn(p.Max-p.Min) 19 | bs := make([]byte, n) 20 | for i := range bs { 21 | bs[i] = paddingChars[rand.Intn(len(paddingChars))] 22 | } 23 | return string(bs) 24 | } 25 | 26 | var ( 27 | authRequestPadding = padding{Min: 256, Max: 2048} 28 | authResponsePadding = padding{Min: 256, Max: 2048} 29 | tcpRequestPadding = padding{Min: 64, Max: 512} 30 | tcpResponsePadding = padding{Min: 128, Max: 1024} 31 | ) 32 | -------------------------------------------------------------------------------- /hysteria2/internal/protocol/proxy.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | 9 | "github.com/sagernet/quic-go/quicvarint" 10 | "github.com/sagernet/sing/common" 11 | "github.com/sagernet/sing/common/buf" 12 | E "github.com/sagernet/sing/common/exceptions" 13 | ) 14 | 15 | const ( 16 | FrameTypeTCPRequest = 0x401 17 | 18 | // Max length values are for preventing DoS attacks 19 | 20 | MaxAddressLength = 2048 21 | MaxMessageLength = 2048 22 | MaxPaddingLength = 4096 23 | 24 | MaxUDPSize = 4096 25 | 26 | maxVarInt1 = 63 27 | maxVarInt2 = 16383 28 | maxVarInt4 = 1073741823 29 | maxVarInt8 = 4611686018427387903 30 | ) 31 | 32 | // TCPRequest format: 33 | // 0x401 (QUIC varint) 34 | // Address length (QUIC varint) 35 | // Address (bytes) 36 | // Padding length (QUIC varint) 37 | // Padding (bytes) 38 | 39 | func ReadTCPRequest(r io.Reader) (string, error) { 40 | bReader := quicvarint.NewReader(r) 41 | addrLen, err := quicvarint.Read(bReader) 42 | if err != nil { 43 | return "", err 44 | } 45 | if addrLen == 0 || addrLen > MaxAddressLength { 46 | return "", E.New("invalid address length") 47 | } 48 | addrBuf := make([]byte, addrLen) 49 | _, err = io.ReadFull(r, addrBuf) 50 | if err != nil { 51 | return "", err 52 | } 53 | 
paddingLen, err := quicvarint.Read(bReader) 54 | if err != nil { 55 | return "", err 56 | } 57 | if paddingLen > MaxPaddingLength { 58 | return "", E.New("invalid padding length") 59 | } 60 | if paddingLen > 0 { 61 | _, err = io.CopyN(io.Discard, r, int64(paddingLen)) 62 | if err != nil { 63 | return "", err 64 | } 65 | } 66 | return string(addrBuf), nil 67 | } 68 | 69 | func WriteTCPRequest(addr string, payload []byte) *buf.Buffer { 70 | padding := tcpRequestPadding.String() 71 | paddingLen := len(padding) 72 | addrLen := len(addr) 73 | sz := int(quicvarint.Len(FrameTypeTCPRequest)) + 74 | int(quicvarint.Len(uint64(addrLen))) + addrLen + 75 | int(quicvarint.Len(uint64(paddingLen))) + paddingLen 76 | buffer := buf.NewSize(sz + len(payload)) 77 | bufferContent := buffer.Extend(sz) 78 | i := varintPut(bufferContent, FrameTypeTCPRequest) 79 | i += varintPut(bufferContent[i:], uint64(addrLen)) 80 | i += copy(bufferContent[i:], addr) 81 | i += varintPut(bufferContent[i:], uint64(paddingLen)) 82 | copy(bufferContent[i:], padding) 83 | buffer.Write(payload) 84 | return buffer 85 | } 86 | 87 | // TCPResponse format: 88 | // Status (byte, 0=ok, 1=error) 89 | // Message length (QUIC varint) 90 | // Message (bytes) 91 | // Padding length (QUIC varint) 92 | // Padding (bytes) 93 | 94 | func ReadTCPResponse(r io.Reader) (ok bool, message string, err error) { 95 | var status [1]byte 96 | _, err = io.ReadFull(r, status[:]) 97 | if err != nil { 98 | return 99 | } 100 | ok = status[0] == 0 101 | bReader := quicvarint.NewReader(r) 102 | messageLen, err := quicvarint.Read(bReader) 103 | if err != nil { 104 | return 105 | } 106 | if messageLen > MaxMessageLength { 107 | return false, "", E.New("invalid message length") 108 | } 109 | messageBytes := make([]byte, messageLen) 110 | _, err = io.ReadFull(r, messageBytes) 111 | if err != nil { 112 | return 113 | } 114 | message = string(messageBytes) 115 | paddingLen, err := quicvarint.Read(bReader) 116 | if err != nil { 117 | return 118 | 
} 119 | if paddingLen > MaxPaddingLength { 120 | return false, "", E.New("invalid padding length") 121 | } 122 | if paddingLen > 0 { 123 | _, err = io.CopyN(io.Discard, r, int64(paddingLen)) 124 | if err != nil { 125 | return 126 | } 127 | } 128 | return 129 | } 130 | 131 | func WriteTCPResponse(ok bool, msg string, payload []byte) *buf.Buffer { 132 | padding := tcpResponsePadding.String() 133 | paddingLen := len(padding) 134 | msgLen := len(msg) 135 | if msgLen > MaxMessageLength { 136 | msgLen = MaxMessageLength 137 | } 138 | sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + 139 | int(quicvarint.Len(uint64(paddingLen))) + paddingLen 140 | buffer := buf.NewSize(sz + len(payload)) 141 | if ok { 142 | buffer.WriteByte(0) 143 | } else { 144 | buffer.WriteByte(1) 145 | } 146 | WriteVString(buffer, msg) 147 | WriteUVariant(buffer, uint64(paddingLen)) 148 | buffer.Extend(paddingLen) 149 | buffer.Write(payload) 150 | return buffer 151 | } 152 | 153 | // UDPMessage format: 154 | // Session ID (uint32 BE) 155 | // Packet ID (uint16 BE) 156 | // Fragment ID (uint8) 157 | // Fragment count (uint8) 158 | // Address length (QUIC varint) 159 | // Address (bytes) 160 | // Data... 
161 | 162 | type UDPMessage struct { 163 | SessionID uint32 // 4 164 | PacketID uint16 // 2 165 | FragID uint8 // 1 166 | FragCount uint8 // 1 167 | Addr string // varint + bytes 168 | Data []byte 169 | } 170 | 171 | func (m *UDPMessage) HeaderSize() int { 172 | lAddr := len(m.Addr) 173 | return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr 174 | } 175 | 176 | func (m *UDPMessage) Size() int { 177 | return m.HeaderSize() + len(m.Data) 178 | } 179 | 180 | func (m *UDPMessage) Serialize(buf []byte) int { 181 | // Make sure the buffer is big enough 182 | if len(buf) < m.Size() { 183 | return -1 184 | } 185 | binary.BigEndian.PutUint32(buf, m.SessionID) 186 | binary.BigEndian.PutUint16(buf[4:], m.PacketID) 187 | buf[6] = m.FragID 188 | buf[7] = m.FragCount 189 | i := varintPut(buf[8:], uint64(len(m.Addr))) 190 | i += copy(buf[8+i:], m.Addr) 191 | i += copy(buf[8+i:], m.Data) 192 | return 8 + i 193 | } 194 | 195 | func ParseUDPMessage(msg []byte) (*UDPMessage, error) { 196 | m := &UDPMessage{} 197 | buf := bytes.NewBuffer(msg) 198 | if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { 199 | return nil, err 200 | } 201 | if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { 202 | return nil, err 203 | } 204 | if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { 205 | return nil, err 206 | } 207 | if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { 208 | return nil, err 209 | } 210 | lAddr, err := quicvarint.Read(buf) 211 | if err != nil { 212 | return nil, err 213 | } 214 | if lAddr == 0 || lAddr > MaxAddressLength { 215 | return nil, E.New("invalid address length") 216 | } 217 | bs := buf.Bytes() 218 | m.Addr = string(bs[:lAddr]) 219 | m.Data = bs[lAddr:] 220 | return m, nil 221 | } 222 | 223 | func ReadVString(reader io.Reader) (string, error) { 224 | length, err := quicvarint.Read(quicvarint.NewReader(reader)) 225 | if err != nil { 226 | return "", err 227 | } 228 | if length > 
MaxAddressLength { 229 | return "", E.New("invalid address length") 230 | } 231 | stringBytes := make([]byte, length) 232 | _, err = io.ReadFull(reader, stringBytes) 233 | if err != nil { 234 | return "", err 235 | } 236 | return string(stringBytes), nil 237 | } 238 | 239 | func WriteVString(writer io.Writer, value string) error { 240 | err := WriteUVariant(writer, uint64(len(value))) 241 | if err != nil { 242 | return err 243 | } 244 | return common.Error(writer.Write([]byte(value))) 245 | } 246 | 247 | func WriteUVariant(writer io.Writer, value uint64) error { 248 | var b [8]byte 249 | return common.Error(writer.Write(b[:varintPut(b[:], value)])) 250 | } 251 | 252 | // varintPut is like quicvarint.Append, but instead of appending to a slice, 253 | // it writes to a fixed-size buffer. Returns the number of bytes written. 254 | func varintPut(b []byte, i uint64) int { 255 | if i <= maxVarInt1 { 256 | b[0] = uint8(i) 257 | return 1 258 | } 259 | if i <= maxVarInt2 { 260 | b[0] = uint8(i>>8) | 0x40 261 | b[1] = uint8(i) 262 | return 2 263 | } 264 | if i <= maxVarInt4 { 265 | b[0] = uint8(i>>24) | 0x80 266 | b[1] = uint8(i >> 16) 267 | b[2] = uint8(i >> 8) 268 | b[3] = uint8(i) 269 | return 4 270 | } 271 | if i <= maxVarInt8 { 272 | b[0] = uint8(i>>56) | 0xc0 273 | b[1] = uint8(i >> 48) 274 | b[2] = uint8(i >> 40) 275 | b[3] = uint8(i >> 32) 276 | b[4] = uint8(i >> 24) 277 | b[5] = uint8(i >> 16) 278 | b[6] = uint8(i >> 8) 279 | b[7] = uint8(i) 280 | return 8 281 | } 282 | panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) 283 | } 284 | -------------------------------------------------------------------------------- /hysteria2/packet.go: -------------------------------------------------------------------------------- 1 | package hysteria2 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/binary" 7 | "errors" 8 | "io" 9 | "math" 10 | "net" 11 | "os" 12 | "sync" 13 | "time" 14 | 15 | "github.com/sagernet/quic-go" 16 | "github.com/sagernet/quic-go/quicvarint" 
// fragUDPMessage splits message into fragments whose payloads each fit in
// maxPacketSize bytes after the per-fragment header. If the payload already
// fits, the original message is returned unchanged as the only element.
// Fragment payloads alias subranges of message.data's backing array (no
// copy), and each fragment starts as a shallow copy of message, so the
// destination string is repeated in every fragment (required by hysteria;
// see the commented-out code below).
// NOTE(review): if maxPacketSize <= message.headerSize() the computed MTU
// is non-positive and the slicing below would misbehave — callers appear to
// always pass sizes above the header size; confirm upstream.
// NOTE(review): fragmentTotal is truncated to uint8, so more than 255
// fragments would wrap silently — presumably prevented by the MaxUDPSize
// checks in the callers; confirm.
func fragUDPMessage(message *udpMessage, maxPacketSize int) []*udpMessage {
	// Usable payload bytes per datagram after the serialized header.
	udpMTU := maxPacketSize - message.headerSize()
	if message.data.Len() <= udpMTU {
		return []*udpMessage{message}
	}
	var fragments []*udpMessage
	originPacket := message.data.Bytes()
	for remaining := len(originPacket); remaining > 0; remaining -= udpMTU {
		fragment := allocMessage()
		// Shallow copy: session/packet IDs and destination are shared.
		*fragment = *message
		if remaining > udpMTU {
			fragment.data = buf.As(originPacket[:udpMTU])
			originPacket = originPacket[udpMTU:]
		} else {
			fragment.data = buf.As(originPacket)
			originPacket = nil
		}
		fragments = append(fragments, fragment)
	}
	fragmentTotal := uint16(len(fragments))
	for index, fragment := range fragments {
		fragment.fragmentID = uint8(index)
		fragment.fragmentTotal = uint8(fragmentTotal)
		/*if index > 0 {
			fragment.destination = ""
			// not work in hysteria
		}*/
	}
	return fragments
}
// WritePacket sends buffer to destination as a hysteria2 UDP datagram,
// fragmenting when the serialized message would exceed the assumed QUIC
// datagram MTU (c.udpMTU). The buffer is always released before returning.
// If the send fails with quic.DatagramTooLargeError, the payload is
// re-fragmented using the limit reported by the error (minus 3 bytes —
// presumably QUIC datagram frame overhead; confirm) and retried once.
func (c *udpPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
	// Fail fast when the connection context is already cancelled.
	select {
	case <-c.ctx.Done():
		return net.ErrClosed
	default:
	}
	if buffer.Len() > protocol.MaxUDPSize {
		return &quic.DatagramTooLargeError{MaxDatagramPayloadSize: protocol.MaxUDPSize}
	}
	// Packet IDs wrap at MaxUint16; they only need to distinguish in-flight
	// fragmented packets for reassembly on the peer.
	packetId := uint16(c.packetId.Add(1) % math.MaxUint16)
	message := allocMessage()
	*message = udpMessage{
		sessionID:     c.sessionID,
		packetID:      packetId,
		fragmentTotal: 1,
		destination:   destination.String(),
		data:          buffer,
	}
	// Returns the message wrapper (and its data reference) to the pool.
	defer message.releaseMessage()
	var err error
	if buffer.Len() > c.udpMTU-message.headerSize() {
		err = c.writePackets(fragUDPMessage(message, c.udpMTU))
	} else {
		err = c.writePacket(message)
	}
	if err == nil {
		return nil
	}
	var tooLargeErr *quic.DatagramTooLargeError
	if !errors.As(err, &tooLargeErr) {
		return err
	}
	// Retry with the actual datagram size limit reported by quic-go.
	return c.writePackets(fragUDPMessage(message, int(tooLargeErr.MaxDatagramPayloadSize-3)))
}
// inputPacket delivers an inbound message to the read queue. Single-fragment
// messages are queued directly; multi-fragment messages go through the
// defragger and are queued only once fully reassembled. When the queue is
// full the message is dropped silently (UDP best-effort semantics).
// NOTE(review): a dropped message is not returned to the pool here — it is
// either reclaimed by the defragger cache's evict hook or left to GC;
// confirm this is intentional.
func (c *udpPacketConn) inputPacket(message *udpMessage) {
	if message.fragmentTotal <= 1 {
		select {
		case c.data <- message:
		default:
		}
	} else {
		newMessage := c.defragger.feed(message)
		if newMessage != nil {
			select {
			case c.data <- newMessage:
			default:
			}
		}
	}
}
// feed adds fragment m to the reassembly state for its packet ID and returns
// the fully reassembled message once all fragments have arrived, or nil while
// the packet is still incomplete (or m is malformed/duplicate).
// Single-fragment messages are returned unchanged. Ownership: retained
// fragments stay in the defragger until reassembly completes or the cache
// entry expires (the evict hook releases them).
func (d *udpDefragger) feed(m *udpMessage) *udpMessage {
	if m.fragmentTotal <= 1 {
		return m
	}
	// Fragment IDs are zero-based, so an ID >= total is malformed.
	if m.fragmentID >= m.fragmentTotal {
		return nil
	}
	item, _ := d.packetMap.LoadOrStore(m.packetID, newPacketItem)
	item.access.Lock()
	defer item.access.Unlock()
	if int(m.fragmentTotal) != len(item.messages) {
		// First fragment for this packet ID, or the advertised total
		// changed: drop any stale fragments and restart with this one.
		releaseMessages(item.messages)
		item.messages = make([]*udpMessage, m.fragmentTotal)
		item.count = 1
		item.messages[m.fragmentID] = m
		return nil
	}
	// Ignore duplicate fragments.
	if item.messages[m.fragmentID] != nil {
		return nil
	}
	item.messages[m.fragmentID] = m
	item.count++
	if int(item.count) != len(item.messages) {
		return nil
	}
	// All fragments present: concatenate payloads in fragment order.
	newMessage := allocMessage()
	newMessage.sessionID = m.sessionID
	newMessage.packetID = m.packetID
	newMessage.destination = item.messages[0].destination
	var finalLength int
	for _, message := range item.messages {
		finalLength += message.data.Len()
	}
	if finalLength > 0 {
		newMessage.data = buf.NewSize(finalLength)
		for _, message := range item.messages {
			newMessage.data.Write(message.data.Bytes())
			message.releaseMessage()
		}
		item.messages = nil
		return newMessage
	} else {
		// Zero-byte reassembly: nothing to deliver; recycle everything.
		// NOTE(review): newMessage.data is nil here, and releaseMessage
		// calls data.Release() — confirm buf.(*Buffer).Release tolerates
		// a nil receiver.
		newMessage.releaseMessage()
		for _, message := range item.messages {
			message.releaseMessage()
		}
	}
	item.messages = nil
	return nil
}
// WaitReadPacket returns the next queued packet. When the caller needs no
// headroom, the message's own data buffer is handed off without copying and
// only the message wrapper is recycled; otherwise the payload is copied into
// a buffer allocated per readWaitOptions. Blocks until a packet arrives, the
// connection context is cancelled (io.ErrClosedPipe), or the read deadline
// fires (os.ErrDeadlineExceeded).
func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
	select {
	case p := <-c.data:
		destination = M.ParseSocksaddr(p.destination)
		if c.readWaitOptions.NeedHeadroom() {
			buffer = c.readWaitOptions.NewPacketBuffer()
			_, err = buffer.Write(p.data.Bytes())
			p.releaseMessage()
			if err != nil {
				buffer.Release()
				return
			}
			c.readWaitOptions.PostReturn(buffer)
		} else {
			// Transfer ownership of the data buffer to the caller; release
			// only the message wrapper, not its data.
			buffer = p.data
			p.release()
		}
		return
	case <-c.ctx.Done():
		return nil, M.Socksaddr{}, io.ErrClosedPipe
	case <-c.readDeadline.Wait():
		return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
	}
}
// ReadFrom reads one obfuscated packet and deobfuscates it in place.
// Wire format: an 8-byte random salt followed by the payload XORed with
// BLAKE2b-256(password || salt). The plaintext is shifted to the front of p
// and n is the payload length with the salt stripped. Packets no longer
// than the salt are passed through unchanged — presumably treated as
// non-salamander traffic; TODO confirm against the hysteria2 obfuscation
// spec.
// NOTE(review): append(s.password, ...) can write into the password slice's
// backing array if it has spare capacity — confirm the password is always
// exactly sized by the caller.
func (s *SalamanderPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	n, addr, err = s.PacketConn.ReadFrom(p)
	if err != nil {
		return
	}
	if n <= salamanderSaltLen {
		return
	}
	key := blake2b.Sum256(append(s.password, p[:salamanderSaltLen]...))
	// The write position trails the read position by salamanderSaltLen
	// bytes, so the in-place XOR never clobbers unread ciphertext.
	for index, c := range p[salamanderSaltLen:n] {
		p[index] = c ^ key[index%blake2b.Size256]
	}
	return n - salamanderSaltLen, addr, nil
}
!= nil { 90 | return 91 | } 92 | return len(p), nil 93 | } 94 | 95 | func (s *VectorisedSalamanderPacketConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { 96 | header := buf.NewSize(salamanderSaltLen) 97 | defer header.Release() 98 | header.WriteRandom(salamanderSaltLen) 99 | key := blake2b.Sum256(append(s.password, header.Bytes()...)) 100 | var bufferIndex int 101 | for _, buffer := range buffers { 102 | content := buffer.Bytes() 103 | for index, c := range content { 104 | content[bufferIndex+index] = c ^ key[bufferIndex+index%blake2b.Size256] 105 | } 106 | bufferIndex += len(content) 107 | } 108 | return s.writer.WriteVectorisedPacket(append([]*buf.Buffer{header}, buffers...), destination) 109 | } 110 | -------------------------------------------------------------------------------- /hysteria2/service.go: -------------------------------------------------------------------------------- 1 | package hysteria2 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "net/http" 9 | "os" 10 | "runtime" 11 | "sync" 12 | "time" 13 | 14 | "github.com/sagernet/quic-go" 15 | "github.com/sagernet/quic-go/congestion" 16 | "github.com/sagernet/quic-go/http3" 17 | "github.com/sagernet/sing-quic" 18 | congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" 19 | congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" 20 | "github.com/sagernet/sing-quic/hysteria" 21 | hyCC "github.com/sagernet/sing-quic/hysteria/congestion" 22 | "github.com/sagernet/sing-quic/hysteria2/internal/protocol" 23 | "github.com/sagernet/sing/common" 24 | "github.com/sagernet/sing/common/auth" 25 | "github.com/sagernet/sing/common/baderror" 26 | E "github.com/sagernet/sing/common/exceptions" 27 | "github.com/sagernet/sing/common/logger" 28 | M "github.com/sagernet/sing/common/metadata" 29 | N "github.com/sagernet/sing/common/network" 30 | "github.com/sagernet/sing/common/ntp" 31 | aTLS "github.com/sagernet/sing/common/tls" 32 | ) 33 | 34 | 
// ServiceOptions configures a hysteria2 Service.
type ServiceOptions struct {
	Context context.Context // base context for the listener and all sessions
	Logger  logger.Logger
	// BrutalDebug enables debug logging in the brutal congestion sender.
	BrutalDebug bool
	// SendBPS caps the server's send rate in bytes/s; a client's requested
	// receive rate is clamped to this value (0 = no cap).
	SendBPS uint64
	// ReceiveBPS is advertised to clients as Rx in the auth response.
	ReceiveBPS uint64
	// IgnoreClientBandwidth ignores the client's advertised rx rate; with
	// ReceiveBPS == 0 this selects automatic (BBR) congestion control.
	IgnoreClientBandwidth bool
	// SalamanderPassword, when non-empty, wraps the packet conn in
	// salamander obfuscation before listening.
	SalamanderPassword string
	TLSConfig          aTLS.ServerConfig
	UDPDisabled        bool
	// UDPTimeout is the idle timeout applied to proxied UDP sessions.
	UDPTimeout time.Duration
	Handler    ServerHandler
	// MasqueradeHandler serves requests that fail authentication or do not
	// match the hysteria2 auth endpoint; defaults to http.NotFoundHandler.
	MasqueradeHandler http.Handler
}
options.BrutalDebug, 94 | sendBPS: options.SendBPS, 95 | receiveBPS: options.ReceiveBPS, 96 | ignoreClientBandwidth: options.IgnoreClientBandwidth, 97 | salamanderPassword: options.SalamanderPassword, 98 | tlsConfig: options.TLSConfig, 99 | quicConfig: quicConfig, 100 | userMap: make(map[string]U), 101 | udpDisabled: options.UDPDisabled, 102 | udpTimeout: options.UDPTimeout, 103 | handler: options.Handler, 104 | masqueradeHandler: options.MasqueradeHandler, 105 | }, nil 106 | } 107 | 108 | func (s *Service[U]) UpdateUsers(userList []U, passwordList []string) { 109 | userMap := make(map[string]U) 110 | for i, user := range userList { 111 | userMap[passwordList[i]] = user 112 | } 113 | s.userMap = userMap 114 | } 115 | 116 | func (s *Service[U]) Start(conn net.PacketConn) error { 117 | if s.salamanderPassword != "" { 118 | conn = NewSalamanderConn(conn, []byte(s.salamanderPassword)) 119 | } 120 | err := qtls.ConfigureHTTP3(s.tlsConfig) 121 | if err != nil { 122 | return err 123 | } 124 | listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) 125 | if err != nil { 126 | return err 127 | } 128 | s.quicListener = listener 129 | go s.loopConnections(listener) 130 | return nil 131 | } 132 | 133 | func (s *Service[U]) Close() error { 134 | return common.Close( 135 | s.quicListener, 136 | ) 137 | } 138 | 139 | func (s *Service[U]) loopConnections(listener qtls.Listener) { 140 | for { 141 | connection, err := listener.Accept(s.ctx) 142 | if err != nil { 143 | if E.IsClosedOrCanceled(err) || errors.Is(err, quic.ErrServerClosed) { 144 | s.logger.Debug(E.Cause(err, "listener closed")) 145 | } else { 146 | s.logger.Error(E.Cause(err, "listener closed")) 147 | } 148 | return 149 | } 150 | go s.handleConnection(connection) 151 | } 152 | } 153 | 154 | func (s *Service[U]) handleConnection(connection quic.Connection) { 155 | session := &serverSession[U]{ 156 | Service: s, 157 | ctx: s.ctx, 158 | quicConn: connection, 159 | source: 
// serverSession tracks one client QUIC connection: its authentication
// state, close signaling, and the per-session UDP conn table.
type serverSession[U comparable] struct {
	*Service[U]
	ctx           context.Context
	quicConn      quic.Connection
	source        M.Socksaddr // client address, unwrapped
	connAccess    sync.Mutex  // serializes closeWithError's done/err transition
	connDone      chan struct{}
	connErr       error
	authenticated bool
	authUser      U
	udpAccess     sync.RWMutex // guards udpConnMap
	udpConnMap    map[uint32]*udpPacketConn
}
ntp.TimeFuncFromContext(s.ctx) 217 | if timeFunc == nil { 218 | timeFunc = time.Now 219 | } 220 | s.quicConn.SetCongestionControl(congestion_meta2.NewBbrSender( 221 | congestion_meta2.DefaultClock{TimeFunc: timeFunc}, 222 | congestion.ByteCount(s.quicConn.Config().InitialPacketSize), 223 | congestion.ByteCount(congestion_meta1.InitialCongestionWindow), 224 | )) 225 | rxAuto = true 226 | } 227 | protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ 228 | UDPEnabled: !s.udpDisabled, 229 | Rx: s.receiveBPS, 230 | RxAuto: rxAuto, 231 | }) 232 | w.WriteHeader(protocol.StatusAuthOK) 233 | if s.ctx.Done() != nil { 234 | go func() { 235 | select { 236 | case <-s.ctx.Done(): 237 | s.closeWithError(s.ctx.Err()) 238 | case <-s.connDone: 239 | } 240 | }() 241 | } 242 | if !s.udpDisabled { 243 | go s.loopMessages() 244 | } 245 | } else { 246 | s.masqueradeHandler.ServeHTTP(w, r) 247 | } 248 | } 249 | 250 | func (s *serverSession[U]) handleStream0(frameType http3.FrameType, id quic.ConnectionTracingID, stream quic.Stream, err error) (bool, error) { 251 | if !s.authenticated || err != nil { 252 | return false, nil 253 | } 254 | if frameType != protocol.FrameTypeTCPRequest { 255 | return false, nil 256 | } 257 | go func() { 258 | hErr := s.handleStream(stream) 259 | if hErr != nil { 260 | stream.CancelRead(0) 261 | stream.Close() 262 | s.logger.Error(E.Cause(hErr, "handle stream request")) 263 | } 264 | }() 265 | return true, nil 266 | } 267 | 268 | func (s *serverSession[U]) handleStream(stream quic.Stream) error { 269 | destinationString, err := protocol.ReadTCPRequest(stream) 270 | if err != nil { 271 | return E.New("read TCP request") 272 | } 273 | s.handler.NewConnectionEx(auth.ContextWithUser(s.ctx, s.authUser), &serverConn{Stream: stream}, s.source, M.ParseSocksaddr(destinationString), nil) 274 | return nil 275 | } 276 | 277 | func (s *serverSession[U]) closeWithError(err error) { 278 | s.connAccess.Lock() 279 | defer s.connAccess.Unlock() 280 | select { 281 | 
// Write writes p to the stream. The first write lazily completes the
// handshake by prefixing p with a success TCP response header; subsequent
// writes pass through unchanged. QUIC stream errors are normalized via
// baderror.WrapQUIC.
func (c *serverConn) Write(p []byte) (n int, err error) {
	if !c.responseWritten {
		c.responseWritten = true
		// Bundle the response header and first payload in one write.
		buffer := protocol.WriteTCPResponse(true, "", p)
		defer buffer.Release()
		_, err = c.Stream.Write(buffer.Bytes())
		if err != nil {
			return 0, baderror.WrapQUIC(err)
		}
		return len(p), nil
	}
	n, err = c.Stream.Write(p)
	return n, baderror.WrapQUIC(err)
}
// handleUDPMessage routes an inbound UDP message to its session's packet
// conn, lazily creating the conn (and announcing a new packet connection to
// the handler) the first time a session ID is seen or after the previous
// conn's context finished. The conn removes itself from the map via the
// onDestroy callback. Called only from loopMessages, so the
// check-then-create sequence is not racing other writers — confirm if more
// callers are ever added.
func (s *serverSession[U]) handleUDPMessage(message *udpMessage) {
	s.udpAccess.RLock()
	udpConn, loaded := s.udpConnMap[message.sessionID]
	s.udpAccess.RUnlock()
	if !loaded || common.Done(udpConn.ctx) {
		udpConn = newUDPPacketConn(auth.ContextWithUser(s.ctx, s.authUser), s.quicConn, func() {
			s.udpAccess.Lock()
			delete(s.udpConnMap, message.sessionID)
			s.udpAccess.Unlock()
		})
		udpConn.sessionID = message.sessionID
		s.udpAccess.Lock()
		s.udpConnMap[message.sessionID] = udpConn
		s.udpAccess.Unlock()
		newCtx, newConn := canceler.NewPacketConn(udpConn.ctx, udpConn, s.udpTimeout)
		go s.handler.NewPacketConnectionEx(newCtx, newConn, s.source, M.ParseSocksaddr(message.destination), nil)
	}
	udpConn.inputPacket(message)
}
-------------------------------------------------------------------------------- /quic.go: -------------------------------------------------------------------------------- 1 | package qtls 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "net" 7 | "net/http" 8 | 9 | "github.com/sagernet/quic-go" 10 | "github.com/sagernet/quic-go/http3" 11 | M "github.com/sagernet/sing/common/metadata" 12 | aTLS "github.com/sagernet/sing/common/tls" 13 | ) 14 | 15 | type Config interface { 16 | Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.Connection, error) 17 | DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config *quic.Config) (quic.EarlyConnection, error) 18 | CreateTransport(conn net.PacketConn, quicConnPtr *quic.EarlyConnection, serverAddr M.Socksaddr, quicConfig *quic.Config) http.RoundTripper 19 | } 20 | 21 | type ServerConfig interface { 22 | Listen(conn net.PacketConn, config *quic.Config) (Listener, error) 23 | ListenEarly(conn net.PacketConn, config *quic.Config) (EarlyListener, error) 24 | ConfigureHTTP3() 25 | } 26 | 27 | type Listener interface { 28 | Accept(ctx context.Context) (quic.Connection, error) 29 | Close() error 30 | Addr() net.Addr 31 | } 32 | 33 | type EarlyListener interface { 34 | Accept(ctx context.Context) (quic.EarlyConnection, error) 35 | Close() error 36 | Addr() net.Addr 37 | } 38 | 39 | func Dial(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.Connection, error) { 40 | if quicTLSConfig, isQUICConfig := config.(Config); isQUICConfig { 41 | return quicTLSConfig.Dial(ctx, conn, addr, quicConfig) 42 | } 43 | tlsConfig, err := config.Config() 44 | if err != nil { 45 | return nil, err 46 | } 47 | return quic.Dial(ctx, conn, addr, tlsConfig, quicConfig) 48 | } 49 | 50 | func DialEarly(ctx context.Context, conn net.PacketConn, addr net.Addr, config aTLS.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { 51 | if 
// Listen creates a QUIC listener on conn. Configs implementing the qtls
// ServerConfig extension handle listening themselves; otherwise the config
// is resolved to a standard *tls.Config and passed to quic.Listen.
func Listen(conn net.PacketConn, config aTLS.ServerConfig, quicConfig *quic.Config) (Listener, error) {
	if quicTLSConfig, isQUICConfig := config.(ServerConfig); isQUICConfig {
		return quicTLSConfig.Listen(conn, quicConfig)
	}
	tlsConfig, err := config.Config()
	if err != nil {
		return nil, err
	}
	return quic.Listen(conn, tlsConfig, quicConfig)
}
ConfigureHTTP3(config aTLS.ServerConfig) error { 106 | if len(config.NextProtos()) == 0 { 107 | config.SetNextProtos([]string{http3.NextProtoH3}) 108 | } 109 | if quicTLSConfig, isQUICConfig := config.(ServerConfig); isQUICConfig { 110 | quicTLSConfig.ConfigureHTTP3() 111 | return nil 112 | } 113 | tlsConfig, err := config.Config() 114 | if err != nil { 115 | return err 116 | } 117 | http3.ConfigureTLSConfig(tlsConfig) 118 | return nil 119 | } 120 | -------------------------------------------------------------------------------- /tuic/address.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import M "github.com/sagernet/sing/common/metadata" 4 | 5 | var AddressSerializer = M.NewSerializer( 6 | M.AddressFamilyByte(0x00, M.AddressFamilyFqdn), 7 | M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), 8 | M.AddressFamilyByte(0x02, M.AddressFamilyIPv6), 9 | M.AddressFamilyByte(0xff, M.AddressFamilyEmpty), 10 | ) 11 | -------------------------------------------------------------------------------- /tuic/client.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "runtime" 8 | "sync" 9 | "time" 10 | 11 | "github.com/sagernet/quic-go" 12 | "github.com/sagernet/sing-quic" 13 | "github.com/sagernet/sing/common" 14 | "github.com/sagernet/sing/common/baderror" 15 | "github.com/sagernet/sing/common/buf" 16 | "github.com/sagernet/sing/common/bufio" 17 | E "github.com/sagernet/sing/common/exceptions" 18 | M "github.com/sagernet/sing/common/metadata" 19 | N "github.com/sagernet/sing/common/network" 20 | aTLS "github.com/sagernet/sing/common/tls" 21 | ) 22 | 23 | type ClientOptions struct { 24 | Context context.Context 25 | Dialer N.Dialer 26 | ServerAddress M.Socksaddr 27 | TLSConfig aTLS.Config 28 | UUID [16]byte 29 | Password string 30 | CongestionControl string 31 | UDPStream bool 32 | ZeroRTTHandshake bool 33 | Heartbeat 
time.Duration 34 | } 35 | 36 | type Client struct { 37 | ctx context.Context 38 | dialer N.Dialer 39 | serverAddr M.Socksaddr 40 | tlsConfig aTLS.Config 41 | quicConfig *quic.Config 42 | uuid [16]byte 43 | password string 44 | congestionControl string 45 | udpStream bool 46 | zeroRTTHandshake bool 47 | heartbeat time.Duration 48 | 49 | connAccess sync.RWMutex 50 | conn *clientQUICConnection 51 | } 52 | 53 | func NewClient(options ClientOptions) (*Client, error) { 54 | if options.Heartbeat == 0 { 55 | options.Heartbeat = 10 * time.Second 56 | } 57 | quicConfig := &quic.Config{ 58 | DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 59 | EnableDatagrams: true, 60 | MaxIncomingUniStreams: 1 << 60, 61 | } 62 | switch options.CongestionControl { 63 | case "": 64 | options.CongestionControl = "cubic" 65 | case "cubic", "new_reno", "bbr": 66 | default: 67 | return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) 68 | } 69 | return &Client{ 70 | ctx: options.Context, 71 | dialer: options.Dialer, 72 | serverAddr: options.ServerAddress, 73 | tlsConfig: options.TLSConfig, 74 | quicConfig: quicConfig, 75 | uuid: options.UUID, 76 | password: options.Password, 77 | congestionControl: options.CongestionControl, 78 | udpStream: options.UDPStream, 79 | zeroRTTHandshake: options.ZeroRTTHandshake, 80 | heartbeat: options.Heartbeat, 81 | }, nil 82 | } 83 | 84 | func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { 85 | conn := c.conn 86 | if conn != nil && conn.active() { 87 | return conn, nil 88 | } 89 | c.connAccess.Lock() 90 | defer c.connAccess.Unlock() 91 | conn = c.conn 92 | if conn != nil && conn.active() { 93 | return conn, nil 94 | } 95 | conn, err := c.offerNew(ctx) 96 | if err != nil { 97 | return nil, err 98 | } 99 | return conn, nil 100 | } 101 | 102 | func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { 
103 | udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.serverAddr) 104 | if err != nil { 105 | return nil, err 106 | } 107 | var quicConn quic.Connection 108 | if c.zeroRTTHandshake { 109 | quicConn, err = qtls.DialEarly(c.ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) 110 | } else { 111 | quicConn, err = qtls.Dial(c.ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) 112 | } 113 | if err != nil { 114 | udpConn.Close() 115 | return nil, E.Cause(err, "open connection") 116 | } 117 | setCongestion(c.ctx, quicConn, c.congestionControl) 118 | conn := &clientQUICConnection{ 119 | quicConn: quicConn, 120 | rawConn: udpConn, 121 | connDone: make(chan struct{}), 122 | udpConnMap: make(map[uint16]*udpPacketConn), 123 | } 124 | go func() { 125 | hErr := c.clientHandshake(quicConn) 126 | if hErr != nil { 127 | conn.closeWithError(hErr) 128 | } 129 | }() 130 | if c.udpStream { 131 | go c.loopUniStreams(conn) 132 | } 133 | go c.loopMessages(conn) 134 | go c.loopHeartbeats(conn) 135 | c.conn = conn 136 | return conn, nil 137 | } 138 | 139 | func (c *Client) clientHandshake(conn quic.Connection) error { 140 | authStream, err := conn.OpenUniStream() 141 | if err != nil { 142 | return E.Cause(err, "open handshake stream") 143 | } 144 | defer authStream.Close() 145 | handshakeState := conn.ConnectionState() 146 | tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32) 147 | if err != nil { 148 | return E.Cause(err, "export keying material") 149 | } 150 | authRequest := buf.NewSize(AuthenticateLen) 151 | authRequest.WriteByte(Version) 152 | authRequest.WriteByte(CommandAuthenticate) 153 | authRequest.Write(c.uuid[:]) 154 | authRequest.Write(tuicAuthToken) 155 | return common.Error(authStream.Write(authRequest.Bytes())) 156 | } 157 | 158 | func (c *Client) loopHeartbeats(conn *clientQUICConnection) { 159 | ticker := time.NewTicker(c.heartbeat) 160 | defer 
ticker.Stop() 161 | for { 162 | select { 163 | case <-conn.connDone: 164 | return 165 | case <-ticker.C: 166 | err := conn.quicConn.SendDatagram([]byte{Version, CommandHeartbeat}) 167 | if err != nil { 168 | conn.closeWithError(E.Cause(err, "send heartbeat")) 169 | } 170 | } 171 | } 172 | } 173 | 174 | func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { 175 | conn, err := c.offer(ctx) 176 | if err != nil { 177 | return nil, err 178 | } 179 | stream, err := conn.quicConn.OpenStream() 180 | if err != nil { 181 | return nil, err 182 | } 183 | return &clientConn{ 184 | Stream: stream, 185 | parent: conn, 186 | destination: destination, 187 | }, nil 188 | } 189 | 190 | func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { 191 | conn, err := c.offer(ctx) 192 | if err != nil { 193 | return nil, err 194 | } 195 | var sessionID uint16 196 | clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() { 197 | conn.udpAccess.Lock() 198 | delete(conn.udpConnMap, sessionID) 199 | conn.udpAccess.Unlock() 200 | }) 201 | conn.udpAccess.Lock() 202 | sessionID = conn.udpSessionID 203 | conn.udpSessionID++ 204 | conn.udpConnMap[sessionID] = clientPacketConn 205 | conn.udpAccess.Unlock() 206 | clientPacketConn.sessionID = sessionID 207 | return clientPacketConn, nil 208 | } 209 | 210 | func (c *Client) CloseWithError(err error) error { 211 | conn := c.conn 212 | if conn != nil { 213 | conn.closeWithError(err) 214 | } 215 | return nil 216 | } 217 | 218 | type clientQUICConnection struct { 219 | quicConn quic.Connection 220 | rawConn io.Closer 221 | closeOnce sync.Once 222 | connDone chan struct{} 223 | connErr error 224 | udpAccess sync.RWMutex 225 | udpConnMap map[uint16]*udpPacketConn 226 | udpSessionID uint16 227 | } 228 | 229 | func (c *clientQUICConnection) active() bool { 230 | select { 231 | case <-c.quicConn.Context().Done(): 232 | return false 233 | default: 234 | } 235 | select { 236 | case 
<-c.connDone: 237 | return false 238 | default: 239 | } 240 | return true 241 | } 242 | 243 | func (c *clientQUICConnection) closeWithError(err error) { 244 | c.closeOnce.Do(func() { 245 | c.connErr = err 246 | close(c.connDone) 247 | _ = c.quicConn.CloseWithError(0, "") 248 | _ = c.rawConn.Close() 249 | }) 250 | } 251 | 252 | type clientConn struct { 253 | quic.Stream 254 | parent *clientQUICConnection 255 | destination M.Socksaddr 256 | requestWritten bool 257 | } 258 | 259 | func (c *clientConn) NeedHandshake() bool { 260 | return !c.requestWritten 261 | } 262 | 263 | func (c *clientConn) Read(b []byte) (n int, err error) { 264 | n, err = c.Stream.Read(b) 265 | return n, baderror.WrapQUIC(err) 266 | } 267 | 268 | func (c *clientConn) Write(b []byte) (n int, err error) { 269 | if !c.requestWritten { 270 | request := buf.NewSize(2 + AddressSerializer.AddrPortLen(c.destination) + len(b)) 271 | defer request.Release() 272 | request.WriteByte(Version) 273 | request.WriteByte(CommandConnect) 274 | err = AddressSerializer.WriteAddrPort(request, c.destination) 275 | if err != nil { 276 | return 277 | } 278 | request.Write(b) 279 | _, err = c.Stream.Write(request.Bytes()) 280 | if err != nil { 281 | c.parent.closeWithError(E.Cause(err, "create new connection")) 282 | return 0, baderror.WrapQUIC(err) 283 | } 284 | c.requestWritten = true 285 | return len(b), nil 286 | } 287 | n, err = c.Stream.Write(b) 288 | return n, baderror.WrapQUIC(err) 289 | } 290 | 291 | func (c *clientConn) Close() error { 292 | c.Stream.CancelRead(0) 293 | return c.Stream.Close() 294 | } 295 | 296 | func (c *clientConn) LocalAddr() net.Addr { 297 | return M.Socksaddr{} 298 | } 299 | 300 | func (c *clientConn) RemoteAddr() net.Addr { 301 | return c.destination 302 | } 303 | -------------------------------------------------------------------------------- /tuic/client_packet.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "io" 
5 | 6 | "github.com/sagernet/quic-go" 7 | "github.com/sagernet/sing/common/buf" 8 | "github.com/sagernet/sing/common/bufio" 9 | E "github.com/sagernet/sing/common/exceptions" 10 | ) 11 | 12 | func (c *Client) loopMessages(conn *clientQUICConnection) { 13 | for { 14 | message, err := conn.quicConn.ReceiveDatagram(c.ctx) 15 | if err != nil { 16 | conn.closeWithError(E.Cause(err, "receive message")) 17 | return 18 | } 19 | go func() { 20 | hErr := c.handleMessage(conn, message) 21 | if hErr != nil { 22 | conn.closeWithError(E.Cause(hErr, "handle message")) 23 | } 24 | }() 25 | } 26 | } 27 | 28 | func (c *Client) handleMessage(conn *clientQUICConnection, data []byte) error { 29 | if len(data) < 2 { 30 | return E.New("invalid message") 31 | } 32 | if data[0] != Version { 33 | return E.New("unknown version ", data[0]) 34 | } 35 | switch data[1] { 36 | case CommandPacket: 37 | message := allocMessage() 38 | err := decodeUDPMessage(message, data[2:]) 39 | if err != nil { 40 | message.release() 41 | return E.Cause(err, "decode UDP message") 42 | } 43 | conn.handleUDPMessage(message) 44 | return nil 45 | case CommandHeartbeat: 46 | return nil 47 | default: 48 | return E.New("unknown command ", data[0]) 49 | } 50 | } 51 | 52 | func (c *Client) loopUniStreams(conn *clientQUICConnection) { 53 | for { 54 | stream, err := conn.quicConn.AcceptUniStream(c.ctx) 55 | if err != nil { 56 | conn.closeWithError(E.Cause(err, "handle uni stream")) 57 | return 58 | } 59 | go func() { 60 | hErr := c.handleUniStream(conn, stream) 61 | if hErr != nil { 62 | conn.closeWithError(hErr) 63 | } 64 | }() 65 | } 66 | } 67 | 68 | func (c *Client) handleUniStream(conn *clientQUICConnection, stream quic.ReceiveStream) error { 69 | defer stream.CancelRead(0) 70 | buffer := buf.NewPacket() 71 | defer buffer.Release() 72 | _, err := buffer.ReadAtLeastFrom(stream, 2) 73 | if err != nil { 74 | return err 75 | } 76 | version, _ := buffer.ReadByte() 77 | if version != Version { 78 | return E.New("unknown 
version ", version) 79 | } 80 | command, _ := buffer.ReadByte() 81 | if command != CommandPacket { 82 | return E.New("unknown command ", command) 83 | } 84 | reader := io.MultiReader(bufio.NewCachedReader(stream, buffer), stream) 85 | message := allocMessage() 86 | err = readUDPMessage(message, reader) 87 | if err != nil { 88 | message.release() 89 | return err 90 | } 91 | conn.handleUDPMessage(message) 92 | return nil 93 | } 94 | 95 | func (c *clientQUICConnection) handleUDPMessage(message *udpMessage) { 96 | c.udpAccess.RLock() 97 | udpConn, loaded := c.udpConnMap[message.sessionID] 98 | c.udpAccess.RUnlock() 99 | if !loaded { 100 | message.releaseMessage() 101 | return 102 | } 103 | select { 104 | case <-udpConn.ctx.Done(): 105 | message.releaseMessage() 106 | return 107 | default: 108 | } 109 | udpConn.inputPacket(message) 110 | } 111 | -------------------------------------------------------------------------------- /tuic/congestion.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/sagernet/quic-go" 8 | "github.com/sagernet/quic-go/congestion" 9 | congestion_meta1 "github.com/sagernet/sing-quic/congestion_meta1" 10 | congestion_meta2 "github.com/sagernet/sing-quic/congestion_meta2" 11 | "github.com/sagernet/sing/common/ntp" 12 | ) 13 | 14 | func setCongestion(ctx context.Context, connection quic.Connection, congestionName string) { 15 | timeFunc := ntp.TimeFuncFromContext(ctx) 16 | if timeFunc == nil { 17 | timeFunc = time.Now 18 | } 19 | switch congestionName { 20 | case "cubic": 21 | connection.SetCongestionControl( 22 | congestion_meta1.NewCubicSender( 23 | congestion_meta1.DefaultClock{TimeFunc: timeFunc}, 24 | congestion.ByteCount(connection.Config().InitialPacketSize), 25 | false, 26 | nil, 27 | ), 28 | ) 29 | case "new_reno": 30 | connection.SetCongestionControl( 31 | congestion_meta1.NewCubicSender( 32 | 
congestion_meta1.DefaultClock{TimeFunc: timeFunc}, 33 | congestion.ByteCount(connection.Config().InitialPacketSize), 34 | true, 35 | nil, 36 | ), 37 | ) 38 | case "bbr_meta_v1": 39 | connection.SetCongestionControl(congestion_meta1.NewBBRSender( 40 | congestion_meta1.DefaultClock{TimeFunc: timeFunc}, 41 | congestion.ByteCount(connection.Config().InitialPacketSize), 42 | congestion_meta1.InitialCongestionWindow*congestion_meta1.InitialMaxDatagramSize, 43 | congestion_meta1.DefaultBBRMaxCongestionWindow*congestion_meta1.InitialMaxDatagramSize, 44 | )) 45 | case "bbr": 46 | connection.SetCongestionControl(congestion_meta2.NewBbrSender( 47 | congestion_meta2.DefaultClock{TimeFunc: timeFunc}, 48 | congestion.ByteCount(connection.Config().InitialPacketSize), 49 | congestion.ByteCount(congestion_meta1.InitialCongestionWindow), 50 | )) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /tuic/packet_wait.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "io" 5 | "os" 6 | 7 | "github.com/sagernet/sing/common/buf" 8 | M "github.com/sagernet/sing/common/metadata" 9 | N "github.com/sagernet/sing/common/network" 10 | ) 11 | 12 | func (c *udpPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { 13 | c.readWaitOptions = options 14 | return options.NeedHeadroom() 15 | } 16 | 17 | func (c *udpPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { 18 | select { 19 | case p := <-c.data: 20 | destination = p.destination 21 | if c.readWaitOptions.NeedHeadroom() { 22 | buffer = c.readWaitOptions.NewPacketBuffer() 23 | p.releaseMessage() 24 | _, err = buffer.Write(p.data.Bytes()) 25 | if err != nil { 26 | buffer.Release() 27 | return 28 | } 29 | c.readWaitOptions.PostReturn(buffer) 30 | } else { 31 | buffer = p.data 32 | p.release() 33 | } 34 | return 35 | case <-c.ctx.Done(): 36 | return nil, M.Socksaddr{}, 
io.ErrClosedPipe 37 | case <-c.readDeadline.Wait(): 38 | return nil, M.Socksaddr{}, os.ErrDeadlineExceeded 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /tuic/protocol.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | const ( 4 | Version = 5 5 | ) 6 | 7 | const ( 8 | CommandAuthenticate = iota 9 | CommandConnect 10 | CommandPacket 11 | CommandDissociate 12 | CommandHeartbeat 13 | ) 14 | 15 | const AuthenticateLen = 2 + 16 + 32 16 | -------------------------------------------------------------------------------- /tuic/service.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/binary" 7 | "errors" 8 | "io" 9 | "net" 10 | "runtime" 11 | "sync" 12 | "time" 13 | 14 | "github.com/sagernet/quic-go" 15 | "github.com/sagernet/sing-quic" 16 | "github.com/sagernet/sing/common" 17 | "github.com/sagernet/sing/common/auth" 18 | "github.com/sagernet/sing/common/baderror" 19 | "github.com/sagernet/sing/common/buf" 20 | "github.com/sagernet/sing/common/bufio" 21 | E "github.com/sagernet/sing/common/exceptions" 22 | "github.com/sagernet/sing/common/logger" 23 | M "github.com/sagernet/sing/common/metadata" 24 | N "github.com/sagernet/sing/common/network" 25 | aTLS "github.com/sagernet/sing/common/tls" 26 | 27 | "github.com/gofrs/uuid/v5" 28 | ) 29 | 30 | type ServiceOptions struct { 31 | Context context.Context 32 | Logger logger.Logger 33 | TLSConfig aTLS.ServerConfig 34 | CongestionControl string 35 | AuthTimeout time.Duration 36 | ZeroRTTHandshake bool 37 | Heartbeat time.Duration 38 | UDPTimeout time.Duration 39 | Handler ServiceHandler 40 | } 41 | 42 | type ServiceHandler interface { 43 | N.TCPConnectionHandlerEx 44 | N.UDPConnectionHandlerEx 45 | } 46 | 47 | type Service[U comparable] struct { 48 | ctx context.Context 49 | logger logger.Logger 50 | tlsConfig 
aTLS.ServerConfig 51 | heartbeat time.Duration 52 | quicConfig *quic.Config 53 | userMap map[[16]byte]U 54 | passwordMap map[U]string 55 | congestionControl string 56 | authTimeout time.Duration 57 | udpTimeout time.Duration 58 | handler ServiceHandler 59 | 60 | quicListener io.Closer 61 | } 62 | 63 | func NewService[U comparable](options ServiceOptions) (*Service[U], error) { 64 | if options.AuthTimeout == 0 { 65 | options.AuthTimeout = 3 * time.Second 66 | } 67 | if options.Heartbeat == 0 { 68 | options.Heartbeat = 10 * time.Second 69 | } 70 | quicConfig := &quic.Config{ 71 | DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 72 | EnableDatagrams: true, 73 | Allow0RTT: options.ZeroRTTHandshake, 74 | MaxIncomingStreams: 1 << 60, 75 | MaxIncomingUniStreams: 1 << 60, 76 | } 77 | switch options.CongestionControl { 78 | case "": 79 | options.CongestionControl = "cubic" 80 | case "cubic", "new_reno", "bbr": 81 | default: 82 | return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) 83 | } 84 | return &Service[U]{ 85 | ctx: options.Context, 86 | logger: options.Logger, 87 | tlsConfig: options.TLSConfig, 88 | heartbeat: options.Heartbeat, 89 | quicConfig: quicConfig, 90 | userMap: make(map[[16]byte]U), 91 | congestionControl: options.CongestionControl, 92 | authTimeout: options.AuthTimeout, 93 | udpTimeout: options.UDPTimeout, 94 | handler: options.Handler, 95 | }, nil 96 | } 97 | 98 | func (s *Service[U]) UpdateUsers(userList []U, uuidList [][16]byte, passwordList []string) { 99 | userMap := make(map[[16]byte]U) 100 | passwordMap := make(map[U]string) 101 | for index := range userList { 102 | userMap[uuidList[index]] = userList[index] 103 | passwordMap[userList[index]] = passwordList[index] 104 | } 105 | s.userMap = userMap 106 | s.passwordMap = passwordMap 107 | } 108 | 109 | func (s *Service[U]) Start(conn net.PacketConn) error { 110 | if 
!s.quicConfig.Allow0RTT { 111 | listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) 112 | if err != nil { 113 | return err 114 | } 115 | s.quicListener = listener 116 | go func() { 117 | for { 118 | connection, hErr := listener.Accept(s.ctx) 119 | if hErr != nil { 120 | if E.IsClosedOrCanceled(hErr) || errors.Is(hErr, quic.ErrServerClosed) { 121 | s.logger.Debug(E.Cause(hErr, "listener closed")) 122 | } else { 123 | s.logger.Error(E.Cause(hErr, "listener closed")) 124 | } 125 | return 126 | } 127 | go s.handleConnection(connection) 128 | } 129 | }() 130 | } else { 131 | listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig) 132 | if err != nil { 133 | return err 134 | } 135 | s.quicListener = listener 136 | go func() { 137 | for { 138 | connection, hErr := listener.Accept(s.ctx) 139 | if hErr != nil { 140 | if E.IsClosedOrCanceled(hErr) || errors.Is(hErr, quic.ErrServerClosed) { 141 | s.logger.Debug(E.Cause(hErr, "listener closed")) 142 | } else { 143 | s.logger.Error(E.Cause(hErr, "listener closed")) 144 | } 145 | return 146 | } 147 | go s.handleConnection(connection) 148 | } 149 | }() 150 | } 151 | return nil 152 | } 153 | 154 | func (s *Service[U]) Close() error { 155 | return common.Close( 156 | s.quicListener, 157 | ) 158 | } 159 | 160 | func (s *Service[U]) handleConnection(connection quic.Connection) { 161 | setCongestion(s.ctx, connection, s.congestionControl) 162 | session := &serverSession[U]{ 163 | Service: s, 164 | ctx: s.ctx, 165 | quicConn: connection, 166 | source: M.SocksaddrFromNet(connection.RemoteAddr()).Unwrap(), 167 | connDone: make(chan struct{}), 168 | authDone: make(chan struct{}), 169 | udpConnMap: make(map[uint16]*udpPacketConn), 170 | } 171 | session.handle() 172 | } 173 | 174 | type serverSession[U comparable] struct { 175 | *Service[U] 176 | ctx context.Context 177 | quicConn quic.Connection 178 | source M.Socksaddr 179 | connAccess sync.Mutex 180 | connDone chan struct{} 181 | connErr error 182 | authDone chan 
struct{} 183 | authUser U 184 | udpAccess sync.RWMutex 185 | udpConnMap map[uint16]*udpPacketConn 186 | } 187 | 188 | func (s *serverSession[U]) handle() { 189 | if s.ctx.Done() != nil { 190 | go func() { 191 | select { 192 | case <-s.ctx.Done(): 193 | s.closeWithError(s.ctx.Err()) 194 | case <-s.connDone: 195 | } 196 | }() 197 | } 198 | go s.loopUniStreams() 199 | go s.loopStreams() 200 | go s.loopMessages() 201 | go s.handleAuthTimeout() 202 | go s.loopHeartbeats() 203 | } 204 | 205 | func (s *serverSession[U]) loopUniStreams() { 206 | for { 207 | uniStream, err := s.quicConn.AcceptUniStream(s.ctx) 208 | if err != nil { 209 | return 210 | } 211 | go func() { 212 | err = s.handleUniStream(uniStream) 213 | if err != nil { 214 | s.closeWithError(E.Cause(err, "handle uni stream")) 215 | } 216 | }() 217 | } 218 | } 219 | 220 | func (s *serverSession[U]) handleUniStream(stream quic.ReceiveStream) error { 221 | defer stream.CancelRead(0) 222 | buffer := buf.New() 223 | defer buffer.Release() 224 | _, err := buffer.ReadAtLeastFrom(stream, 2) 225 | if err != nil { 226 | return E.Cause(err, "read request") 227 | } 228 | version := buffer.Byte(0) 229 | if version != Version { 230 | return E.New("unknown version ", buffer.Byte(0)) 231 | } 232 | command := buffer.Byte(1) 233 | switch command { 234 | case CommandAuthenticate: 235 | select { 236 | case <-s.authDone: 237 | return E.New("authentication: multiple authentication requests") 238 | default: 239 | } 240 | if buffer.Len() < AuthenticateLen { 241 | _, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len()) 242 | if err != nil { 243 | return E.Cause(err, "authentication: read request") 244 | } 245 | } 246 | var userUUID [16]byte 247 | copy(userUUID[:], buffer.Range(2, 2+16)) 248 | user, loaded := s.userMap[userUUID] 249 | if !loaded { 250 | return E.New("authentication: unknown user ", uuid.UUID(userUUID)) 251 | } 252 | handshakeState := s.quicConn.ConnectionState() 253 | tuicToken, err := 
handshakeState.ExportKeyingMaterial(string(userUUID[:]), []byte(s.passwordMap[user]), 32) 254 | if err != nil { 255 | return E.Cause(err, "authentication: export keying material") 256 | } 257 | if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) { 258 | return E.New("authentication: token mismatch") 259 | } 260 | s.authUser = user 261 | close(s.authDone) 262 | return nil 263 | case CommandPacket: 264 | select { 265 | case <-s.connDone: 266 | return s.connErr 267 | case <-s.authDone: 268 | } 269 | message := allocMessage() 270 | err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream)) 271 | if err != nil { 272 | message.release() 273 | return err 274 | } 275 | s.handleUDPMessage(message, true) 276 | return nil 277 | case CommandDissociate: 278 | select { 279 | case <-s.connDone: 280 | return s.connErr 281 | case <-s.authDone: 282 | } 283 | if buffer.Len() > 4 { 284 | return E.New("invalid dissociate message") 285 | } 286 | var sessionID uint16 287 | err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID) 288 | if err != nil { 289 | return err 290 | } 291 | s.udpAccess.RLock() 292 | udpConn, loaded := s.udpConnMap[sessionID] 293 | s.udpAccess.RUnlock() 294 | if loaded { 295 | udpConn.closeWithError(E.New("remote closed")) 296 | s.udpAccess.Lock() 297 | delete(s.udpConnMap, sessionID) 298 | s.udpAccess.Unlock() 299 | } 300 | return nil 301 | default: 302 | return E.New("unknown command ", command) 303 | } 304 | } 305 | 306 | func (s *serverSession[U]) handleAuthTimeout() { 307 | select { 308 | case <-s.connDone: 309 | case <-s.authDone: 310 | case <-time.After(s.authTimeout): 311 | s.closeWithError(E.New("authentication timeout")) 312 | } 313 | } 314 | 315 | func (s *serverSession[U]) loopStreams() { 316 | for { 317 | stream, err := s.quicConn.AcceptStream(s.ctx) 318 | if err != nil { 319 | return 320 | } 321 | go func() { 322 | err = s.handleStream(stream) 323 | if err != nil { 324 | 
stream.CancelRead(0) 325 | stream.Close() 326 | s.logger.Error(E.Cause(err, "handle stream request")) 327 | } 328 | }() 329 | } 330 | } 331 | 332 | func (s *serverSession[U]) handleStream(stream quic.Stream) error { 333 | buffer := buf.NewSize(2 + M.MaxSocksaddrLength) 334 | defer buffer.Release() 335 | _, err := buffer.ReadAtLeastFrom(stream, 2) 336 | if err != nil { 337 | return E.Cause(err, "read request") 338 | } 339 | version, _ := buffer.ReadByte() 340 | if version != Version { 341 | return E.New("unknown version ", buffer.Byte(0)) 342 | } 343 | command, _ := buffer.ReadByte() 344 | if command != CommandConnect { 345 | return E.New("unsupported stream command ", command) 346 | } 347 | destination, err := AddressSerializer.ReadAddrPort(io.MultiReader(buffer, stream)) 348 | if err != nil { 349 | return E.Cause(err, "read request destination") 350 | } 351 | select { 352 | case <-s.connDone: 353 | return s.connErr 354 | case <-s.authDone: 355 | } 356 | var conn net.Conn = &serverConn{ 357 | Stream: stream, 358 | destination: destination, 359 | } 360 | if buffer.IsEmpty() { 361 | buffer.Release() 362 | } else { 363 | conn = bufio.NewCachedConn(conn, buffer) 364 | } 365 | s.handler.NewConnectionEx(auth.ContextWithUser(s.ctx, s.authUser), conn, s.source, destination, nil) 366 | return nil 367 | } 368 | 369 | func (s *serverSession[U]) loopHeartbeats() { 370 | ticker := time.NewTicker(s.heartbeat) 371 | defer ticker.Stop() 372 | for { 373 | select { 374 | case <-s.connDone: 375 | return 376 | case <-ticker.C: 377 | err := s.quicConn.SendDatagram([]byte{Version, CommandHeartbeat}) 378 | if err != nil { 379 | s.closeWithError(E.Cause(err, "send heartbeat")) 380 | } 381 | } 382 | } 383 | } 384 | 385 | func (s *serverSession[U]) closeWithError(err error) { 386 | s.connAccess.Lock() 387 | defer s.connAccess.Unlock() 388 | select { 389 | case <-s.connDone: 390 | return 391 | default: 392 | s.connErr = err 393 | close(s.connDone) 394 | } 395 | if E.IsClosedOrCanceled(err) { 
396 | s.logger.Debug(E.Cause(err, "connection failed")) 397 | } else { 398 | s.logger.Error(E.Cause(err, "connection failed")) 399 | } 400 | _ = s.quicConn.CloseWithError(0, "") 401 | } 402 | 403 | type serverConn struct { 404 | quic.Stream 405 | destination M.Socksaddr 406 | } 407 | 408 | func (c *serverConn) Read(p []byte) (n int, err error) { 409 | n, err = c.Stream.Read(p) 410 | return n, baderror.WrapQUIC(err) 411 | } 412 | 413 | func (c *serverConn) Write(p []byte) (n int, err error) { 414 | n, err = c.Stream.Write(p) 415 | return n, baderror.WrapQUIC(err) 416 | } 417 | 418 | func (c *serverConn) LocalAddr() net.Addr { 419 | return c.destination 420 | } 421 | 422 | func (c *serverConn) RemoteAddr() net.Addr { 423 | return M.Socksaddr{} 424 | } 425 | 426 | func (c *serverConn) Close() error { 427 | c.Stream.CancelRead(0) 428 | return c.Stream.Close() 429 | } 430 | -------------------------------------------------------------------------------- /tuic/service_packet.go: -------------------------------------------------------------------------------- 1 | package tuic 2 | 3 | import ( 4 | "github.com/sagernet/sing/common" 5 | "github.com/sagernet/sing/common/auth" 6 | "github.com/sagernet/sing/common/canceler" 7 | E "github.com/sagernet/sing/common/exceptions" 8 | ) 9 | 10 | func (s *serverSession[U]) loopMessages() { 11 | select { 12 | case <-s.connDone: 13 | return 14 | case <-s.authDone: 15 | } 16 | for { 17 | message, err := s.quicConn.ReceiveDatagram(s.ctx) 18 | if err != nil { 19 | s.closeWithError(E.Cause(err, "receive message")) 20 | return 21 | } 22 | hErr := s.handleMessage(message) 23 | if hErr != nil { 24 | s.closeWithError(E.Cause(hErr, "handle message")) 25 | return 26 | } 27 | } 28 | } 29 | 30 | func (s *serverSession[U]) handleMessage(data []byte) error { 31 | if len(data) < 2 { 32 | return E.New("invalid message") 33 | } 34 | if data[0] != Version { 35 | return E.New("unknown version ", data[0]) 36 | } 37 | switch data[1] { 38 | case 
CommandPacket: 39 | message := allocMessage() 40 | err := decodeUDPMessage(message, data[2:]) 41 | if err != nil { 42 | message.release() 43 | return E.Cause(err, "decode UDP message") 44 | } 45 | s.handleUDPMessage(message, false) 46 | return nil 47 | case CommandHeartbeat: 48 | return nil 49 | default: 50 | return E.New("unknown command ", data[0]) 51 | } 52 | } 53 | 54 | func (s *serverSession[U]) handleUDPMessage(message *udpMessage, udpStream bool) { 55 | s.udpAccess.RLock() 56 | udpConn, loaded := s.udpConnMap[message.sessionID] 57 | s.udpAccess.RUnlock() 58 | if !loaded || common.Done(udpConn.ctx) { 59 | udpConn = newUDPPacketConn(auth.ContextWithUser(s.ctx, s.authUser), s.quicConn, udpStream, true, func() { 60 | s.udpAccess.Lock() 61 | delete(s.udpConnMap, message.sessionID) 62 | s.udpAccess.Unlock() 63 | }) 64 | udpConn.sessionID = message.sessionID 65 | s.udpAccess.Lock() 66 | s.udpConnMap[message.sessionID] = udpConn 67 | s.udpAccess.Unlock() 68 | newCtx, newConn := canceler.NewPacketConn(udpConn.ctx, udpConn, s.udpTimeout) 69 | go s.handler.NewPacketConnectionEx(newCtx, newConn, s.source, message.destination, nil) 70 | } 71 | udpConn.inputPacket(message) 72 | } 73 | -------------------------------------------------------------------------------- /workflows/debug.yml: -------------------------------------------------------------------------------- 1 | name: Debug build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | paths-ignore: 9 | - '**.md' 10 | - '.github/**' 11 | - '!.github/workflows/debug.yml' 12 | pull_request: 13 | branches: 14 | - main 15 | - dev 16 | 17 | jobs: 18 | build: 19 | name: Linux Debug build 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 26 | - name: Setup Go 27 | uses: actions/setup-go@v4 28 | with: 29 | go-version: ^1.22 30 | - name: Build 31 | run: | 32 | make test 33 | build_go120: 34 | name: Linux Debug build (Go 1.20) 35 | runs-on: 
ubuntu-latest 36 | steps: 37 | - name: Checkout 38 | uses: actions/checkout@v4 39 | with: 40 | fetch-depth: 0 41 | - name: Setup Go 42 | uses: actions/setup-go@v4 43 | with: 44 | go-version: ~1.20 45 | continue-on-error: true 46 | - name: Build 47 | run: | 48 | make test 49 | build_go121: 50 | name: Linux Debug build (Go 1.21) 51 | runs-on: ubuntu-latest 52 | steps: 53 | - name: Checkout 54 | uses: actions/checkout@v4 55 | with: 56 | fetch-depth: 0 57 | - name: Setup Go 58 | uses: actions/setup-go@v4 59 | with: 60 | go-version: ~1.21 61 | continue-on-error: true 62 | - name: Build 63 | run: | 64 | make test 65 | build__windows: 66 | name: Windows Debug build 67 | runs-on: windows-latest 68 | steps: 69 | - name: Checkout 70 | uses: actions/checkout@v4 71 | with: 72 | fetch-depth: 0 73 | - name: Setup Go 74 | uses: actions/setup-go@v4 75 | with: 76 | go-version: ^1.22 77 | continue-on-error: true 78 | - name: Build 79 | run: | 80 | make test 81 | build_darwin: 82 | name: macOS Debug build 83 | runs-on: macos-latest 84 | steps: 85 | - name: Checkout 86 | uses: actions/checkout@v4 87 | with: 88 | fetch-depth: 0 89 | - name: Setup Go 90 | uses: actions/setup-go@v4 91 | with: 92 | go-version: ^1.22 93 | continue-on-error: true 94 | - name: Build 95 | run: | 96 | make test --------------------------------------------------------------------------------