├── vnet ├── .gitignore ├── vnet.go ├── errors.go ├── udpproxy_direct.go ├── chunk_queue.go ├── chunk_queue_test.go ├── queue.go ├── tbf_queue_test.go ├── resolver.go ├── resolver_test.go ├── tbf_queue.go ├── queue_test.go ├── tbf.go ├── conn_map.go ├── delay_filter.go ├── chunk_test.go ├── stress_test.go ├── tbf_test.go ├── delay_filter_test.go ├── udpproxy.go ├── chunk.go ├── conn_test.go └── conn_map_test.go ├── .github ├── .gitignore ├── workflows │ ├── api.yaml │ ├── lint.yaml │ ├── reuse.yml │ ├── release.yml │ ├── renovate-go-sum-fix.yaml │ ├── tidy-check.yaml │ ├── codeql-analysis.yml │ ├── fuzz.yaml │ └── test.yaml ├── install-hooks.sh └── fetch-scripts.sh ├── .goreleaser.yml ├── renovate.json ├── packetio ├── hardlimit.go ├── no_hardlimit.go └── errors.go ├── test ├── util_nowasm.go ├── test.go ├── util_wasm.go ├── connctx.go ├── test_test.go ├── rand.go ├── util_test.go ├── stress.go └── util.go ├── deadline ├── timer.go ├── timer_generic.go ├── timer_js.go ├── deadline.go └── deadline_test.go ├── netctx ├── pipe.go ├── conn.go ├── packetconn.go ├── conn_test.go └── packetconn_test.go ├── connctx ├── pipe.go ├── connctx.go └── connctx_test.go ├── go.mod ├── .gitignore ├── codecov.yml ├── utils └── xor │ ├── xor_generic.go │ ├── xor_arm.go │ ├── xor_test.go │ ├── xor_old.go │ └── xor_arm.s ├── .reuse └── dep5 ├── LICENSES ├── MIT.txt └── BSD-3-Clause.txt ├── LICENSE ├── examples └── vnet-udpproxy │ ├── README.md │ └── main.go ├── go.sum ├── replaydetector ├── fixedbig_test.go ├── fixedbig.go └── replaydetector.go ├── README.md ├── dpipe ├── dpipe_test.go └── dpipe.go ├── stdnet ├── net.go └── net_test.go └── udp ├── batchconn_test.go └── batchconn.go /vnet/.gitignore: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023 The Pion community 2 | # SPDX-License-Identifier: MIT 3 | 4 | *.sw[poe] 5 | -------------------------------------------------------------------------------- 
/.github/.gitignore: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023 The Pion community 2 | # SPDX-License-Identifier: MIT 3 | 4 | .goassets 5 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023 The Pion community 2 | # SPDX-License-Identifier: MIT 3 | 4 | builds: 5 | - skip: true 6 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "github>pion/renovate-config" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /vnet/vnet.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package vnet provides a virtual network layer for pion 5 | package vnet 6 | -------------------------------------------------------------------------------- /packetio/hardlimit.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build packetioSizeHardlimit 5 | // +build packetioSizeHardlimit 6 | 7 | package packetio 8 | 9 | const sizeHardLimit = true 10 | -------------------------------------------------------------------------------- /test/util_nowasm.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !wasm 5 | // +build !wasm 6 | 7 | package test 8 | 9 | func filterRoutineWASM(string) bool { 
10 | return false 11 | } 12 | -------------------------------------------------------------------------------- /deadline/timer.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package deadline 5 | 6 | import ( 7 | "time" 8 | ) 9 | 10 | type timer interface { 11 | Stop() bool 12 | Reset(time.Duration) bool 13 | } 14 | -------------------------------------------------------------------------------- /packetio/no_hardlimit.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !packetioSizeHardlimit 5 | // +build !packetioSizeHardlimit 6 | 7 | package packetio 8 | 9 | const sizeHardLimit = false 10 | -------------------------------------------------------------------------------- /netctx/pipe.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package netctx 5 | 6 | import ( 7 | "net" 8 | ) 9 | 10 | // Pipe creates piped pair of Conn. 11 | func Pipe() (Conn, Conn) { 12 | ca, cb := net.Pipe() 13 | 14 | return NewConn(ca), NewConn(cb) 15 | } 16 | -------------------------------------------------------------------------------- /connctx/pipe.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package connctx 5 | 6 | import ( 7 | "net" 8 | ) 9 | 10 | // Pipe creates piped pair of ConnCtx. 
11 | func Pipe() (ConnCtx, ConnCtx) { 12 | ca, cb := net.Pipe() 13 | 14 | return New(ca), New(cb) 15 | } 16 | -------------------------------------------------------------------------------- /deadline/timer_generic.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !js 5 | // +build !js 6 | 7 | package deadline 8 | 9 | import ( 10 | "time" 11 | ) 12 | 13 | func afterFunc(d time.Duration, f func()) timer { 14 | return time.AfterFunc(d, f) 15 | } 16 | -------------------------------------------------------------------------------- /test/test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package test provides helpers to test the various pion transports implementations. 5 | // The tests are standardized around the io.ReadWriteCloser interface. 6 | // This package is meant to be used in addition to golang.org/x/net/nettest. 
7 | package test 8 | -------------------------------------------------------------------------------- /test/util_wasm.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "strings" 8 | ) 9 | 10 | func filterRoutineWASM(stack string) bool { 11 | // Nested t.Run on Go 1.14-1.21 and go1.22 WASM have these routines 12 | return strings.Contains(stack, "runtime.goexit()") || 13 | strings.Contains(stack, "runtime.goexit({})") 14 | } 15 | -------------------------------------------------------------------------------- /vnet/errors.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | type timeoutError struct { 7 | msg string 8 | } 9 | 10 | func newTimeoutError(msg string) error { 11 | return &timeoutError{ 12 | msg: msg, 13 | } 14 | } 15 | 16 | func (e *timeoutError) Error() string { 17 | return e.msg 18 | } 19 | 20 | func (e *timeoutError) Timeout() bool { 21 | return true 22 | } 23 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pion/transport/v3 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/pion/logging v0.2.4 7 | github.com/stretchr/testify v1.11.1 8 | github.com/wlynxg/anet v0.0.5 9 | golang.org/x/net v0.34.0 10 | golang.org/x/sys v0.29.0 11 | golang.org/x/time v0.10.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/pmezard/go-difflib v1.0.0 // indirect 17 | gopkg.in/yaml.v3 v3.0.1 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023 The Pion community 2 | # SPDX-License-Identifier: MIT 3 | 4 | ### JetBrains IDE ### 5 | ##################### 6 | .idea/ 7 | 8 | ### Emacs Temporary Files ### 9 | ############################# 10 | *~ 11 | 12 | ### Folders ### 13 | ############### 14 | bin/ 15 | vendor/ 16 | node_modules/ 17 | 18 | ### Files ### 19 | ############# 20 | *.ivf 21 | *.ogg 22 | tags 23 | cover.out 24 | *.sw[poe] 25 | *.wasm 26 | examples/sfu-ws/cert.pem 27 | examples/sfu-ws/key.pem 28 | wasm_exec.js 29 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # 6 | # SPDX-FileCopyrightText: 2023 The Pion community 7 | # SPDX-License-Identifier: MIT 8 | 9 | coverage: 10 | status: 11 | project: 12 | default: 13 | # Allow decreasing 2% of total coverage to avoid noise. 14 | threshold: 2% 15 | patch: 16 | default: 17 | target: 70% 18 | only_pulls: true 19 | 20 | ignore: 21 | - "examples/*" 22 | - "examples/**/*" 23 | -------------------------------------------------------------------------------- /utils/xor/xor_generic.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. 2 | // SPDX-License-Identifier: BSD-3-Clause 3 | // SPDX-FileCopyrightText: 2024 The Pion community 4 | // SPDX-License-Identifier: MIT 5 | 6 | //go:build go1.20 && !arm && !gccgo 7 | 8 | // Package xor provides the XorBytes function. 9 | package xor 10 | 11 | import ( 12 | "crypto/subtle" 13 | ) 14 | 15 | // XorBytes calls [crypto/subtle.XORBytes].
16 | // 17 | //revive:disable-next-line 18 | func XorBytes(dst, a, b []byte) int { 19 | return subtle.XORBytes(dst, a, b) 20 | } 21 | -------------------------------------------------------------------------------- /.reuse/dep5: -------------------------------------------------------------------------------- 1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ 2 | Upstream-Name: Pion 3 | Source: https://github.com/pion/ 4 | 5 | Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json yarn.lock 6 | Copyright: 2023 The Pion community 7 | License: MIT 8 | 9 | Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt 10 | Copyright: 2023 The Pion community 11 | License: CC0-1.0 12 | -------------------------------------------------------------------------------- /.github/workflows/api.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 
10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: API 15 | on: 16 | pull_request: 17 | 18 | jobs: 19 | check: 20 | uses: pion/.goassets/.github/workflows/api.reusable.yml@master 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Lint 15 | on: 16 | pull_request: 17 | 18 | jobs: 19 | lint: 20 | uses: pion/.goassets/.github/workflows/lint.reusable.yml@master 21 | -------------------------------------------------------------------------------- /.github/workflows/reuse.yml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 
10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: REUSE Compliance Check 15 | 16 | on: 17 | push: 18 | pull_request: 19 | 20 | jobs: 21 | lint: 22 | uses: pion/.goassets/.github/workflows/reuse.reusable.yml@master 23 | -------------------------------------------------------------------------------- /packetio/errors.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package packetio 5 | 6 | import ( 7 | "errors" 8 | ) 9 | 10 | // netError implements net.Error. 11 | type netError struct { 12 | error 13 | timeout, temporary bool 14 | } 15 | 16 | func (e *netError) Timeout() bool { 17 | return e.timeout 18 | } 19 | 20 | func (e *netError) Temporary() bool { 21 | return e.temporary 22 | } 23 | 24 | var ( 25 | // ErrFull is returned when the buffer has hit the configured limits. 26 | ErrFull = errors.New("packetio.Buffer is full, discarding write") 27 | 28 | // ErrTimeout is returned when a deadline has expired. 29 | ErrTimeout = errors.New("i/o timeout") 30 | ) 31 | -------------------------------------------------------------------------------- /.github/install-hooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # DO NOT EDIT THIS FILE 5 | # 6 | # It is automatically copied from https://github.com/pion/.goassets repository. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | SCRIPT_PATH="$(realpath "$(dirname "$0")")" 15 | 16 | . 
${SCRIPT_PATH}/fetch-scripts.sh 17 | 18 | cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" 19 | cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" 20 | cp "${GOASSETS_PATH}/hooks/pre-push.sh" "${SCRIPT_PATH}/../.git/hooks/pre-push" 21 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Release 15 | on: 16 | push: 17 | tags: 18 | - 'v*' 19 | 20 | jobs: 21 | release: 22 | uses: pion/.goassets/.github/workflows/release.reusable.yml@master 23 | with: 24 | go-version: "1.25" # auto-update/latest-go-version 25 | -------------------------------------------------------------------------------- /.github/workflows/renovate-go-sum-fix.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 
10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Fix go.sum 15 | on: 16 | push: 17 | branches: 18 | - renovate/* 19 | 20 | jobs: 21 | fix: 22 | uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master 23 | secrets: 24 | token: ${{ secrets.PIONBOT_PRIVATE_KEY }} 25 | -------------------------------------------------------------------------------- /.github/workflows/tidy-check.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Go mod tidy 15 | on: 16 | pull_request: 17 | push: 18 | branches: 19 | - master 20 | 21 | jobs: 22 | tidy: 23 | uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master 24 | with: 25 | go-version: "1.25" # auto-update/latest-go-version 26 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 
10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: CodeQL 15 | 16 | on: 17 | workflow_dispatch: 18 | schedule: 19 | - cron: '23 5 * * 0' 20 | pull_request: 21 | branches: 22 | - master 23 | paths: 24 | - '**.go' 25 | 26 | jobs: 27 | analyze: 28 | uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master 29 | -------------------------------------------------------------------------------- /.github/workflows/fuzz.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Fuzz 15 | on: 16 | push: 17 | branches: 18 | - master 19 | schedule: 20 | - cron: "0 */8 * * *" 21 | 22 | jobs: 23 | fuzz: 24 | uses: pion/.goassets/.github/workflows/fuzz.reusable.yml@master 25 | with: 26 | go-version: "1.25" # auto-update/latest-go-version 27 | fuzz-time: "60s" 28 | -------------------------------------------------------------------------------- /test/connctx.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "context" 8 | "io" 9 | ) 10 | 11 | type wrappedReader struct { 12 | io.Reader 13 | } 14 | 15 | func (r *wrappedReader) ReadContext(_ context.Context, b []byte) (int, error) { 16 | return r.Reader.Read(b) 17 | } 18 | 19 | type wrappedWriter struct { 20 | io.Writer 21 | } 22 | 23 | func (r 
*wrappedWriter) WriteContext(_ context.Context, b []byte) (int, error) { 24 | return r.Writer.Write(b) 25 | } 26 | 27 | type wrappedReadWriter struct { 28 | io.ReadWriter 29 | } 30 | 31 | func (r *wrappedReadWriter) ReadContext(_ context.Context, b []byte) (int, error) { 32 | return r.ReadWriter.Read(b) 33 | } 34 | 35 | func (r *wrappedReadWriter) WriteContext(_ context.Context, b []byte) (int, error) { 36 | return r.ReadWriter.Write(b) 37 | } 38 | -------------------------------------------------------------------------------- /test/test_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "io" 8 | "net" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestStressIOPipe(t *testing.T) { 15 | r, w := io.Pipe() 16 | 17 | opt := Options{ 18 | MsgSize: 2048, 19 | MsgCount: 100, 20 | } 21 | 22 | assert.NoError(t, Stress(w, r, opt)) 23 | } 24 | 25 | func TestStressDuplexNetPipe(t *testing.T) { 26 | ca, cb := net.Pipe() 27 | 28 | opt := Options{ 29 | MsgSize: 2048, 30 | MsgCount: 100, 31 | } 32 | 33 | assert.NoError(t, StressDuplex(ca, cb, opt)) 34 | } 35 | 36 | func BenchmarkPipe(b *testing.B) { 37 | ca, cb := net.Pipe() 38 | 39 | b.ResetTimer() 40 | 41 | opt := Options{ 42 | MsgSize: 2048, 43 | MsgCount: b.N, 44 | } 45 | 46 | assert.NoError(b, Stress(ca, cb, opt)) 47 | } 48 | -------------------------------------------------------------------------------- /test/rand.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | crand "crypto/rand" 8 | "errors" 9 | "fmt" 10 | mrand "math/rand" 11 | ) 12 | 13 | var errRequestTooLargeBuffer = errors.New("requested too large buffer") 14 | 15 | type randomizer struct { 16 | 
randomness []byte 17 | } 18 | 19 | func initRand() randomizer { 20 | // read 1MB of randomness 21 | randomness := make([]byte, 1<<20) 22 | if _, err := crand.Read(randomness); err != nil { 23 | fmt.Println("Failed to initiate randomness:", err) // nolint 24 | } 25 | 26 | return randomizer{ 27 | randomness: randomness, 28 | } 29 | } 30 | // randBuf returns size bytes starting at a random offset in r.randomness. The result aliases that shared buffer (writing to it changes later results); size must be in [0, len(r.randomness)], larger sizes return errRequestTooLargeBuffer. 31 | func (r *randomizer) randBuf(size int) ([]byte, error) { 32 | n := len(r.randomness) - size 33 | if n < 0 { 34 | return nil, fmt.Errorf("%w (%d). max is %d", errRequestTooLargeBuffer, size, len(r.randomness)) 35 | } 36 | 37 | start := mrand.Intn(n + 1) //nolint:gosec // n+1 makes the last window reachable, so size == len(r.randomness) is accepted as the error message promises 38 | 39 | return r.randomness[start : start+size], nil 40 | } 41 | -------------------------------------------------------------------------------- /.github/fetch-scripts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # DO NOT EDIT THIS FILE 5 | # 6 | # It is automatically copied from https://github.com/pion/.goassets repository. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | set -eu 15 | 16 | SCRIPT_PATH="$(realpath "$(dirname "$0")")" 17 | GOASSETS_PATH="${SCRIPT_PATH}/.goassets" 18 | 19 | GOASSETS_REF=${GOASSETS_REF:-master} 20 | 21 | if [ -d "${GOASSETS_PATH}" ]; then 22 | if !
git -C "${GOASSETS_PATH}" diff --exit-code; then 23 | echo "${GOASSETS_PATH} has uncommitted changes" >&2 24 | exit 1 25 | fi 26 | git -C "${GOASSETS_PATH}" fetch origin 27 | git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} 28 | git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} 29 | else 30 | git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" 31 | fi 32 | -------------------------------------------------------------------------------- /LICENSES/MIT.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 The Pion community 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /deadline/timer_js.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build js 5 | // +build js 6 | 7 | package deadline 8 | 9 | import ( 10 | "sync" 11 | "time" 12 | ) 13 | 14 | // jsTimer is a timer utility for wasm with a working Reset function. 
15 | type jsTimer struct { 16 | f func() 17 | mu sync.Mutex 18 | timer *time.Timer 19 | version uint64 20 | started bool 21 | } 22 | 23 | func afterFunc(d time.Duration, f func()) timer { 24 | t := &jsTimer{f: f} 25 | t.Reset(d) 26 | return t 27 | } 28 | 29 | func (t *jsTimer) Stop() bool { 30 | t.mu.Lock() 31 | defer t.mu.Unlock() 32 | 33 | t.version++ 34 | t.timer.Stop() 35 | 36 | started := t.started 37 | t.started = false 38 | return started 39 | } 40 | 41 | func (t *jsTimer) Reset(d time.Duration) bool { 42 | t.mu.Lock() 43 | defer t.mu.Unlock() 44 | 45 | if t.timer != nil { 46 | t.timer.Stop() 47 | } 48 | 49 | t.version++ 50 | version := t.version 51 | t.timer = time.AfterFunc(d, func() { 52 | t.mu.Lock() 53 | if version != t.version { 54 | t.mu.Unlock() 55 | return 56 | } 57 | 58 | t.started = false 59 | t.mu.Unlock() 60 | 61 | t.f() 62 | }) 63 | 64 | started := t.started 65 | t.started = true 66 | return started 67 | } 68 | -------------------------------------------------------------------------------- /examples/vnet-udpproxy/README.md: -------------------------------------------------------------------------------- 1 | # vnet-udpproxy 2 | 3 | This example demonstrates how VNet can be used to communicate with non-VNet addresses using UDPProxy. 4 | 5 | In this example we map the VNet address `10.0.0.11` to a real address of our choice. We then 6 | send to our real address from three different VNet addresses.
7 | 8 | If you pass `-address 192.168.1.3:8000` the traffic will be the following 9 | 10 | ``` 11 | vnet(10.0.0.11:5787) => proxy => 192.168.1.3:8000 12 | vnet(10.0.0.11:5788) => proxy => 192.168.1.3:8000 13 | vnet(10.0.0.11:5789) => proxy => 192.168.1.3:8000 14 | ``` 15 | 16 | ## Running 17 | ``` 18 | go run main.go -address 192.168.1.3:8000 19 | ``` 20 | 21 | You should see the following in tcpdump 22 | ``` 23 | sean@SeanLaptop:~/go/src/github.com/pion/transport/examples$ sudo tcpdump -i any udp and port 8000 24 | tcpdump: data link type LINUX_SLL2 25 | tcpdump: verbose output suppressed, use -v[v]... for full protocol decode 26 | listening on any, link-type LINUX_SLL2 (Linux cooked v2), snapshot length 262144 bytes 27 | 13:21:18.239943 lo In IP 192.168.1.7.40574 > 192.168.1.7.8000: UDP, length 5 28 | 13:21:18.240105 lo In IP 192.168.1.7.40647 > 192.168.1.7.8000: UDP, length 5 29 | 13:21:18.240304 lo In IP 192.168.1.7.57744 > 192.168.1.7.8000: UDP, length 5 30 | ``` 31 | -------------------------------------------------------------------------------- /test/util_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestCheckRoutines(t *testing.T) { 15 | // Limit runtime in case of deadlocks 16 | lim := TimeOut(time.Second * 20) 17 | defer lim.Stop() 18 | 19 | // Check for leaking routines 20 | report := CheckRoutines(t) 21 | defer report() 22 | 23 | go func() { 24 | time.Sleep(1 * time.Second) 25 | }() 26 | } 27 | 28 | func TestCheckRoutinesStrict(t *testing.T) { 29 | mock := &tbMock{TB: t} 30 | 31 | // Limit runtime in case of deadlocks 32 | lim := TimeOut(time.Second * 20) 33 | defer lim.Stop() 34 | 35 | // Check for leaking routines 36 | report := CheckRoutinesStrict(mock) 37 | defer func() { 38 | 
report() 39 | assert.NotEmpty(t, mock.fatalfCalled, "expected Fatalf to be called") 40 | assert.Contains(t, mock.fatalfCalled[0], "Unexpected routines") 41 | }() 42 | 43 | go func() { 44 | time.Sleep(1 * time.Second) 45 | }() 46 | } 47 | 48 | type tbMock struct { 49 | testing.TB 50 | 51 | fatalfCalled []string 52 | } 53 | 54 | func (m *tbMock) Fatalf(format string, args ...any) { 55 | m.fatalfCalled = append(m.fatalfCalled, fmt.Sprintf(format, args...)) 56 | } 57 | -------------------------------------------------------------------------------- /LICENSES/BSD-3-Clause.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) . 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /vnet/udpproxy_direct.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | ) 10 | 11 | // Deliver directly send packet to vnet or real-server. 12 | // For example, we can use this API to simulate the REPLAY ATTACK. 13 | func (v *UDPProxy) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) { 14 | v.workers.Range(func(_, value any) bool { 15 | worker, ok := value.(*aUDPProxyWorker) 16 | if !ok { 17 | return false 18 | } 19 | 20 | if nn, err = worker.Deliver(sourceAddr, destAddr, b); err != nil { 21 | return false // Fail, abort. 22 | } else if nn == len(b) { 23 | return false // Done. 24 | } 25 | 26 | return true // Deliver by next worker. 27 | }) 28 | 29 | return 30 | } 31 | 32 | func (v *aUDPProxyWorker) Deliver(sourceAddr, _ net.Addr, b []byte) (nn int, err error) { 33 | addr, ok := sourceAddr.(*net.UDPAddr) 34 | if !ok { 35 | return 0, fmt.Errorf("invalid addr %v", sourceAddr) // nolint:err113 36 | } 37 | 38 | // nolint:godox // TODO: Support deliver packet from real server to vnet. 39 | // If packet is from vnet, proxy to real server. 
40 | var realSocket *net.UDPConn 41 | value, ok := v.endpoints.Load(addr.String()) 42 | if !ok { 43 | return 0, nil 44 | } 45 | 46 | realSocket = value.(*net.UDPConn) // nolint:forcetypeassert 47 | 48 | // Send to real server. 49 | if _, err := realSocket.Write(b); err != nil { 50 | return 0, err 51 | } 52 | 53 | return len(b), nil 54 | } 55 | -------------------------------------------------------------------------------- /vnet/chunk_queue.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "sync" 8 | ) 9 | 10 | type chunkQueue struct { 11 | chunks []Chunk 12 | maxSize int // 0 or negative value: unlimited 13 | maxBytes int // 0 or negative value: unlimited 14 | currentBytes int 15 | mutex sync.RWMutex 16 | } 17 | 18 | func newChunkQueue(maxSize int, maxBytes int) *chunkQueue { 19 | return &chunkQueue{ 20 | chunks: []Chunk{}, 21 | maxSize: maxSize, 22 | maxBytes: maxBytes, 23 | currentBytes: 0, 24 | mutex: sync.RWMutex{}, 25 | } 26 | } 27 | 28 | func (q *chunkQueue) push(c Chunk) bool { 29 | q.mutex.Lock() 30 | defer q.mutex.Unlock() 31 | 32 | if q.maxSize > 0 && len(q.chunks) >= q.maxSize { 33 | return false // dropped 34 | } 35 | if q.maxBytes > 0 && q.currentBytes+len(c.UserData()) >= q.maxBytes { 36 | return false 37 | } 38 | 39 | q.currentBytes += len(c.UserData()) 40 | q.chunks = append(q.chunks, c) 41 | 42 | return true 43 | } 44 | 45 | func (q *chunkQueue) pop() (Chunk, bool) { 46 | q.mutex.Lock() 47 | defer q.mutex.Unlock() 48 | 49 | if len(q.chunks) == 0 { 50 | return nil, false 51 | } 52 | 53 | c := q.chunks[0] 54 | q.chunks = q.chunks[1:] 55 | q.currentBytes -= len(c.UserData()) 56 | 57 | return c, true 58 | } 59 | 60 | func (q *chunkQueue) peek() Chunk { 61 | q.mutex.RLock() 62 | defer q.mutex.RUnlock() 63 | 64 | if len(q.chunks) == 0 { 65 | return nil 66 | } 67 | 68 | return q.chunks[0] 
69 | } 70 | -------------------------------------------------------------------------------- /vnet/chunk_queue_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "net" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestChunkQueue(t *testing.T) { 14 | chunk := newChunkUDP(&net.UDPAddr{ 15 | IP: net.ParseIP("192.188.0.2"), 16 | Port: 1234, 17 | }, &net.UDPAddr{ 18 | IP: net.ParseIP(demoIP), 19 | Port: 5678, 20 | }) 21 | chunk.userData = make([]byte, 1200) 22 | 23 | var ok bool 24 | var queue *chunkQueue 25 | var chunk2 Chunk 26 | 27 | queue = newChunkQueue(0, 0) 28 | 29 | chunk2 = queue.peek() 30 | assert.Nil(t, chunk2, "should return nil") 31 | 32 | ok = queue.push(chunk) 33 | assert.True(t, ok, "should succeed") 34 | 35 | chunk2, ok = queue.pop() 36 | assert.True(t, ok, "should succeed") 37 | assert.Equal(t, chunk, chunk2, "should be the same") 38 | 39 | chunk2, ok = queue.pop() 40 | assert.False(t, ok, "should fail") 41 | assert.Nil(t, chunk2, "should be nil") 42 | 43 | queue = newChunkQueue(1, 0) 44 | ok = queue.push(chunk) 45 | assert.True(t, ok, "should succeed") 46 | 47 | ok = queue.push(chunk) 48 | assert.False(t, ok, "should fail") 49 | 50 | chunk2 = queue.peek() 51 | assert.Equal(t, chunk, chunk2, "should be the same") 52 | 53 | queue = newChunkQueue(0, 1500) 54 | ok = queue.push(chunk) 55 | assert.True(t, ok, "should succeed") 56 | 57 | ok = queue.push(chunk) 58 | assert.False(t, ok, "should fail") 59 | 60 | chunk2 = queue.peek() 61 | assert.Equal(t, chunk, chunk2, "should be the same") 62 | } 63 | -------------------------------------------------------------------------------- /utils/xor/xor_arm.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2022 The Pion community 2 | // 
SPDX-License-Identifier: MIT 3 | 4 | //go:build !gccgo 5 | // +build !gccgo 6 | 7 | // Package xor provides utility functions used by other Pion 8 | // packages. ARM arch. 9 | package xor 10 | 11 | import ( 12 | "unsafe" 13 | 14 | "golang.org/x/sys/cpu" 15 | ) 16 | 17 | const wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec 18 | var hasNEON = cpu.ARM.HasNEON // nolint:gochecknoglobals 19 | 20 | func isAligned(a *byte) bool { 21 | return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 22 | } 23 | 24 | // XorBytes xors the bytes in a and b. The destination should have enough 25 | // space, otherwise xorBytes will panic. Returns the number of bytes xor'd. 26 | // 27 | //revive:disable-next-line 28 | func XorBytes(dst, a, b []byte) int { 29 | n := len(a) 30 | if len(b) < n { 31 | n = len(b) 32 | } 33 | if n == 0 { 34 | return 0 35 | } 36 | // make sure dst has enough space 37 | _ = dst[n-1] 38 | 39 | if hasNEON { 40 | xorBytesNEON32(&dst[0], &a[0], &b[0], n) 41 | } else if isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]) { 42 | xorBytesARM32(&dst[0], &a[0], &b[0], n) 43 | } else { 44 | safeXORBytes(dst, a, b, n) 45 | } 46 | return n 47 | } 48 | 49 | // n needs to be smaller or equal than the length of a and b. 
50 | func safeXORBytes(dst, a, b []byte, n int) { 51 | for i := 0; i < n; i++ { 52 | dst[i] = a[i] ^ b[i] 53 | } 54 | } 55 | 56 | //go:noescape 57 | func xorBytesARM32(dst, a, b *byte, n int) 58 | 59 | //go:noescape 60 | func xorBytesNEON32(dst, a, b *byte, n int) 61 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= 4 | github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 8 | github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 9 | github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= 10 | github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= 11 | golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= 12 | golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= 13 | golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= 14 | golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 15 | golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= 16 | golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 20 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 21 | -------------------------------------------------------------------------------- /replaydetector/fixedbig_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package replaydetector 5 | 6 | import ( 7 | "fmt" 8 | ) 9 | 10 | func Example_fixedBigInt_SetBit() { 11 | bi := newFixedBigInt(224) 12 | 13 | bi.SetBit(0) 14 | fmt.Println(bi.String()) 15 | bi.Lsh(1) 16 | fmt.Println(bi.String()) 17 | 18 | bi.Lsh(0) 19 | fmt.Println(bi.String()) 20 | 21 | bi.SetBit(10) 22 | fmt.Println(bi.String()) 23 | bi.Lsh(20) 24 | fmt.Println(bi.String()) 25 | 26 | bi.SetBit(80) 27 | fmt.Println(bi.String()) 28 | bi.Lsh(4) 29 | fmt.Println(bi.String()) 30 | 31 | bi.SetBit(130) 32 | fmt.Println(bi.String()) 33 | bi.Lsh(64) 34 | fmt.Println(bi.String()) 35 | 36 | bi.SetBit(7) 37 | fmt.Println(bi.String()) 38 | 39 | bi.Lsh(129) 40 | fmt.Println(bi.String()) 41 | 42 | for i := 0; i < 256; i++ { 43 | bi.Lsh(1) 44 | bi.SetBit(0) 45 | } 46 | fmt.Println(bi.String()) 47 | 48 | // output: 49 | // 0000000000000000000000000000000000000000000000000000000000000001 50 | // 0000000000000000000000000000000000000000000000000000000000000002 51 | // 0000000000000000000000000000000000000000000000000000000000000002 52 | // 0000000000000000000000000000000000000000000000000000000000000402 53 | // 0000000000000000000000000000000000000000000000000000000040200000 54 | // 0000000000000000000000000000000000000000000100000000000040200000 55 | // 0000000000000000000000000000000000000000001000000000000402000000 56 | // 0000000000000000000000000000000400000000001000000000000402000000 57 | // 0000000000000004000000000010000000000004020000000000000000000000 58 | // 
0000000000000004000000000010000000000004020000000000000000000080 59 | // 0000000004000000000000000000010000000000000000000000000000000000 60 | // 00000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF 61 | } 62 | -------------------------------------------------------------------------------- /replaydetector/fixedbig.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package replaydetector 5 | 6 | import ( 7 | "fmt" 8 | ) 9 | 10 | // fixedBigInt is the fix-sized multi-word integer. 11 | type fixedBigInt struct { 12 | bits []uint64 13 | n uint 14 | msbMask uint64 15 | } 16 | 17 | // newFixedBigInt creates a new fix-sized multi-word int. 18 | func newFixedBigInt(n uint) *fixedBigInt { 19 | chunkSize := (n + 63) / 64 20 | if chunkSize == 0 { 21 | chunkSize = 1 22 | } 23 | 24 | return &fixedBigInt{ 25 | bits: make([]uint64, chunkSize), 26 | n: n, 27 | msbMask: (1 << (64 - n%64)) - 1, 28 | } 29 | } 30 | 31 | // Lsh is the left shift operation. 32 | func (s *fixedBigInt) Lsh(n uint) { //nolint:varnamelen 33 | if n == 0 { 34 | return 35 | } 36 | nChunk := int(n / 64) //nolint:gosec 37 | nN := n % 64 38 | 39 | for i := len(s.bits) - 1; i >= 0; i-- { 40 | var carry uint64 41 | if i-nChunk >= 0 { 42 | carry = s.bits[i-nChunk] << nN 43 | if i-nChunk-1 >= 0 { 44 | carry |= s.bits[i-nChunk-1] >> (64 - nN) 45 | } 46 | } 47 | s.bits[i] = (s.bits[i] << n) | carry 48 | } 49 | s.bits[len(s.bits)-1] &= s.msbMask 50 | } 51 | 52 | // Bit returns i-th bit of the fixedBigInt. 53 | func (s *fixedBigInt) Bit(i uint) uint { 54 | if i >= s.n { 55 | return 0 56 | } 57 | chunk := i / 64 58 | pos := i % 64 59 | if s.bits[chunk]&(1<= s.n { 69 | return 70 | } 71 | chunk := i / 64 72 | pos := i % 64 73 | s.bits[chunk] |= 1 << pos 74 | } 75 | 76 | // String returns string representation of fixedBigInt. 
77 | func (s *fixedBigInt) String() string { 78 | var out string 79 | for i := len(s.bits) - 1; i >= 0; i-- { 80 | out += fmt.Sprintf("%016X", s.bits[i]) 81 | } 82 | 83 | return out 84 | } 85 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # DO NOT EDIT THIS FILE 3 | # 4 | # It is automatically copied from https://github.com/pion/.goassets repository. 5 | # If this repository should have package specific CI config, 6 | # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 7 | # 8 | # If you want to update the shared CI config, send a PR to 9 | # https://github.com/pion/.goassets instead of this repository. 10 | # 11 | # SPDX-FileCopyrightText: 2023 The Pion community 12 | # SPDX-License-Identifier: MIT 13 | 14 | name: Test 15 | on: 16 | push: 17 | branches: 18 | - master 19 | pull_request: 20 | 21 | jobs: 22 | test: 23 | uses: pion/.goassets/.github/workflows/test.reusable.yml@master 24 | strategy: 25 | matrix: 26 | go: ["1.25", "1.24"] # auto-update/supported-go-version-list 27 | fail-fast: false 28 | with: 29 | go-version: ${{ matrix.go }} 30 | secrets: inherit 31 | 32 | test-i386: 33 | uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master 34 | strategy: 35 | matrix: 36 | go: ["1.25", "1.24"] # auto-update/supported-go-version-list 37 | fail-fast: false 38 | with: 39 | go-version: ${{ matrix.go }} 40 | 41 | test-windows: 42 | uses: pion/.goassets/.github/workflows/test-windows.reusable.yml@master 43 | strategy: 44 | matrix: 45 | go: ["1.25", "1.24"] # auto-update/supported-go-version-list 46 | fail-fast: false 47 | with: 48 | go-version: ${{ matrix.go }} 49 | 50 | test-macos: 51 | uses: pion/.goassets/.github/workflows/test-macos.reusable.yml@master 52 | strategy: 53 | matrix: 54 | go: ["1.25", "1.24"] # auto-update/supported-go-version-list 55 | fail-fast: false 56 | 
with: 57 | go-version: ${{ matrix.go }} 58 | 59 | test-wasm: 60 | uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master 61 | with: 62 | go-version: "1.25" # auto-update/latest-go-version 63 | secrets: inherit 64 | -------------------------------------------------------------------------------- /vnet/queue.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2025 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "sync" 8 | "time" 9 | ) 10 | 11 | type Discipline interface { 12 | push(Chunk) 13 | pop() Chunk 14 | empty() bool 15 | next() time.Time 16 | } 17 | 18 | type Queue struct { 19 | NIC 20 | data Discipline 21 | chunkCh chan Chunk 22 | closed bool 23 | close chan struct{} 24 | wg sync.WaitGroup 25 | lock sync.Mutex 26 | } 27 | 28 | func NewQueue(n NIC, d Discipline) (*Queue, error) { 29 | q := &Queue{ 30 | NIC: n, 31 | data: d, 32 | chunkCh: make(chan Chunk), 33 | closed: false, 34 | close: make(chan struct{}), 35 | wg: sync.WaitGroup{}, 36 | lock: sync.Mutex{}, 37 | } 38 | q.wg.Add(1) 39 | go q.run() 40 | 41 | return q, nil 42 | } 43 | 44 | func (q *Queue) onInboundChunk(c Chunk) { 45 | select { 46 | case q.chunkCh <- c: 47 | case <-q.close: 48 | 49 | return 50 | } 51 | } 52 | 53 | func (q *Queue) run() { 54 | defer q.wg.Done() 55 | for { 56 | if !q.schedule() { 57 | return 58 | } 59 | } 60 | } 61 | 62 | func (q *Queue) schedule() bool { 63 | q.lock.Lock() 64 | if q.closed { 65 | q.lock.Unlock() 66 | 67 | return false 68 | } 69 | q.lock.Unlock() 70 | 71 | var timer <-chan time.Time 72 | 73 | if !q.data.empty() { 74 | next := q.data.next() 75 | timer = time.After(time.Until(next)) 76 | } 77 | 78 | select { 79 | case chunk := <-q.chunkCh: 80 | q.data.push(chunk) 81 | case <-timer: 82 | chunk := q.data.pop() 83 | if chunk != nil { 84 | q.NIC.onInboundChunk(chunk) 85 | } 86 | case <-q.close: 87 | return false 88 | } 89 | 90 | return true 91 | 
} 92 | 93 | func (q *Queue) Close() error { 94 | defer q.wg.Wait() 95 | q.lock.Lock() 96 | defer q.lock.Unlock() 97 | if q.closed { 98 | return nil 99 | } 100 | q.closed = true 101 | close(q.close) 102 | 103 | return nil 104 | } 105 | -------------------------------------------------------------------------------- /vnet/tbf_queue_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2025 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestTBFQueue(t *testing.T) { 14 | t.Run("enqueue-dequeue", func(t *testing.T) { 15 | q := NewTBFQueue(1_000_000, 10_000, 15*1500) 16 | chunk := &chunkUDP{ 17 | userData: make([]byte, 1300), 18 | } 19 | q.push(chunk) 20 | res := q.pop() 21 | assert.Equal(t, chunk, res) 22 | }) 23 | t.Run("drop-packets", func(t *testing.T) { 24 | q := NewTBFQueue(1_000_000, 10_000, 15*1500) 25 | chunk := &chunkUDP{ 26 | userData: make([]byte, 1500), 27 | } 28 | for i := 0; i < 20; i++ { 29 | q.push(chunk) 30 | } 31 | assert.Len(t, q.chunks, 15) 32 | }) 33 | t.Run("burst", func(t *testing.T) { 34 | queue := NewTBFQueue(150_000, 10*1500, 15*1500) 35 | chunk := &chunkUDP{ 36 | userData: make([]byte, 1500), 37 | } 38 | for i := 0; i < 15; i++ { 39 | queue.push(chunk) 40 | } 41 | // queue size is 15, burst is 10 so we should be allowed to burst 10 42 | // packets immediately 43 | for i := 0; i < 10; i++ { 44 | assert.Equal(t, chunk, queue.pop()) 45 | } 46 | // But no more than 10 47 | assert.Nil(t, queue.pop()) 48 | }) 49 | t.Run("rate", func(t *testing.T) { 50 | queue := NewTBFQueue(8*15_000, 1500, 30_000) 51 | chunk := &chunkUDP{ 52 | userData: make([]byte, 1000), 53 | } 54 | for i := 0; i < 30; i++ { 55 | queue.push(chunk) 56 | } 57 | now := time.Now() 58 | received := 0 59 | // dequeue for one second 60 | for time.Since(now) < time.Second { 61 | res := 
queue.pop() 62 | if res == nil { 63 | next := queue.next() 64 | if next.IsZero() { 65 | continue 66 | } 67 | time.Sleep(time.Until(queue.next())) 68 | 69 | continue 70 | } 71 | assert.NotNil(t, res, "received nil chunk after receiving %v bytes", received) 72 | received += len(res.UserData()) 73 | } 74 | // Initial burst of 1000+1s*15_000 byte/s = 16_000 bytes 75 | assert.Equal(t, 16_000, received) 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /vnet/resolver.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "net" 10 | "sync" 11 | 12 | "github.com/pion/logging" 13 | ) 14 | 15 | var ( 16 | errHostnameEmpty = errors.New("host name must not be empty") 17 | errFailedToParseIPAddr = errors.New("failed to parse IP address") 18 | ) 19 | 20 | type resolverConfig struct { 21 | LoggerFactory logging.LoggerFactory 22 | } 23 | 24 | type resolver struct { 25 | parent *resolver // read-only 26 | hosts map[string]net.IP // requires mutex 27 | mutex sync.RWMutex // thread-safe 28 | log logging.LeveledLogger // read-only 29 | } 30 | 31 | func newResolver(config *resolverConfig) *resolver { 32 | r := &resolver{ 33 | hosts: map[string]net.IP{}, 34 | log: config.LoggerFactory.NewLogger("vnet"), 35 | } 36 | 37 | if err := r.addHost("localhost", "127.0.0.1"); err != nil { 38 | r.log.Warn("failed to add localhost to resolver") 39 | } 40 | 41 | return r 42 | } 43 | 44 | func (r *resolver) setParent(parent *resolver) { 45 | r.mutex.Lock() 46 | defer r.mutex.Unlock() 47 | 48 | r.parent = parent 49 | } 50 | 51 | func (r *resolver) addHost(name string, ipAddr string) error { 52 | r.mutex.Lock() 53 | defer r.mutex.Unlock() 54 | 55 | if len(name) == 0 { 56 | return errHostnameEmpty 57 | } 58 | ip := net.ParseIP(ipAddr) 59 | if ip == nil { 60 | return 
fmt.Errorf("%w \"%s\"", errFailedToParseIPAddr, ipAddr) 61 | } 62 | r.hosts[name] = ip 63 | 64 | return nil 65 | } 66 | 67 | func (r *resolver) lookUp(hostName string) (net.IP, error) { 68 | ip := func() net.IP { 69 | r.mutex.RLock() 70 | defer r.mutex.RUnlock() 71 | 72 | if ip2, ok := r.hosts[hostName]; ok { 73 | return ip2 74 | } 75 | 76 | return nil 77 | }() 78 | if ip != nil { 79 | return ip, nil 80 | } 81 | 82 | // mutex must be unlocked before calling into parent resolver 83 | 84 | if r.parent != nil { 85 | return r.parent.lookUp(hostName) 86 | } 87 | 88 | return nil, &net.DNSError{ 89 | Err: "host not found", 90 | Name: hostName, 91 | Server: "vnet resolver", 92 | IsTimeout: false, 93 | IsTemporary: false, 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /vnet/resolver_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "net" 8 | "testing" 9 | 10 | "github.com/pion/logging" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestResolver(t *testing.T) { 15 | loggerFactory := logging.NewDefaultLoggerFactory() 16 | log := loggerFactory.NewLogger("test") 17 | 18 | t.Run("Standalone", func(t *testing.T) { 19 | resolver := newResolver(&resolverConfig{ 20 | LoggerFactory: loggerFactory, 21 | }) 22 | 23 | // should have localhost by default 24 | name := "localhost" 25 | ipAddr := "127.0.0.1" 26 | ip := net.ParseIP(ipAddr) 27 | 28 | resolved, err := resolver.lookUp(name) 29 | assert.NoError(t, err, "should succeed") 30 | assert.True(t, resolved.Equal(ip), "should match") 31 | 32 | name = "abc.com" 33 | ipAddr = demoIP 34 | ip = net.ParseIP(ipAddr) 35 | log.Debugf("adding %s %s", name, ipAddr) 36 | 37 | err = resolver.addHost(name, ipAddr) 38 | assert.NoError(t, err, "should succeed") 39 | 40 | resolved, err = resolver.lookUp(name) 41 | 
assert.NoError(t, err, "should succeed") 42 | assert.True(t, resolved.Equal(ip), "should match") 43 | }) 44 | 45 | t.Run("Cascaded", func(t *testing.T) { 46 | r0 := newResolver(&resolverConfig{ 47 | LoggerFactory: loggerFactory, 48 | }) 49 | name0 := "abc.com" 50 | ipAddr0 := demoIP 51 | ip0 := net.ParseIP(ipAddr0) 52 | err := r0.addHost(name0, ipAddr0) 53 | assert.NoError(t, err, "should succeed") 54 | 55 | r1 := newResolver(&resolverConfig{ 56 | LoggerFactory: loggerFactory, 57 | }) 58 | name1 := "myserver.local" 59 | ipAddr1 := "10.1.2.5" 60 | ip1 := net.ParseIP(ipAddr1) 61 | err = r1.addHost(name1, ipAddr1) 62 | assert.NoError(t, err, "should succeed") 63 | r1.setParent(r0) 64 | 65 | resolved, err := r1.lookUp(name0) 66 | assert.NoError(t, err, "should succeed") 67 | assert.True(t, resolved.Equal(ip0), "should match") 68 | 69 | resolved, err = r1.lookUp(name1) 70 | assert.NoError(t, err, "should succeed") 71 | assert.True(t, resolved.Equal(ip1), "should match") 72 | 73 | // should fail if the name does not exist 74 | _, err = r1.lookUp("bad.com") 75 | assert.NotNil(t, err, "should fail") 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |
3 | Pion Transport 4 |
5 |

6 |

Transport testing for Pion

7 |

8 | Pion transport 9 | join us on Discord Follow us on Bluesky 10 |
11 | GitHub Workflow Status 12 | Go Reference 13 | Coverage Status 14 | Go Report Card 15 | License: MIT 16 |

17 |
18 | 19 | ### Roadmap 20 | The library is used as part of our WebRTC implementation. Please refer to the WebRTC [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. 21 | 22 | ### Community 23 | Pion has an active community on [Discord](https://discord.gg/PngbdqpFbt). 24 | 25 | Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. 26 | 27 | We are always looking to support **your projects**. Please reach out if you have something to build! 28 | If you need commercial support or prefer not to use public channels, you can contact us at [team@pion.ly](mailto:team@pion.ly). 29 | 30 | ### Contributing 31 | Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible. 32 | 33 | ### License 34 | MIT License - see [LICENSE](LICENSE) for the full text. 35 | -------------------------------------------------------------------------------- /utils/xor/xor_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved.
2 | // SPDX-License-Identifier: BSD-3-Clause 3 | 4 | package xor 5 | 6 | import ( 7 | "crypto/rand" 8 | "fmt" 9 | "io" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestXOR(t *testing.T) { //nolint:cyclop 16 | for j := 1; j <= 1024; j++ { //nolint:varnamelen 17 | if testing.Short() && j > 16 { 18 | break 19 | } 20 | for alignP := 0; alignP < 2; alignP++ { 21 | for alignQ := 0; alignQ < 2; alignQ++ { 22 | for alignD := 0; alignD < 2; alignD++ { 23 | p := make([]byte, j)[alignP:] //nolint:varnamelen 24 | q := make([]byte, j)[alignQ:] //nolint:varnamelen 25 | d0 := make([]byte, j+alignD+1) 26 | d0[j+alignD] = 42 27 | d1 := d0[alignD : j+alignD] 28 | d2 := make([]byte, j+alignD)[alignD:] 29 | _, err := io.ReadFull(rand.Reader, p) 30 | assert.NoError(t, err) 31 | _, err = io.ReadFull(rand.Reader, q) 32 | assert.NoError(t, err) 33 | XorBytes(d1, p, q) 34 | n := minInt(p, q) 35 | for i := 0; i < n; i++ { 36 | d2[i] = p[i] ^ q[i] 37 | } 38 | assert.Equalf(t, d1, d2, "p: %#v, q: %#v", p, q) 39 | assert.Equal(t, byte(42), d0[j+alignD], "guard overwritten") 40 | } 41 | } 42 | } 43 | } 44 | } 45 | 46 | func minInt(a, b []byte) int { 47 | n := len(a) 48 | if len(b) < n { 49 | n = len(b) 50 | } 51 | 52 | return n 53 | } 54 | 55 | func BenchmarkXORAligned(b *testing.B) { 56 | dst := make([]byte, 1<<15) 57 | data0 := make([]byte, 1<<15) 58 | data1 := make([]byte, 1<<15) 59 | sizes := []int64{1 << 3, 1 << 7, 1 << 11, 1 << 15} 60 | for _, size := range sizes { 61 | b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { 62 | s0 := data0[:size] 63 | s1 := data1[:size] 64 | b.SetBytes(size) 65 | for i := 0; i < b.N; i++ { 66 | XorBytes(dst, s0, s1) 67 | } 68 | }) 69 | } 70 | } 71 | 72 | func BenchmarkXORUnalignedDst(b *testing.B) { 73 | dst := make([]byte, 1<<15+1) 74 | data0 := make([]byte, 1<<15) 75 | data1 := make([]byte, 1<<15) 76 | sizes := []int64{1 << 3, 1 << 7, 1 << 11, 1 << 15} 77 | for _, size := range sizes { 78 | 
b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { 79 | s0 := data0[:size] 80 | s1 := data1[:size] 81 | b.SetBytes(size) 82 | for i := 0; i < b.N; i++ { 83 | XorBytes(dst[1:], s0, s1) 84 | } 85 | }) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /vnet/tbf_queue.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2025 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "sync/atomic" 8 | "time" 9 | 10 | "golang.org/x/time/rate" 11 | ) 12 | 13 | var _ Discipline = (*TBFQueue)(nil) 14 | 15 | type TBFQueue struct { 16 | limiter *rate.Limiter 17 | chunks []Chunk 18 | maxSize atomic.Int64 19 | currentSize int 20 | } 21 | 22 | // NewTBFQueue creates a new Token Bucket Filter queue with initial rate r in 23 | // bit per second, burst size b in bytes and queue size s in bytes. 24 | func NewTBFQueue(r int, b int, s int64) *TBFQueue { 25 | q := &TBFQueue{ 26 | limiter: rate.NewLimiter(rate.Limit(r), b*8), 27 | chunks: []Chunk{}, 28 | maxSize: atomic.Int64{}, 29 | currentSize: 0, 30 | } 31 | q.maxSize.Store(s) 32 | 33 | return q 34 | } 35 | 36 | // SetRate updates the rate to r bit per second. 37 | func (t *TBFQueue) SetRate(r int) { 38 | t.limiter.SetLimit(rate.Limit(r)) 39 | } 40 | 41 | // SetBurst updates the max burst size to b bytes. 42 | func (t *TBFQueue) SetBurst(b int) { 43 | t.limiter.SetBurst(b * 8) 44 | } 45 | 46 | func (t *TBFQueue) SetSize(s int64) { 47 | t.maxSize.Store(s) 48 | } 49 | 50 | // empty implements discipline. 51 | func (t *TBFQueue) empty() bool { 52 | return len(t.chunks) == 0 53 | } 54 | 55 | // next implements discipline. 
56 | func (t *TBFQueue) next() time.Time { 57 | if t.empty() { 58 | return time.Time{} 59 | } 60 | now := time.Now() 61 | if t.limiter.TokensAt(now) > 8*float64(len(t.chunks[0].UserData())) { 62 | return now 63 | } 64 | res := t.limiter.ReserveN(now, 8*len(t.chunks[0].UserData())) 65 | delay := res.Delay() 66 | res.Cancel() 67 | 68 | return now.Add(delay) 69 | } 70 | 71 | // pop implements discipline. 72 | func (t *TBFQueue) pop() (chunk Chunk) { 73 | if t.empty() { 74 | return nil 75 | } 76 | if !t.limiter.AllowN(time.Now(), 8*len(t.chunks[0].UserData())) { 77 | return nil 78 | } 79 | chunk, t.chunks = t.chunks[0], t.chunks[1:] 80 | t.currentSize -= len(chunk.UserData()) 81 | 82 | return chunk 83 | } 84 | 85 | // push implements discipline. 86 | func (t *TBFQueue) push(chunk Chunk) { 87 | maxSize := int(t.maxSize.Load()) 88 | if t.currentSize+len(chunk.UserData()) > maxSize { 89 | // drop chunk because queue is full 90 | return 91 | } 92 | t.currentSize += len(chunk.UserData()) 93 | t.chunks = append(t.chunks, chunk) 94 | } 95 | -------------------------------------------------------------------------------- /utils/xor/xor_old.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. 2 | // SPDX-License-Identifier: BSD-3-Clause 3 | // SPDX-FileCopyrightText: 2022 The Pion community 4 | // SPDX-License-Identifier: MIT 5 | 6 | //go:build (!go1.20 && !arm) || gccgo 7 | 8 | // Package xor provides the XorBytes function. 9 | // This version is only used on Go up to version 1.19. 
10 | package xor 11 | 12 | import ( 13 | "runtime" 14 | "unsafe" 15 | ) 16 | 17 | const ( 18 | wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec 19 | supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x" // nolint:gochecknoglobals 20 | ) 21 | 22 | func isAligned(a *byte) bool { 23 | return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 24 | } 25 | 26 | // XorBytes xors the bytes in a and b. The destination should have enough 27 | // space, otherwise xorBytes will panic. Returns the number of bytes xor'd. 28 | // 29 | //revive:disable-next-line 30 | func XorBytes(dst, a, b []byte) int { 31 | n := len(a) 32 | if len(b) < n { 33 | n = len(b) 34 | } 35 | if n == 0 { 36 | return 0 37 | } 38 | 39 | switch { 40 | case supportsUnaligned: 41 | fastXORBytes(dst, a, b, n) 42 | case isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]): 43 | fastXORBytes(dst, a, b, n) 44 | default: 45 | safeXORBytes(dst, a, b, n) 46 | } 47 | return n 48 | } 49 | 50 | // fastXORBytes xors in bulk. It only works on architectures that 51 | // support unaligned read/writes. 52 | // n needs to be smaller or equal than the length of a and b. 53 | func fastXORBytes(dst, a, b []byte, n int) { 54 | // Assert dst has enough space 55 | _ = dst[n-1] 56 | 57 | w := n / wordSize 58 | if w > 0 { 59 | dw := *(*[]uintptr)(unsafe.Pointer(&dst)) // nolint:gosec 60 | aw := *(*[]uintptr)(unsafe.Pointer(&a)) // nolint:gosec 61 | bw := *(*[]uintptr)(unsafe.Pointer(&b)) // nolint:gosec 62 | for i := 0; i < w; i++ { 63 | dw[i] = aw[i] ^ bw[i] 64 | } 65 | } 66 | 67 | for i := (n - n%wordSize); i < n; i++ { 68 | dst[i] = a[i] ^ b[i] 69 | } 70 | } 71 | 72 | // n needs to be smaller or equal than the length of a and b. 
73 | func safeXORBytes(dst, a, b []byte, n int) { 74 | for i := 0; i < n; i++ { 75 | dst[i] = a[i] ^ b[i] 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /utils/xor/xor_arm.s: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2022 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // go:build !gccgo 5 | // +build !gccgo 6 | 7 | #include "textflag.h" 8 | 9 | // func xorBytesARM32(dst, a, b *byte, n int) 10 | TEXT ·xorBytesARM32(SB), NOSPLIT|NOFRAME, $0 11 | MOVW dst+0(FP), R0 12 | MOVW a+4(FP), R1 13 | MOVW b+8(FP), R2 14 | MOVW n+12(FP), R3 15 | CMP $4, R3 16 | BLT less_than4 17 | 18 | loop_4: 19 | MOVW.P 4(R1), R4 20 | MOVW.P 4(R2), R5 21 | EOR R4, R5, R5 22 | MOVW.P R5, 4(R0) 23 | 24 | SUB $4, R3 25 | CMP $4, R3 26 | BGE loop_4 27 | 28 | less_than4: 29 | CMP $2, R3 30 | BLT less_than2 31 | MOVH.P 2(R1), R4 32 | MOVH.P 2(R2), R5 33 | EOR R4, R5, R5 34 | MOVH.P R5, 2(R0) 35 | 36 | SUB $2, R3 37 | 38 | less_than2: 39 | CMP $0, R3 40 | BEQ end 41 | MOVB (R1), R4 42 | MOVB (R2), R5 43 | EOR R4, R5, R5 44 | MOVB R5, (R0) 45 | end: 46 | RET 47 | 48 | // func xorBytesNEON32(dst, a, b *byte, n int) 49 | TEXT ·xorBytesNEON32(SB), NOSPLIT|NOFRAME, $0 50 | MOVW dst+0(FP), R0 51 | MOVW a+4(FP), R1 52 | MOVW b+8(FP), R2 53 | MOVW n+12(FP), R3 54 | CMP $32, R3 55 | BLT less_than32 56 | 57 | loop_32: 58 | WORD $0xF421020D // vld1.u8 {q0, q1}, [r1]! 59 | WORD $0xF422420D // vld1.u8 {q2, q3}, [r2]! 60 | WORD $0xF3004154 // veor q2, q0, q2 61 | WORD $0xF3026156 // veor q3, q1, q3 62 | WORD $0xF400420D // vst1.u8 {q2, q3}, [r0]! 63 | 64 | SUB $32, R3 65 | CMP $32, R3 66 | BGE loop_32 67 | 68 | less_than32: 69 | CMP $16, R3 70 | BLT less_than16 71 | WORD $0xF4210A0D // vld1.u8 q0, [r1]! 72 | WORD $0xF4222A0D // vld1.u8 q1, [r2]! 73 | WORD $0xF3002152 // veor q1, q0, q1 74 | WORD $0xF4002A0D // vst1.u8 {q1}, [r0]! 
75 | 76 | SUB $16, R3 77 | 78 | less_than16: 79 | CMP $8, R3 80 | BLT less_than8 81 | WORD $0xF421070D // vld1.u8 d0, [r1]! 82 | WORD $0xF422170D // vld1.u8 d1, [r2]! 83 | WORD $0xF3001111 // veor d1, d0, d1 84 | WORD $0xF400170D // vst1.u8 {d1}, [r0]! 85 | 86 | SUB $8, R3 87 | 88 | less_than8: 89 | CMP $4, R3 90 | BLT less_than4 91 | MOVW.P 4(R1), R4 92 | MOVW.P 4(R2), R5 93 | EOR R4, R5, R5 94 | MOVW.P R5, 4(R0) 95 | 96 | SUB $4, R3 97 | 98 | less_than4: 99 | CMP $2, R3 100 | BLT less_than2 101 | MOVH.P 2(R1), R4 102 | MOVH.P 2(R2), R5 103 | EOR R4, R5, R5 104 | MOVH.P R5, 2(R0) 105 | 106 | SUB $2, R3 107 | 108 | less_than2: 109 | CMP $0, R3 110 | BEQ end 111 | MOVB (R1), R4 112 | MOVB (R2), R5 113 | EOR R4, R5, R5 114 | MOVB R5, (R0) 115 | end: 116 | RET 117 | -------------------------------------------------------------------------------- /examples/vnet-udpproxy/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package main implements an example for the virtual Net 5 | // UDP proxy. 6 | package main 7 | 8 | import ( 9 | "flag" 10 | "net" 11 | "time" 12 | 13 | "github.com/pion/logging" 14 | "github.com/pion/transport/v3/vnet" 15 | ) 16 | 17 | func main() { //nolint:cyclop 18 | address := flag.String("address", "", "Destination address that three separate vnet clients will send too") 19 | flag.Parse() 20 | 21 | // Create vnet WAN with one endpoint 22 | // See the following docs for more information 23 | // https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet 24 | router, err := vnet.NewRouter(&vnet.RouterConfig{ 25 | CIDR: "0.0.0.0/0", 26 | LoggerFactory: logging.NewDefaultLoggerFactory(), 27 | }) 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | // Create a network and add to router, for example, for client. 
33 | clientNetwork, err := vnet.NewNet(&vnet.NetConfig{ 34 | StaticIP: "10.0.0.11", 35 | }) 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | if err = router.AddNet(clientNetwork); err != nil { 41 | panic(err) 42 | } 43 | 44 | if err = router.Start(); err != nil { 45 | panic(err) 46 | } 47 | defer router.Stop() // nolint:errcheck 48 | 49 | // Create a proxy, bind to the router. 50 | proxy, err := vnet.NewProxy(router) 51 | if err != nil { 52 | panic(err) 53 | } 54 | defer proxy.Close() // nolint:errcheck 55 | 56 | serverAddr, err := net.ResolveUDPAddr("udp4", *address) 57 | if err != nil { 58 | panic(err) 59 | } 60 | 61 | // Start to proxy some addresses, clientNetwork is a hit for proxy, 62 | // that the client in vnet is from this network. 63 | if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { 64 | panic(err) 65 | } 66 | 67 | // Now, all packets from client, will be proxy to real server, vice versa. 68 | client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") 69 | if err != nil { 70 | panic(err) 71 | } 72 | _, _ = client0.WriteTo([]byte("Hello"), serverAddr) 73 | 74 | client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788") 75 | if err != nil { 76 | panic(err) 77 | } 78 | _, _ = client1.WriteTo([]byte("Hello"), serverAddr) 79 | 80 | client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789") 81 | if err != nil { 82 | panic(err) 83 | } 84 | _, _ = client2.WriteTo([]byte("Hello"), serverAddr) 85 | 86 | // Packets are delivered by a goroutine so WriteTo 87 | // return doesn't mean success. This may improve in 88 | // the future. 
89 | time.Sleep(time.Second * 3) 90 | } 91 | -------------------------------------------------------------------------------- /dpipe/dpipe_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !js 5 | // +build !js 6 | 7 | package dpipe 8 | 9 | import ( 10 | "fmt" 11 | "io" 12 | "net" 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "golang.org/x/net/nettest" 18 | ) 19 | 20 | var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn") 21 | 22 | func TestNetTest(t *testing.T) { 23 | nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) { 24 | ca, cb := Pipe() 25 | caConn, ok := ca.(*conn) 26 | if !ok { 27 | return nil, nil, nil, errFailedToCast 28 | } 29 | 30 | cbConn, ok := cb.(*conn) 31 | if !ok { 32 | return nil, nil, nil, errFailedToCast 33 | } 34 | 35 | return &closePropagator{caConn, cbConn}, 36 | &closePropagator{cbConn, caConn}, 37 | func() { 38 | _ = ca.Close() 39 | _ = cb.Close() 40 | }, nil 41 | }) 42 | } 43 | 44 | type closePropagator struct { 45 | *conn 46 | otherEnd *conn 47 | } 48 | 49 | func (c *closePropagator) Close() error { 50 | close(c.otherEnd.closing) 51 | 52 | return c.conn.Close() 53 | } 54 | 55 | func TestPipe(t *testing.T) { //nolint:cyclop 56 | ca, cb := Pipe() 57 | 58 | testData := []byte{0x01, 0x02} 59 | 60 | for name, cond := range map[string]struct { 61 | ca net.Conn 62 | cb net.Conn 63 | }{ 64 | "AtoB": {ca, cb}, 65 | "BtoA": {cb, ca}, 66 | } { 67 | c0 := cond.ca 68 | c1 := cond.cb 69 | t.Run(name, func(t *testing.T) { 70 | n, err := c0.Write(testData) 71 | assert.NoError(t, err) 72 | assert.Equal(t, len(testData), n) 73 | 74 | readData := make([]byte, 4) 75 | n, err = c1.Read(readData) 76 | assert.NoError(t, err) 77 | assert.Len(t, testData, n) 78 | assert.Equal(t, testData, readData[:n]) 79 | }) 80 | } 81 | 82 | assert.NoError(t, ca.Close()) 
83 | _, err := ca.Write(testData) 84 | assert.ErrorIs(t, err, io.ErrClosedPipe, "Write to closed conn should fail") 85 | 86 | // Other side should be writable. 87 | _, err = cb.Write(testData) 88 | assert.NoError(t, err) 89 | 90 | readData := make([]byte, 4) 91 | _, err = ca.Read(readData) 92 | assert.ErrorIs(t, err, io.EOF, "Read from closed conn should fail with io.EOF") 93 | 94 | // Other side should be readable. 95 | readDone := make(chan struct{}) 96 | go func() { 97 | readData := make([]byte, 4) 98 | n, err := cb.Read(readData) 99 | assert.Errorf(t, err, "Unexpected data %v was arrived to orphaned conn", readData[:n]) 100 | close(readDone) 101 | }() 102 | select { 103 | case <-readDone: 104 | assert.Fail(t, "Read should be blocked if the other side is closed") 105 | case <-time.After(10 * time.Millisecond): 106 | } 107 | assert.NoError(t, cb.Close()) 108 | } 109 | -------------------------------------------------------------------------------- /deadline/deadline.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package deadline provides deadline timer used to implement 5 | // net.Conn compatible connection 6 | package deadline 7 | 8 | import ( 9 | "context" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | type deadlineState uint8 15 | 16 | const ( 17 | deadlineStopped deadlineState = iota 18 | deadlineStarted 19 | deadlineExceeded 20 | ) 21 | 22 | var _ context.Context = (*Deadline)(nil) 23 | 24 | // Deadline signals updatable deadline timer. 25 | // Also, it implements context.Context. 26 | type Deadline struct { 27 | mu sync.RWMutex 28 | timer timer 29 | done chan struct{} 30 | deadline time.Time 31 | state deadlineState 32 | pending uint8 33 | } 34 | 35 | // New creates new deadline timer. 
// The returned Deadline has no deadline set and an open Done channel.
func New() *Deadline {
	return &Deadline{
		done: make(chan struct{}),
	}
}

// timeout is the timer callback. Every Set call bumps pending, and Set drops
// it again whenever it knows no callback will run for that bump. Only a
// callback that brings pending back to zero while the state is still
// deadlineStarted marks the deadline exceeded; stale callbacks just
// decrement the counter and return.
func (d *Deadline) timeout() {
	d.mu.Lock()
	if d.pending--; d.pending != 0 || d.state != deadlineStarted {
		d.mu.Unlock()

		return
	}

	d.state = deadlineExceeded
	done := d.done
	d.mu.Unlock()

	// Close outside the lock; done was captured while holding it.
	close(done)
}

// Set new deadline. Zero value means no deadline.
//
// A deadline that is not in the future marks the Deadline exceeded
// immediately and closes the Done channel. If the previous deadline had
// already been exceeded, a fresh Done channel is created first, so callers
// should call Done again after Set to observe the new deadline.
func (d *Deadline) Set(setTo time.Time) {
	d.mu.Lock()
	defer d.mu.Unlock()

	// A successfully stopped timer will never call timeout, so drop its
	// pending count here.
	// NOTE(review): assumes timer.Stop reports whether the callback was
	// prevented from running, like time.Timer.Stop.
	if d.state == deadlineStarted && d.timer.Stop() {
		d.pending--
	}

	d.deadline = setTo
	d.pending++

	// The old channel has already been closed; start over with a new one.
	if d.state == deadlineExceeded {
		d.done = make(chan struct{})
	}

	if setTo.IsZero() {
		// No deadline: nothing is scheduled for this bump.
		d.pending--
		d.state = deadlineStopped

		return
	}

	if dur := time.Until(setTo); dur > 0 {
		d.state = deadlineStarted
		// Create the timer lazily on first use and reuse it afterwards.
		if d.timer == nil {
			d.timer = afterFunc(dur, d.timeout)
		} else {
			d.timer.Reset(dur)
		}

		return
	}

	// The deadline has already passed: expire synchronously, no timer.
	d.pending--
	d.state = deadlineExceeded
	close(d.done)
}

// Done receives deadline signal.
// The returned channel is closed once the deadline is exceeded. Set may
// replace it after that, so fetch it again after calling Set.
func (d *Deadline) Done() <-chan struct{} {
	d.mu.RLock()
	defer d.mu.RUnlock()

	return d.done
}

// Err returns context.DeadlineExceeded if the deadline is exceeded.
// Otherwise, it returns nil.
func (d *Deadline) Err() error {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if d.state == deadlineExceeded {
		return context.DeadlineExceeded
	}

	return nil
}

// Deadline returns current deadline.
// ok is false when no deadline is set (the stored time is zero).
func (d *Deadline) Deadline() (time.Time, bool) {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if d.deadline.IsZero() {
		return d.deadline, false
	}

	return d.deadline, true
}

// Value returns nil.
128 | func (d *Deadline) Value(any) any { 129 | return nil 130 | } 131 | -------------------------------------------------------------------------------- /vnet/queue_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2025 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type mockDiscipline struct { 14 | mockPush func(Chunk) 15 | mockPop func() Chunk 16 | mockEmpty func() bool 17 | mockNext func() time.Time 18 | } 19 | 20 | // empty implements Discipline. 21 | func (m *mockDiscipline) empty() bool { 22 | return m.mockEmpty() 23 | } 24 | 25 | // next implements Discipline. 26 | func (m *mockDiscipline) next() time.Time { 27 | return m.mockNext() 28 | } 29 | 30 | // pop implements Discipline. 31 | func (m *mockDiscipline) pop() Chunk { 32 | return m.mockPop() 33 | } 34 | 35 | // push implements Discipline. 
36 | func (m *mockDiscipline) push(c Chunk) { 37 | m.mockPush(c) 38 | } 39 | 40 | func newMockDiscipline(t *testing.T) *mockDiscipline { 41 | t.Helper() 42 | 43 | return &mockDiscipline{ 44 | mockPush: func(Chunk) { 45 | assert.Fail(t, "unexpected call to push") 46 | }, 47 | mockPop: func() Chunk { 48 | assert.Fail(t, "unexpected call to pop") 49 | 50 | return nil 51 | }, 52 | mockEmpty: func() bool { 53 | assert.Fail(t, "unexpected call to empty") 54 | 55 | return false 56 | }, 57 | mockNext: func() time.Time { 58 | assert.Fail(t, "unexpected call to next") 59 | 60 | return time.Time{} 61 | }, 62 | } 63 | } 64 | 65 | func TestQueue(t *testing.T) { 66 | t.Run("enqueue-chunk", func(t *testing.T) { 67 | mnic := newMockNIC(t) 68 | md := newMockDiscipline(t) 69 | pushCh := make(chan struct{}) 70 | chunk := &chunkUDP{ 71 | userData: make([]byte, 1300), 72 | } 73 | md.mockPush = func(c Chunk) { 74 | assert.Equal(t, chunk, c) 75 | close(pushCh) 76 | } 77 | md.mockEmpty = func() bool { 78 | return true 79 | } 80 | q, err := NewQueue(mnic, md) 81 | assert.NoError(t, err) 82 | 83 | q.onInboundChunk(chunk) 84 | 85 | select { 86 | case <-pushCh: 87 | case <-time.After(10 * time.Millisecond): 88 | assert.Fail(t, "timeout before chunk was pushed") 89 | } 90 | assert.NoError(t, q.Close()) 91 | }) 92 | t.Run("dequeue-chunk", func(t *testing.T) { 93 | mnic := newMockNIC(t) 94 | 95 | data := []Chunk{ 96 | &chunkUDP{ 97 | userData: make([]byte, 1300), 98 | }, 99 | } 100 | pushCh := make(chan struct{}) 101 | mnic.mockOnInboundChunk = func(c Chunk) { 102 | close(pushCh) 103 | } 104 | md := newMockDiscipline(t) 105 | 106 | md.mockEmpty = func() bool { 107 | return len(data) == 0 108 | } 109 | md.mockPop = func() Chunk { 110 | var next Chunk 111 | next, data = data[0], data[1:] 112 | 113 | return next 114 | } 115 | // return false the first time and only dequeue after 5ms 116 | nextCalled := false 117 | md.mockNext = func() time.Time { 118 | if nextCalled { 119 | return time.Time{} 120 | 
} 121 | nextCalled = true 122 | 123 | return time.Now().Add(5 * time.Millisecond) 124 | } 125 | 126 | q, err := NewQueue(mnic, md) 127 | assert.NoError(t, err) 128 | 129 | select { 130 | case <-pushCh: 131 | case <-time.After(10 * time.Millisecond): 132 | assert.Fail(t, "timeout before chunk was pushed") 133 | } 134 | 135 | assert.NoError(t, q.Close()) 136 | }) 137 | } 138 | -------------------------------------------------------------------------------- /test/stress.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "bytes" 8 | "context" 9 | "errors" 10 | "fmt" 11 | "io" 12 | "sync" 13 | 14 | "github.com/pion/transport/v3/netctx" 15 | ) 16 | 17 | var errByteSequenceChanged = errors.New("byte sequence changed") 18 | 19 | // Options represents the configuration of the stress test. 20 | type Options struct { 21 | MsgSize int 22 | MsgCount int 23 | } 24 | 25 | // Stress enables stress testing of a io.ReadWriter. 26 | // It checks that packets are received correctly and in order. 27 | func Stress(ca io.Writer, cb io.Reader, opt Options) error { 28 | return StressContext(context.Background(), &wrappedWriter{ca}, &wrappedReader{cb}, opt) 29 | } 30 | 31 | // StressContext enables stress testing of a io.ReadWriter. 32 | // It checks that packets are received correctly and in order. 
33 | func StressContext(ctx context.Context, ca netctx.Writer, cb netctx.Reader, opt Options) error { 34 | bufs := make(chan []byte, opt.MsgCount) 35 | errCh := make(chan error) 36 | // Write 37 | go func() { 38 | err := write(ctx, ca, bufs, opt) 39 | errCh <- err 40 | close(bufs) 41 | }() 42 | 43 | // Read 44 | go func() { 45 | result := make([]byte, opt.MsgSize) 46 | 47 | for original := range bufs { 48 | err := read(ctx, cb, original, result) 49 | if err != nil { 50 | errCh <- err 51 | } 52 | } 53 | 54 | close(errCh) 55 | }() 56 | 57 | return FlattenErrs(GatherErrs(errCh)) 58 | } 59 | 60 | func read(ctx context.Context, r netctx.Reader, original, result []byte) error { 61 | n, err := r.ReadContext(ctx, result) 62 | if err != nil { 63 | return err 64 | } 65 | if !bytes.Equal(original, result[:n]) { 66 | return fmt.Errorf("%w %#v != %#v", errByteSequenceChanged, original, result) 67 | } 68 | 69 | return nil 70 | } 71 | 72 | // StressDuplex enables duplex stress testing of a io.ReadWriter. 73 | // It checks that packets are received correctly and in order. 74 | func StressDuplex(ca io.ReadWriter, cb io.ReadWriter, opt Options) error { 75 | return StressDuplexContext(context.Background(), &wrappedReadWriter{ca}, &wrappedReadWriter{cb}, opt) 76 | } 77 | 78 | // StressDuplexContext enables duplex stress testing of a io.ReadWriter. 79 | // It checks that packets are received correctly and in order. 
80 | func StressDuplexContext(ctx context.Context, ca netctx.ReadWriter, cb netctx.ReadWriter, opt Options) error { 81 | errCh := make(chan error) 82 | 83 | var wg sync.WaitGroup 84 | wg.Add(2) 85 | 86 | go func() { 87 | defer wg.Done() 88 | errCh <- StressContext(ctx, ca, cb, opt) 89 | }() 90 | 91 | go func() { 92 | defer wg.Done() 93 | errCh <- StressContext(ctx, cb, ca, opt) 94 | }() 95 | 96 | go func() { 97 | wg.Wait() 98 | close(errCh) 99 | }() 100 | 101 | return FlattenErrs(GatherErrs(errCh)) 102 | } 103 | 104 | func write(ctx context.Context, c netctx.Writer, bufs chan []byte, opt Options) error { 105 | randomizer := initRand() 106 | for i := 0; i < opt.MsgCount; i++ { 107 | buf, err := randomizer.randBuf(opt.MsgSize) 108 | if err != nil { 109 | return err 110 | } 111 | bufs <- buf 112 | if _, err = c.WriteContext(ctx, buf); err != nil { 113 | return err 114 | } 115 | } 116 | 117 | return nil 118 | } 119 | -------------------------------------------------------------------------------- /vnet/tbf.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import "sync" 7 | 8 | const ( 9 | // Bit is a single bit. 10 | Bit = 1 11 | // KBit is a kilobit. 12 | KBit = 1000 * Bit 13 | // MBit is a Megabit. 14 | MBit = 1000 * KBit 15 | ) 16 | 17 | // TokenBucketFilter implements a token bucket rate limit algorithm. 18 | // 19 | // Deprecated: TokenBucketFilter is now a wrapper around Queue with a TBF 20 | // discipline. Use Queue directly. 21 | type TokenBucketFilter struct { 22 | NIC 23 | queue *Queue 24 | tbf *TBFQueue 25 | 26 | lock sync.Mutex 27 | maxSize int64 28 | rate int 29 | burst int 30 | } 31 | 32 | // TBFOption is the option type to configure a TokenBucketFilter. 33 | // 34 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 
35 | type TBFOption func(*TokenBucketFilter) TBFOption 36 | 37 | // TBFQueueSizeInBytes sets the max number of bytes waiting in the queue. Can 38 | // only be set in constructor before using the TBF. 39 | // 40 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 41 | func TBFQueueSizeInBytes(bytes int) TBFOption { 42 | return func(t *TokenBucketFilter) TBFOption { 43 | t.lock.Lock() 44 | defer t.lock.Unlock() 45 | prev := t.maxSize 46 | t.maxSize = int64(bytes) 47 | t.tbf.SetSize(t.maxSize) 48 | 49 | return TBFQueueSizeInBytes(int(prev)) 50 | } 51 | } 52 | 53 | // TBFRate sets the bit rate of a TokenBucketFilter. 54 | // 55 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 56 | func TBFRate(rate int) TBFOption { 57 | return func(t *TokenBucketFilter) TBFOption { 58 | t.lock.Lock() 59 | defer t.lock.Unlock() 60 | prev := t.rate 61 | t.rate = rate 62 | t.tbf.SetRate(t.rate) 63 | 64 | return TBFRate(prev) 65 | } 66 | } 67 | 68 | // TBFMaxBurst sets the bucket size of the token bucket filter. This is the 69 | // maximum size that can instantly leave the filter, if the bucket is full. 70 | // 71 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 72 | func TBFMaxBurst(size int) TBFOption { 73 | return func(t *TokenBucketFilter) TBFOption { 74 | t.lock.Lock() 75 | defer t.lock.Unlock() 76 | prev := t.burst 77 | t.burst = size 78 | t.tbf.SetBurst(t.burst) 79 | 80 | return TBFMaxBurst(prev) 81 | } 82 | } 83 | 84 | // Set updates a setting on the token bucket filter. 85 | // 86 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 87 | func (t *TokenBucketFilter) Set(opts ...TBFOption) (previous TBFOption) { 88 | for _, opt := range opts { 89 | previous = opt(t) 90 | } 91 | 92 | return previous 93 | } 94 | 95 | // NewTokenBucketFilter creates and starts a new TokenBucketFilter. 96 | // 97 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 
98 | func NewTokenBucketFilter(n NIC, opts ...TBFOption) (*TokenBucketFilter, error) { 99 | tbfQueue := NewTBFQueue(1*MBit, 8*KBit, int64(50_000)) 100 | q, err := NewQueue(n, tbfQueue) 101 | if err != nil { 102 | return nil, err 103 | } 104 | tbf := &TokenBucketFilter{ 105 | NIC: q.NIC, 106 | tbf: tbfQueue, 107 | queue: q, 108 | } 109 | tbf.Set(opts...) 110 | 111 | return tbf, nil 112 | } 113 | 114 | func (t *TokenBucketFilter) onInboundChunk(c Chunk) { 115 | t.queue.onInboundChunk(c) 116 | } 117 | 118 | // Close closes and stops the token bucket filter queue. 119 | // 120 | // Deprecated: TokenBucketFilter is deprecated, use Queue instead. 121 | func (t *TokenBucketFilter) Close() error { 122 | return t.queue.Close() 123 | } 124 | -------------------------------------------------------------------------------- /vnet/conn_map.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "errors" 8 | "net" 9 | "sync" 10 | ) 11 | 12 | var ( 13 | errAddressAlreadyInUse = errors.New("address already in use") 14 | errNoSuchUDPConn = errors.New("no such UDPConn") 15 | errCannotRemoveUnspecifiedIP = errors.New("cannot remove unspecified IP by the specified IP") 16 | ) 17 | 18 | type udpConnMap struct { 19 | portMap map[int][]*UDPConn 20 | mutex sync.RWMutex 21 | } 22 | 23 | func newUDPConnMap() *udpConnMap { 24 | return &udpConnMap{ 25 | portMap: map[int][]*UDPConn{}, 26 | } 27 | } 28 | 29 | func (m *udpConnMap) insert(conn *UDPConn) error { 30 | m.mutex.Lock() 31 | defer m.mutex.Unlock() 32 | 33 | udpAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert 34 | 35 | // check if the port has a listener 36 | conns, ok := m.portMap[udpAddr.Port] 37 | if ok { 38 | if udpAddr.IP.IsUnspecified() { 39 | return errAddressAlreadyInUse 40 | } 41 | 42 | for _, conn := range conns { 43 | laddr := 
conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert 44 | if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { 45 | return errAddressAlreadyInUse 46 | } 47 | } 48 | 49 | conns = append(conns, conn) 50 | } else { 51 | conns = []*UDPConn{conn} 52 | } 53 | 54 | m.portMap[udpAddr.Port] = conns 55 | 56 | return nil 57 | } 58 | 59 | func (m *udpConnMap) find(addr net.Addr) (*UDPConn, bool) { 60 | m.mutex.Lock() // could be RLock, but we have delete() op 61 | defer m.mutex.Unlock() 62 | 63 | udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert 64 | 65 | if conns, ok := m.portMap[udpAddr.Port]; ok { 66 | if udpAddr.IP.IsUnspecified() { 67 | // pick the first one appears in the iteration 68 | if len(conns) == 0 { 69 | // This can't happen! 70 | delete(m.portMap, udpAddr.Port) 71 | 72 | return nil, false 73 | } 74 | 75 | return conns[0], true 76 | } 77 | 78 | for _, conn := range conns { 79 | laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert 80 | if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { 81 | return conn, ok 82 | } 83 | } 84 | } 85 | 86 | return nil, false 87 | } 88 | 89 | func (m *udpConnMap) delete(addr net.Addr) error { 90 | m.mutex.Lock() 91 | defer m.mutex.Unlock() 92 | 93 | udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert 94 | 95 | conns, ok := m.portMap[udpAddr.Port] 96 | if !ok { 97 | return errNoSuchUDPConn 98 | } 99 | 100 | if udpAddr.IP.IsUnspecified() { 101 | // remove all from this port 102 | delete(m.portMap, udpAddr.Port) 103 | 104 | return nil 105 | } 106 | 107 | newConns := []*UDPConn{} 108 | 109 | for _, conn := range conns { 110 | laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert 111 | if laddr.IP.IsUnspecified() { 112 | // This can't happen! 
113 | return errCannotRemoveUnspecifiedIP 114 | } 115 | 116 | if laddr.IP.Equal(udpAddr.IP) { 117 | continue 118 | } 119 | 120 | newConns = append(newConns, conn) 121 | } 122 | 123 | if len(newConns) == 0 { 124 | delete(m.portMap, udpAddr.Port) 125 | } else { 126 | m.portMap[udpAddr.Port] = newConns 127 | } 128 | 129 | return nil 130 | } 131 | 132 | // size returns the number of UDPConns (UDP listeners). 133 | func (m *udpConnMap) size() int { 134 | m.mutex.RLock() 135 | defer m.mutex.RUnlock() 136 | 137 | n := 0 138 | for _, conns := range m.portMap { 139 | n += len(conns) 140 | } 141 | 142 | return n 143 | } 144 | -------------------------------------------------------------------------------- /dpipe/dpipe.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package dpipe provides the pipe works like datagram protocol on memory. 5 | // 6 | // This package is mainly intended for testing and not for production! 7 | package dpipe 8 | 9 | import ( 10 | "context" 11 | "io" 12 | "net" 13 | "sync" 14 | "time" 15 | 16 | "github.com/pion/transport/v3/deadline" 17 | ) 18 | 19 | // Pipe creates pair of non-stream conn on memory. 20 | // Close of the one end doesn't make effect to the other end. 
21 | func Pipe() (net.Conn, net.Conn) { 22 | ch0 := make(chan []byte, 1000) 23 | ch1 := make(chan []byte, 1000) 24 | 25 | return &conn{ 26 | rCh: ch0, 27 | wCh: ch1, 28 | closed: make(chan struct{}), 29 | closing: make(chan struct{}), 30 | readDeadline: deadline.New(), 31 | writeDeadline: deadline.New(), 32 | }, &conn{ 33 | rCh: ch1, 34 | wCh: ch0, 35 | closed: make(chan struct{}), 36 | closing: make(chan struct{}), 37 | readDeadline: deadline.New(), 38 | writeDeadline: deadline.New(), 39 | } 40 | } 41 | 42 | type pipeAddr struct{} 43 | 44 | func (pipeAddr) Network() string { return "pipe" } 45 | func (pipeAddr) String() string { return ":1" } 46 | 47 | type conn struct { 48 | rCh chan []byte 49 | wCh chan []byte 50 | closed chan struct{} 51 | closing chan struct{} 52 | closeOnce sync.Once 53 | 54 | readDeadline *deadline.Deadline 55 | writeDeadline *deadline.Deadline 56 | } 57 | 58 | func (*conn) LocalAddr() net.Addr { return pipeAddr{} } 59 | func (*conn) RemoteAddr() net.Addr { return pipeAddr{} } 60 | 61 | func (c *conn) SetDeadline(t time.Time) error { 62 | c.readDeadline.Set(t) 63 | c.writeDeadline.Set(t) 64 | 65 | return nil 66 | } 67 | 68 | func (c *conn) SetReadDeadline(t time.Time) error { 69 | c.readDeadline.Set(t) 70 | 71 | return nil 72 | } 73 | 74 | func (c *conn) SetWriteDeadline(t time.Time) error { 75 | c.writeDeadline.Set(t) 76 | 77 | return nil 78 | } 79 | 80 | func (c *conn) Read(data []byte) (n int, err error) { //nolint:cyclop 81 | select { 82 | case <-c.closed: 83 | return 0, io.EOF 84 | case <-c.closing: 85 | if len(c.rCh) == 0 { 86 | return 0, io.EOF 87 | } 88 | case <-c.readDeadline.Done(): 89 | return 0, context.DeadlineExceeded 90 | default: 91 | } 92 | 93 | for { 94 | select { 95 | case d := <-c.rCh: 96 | if len(d) <= len(data) { 97 | copy(data, d) 98 | 99 | return len(d), nil 100 | } 101 | copy(data, d[:len(data)]) 102 | 103 | return len(data), nil 104 | case <-c.closed: 105 | return 0, io.EOF 106 | case <-c.closing: 107 | if 
len(c.rCh) == 0 { 108 | return 0, io.EOF 109 | } 110 | case <-c.readDeadline.Done(): 111 | return 0, context.DeadlineExceeded 112 | } 113 | } 114 | } 115 | 116 | func (c *conn) cleanWriteBuffer() { 117 | for { 118 | select { 119 | case <-c.wCh: 120 | default: 121 | return 122 | } 123 | } 124 | } 125 | 126 | func (c *conn) Write(data []byte) (n int, err error) { 127 | select { 128 | case <-c.closed: 129 | return 0, io.ErrClosedPipe 130 | case <-c.writeDeadline.Done(): 131 | c.cleanWriteBuffer() 132 | 133 | return 0, context.DeadlineExceeded 134 | default: 135 | } 136 | 137 | cData := make([]byte, len(data)) 138 | copy(cData, data) 139 | 140 | select { 141 | case <-c.closed: 142 | return 0, io.ErrClosedPipe 143 | case <-c.writeDeadline.Done(): 144 | c.cleanWriteBuffer() 145 | 146 | return 0, context.DeadlineExceeded 147 | case c.wCh <- cData: 148 | return len(cData), nil 149 | } 150 | } 151 | 152 | func (c *conn) Close() error { 153 | c.closeOnce.Do(func() { 154 | close(c.closed) 155 | }) 156 | 157 | return nil 158 | } 159 | -------------------------------------------------------------------------------- /replaydetector/replaydetector.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package replaydetector provides packet replay detection algorithm. 5 | package replaydetector 6 | 7 | // ReplayDetector is the interface of sequence replay detector. 8 | type ReplayDetector interface { 9 | // Check returns true if given sequence number is not replayed. 10 | // Call accept() to mark the packet is received properly. 11 | // The return value of accept() indicates whether the accepted packet is 12 | // has the latest observed sequence number. 13 | Check(seq uint64) (accept func() bool, ok bool) 14 | } 15 | 16 | // nop is a no-op func that is returned in the case that Check() fails. 
17 | func nop() bool { 18 | return false 19 | } 20 | 21 | type slidingWindowDetector struct { 22 | latestSeq uint64 23 | maxSeq uint64 24 | windowSize uint 25 | mask *fixedBigInt 26 | } 27 | 28 | // New creates ReplayDetector. 29 | // Created ReplayDetector doesn't allow wrapping. 30 | // It can handle monotonically increasing sequence number up to 31 | // full 64bit number. It is suitable for DTLS replay protection. 32 | func New(windowSize uint, maxSeq uint64) ReplayDetector { 33 | return &slidingWindowDetector{ 34 | maxSeq: maxSeq, 35 | windowSize: windowSize, 36 | mask: newFixedBigInt(windowSize), 37 | } 38 | } 39 | 40 | func (d *slidingWindowDetector) Check(seq uint64) (func() bool, bool) { 41 | if seq > d.maxSeq { 42 | // Exceeded upper limit. 43 | return nop, false 44 | } 45 | 46 | if seq <= d.latestSeq { 47 | if d.latestSeq >= uint64(d.windowSize)+seq { 48 | return nop, false 49 | } 50 | if d.mask.Bit(uint(d.latestSeq-seq)) != 0 { 51 | // The sequence number is duplicated. 52 | return nop, false 53 | } 54 | } 55 | 56 | return func() bool { 57 | latest := seq == 0 58 | if seq > d.latestSeq { 59 | // Update the head of the window. 60 | d.mask.Lsh(uint(seq - d.latestSeq)) 61 | d.latestSeq = seq 62 | latest = true 63 | } 64 | diff := (d.latestSeq - seq) % d.maxSeq 65 | d.mask.SetBit(uint(diff)) 66 | 67 | return latest 68 | }, true 69 | } 70 | 71 | // WithWrap creates ReplayDetector allowing sequence wrapping. 72 | // This is suitable for short bit width counter like SRTP and SRTCP. 
func WithWrap(windowSize uint, maxSeq uint64) ReplayDetector {
	return &wrappedSlidingWindowDetector{
		maxSeq:     maxSeq,
		windowSize: windowSize,
		mask:       newFixedBigInt(windowSize),
	}
}

// wrappedSlidingWindowDetector is the detector returned by WithWrap. Sequence
// numbers live in [0, maxSeq] and wrap around; bit i of mask records whether
// the number i steps behind latestSeq (modulo maxSeq+1) was accepted.
type wrappedSlidingWindowDetector struct {
	latestSeq  uint64
	maxSeq     uint64
	windowSize uint
	mask       *fixedBigInt
	init       bool // set by the first Check call
}

// Check implements ReplayDetector for wrapping sequence numbers.
func (d *wrappedSlidingWindowDetector) Check(seq uint64) (func() bool, bool) {
	if seq > d.maxSeq {
		// Exceeded upper limit.
		return nop, false
	}
	// The first Check primes latestSeq to the number just before seq
	// (wrapping to maxSeq for seq 0), so the first packet counts as one step
	// ahead. This happens even if the returned accept is never called.
	if !d.init {
		if seq != 0 {
			d.latestSeq = seq - 1
		} else {
			d.latestSeq = d.maxSeq
		}
		d.init = true
	}

	// diff is how far seq lies behind latestSeq; negative means ahead.
	diff := int64(d.latestSeq) - int64(seq) //nolint:gosec // GG115 TODO check
	// Wrap the number.
	// Map diff into roughly (-(maxSeq+1)/2, (maxSeq+1)/2] so that a
	// wraparound looks like a small step.
	if diff > int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check
		diff -= int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check
	} else if diff <= -int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check
		diff += int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check
	}

	if diff >= int64(d.windowSize) { //nolint:gosec // GG115 TODO check
		// Too old.
		return nop, false
	}
	if diff >= 0 {
		if d.mask.Bit(uint(diff)) != 0 {
			// The sequence number is duplicated.
			return nop, false
		}
	}

	// NOTE(review): diff is captured at Check time. If other packets are
	// accepted before this accept runs, the bit index may be stale.
	return func() bool {
		latest := false
		if diff < 0 {
			// Update the head of the window.
			d.mask.Lsh(uint(-diff))
			d.latestSeq = seq
			latest = true
			d.mask.SetBit(0)
		} else {
			d.mask.SetBit(uint(diff))
		}

		return latest
	}, true
}

--------------------------------------------------------------------------------
/vnet/delay_filter.go:
--------------------------------------------------------------------------------
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package vnet

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
)

// DelayFilter delays inbound packets by the given delay. Automatically starts
// processing when created and runs until Close() is called.
type DelayFilter struct {
	NIC
	delay atomic.Int64 // atomic field - stores time.Duration as int64
	push  chan struct{}  // wakes the run loop after a chunk was queued
	queue *chunkQueue    // pending timedChunks
	done  chan struct{}  // presumably closed by Close to stop the run loop
	wg    sync.WaitGroup // tracks the run goroutine
}

// timedChunk pairs a queued chunk with the time it becomes due.
type timedChunk struct {
	Chunk
	deadline time.Time
}

// NewDelayFilter creates and starts a new DelayFilter with the given nic and delay.
func NewDelayFilter(nic NIC, delay time.Duration) (*DelayFilter, error) {
	delayFilter := &DelayFilter{
		NIC:   nic,
		push:  make(chan struct{}),
		queue: newChunkQueue(0, 0), // NOTE(review): presumably 0, 0 means unbounded — confirm in chunk_queue.go
		done:  make(chan struct{}),
	}

	delayFilter.delay.Store(int64(delay))

	// Start processing automatically
	delayFilter.wg.Add(1)
	go delayFilter.run()

	return delayFilter, nil
}

// SetDelay atomically updates the delay.
func (f *DelayFilter) SetDelay(newDelay time.Duration) {
	// Only chunks queued after this call use the new delay; chunks already
	// queued keep the deadline computed when they arrived.
	f.delay.Store(int64(newDelay))
}

// getDelay returns the currently configured delay.
func (f *DelayFilter) getDelay() time.Duration {
	return time.Duration(f.delay.Load())
}

// onInboundChunk queues c with a deadline of now+delay and wakes the run loop
// so it can re-arm its timer. A chunk queued after Close may never be
// delivered.
func (f *DelayFilter) onInboundChunk(c Chunk) {
	f.queue.push(timedChunk{
		Chunk:    c,
		deadline: time.Now().Add(f.getDelay()),
	})

	// f.push is unbuffered and only run receives from it. Once Close has
	// stopped run, an unconditional send would block this caller forever, so
	// skip the wake-up when the filter is done.
	select {
	case f.push <- struct{}{}:
	case <-f.done:
	}
}

// run processes the delayed packets queue until Close() is called.
func (f *DelayFilter) run() {
	defer f.wg.Done()

	timer := time.NewTimer(0)
	defer timer.Stop()

	for {
		select {
		case <-f.done:
			f.drainRemainingPackets()

			return

		case <-f.push:
			f.updateTimerForNextPacket(timer)

		case now := <-timer.C:
			f.processReadyPackets(now)
			f.scheduleNextPacketTimer(timer)
		}
	}
}

// drainRemainingPackets sends all remaining packets immediately during shutdown.
func (f *DelayFilter) drainRemainingPackets() {
	for {
		next, ok := f.queue.pop()
		if !ok {
			break
		}
		if chunk, ok := next.(timedChunk); ok {
			f.NIC.onInboundChunk(chunk.Chunk)
		}
	}
}

// updateTimerForNextPacket re-arms the timer for the deadline of the chunk at
// the head of the queue after a new packet arrives.
func (f *DelayFilter) updateTimerForNextPacket(timer *time.Timer) {
	next := f.queue.peek()
	if next == nil {
		return
	}
	chunk, ok := next.(timedChunk)
	if !ok {
		return
	}
	if !timer.Stop() {
		// Discard a pending expiry without blocking. If the value was already
		// received (or none is buffered), a plain <-timer.C would hang run.
		select {
		case <-timer.C:
		default:
		}
	}
	timer.Reset(time.Until(chunk.deadline))
}

// processReadyPackets delivers, in queue order, every packet whose deadline is
// not after the given time.
115 | func (f *DelayFilter) processReadyPackets(now time.Time) { 116 | for { 117 | next := f.queue.peek() 118 | if next == nil { 119 | break 120 | } 121 | if chunk, ok := next.(timedChunk); ok && !chunk.deadline.After(now) { 122 | _, _ = f.queue.pop() // We already have the item from peek() 123 | f.NIC.onInboundChunk(chunk.Chunk) 124 | } else { 125 | break 126 | } 127 | } 128 | } 129 | 130 | // scheduleNextPacketTimer schedules the timer for the next packet to be processed. 131 | func (f *DelayFilter) scheduleNextPacketTimer(timer *time.Timer) { 132 | next := f.queue.peek() 133 | if next == nil { 134 | timer.Reset(time.Minute) // Long timeout when queue is empty 135 | } else if chunk, ok := next.(timedChunk); ok { 136 | timer.Reset(time.Until(chunk.deadline)) 137 | } 138 | } 139 | 140 | // Run is provided for backward compatibility. The DelayFilter now starts 141 | // automatically when created, so this method is a no-op. 142 | func (f *DelayFilter) Run(_ context.Context) { 143 | // DelayFilter now starts automatically in NewDelayFilter, so this is a no-op 144 | } 145 | 146 | // Close stops the DelayFilter and waits for graceful shutdown. 147 | func (f *DelayFilter) Close() error { 148 | close(f.done) 149 | f.wg.Wait() 150 | 151 | return nil 152 | } 153 | -------------------------------------------------------------------------------- /connctx/connctx.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package connctx wraps net.Conn using context.Context. 5 | // 6 | // Deprecated: use netctx instead. 7 | package connctx 8 | 9 | import ( 10 | "context" 11 | "errors" 12 | "io" 13 | "net" 14 | "sync" 15 | "sync/atomic" 16 | "time" 17 | ) 18 | 19 | // ErrClosing is returned on Write to closed connection. 
20 | var ErrClosing = errors.New("use of closed network connection") 21 | 22 | // Reader is an interface for context controlled reader. 23 | type Reader interface { 24 | ReadContext(context.Context, []byte) (int, error) 25 | } 26 | 27 | // Writer is an interface for context controlled writer. 28 | type Writer interface { 29 | WriteContext(context.Context, []byte) (int, error) 30 | } 31 | 32 | // ReadWriter is a composite of ReadWriter. 33 | type ReadWriter interface { 34 | Reader 35 | Writer 36 | } 37 | 38 | // ConnCtx is a wrapper of net.Conn using context.Context. 39 | type ConnCtx interface { 40 | Reader 41 | Writer 42 | io.Closer 43 | LocalAddr() net.Addr 44 | RemoteAddr() net.Addr 45 | Conn() net.Conn 46 | } 47 | 48 | type connCtx struct { 49 | nextConn net.Conn 50 | closed chan struct{} 51 | closeOnce sync.Once 52 | readMu sync.Mutex 53 | writeMu sync.Mutex 54 | } 55 | 56 | var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals 57 | 58 | // New creates a new ConnCtx wrapping given net.Conn. 
59 | func New(conn net.Conn) ConnCtx { 60 | c := &connCtx{ 61 | nextConn: conn, 62 | closed: make(chan struct{}), 63 | } 64 | 65 | return c 66 | } 67 | 68 | func (c *connCtx) ReadContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop 69 | c.readMu.Lock() 70 | defer c.readMu.Unlock() 71 | 72 | select { 73 | case <-c.closed: 74 | return 0, io.EOF 75 | default: 76 | } 77 | 78 | done := make(chan struct{}) 79 | var wg sync.WaitGroup 80 | var errSetDeadline atomic.Value 81 | wg.Add(1) 82 | go func() { 83 | defer wg.Done() 84 | select { 85 | case <-ctx.Done(): 86 | // context canceled 87 | if err := c.nextConn.SetReadDeadline(veryOld); err != nil { 88 | errSetDeadline.Store(err) 89 | 90 | return 91 | } 92 | <-done 93 | if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { 94 | errSetDeadline.Store(err) 95 | } 96 | case <-done: 97 | } 98 | }() 99 | 100 | n, err := c.nextConn.Read(b) 101 | 102 | close(done) 103 | wg.Wait() 104 | if e := ctx.Err(); e != nil && n == 0 { 105 | err = e 106 | } 107 | if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { 108 | err = err2 109 | } 110 | 111 | return n, err 112 | } 113 | 114 | func (c *connCtx) WriteContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop 115 | c.writeMu.Lock() 116 | defer c.writeMu.Unlock() 117 | 118 | select { 119 | case <-c.closed: 120 | return 0, ErrClosing 121 | default: 122 | } 123 | 124 | done := make(chan struct{}) 125 | var wg sync.WaitGroup 126 | var errSetDeadline atomic.Value 127 | wg.Add(1) 128 | go func() { 129 | defer wg.Done() 130 | select { 131 | case <-ctx.Done(): 132 | // context canceled 133 | if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { 134 | errSetDeadline.Store(err) 135 | 136 | return 137 | } 138 | <-done 139 | if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { 140 | errSetDeadline.Store(err) 141 | } 142 | case <-done: 143 | } 144 | }() 145 | 146 | n, err := c.nextConn.Write(b) 147 | 148 | 
close(done) 149 | wg.Wait() 150 | if e := ctx.Err(); e != nil && n == 0 { 151 | err = e 152 | } 153 | if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { 154 | err = err2 155 | } 156 | 157 | return n, err 158 | } 159 | 160 | func (c *connCtx) Close() error { 161 | err := c.nextConn.Close() 162 | c.closeOnce.Do(func() { 163 | c.writeMu.Lock() 164 | c.readMu.Lock() 165 | close(c.closed) 166 | c.readMu.Unlock() 167 | c.writeMu.Unlock() 168 | }) 169 | 170 | return err 171 | } 172 | 173 | func (c *connCtx) LocalAddr() net.Addr { 174 | return c.nextConn.LocalAddr() 175 | } 176 | 177 | func (c *connCtx) RemoteAddr() net.Addr { 178 | return c.nextConn.RemoteAddr() 179 | } 180 | 181 | func (c *connCtx) Conn() net.Conn { 182 | return c.nextConn 183 | } 184 | -------------------------------------------------------------------------------- /test/util.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package test 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "os" 10 | "runtime" 11 | "runtime/pprof" 12 | "strings" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | var errFlattenErrs = errors.New("") 18 | 19 | // TimeOut is used to panic if a test takes to long. 20 | // It will print the current goroutines and panic. 21 | // It is meant as an aid in debugging deadlocks. 
22 | func TimeOut(t time.Duration) *time.Timer { 23 | return time.AfterFunc(t, func() { 24 | if err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1); err != nil { 25 | fmt.Printf("failed to print goroutines: %v \n", err) // nolint 26 | } 27 | panic("timeout") // nolint 28 | }) 29 | } 30 | 31 | func tryCheckRoutinesLoop(tb testing.TB, failMessage string) { 32 | tb.Helper() 33 | 34 | try := 0 35 | ticker := time.NewTicker(200 * time.Millisecond) 36 | defer ticker.Stop() 37 | for range ticker.C { 38 | runtime.GC() 39 | routines := getRoutines() 40 | if len(routines) == 0 { 41 | return 42 | } 43 | if try >= 50 { 44 | tb.Fatalf("%s: \n%s", failMessage, strings.Join(routines, "\n\n")) // nolint 45 | } 46 | try++ 47 | } 48 | } 49 | 50 | // CheckRoutines is used to check for leaked go-routines. 51 | func CheckRoutines(t *testing.T) func() { 52 | t.Helper() 53 | tryCheckRoutinesLoop(t, "Unexpected routines on test startup") 54 | 55 | return func() { 56 | tryCheckRoutinesLoop(t, "Unexpected routines on test end") 57 | } 58 | } 59 | 60 | // CheckRoutinesStrict is used to check for leaked go-routines. 61 | // It differs from CheckRoutines in that it has very little tolerance 62 | // for lingering goroutines. This is helpful for tests that need 63 | // to ensure clean closure of resources. 64 | // Checking the state of goroutines exactly is tricky. As users writing 65 | // goroutines, we tend to clean up gracefully using some synchronization 66 | // pattern. When used correctly, we won't leak goroutines, but we cannot 67 | // guarantee *when* the goroutines will end. This is the nature of waiting 68 | // on the runtime's goexit1 being called which is the final subroutine 69 | // called, which is after any user written code. This small, but possible 70 | // chance to have a thread (not goroutine) be preempted before this is 71 | // called, can have our goroutine stack be not quite correct yet. 
The 72 | // best we can do is sleep a little bit and try to encourage the runtime 73 | // to run that goroutine (G) on the machine (M) it belongs to. 74 | func CheckRoutinesStrict(tb testing.TB) func() { 75 | tb.Helper() 76 | 77 | tryCheckRoutinesLoop(tb, "Unexpected routines on test startup") 78 | 79 | return func() { 80 | runtime.Gosched() 81 | runtime.GC() 82 | routines := getRoutines() 83 | if len(routines) == 0 { 84 | return 85 | } 86 | // arbitrarily short choice to allow the runtime to cleanup any 87 | // goroutines that really aren't doing anything but haven't yet 88 | // completed. 89 | time.Sleep(time.Millisecond) 90 | runtime.Gosched() 91 | runtime.GC() 92 | routines = getRoutines() 93 | if len(routines) == 0 { 94 | return 95 | } 96 | 97 | tb.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint 98 | } 99 | } 100 | 101 | func getRoutines() []string { 102 | buf := make([]byte, 2<<20) 103 | buf = buf[:runtime.Stack(buf, true)] 104 | 105 | return filterRoutines(strings.Split(string(buf), "\n\n")) 106 | } 107 | 108 | func filterRoutines(routines []string) []string { 109 | result := []string{} 110 | for _, stack := range routines { 111 | if stack == "" || // Empty 112 | filterRoutineWASM(stack) || // WASM specific exception 113 | strings.Contains(stack, "testing.Main(") || // Tests 114 | strings.Contains(stack, "testing.(*T).Run(") || // Test run 115 | strings.Contains(stack, "test.getRoutines(") { // This routine 116 | continue 117 | } 118 | result = append(result, stack) 119 | } 120 | 121 | return result 122 | } 123 | 124 | // GatherErrs gathers all errors returned by a channel. 125 | // It blocks until the channel is closed. 126 | func GatherErrs(c chan error) []error { 127 | errs := make([]error, 0) 128 | 129 | for err := range c { 130 | errs = append(errs, err) 131 | } 132 | 133 | return errs 134 | } 135 | 136 | // FlattenErrs flattens a slice of errors into a single error. 
137 | func FlattenErrs(errs []error) error { 138 | var errStrings []string 139 | 140 | for _, err := range errs { 141 | if err != nil { 142 | errStrings = append(errStrings, err.Error()) 143 | } 144 | } 145 | 146 | if len(errStrings) == 0 { 147 | return nil 148 | } 149 | 150 | return fmt.Errorf("%w %s", errFlattenErrs, strings.Join(errStrings, "\n")) 151 | } 152 | -------------------------------------------------------------------------------- /netctx/conn.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package netctx wraps common net interfaces using context.Context. 5 | package netctx 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "io" 11 | "net" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | ) 16 | 17 | // ErrClosing is returned on Write to closed connection. 18 | var ErrClosing = errors.New("use of closed network connection") 19 | 20 | // Reader is an interface for context controlled reader. 21 | type Reader interface { 22 | ReadContext(context.Context, []byte) (int, error) 23 | } 24 | 25 | // Writer is an interface for context controlled writer. 26 | type Writer interface { 27 | WriteContext(context.Context, []byte) (int, error) 28 | } 29 | 30 | // ReadWriter is a composite of ReadWriter. 31 | type ReadWriter interface { 32 | Reader 33 | Writer 34 | } 35 | 36 | // Conn is a wrapper of net.Conn using context.Context. 37 | type Conn interface { 38 | Reader 39 | Writer 40 | io.Closer 41 | LocalAddr() net.Addr 42 | RemoteAddr() net.Addr 43 | Conn() net.Conn 44 | } 45 | 46 | type conn struct { 47 | nextConn net.Conn 48 | closed chan struct{} 49 | closeOnce sync.Once 50 | readMu sync.Mutex 51 | writeMu sync.Mutex 52 | } 53 | 54 | var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals 55 | 56 | // NewConn creates a new Conn wrapping given net.Conn. 
func NewConn(netConn net.Conn) Conn {
	c := &conn{
		nextConn: netConn,
		closed:   make(chan struct{}),
	}

	return c
}

// ReadContext reads data from the connection.
// Unlike net.Conn.Read(), the provided context is used to control timeout.
// Reads are serialized by readMu. It returns net.ErrClosed if Close has
// already run.
func (c *conn) ReadContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop
	c.readMu.Lock()
	defer c.readMu.Unlock()

	select {
	case <-c.closed:
		return 0, net.ErrClosed
	default:
	}

	// done is closed after Read returns. The watcher goroutine interrupts a
	// blocked Read on ctx cancellation by moving the read deadline into the
	// past, then clears the deadline once Read has returned.
	done := make(chan struct{})
	var wg sync.WaitGroup
	var errSetDeadline atomic.Value
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			// context canceled
			if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
				errSetDeadline.Store(err)

				return
			}
			<-done
			if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
				errSetDeadline.Store(err)
			}
		case <-done:
		}
	}()

	n, err := c.nextConn.Read(b)

	close(done)
	wg.Wait()
	// Prefer the context error when nothing was read.
	if e := ctx.Err(); e != nil && n == 0 {
		err = e
	}
	// A failure to (re)set the deadline is reported only if nothing else failed.
	if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
		err = err2
	}

	return n, err
}

// WriteContext writes data to the connection.
// Unlike net.Conn.Write(), the provided context is used to control timeout.
116 | func (c *conn) WriteContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop 117 | c.writeMu.Lock() 118 | defer c.writeMu.Unlock() 119 | 120 | select { 121 | case <-c.closed: 122 | return 0, ErrClosing 123 | default: 124 | } 125 | 126 | done := make(chan struct{}) 127 | var wg sync.WaitGroup 128 | var errSetDeadline atomic.Value 129 | wg.Add(1) 130 | go func() { 131 | defer wg.Done() 132 | select { 133 | case <-ctx.Done(): 134 | // context canceled 135 | if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { 136 | errSetDeadline.Store(err) 137 | 138 | return 139 | } 140 | <-done 141 | if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { 142 | errSetDeadline.Store(err) 143 | } 144 | case <-done: 145 | } 146 | }() 147 | 148 | n, err := c.nextConn.Write(b) 149 | 150 | close(done) 151 | wg.Wait() 152 | if e := ctx.Err(); e != nil && n == 0 { 153 | err = e 154 | } 155 | if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { 156 | err = err2 157 | } 158 | 159 | return n, err 160 | } 161 | 162 | // Close closes the connection. 163 | // Any blocked ReadContext or WriteContext operations will be unblocked and 164 | // return errors. 165 | func (c *conn) Close() error { 166 | err := c.nextConn.Close() 167 | c.closeOnce.Do(func() { 168 | c.writeMu.Lock() 169 | c.readMu.Lock() 170 | close(c.closed) 171 | c.readMu.Unlock() 172 | c.writeMu.Unlock() 173 | }) 174 | 175 | return err 176 | } 177 | 178 | // LocalAddr returns the local network address, if known. 179 | func (c *conn) LocalAddr() net.Addr { 180 | return c.nextConn.LocalAddr() 181 | } 182 | 183 | // LocalAddr returns the local network address, if known. 184 | func (c *conn) RemoteAddr() net.Addr { 185 | return c.nextConn.RemoteAddr() 186 | } 187 | 188 | // Conn returns the underlying net.Conn. 
189 | func (c *conn) Conn() net.Conn { 190 | return c.nextConn 191 | } 192 | -------------------------------------------------------------------------------- /netctx/packetconn.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package netctx 5 | 6 | import ( 7 | "context" 8 | "io" 9 | "net" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | ) 14 | 15 | // ReaderFrom is an interface for context controlled packet reader. 16 | type ReaderFrom interface { 17 | ReadFromContext(context.Context, []byte) (int, net.Addr, error) 18 | } 19 | 20 | // WriterTo is an interface for context controlled packet writer. 21 | type WriterTo interface { 22 | WriteToContext(context.Context, []byte, net.Addr) (int, error) 23 | } 24 | 25 | // PacketConn is a wrapper of net.PacketConn using context.Context. 26 | type PacketConn interface { 27 | ReaderFrom 28 | WriterTo 29 | io.Closer 30 | LocalAddr() net.Addr 31 | Conn() net.PacketConn 32 | } 33 | 34 | type packetConn struct { 35 | nextConn net.PacketConn 36 | closed chan struct{} 37 | closeOnce sync.Once 38 | readMu sync.Mutex 39 | writeMu sync.Mutex 40 | } 41 | 42 | // NewPacketConn creates a new PacketConn wrapping the given net.PacketConn. 43 | func NewPacketConn(pconn net.PacketConn) PacketConn { 44 | p := &packetConn{ 45 | nextConn: pconn, 46 | closed: make(chan struct{}), 47 | } 48 | 49 | return p 50 | } 51 | 52 | // ReadFromContext reads a packet from the connection, 53 | // copying the payload into p. It returns the number of 54 | // bytes copied into p and the return address that 55 | // was on the packet. 56 | // It returns the number of bytes read (0 <= n <= len(p)) 57 | // and any error encountered. Callers should always process 58 | // the n > 0 bytes returned before considering the error err. 59 | // Unlike net.PacketConn.ReadFrom(), the provided context is 60 | // used to control timeout. 
func (p *packetConn) ReadFromContext(ctx context.Context, b []byte) (int, net.Addr, error) { //nolint:cyclop
	p.readMu.Lock()
	defer p.readMu.Unlock()

	// Fail fast once Close has run.
	select {
	case <-p.closed:
		return 0, nil, net.ErrClosed
	default:
	}

	// done is closed after ReadFrom returns. The watcher goroutine interrupts
	// a blocked ReadFrom on ctx cancellation by moving the read deadline into
	// the past, then clears the deadline once ReadFrom has returned.
	done := make(chan struct{})
	var wg sync.WaitGroup
	var errSetDeadline atomic.Value
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			// context canceled
			if err := p.nextConn.SetReadDeadline(veryOld); err != nil {
				errSetDeadline.Store(err)

				return
			}
			<-done
			if err := p.nextConn.SetReadDeadline(time.Time{}); err != nil {
				errSetDeadline.Store(err)
			}
		case <-done:
		}
	}()

	n, raddr, err := p.nextConn.ReadFrom(b)

	close(done)
	wg.Wait()
	// Prefer the context error when nothing was read.
	if e := ctx.Err(); e != nil && n == 0 {
		err = e
	}
	// A failure to (re)set the deadline is reported only if nothing else failed.
	if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
		err = err2
	}

	return n, raddr, err
}

// WriteToContext writes a packet with payload b to raddr.
// Unlike net.PacketConn.WriteTo(), the provided context
// is used to control timeout.
// On packet-oriented connections, write timeouts are rare.
func (p *packetConn) WriteToContext(ctx context.Context, b []byte, raddr net.Addr) (int, error) { //nolint:cyclop
	p.writeMu.Lock()
	defer p.writeMu.Unlock()

	// Writes after Close fail with ErrClosing.
	select {
	case <-p.closed:
		return 0, ErrClosing
	default:
	}

	// done is closed after WriteTo returns. The watcher goroutine interrupts
	// a blocked WriteTo on ctx cancellation by moving the write deadline into
	// the past, then clears the deadline once WriteTo has returned.
	done := make(chan struct{})
	var wg sync.WaitGroup
	var errSetDeadline atomic.Value
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			// context canceled
			if err := p.nextConn.SetWriteDeadline(veryOld); err != nil {
				errSetDeadline.Store(err)

				return
			}
			<-done
			if err := p.nextConn.SetWriteDeadline(time.Time{}); err != nil {
				errSetDeadline.Store(err)
			}
		case <-done:
		}
	}()

	n, err := p.nextConn.WriteTo(b, raddr)

	close(done)
	wg.Wait()
	// Prefer the context error when nothing was written.
	if e := ctx.Err(); e != nil && n == 0 {
		err = e
	}
	// A failure to (re)set the deadline is reported only if nothing else failed.
	if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
		err = err2
	}

	return n, err
}

// Close closes the connection.
// Any blocked ReadFromContext or WriteToContext operations will be unblocked
// and return errors.
func (p *packetConn) Close() error {
	err := p.nextConn.Close()
	p.closeOnce.Do(func() {
		// Taking both locks waits for in-flight reads/writes to finish
		// before marking the wrapper closed.
		p.writeMu.Lock()
		p.readMu.Lock()
		close(p.closed)
		p.readMu.Unlock()
		p.writeMu.Unlock()
	})

	return err
}

// LocalAddr returns the local network address, if known.
func (p *packetConn) LocalAddr() net.Addr {
	return p.nextConn.LocalAddr()
}

// Conn returns the underlying net.PacketConn.
179 | func (p *packetConn) Conn() net.PacketConn { 180 | return p.nextConn 181 | } 182 | -------------------------------------------------------------------------------- /stdnet/net.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | // Package stdnet implements the transport.Net interface 5 | // using methods from Go's standard net package. 6 | package stdnet 7 | 8 | import ( 9 | "fmt" 10 | "net" 11 | 12 | "github.com/pion/transport/v3" 13 | "github.com/wlynxg/anet" 14 | ) 15 | 16 | const ( 17 | lo0String = "lo0String" 18 | udpString = "udp" 19 | ) 20 | 21 | // Net is an implementation of the net.Net interface 22 | // based on functions of the standard net package. 23 | type Net struct { 24 | interfaces []*transport.Interface 25 | } 26 | 27 | // NewNet creates a new StdNet instance. 28 | func NewNet() (*Net, error) { 29 | n := &Net{} 30 | 31 | return n, n.UpdateInterfaces() 32 | } 33 | 34 | // Compile-time assertion. 35 | var _ transport.Net = &Net{} 36 | 37 | // UpdateInterfaces updates the internal list of network interfaces 38 | // and associated addresses. 39 | func (n *Net) UpdateInterfaces() error { 40 | ifs := []*transport.Interface{} 41 | 42 | oifs, err := anet.Interfaces() 43 | if err != nil { 44 | return err 45 | } 46 | 47 | for i := range oifs { 48 | ifc := transport.NewInterface(oifs[i]) 49 | 50 | addrs, err := anet.InterfaceAddrsByInterface(&oifs[i]) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | for _, addr := range addrs { 56 | ifc.AddAddress(addr) 57 | } 58 | 59 | ifs = append(ifs, ifc) 60 | } 61 | 62 | n.interfaces = ifs 63 | 64 | return nil 65 | } 66 | 67 | // Interfaces returns a slice of interfaces which are available on the 68 | // system. 
func (n *Net) Interfaces() ([]*transport.Interface, error) {
	// The cached slice itself is returned, not a copy.
	return n.interfaces, nil
}

// InterfaceByIndex returns the interface specified by index.
//
// On Solaris, it returns one of the logical network interfaces
// sharing the logical data link; for more precision use
// InterfaceByName.
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
	for _, ifc := range n.interfaces {
		if ifc.Index == index {
			return ifc, nil
		}
	}

	return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
}

// InterfaceByName returns the interface specified by name.
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
	for _, ifc := range n.interfaces {
		if ifc.Name == name {
			return ifc, nil
		}
	}

	return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
}

// ListenPacket announces on the local network address.
func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) {
	return net.ListenPacket(network, address) //nolint: noctx
}

// ListenUDP acts like ListenPacket for UDP networks.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
	conn, err := net.ListenUDP(network, locAddr)
	if err != nil {
		// Return an untyped nil: converting the nil *net.UDPConn to the
		// transport.UDPConn interface would produce a non-nil interface.
		return nil, err
	}

	return conn, nil
}

// Dial connects to the address on the named network.
func (n *Net) Dial(network, address string) (net.Conn, error) {
	return net.Dial(network, address) //nolint: noctx
}

// DialUDP acts like Dial for UDP networks.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
	conn, err := net.DialUDP(network, laddr, raddr)
	if err != nil {
		// Untyped nil, for the same reason as in ListenUDP.
		return nil, err
	}

	return conn, nil
}

// ResolveIPAddr returns an address of IP end point.
func (n *Net) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
	return net.ResolveIPAddr(network, address)
}

// ResolveUDPAddr returns an address of UDP end point.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	return net.ResolveUDPAddr(network, address)
}

// ResolveTCPAddr returns an address of TCP end point.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
	return net.ResolveTCPAddr(network, address)
}

// DialTCP acts like Dial for TCP networks.
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
	conn, err := net.DialTCP(network, laddr, raddr)
	if err != nil {
		// Return an untyped nil: converting the nil *net.TCPConn to the
		// transport.TCPConn interface would produce a non-nil interface.
		return nil, err
	}

	return conn, nil
}

// ListenTCP acts like Listen for TCP networks.
func (n *Net) ListenTCP(network string, laddr *net.TCPAddr) (transport.TCPListener, error) {
	l, err := net.ListenTCP(network, laddr)
	if err != nil {
		return nil, err
	}

	return tcpListener{l}, nil
}

// tcpListener adapts *net.TCPListener to transport.TCPListener.
type tcpListener struct {
	*net.TCPListener
}

// AcceptTCP accepts the next incoming connection as a transport.TCPConn.
func (l tcpListener) AcceptTCP() (transport.TCPConn, error) {
	conn, err := l.TCPListener.AcceptTCP()
	if err != nil {
		// Untyped nil, for the same reason as in DialTCP.
		return nil, err
	}

	return conn, nil
}

// stdDialer adapts *net.Dialer to transport.Dialer.
type stdDialer struct {
	*net.Dialer
}

// Dial connects to address on the named network using the wrapped net.Dialer.
func (d stdDialer) Dial(network, address string) (net.Conn, error) {
	return d.Dialer.Dial(network, address)
}

// CreateDialer creates a transport.Dialer backed by the given *net.Dialer.
func (n *Net) CreateDialer(d *net.Dialer) transport.Dialer {
	return stdDialer{d}
}
--------------------------------------------------------------------------------
/deadline/deadline_test.go:
--------------------------------------------------------------------------------
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package deadline

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestDeadline races a Deadline against context deadlines at 40ms and 60ms
// and checks the order in which their Done channels fire.
func TestDeadline(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// A deadline at 50ms fires between the 40ms and 60ms contexts.
	t.Run("Deadline", func(t *testing.T) {
		now := time.Now()

		ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond))
		defer cancel0()
		ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond))
		defer cancel1()
		d := New()
		d.Set(now.Add(50 * time.Millisecond))

		ch := make(chan byte)
		go sendOnDone(ctx, ctx0.Done(), ch, 0)
		go sendOnDone(ctx, ctx1.Done(), ch, 1)
		go sendOnDone(ctx, d.Done(), ch, 2)

		calls := collectCh(ch, 3, 100*time.Millisecond)
		expectedCalls := []byte{0, 2, 1}
		assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal")
	})

	// Moving the deadline later (50ms -> 70ms) makes it fire last.
	t.Run("DeadlineExtend", func(t *testing.T) { //nolint:dupl
		now := time.Now()

		ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond))
		defer cancel0()
		ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond))
		defer cancel1()
		d := New()
		d.Set(now.Add(50 * time.Millisecond))
		d.Set(now.Add(70 * time.Millisecond))

		ch := make(chan byte)
		go sendOnDone(ctx, ctx0.Done(), ch, 0)
		go sendOnDone(ctx, ctx1.Done(), ch, 1)
		go sendOnDone(ctx, d.Done(), ch, 2)

		calls := collectCh(ch, 3, 100*time.Millisecond)
		expectedCalls := []byte{0, 1, 2}
		assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal")
	})

	// Moving the deadline earlier (50ms -> 30ms) makes it fire first.
	t.Run("DeadlinePretend", func(t *testing.T) { //nolint:dupl
		now := time.Now()

		ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond))
		defer cancel0()
		ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond))
		defer cancel1()
		d := New()
		d.Set(now.Add(50 * time.Millisecond))
		d.Set(now.Add(30 * time.Millisecond))

		ch := make(chan byte)
		go sendOnDone(ctx, ctx0.Done(), ch, 0)
		go sendOnDone(ctx, ctx1.Done(), ch, 1)
		go sendOnDone(ctx, d.Done(), ch, 2)

		calls := collectCh(ch, 3, 100*time.Millisecond)
		expectedCalls := []byte{2, 0, 1}
		assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal")
	})

	// Setting the zero time cancels the pending deadline, so only ctx0 fires
	// within the collection window.
	t.Run("DeadlineCancel", func(t *testing.T) {
		now := time.Now()

		ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond))
		defer cancel0()
		d := New()
		d.Set(now.Add(50 * time.Millisecond))
		d.Set(time.Time{})

		ch := make(chan byte)
		go sendOnDone(ctx, ctx0.Done(), ch, 0)
		go sendOnDone(ctx, d.Done(), ch, 1)

		calls := collectCh(ch, 2, 60*time.Millisecond)
		expectedCalls := []byte{0}
		assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal")
	})
}

// sendOnDone sends val on dest once done fires. It returns without sending if
// ctx is done first. The final send itself is not guarded by ctx, so it
// blocks until someone receives from dest.
func sendOnDone(ctx context.Context, done <-chan struct{}, dest chan byte, val byte) {
	select {
	case <-done:
	case <-ctx.Done():
		return
	}
	dest <- val
}

// collectCh receives up to n values from ch in arrival order, returning early
// with whatever has arrived once timeout elapses.
func collectCh(ch <-chan byte, n int, timeout time.Duration) []byte {
	a := time.After(timeout)
	var calls []byte
	for len(calls) < n {
		select {
		case call := <-ch:
			calls = append(calls, call)
		case <-a:
			return calls
		}
	}

	return calls
}

// TestContext checks the context-like API of Deadline: Done/Err before and
// after a past deadline is set, and the values returned by Deadline().
func TestContext(t *testing.T) { //nolint:cyclop
	t.Run("Cancel", func(t *testing.T) {
		deadline := New()

		// With no deadline set, Done must not fire and Err must be nil.
		select {
		case <-deadline.Done():
			assert.Fail(t, "Deadline unexpectedly done")
		case <-time.After(50 * time.Millisecond):
		}
		assert.NoError(t, deadline.Err())
		deadline.Set(time.Unix(0, 1)) // exceeded
		select {
		case <-deadline.Done():
		case <-time.After(50 * time.Millisecond):
			assert.Fail(t, "Timeout")
		}
		assert.ErrorIs(t, deadline.Err(), context.DeadlineExceeded)
	})
	t.Run("Deadline", func(t *testing.T) {
		d := New()
		// Before any Set: zero time and a false second value.
		t0, expired0 := d.Deadline()
		assert.True(t, t0.IsZero(), "Initial Deadline is expected to be 0")
		assert.False(t, expired0, "Deadline is not expected to be expired at initial state")

		dl := time.Unix(12345, 0)
		d.Set(dl) // exceeded

		// After Set: the stored time and a true second value.
		t1, expired1 := d.Deadline()
		assert.True(t, t1.Equal(dl), "Initial Deadline is expected to be %v, got %v", dl, t1)
		assert.True(t, expired1, "Deadline is expected to be expired")
	})
}

// BenchmarkDeadline measures repeated Set calls with the same future time.
func BenchmarkDeadline(b *testing.B) {
	b.Run("Set", func(b *testing.B) {
		d := New()
		t := time.Now().Add(time.Minute)
		for i := 0; i < b.N; i++ {
			d.Set(t)
		}
	})
}
--------------------------------------------------------------------------------
/vnet/chunk_test.go:
--------------------------------------------------------------------------------
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package vnet

import (
	"net"
	"strings"
	"testing"

	"github.com/pion/logging"
	"github.com/stretchr/testify/assert"
)

// TestTCPFragString checks the String() form of individual TCP flags and of
// the SYN|ACK combination.
func TestTCPFragString(t *testing.T) {
	f := tcpFIN
	assert.Equal(t, "FIN", f.String(), "should match")
	f = tcpSYN
	assert.Equal(t, "SYN", f.String(), "should match")
	f = tcpRST
	assert.Equal(t, "RST", f.String(), "should match")
	f = tcpPSH
	assert.Equal(t, "PSH", f.String(), "should match")
	f = tcpACK
	assert.Equal(t, "ACK", f.String(), "should match")
	f = tcpSYN | tcpACK
	assert.Equal(t, "SYN-ACK", f.String(), "should match")
}

// TestChunk exercises chunk construction, accessors and cloning.
func TestChunk(t *testing.T) {
	loggerFactory := logging.NewDefaultLoggerFactory()
	log := loggerFactory.NewLogger("test")

	t.Run("ChunkUDP", func(t *testing.T) {
		src := &net.UDPAddr{
			IP:   net.ParseIP("192.168.0.2"),
			Port: 1234,
		}
		dst := &net.UDPAddr{
			IP:   net.ParseIP(demoIP),
			Port: 5678,
		}

		var chunk Chunk = newChunkUDP(src, dst)
		str := chunk.String()
		log.Debugf("chunk: %s", str)
		assert.Equal(t, udp, chunk.Network(), "should match")
		assert.True(t, strings.Contains(str, src.Network()), "should include network type")
		assert.True(t, strings.Contains(str, src.String()), "should include address")
		assert.True(t, strings.Contains(str, dst.String()), "should include address")
		assert.True(t, chunk.getSourceIP().Equal(src.IP), "ip should match")
		assert.True(t, chunk.getDestinationIP().Equal(dst.IP), "ip should match")

		// Test timestamp
		ts := chunk.setTimestamp()
		assert.Equal(t, ts, chunk.getTimestamp(), "timestamp should match")

		uc := chunk.(*chunkUDP) //nolint:forcetypeassert
		uc.userData = []byte("Hello")

		cloned := chunk.Clone().(*chunkUDP) //nolint:forcetypeassert

		// Test setSourceAddr
		err := uc.setSourceAddr("2.3.4.5:4000")
		assert.Nil(t, err, "should succeed")
		assert.Equal(t, "2.3.4.5:4000", uc.SourceAddr().String())

		// Test Tag()
		assert.True(t, len(uc.tag) > 0, "should not be empty")
		assert.Equal(t, uc.tag, uc.Tag(), "should match")

		// Verify cloned chunk was not affected by the changes to original chunk
		uc.userData[0] = []byte("!")[0] // original: "Hello" -> "!ello"
74 | assert.Equal(t, "Hello", string(cloned.userData), "should match") 75 | assert.Equal(t, "192.168.0.2:1234", cloned.SourceAddr().String()) 76 | assert.True(t, cloned.getSourceIP().Equal(src.IP), "ip should match") 77 | assert.True(t, cloned.getDestinationIP().Equal(dst.IP), "ip should match") 78 | }) 79 | 80 | t.Run("ChunkTCP", func(t *testing.T) { 81 | src := &net.TCPAddr{ 82 | IP: net.ParseIP("192.168.0.2"), 83 | Port: 1234, 84 | } 85 | dst := &net.TCPAddr{ 86 | IP: net.ParseIP(demoIP), 87 | Port: 5678, 88 | } 89 | 90 | var chunk Chunk = newChunkTCP(src, dst, tcpSYN) 91 | str := chunk.String() 92 | log.Debugf("chunk: %s", str) 93 | assert.Equal(t, "tcp", chunk.Network(), "should match") 94 | assert.True(t, strings.Contains(str, src.Network()), "should include network type") 95 | assert.True(t, strings.Contains(str, src.String()), "should include address") 96 | assert.True(t, strings.Contains(str, dst.String()), "should include address") 97 | assert.True(t, chunk.getSourceIP().Equal(src.IP), "ip should match") 98 | assert.True(t, chunk.getDestinationIP().Equal(dst.IP), "ip should match") 99 | 100 | tcp, ok := chunk.(*chunkTCP) 101 | assert.True(t, ok, "type should match") 102 | assert.Equal(t, tcp.flags, tcpSYN, "flags should match") 103 | 104 | // Test timestamp 105 | ts := chunk.setTimestamp() 106 | assert.Equal(t, ts, chunk.getTimestamp(), "timestamp should match") 107 | 108 | tc := chunk.(*chunkTCP) //nolint:forcetypeassert 109 | tc.userData = []byte("Hello") 110 | 111 | cloned := chunk.Clone().(*chunkTCP) //nolint:forcetypeassert 112 | 113 | // Test setSourceAddr 114 | err := tc.setSourceAddr("2.3.4.5:4000") 115 | assert.Nil(t, err, "should succeed") 116 | assert.Equal(t, "2.3.4.5:4000", tc.SourceAddr().String()) 117 | 118 | // Test Tag() 119 | assert.True(t, len(tc.tag) > 0, "should not be empty") 120 | assert.Equal(t, tc.tag, tc.Tag(), "should match") 121 | 122 | // Verify cloned chunk was not affected by the changes to original chunk 123 | 
tc.userData[0] = []byte("!")[0] // original: "Hello" -> "Hell!" 124 | assert.Equal(t, "Hello", string(cloned.userData), "should match") 125 | assert.Equal(t, "192.168.0.2:1234", cloned.SourceAddr().String()) 126 | assert.True(t, cloned.getSourceIP().Equal(src.IP), "ip should match") 127 | assert.True(t, cloned.getDestinationIP().Equal(dst.IP), "ip should match") 128 | 129 | // Test setDestinationAddr 130 | err = tc.setDestinationAddr("3.4.5.6:7000") 131 | assert.Nil(t, err, "should succeed") 132 | assert.Equal(t, "3.4.5.6:7000", tc.DestinationAddr().String()) 133 | }) 134 | } 135 | -------------------------------------------------------------------------------- /vnet/stress_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "github.com/pion/logging" 14 | "github.com/pion/transport/v3/test" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | func TestStressTestUDP(t *testing.T) { //nolint:cyclop 19 | loggerFactory := logging.NewDefaultLoggerFactory() 20 | log := loggerFactory.NewLogger("test") 21 | 22 | t.Run("lan to wan", func(t *testing.T) { 23 | tt := test.TimeOut(30 * time.Second) 24 | defer tt.Stop() 25 | 26 | // WAN with a nic (net0) 27 | wan, err := NewRouter(&RouterConfig{ 28 | CIDR: "1.2.3.0/24", 29 | QueueSize: 1000, 30 | LoggerFactory: loggerFactory, 31 | }) 32 | assert.NoError(t, err, "should succeed") 33 | 34 | net0, err := NewNet(&NetConfig{ 35 | StaticIPs: []string{demoIP}, 36 | }) 37 | if !assert.NoError(t, err, "should succeed") { 38 | return 39 | } 40 | 41 | err = wan.AddNet(net0) 42 | assert.NoError(t, err, "should succeed") 43 | 44 | // LAN with a nic (net1) 45 | lan, err := NewRouter(&RouterConfig{ 46 | CIDR: "192.168.0.0/24", 47 | QueueSize: 1000, 48 | LoggerFactory: loggerFactory, 49 | }) 50 | 
assert.NoError(t, err, "should succeed") 51 | 52 | net1, err := NewNet(&NetConfig{}) 53 | if !assert.NoError(t, err, "should succeed") { 54 | return 55 | } 56 | 57 | err = lan.AddNet(net1) 58 | assert.NoError(t, err, "should succeed") 59 | 60 | err = wan.AddRouter(lan) 61 | assert.NoError(t, err, "should succeed") 62 | 63 | err = wan.Start() 64 | assert.NoError(t, err, "should succeed") 65 | defer func() { 66 | err = wan.Stop() 67 | assert.NoError(t, err, "should succeed") 68 | }() 69 | 70 | // Find IP address for net0 71 | ifs, err := net0.Interfaces() 72 | if !assert.NoError(t, err, "should succeed") { 73 | return 74 | } 75 | log.Debugf("num ifs: %d", len(ifs)) 76 | 77 | var echoServerIP net.IP 78 | loop: 79 | for _, ifc := range ifs { 80 | log.Debugf("flags: %v", ifc.Flags) 81 | if ifc.Flags&net.FlagUp == 0 { 82 | continue 83 | } 84 | if ifc.Flags&net.FlagLoopback != 0 { 85 | continue 86 | } 87 | 88 | addrs, err2 := ifc.Addrs() 89 | if !assert.NoError(t, err2, "should succeed") { 90 | return 91 | } 92 | log.Debugf("num addrs: %d", len(addrs)) 93 | for _, addr := range addrs { 94 | log.Debugf("addr: %s", addr.String()) 95 | switch addr := addr.(type) { 96 | case *net.IPNet: 97 | echoServerIP = addr.IP 98 | 99 | break loop 100 | case *net.IPAddr: 101 | echoServerIP = addr.IP 102 | 103 | break loop 104 | } 105 | } 106 | } 107 | if !assert.NotNil(t, echoServerIP, "should have IP address") { 108 | return 109 | } 110 | 111 | log.Debugf("echo server IP: %s", echoServerIP.String()) 112 | 113 | // Set up an echo server on WAN 114 | conn0, err := net0.ListenPacket( 115 | "udp4", fmt.Sprintf("%s:0", echoServerIP)) 116 | if !assert.NoError(t, err, "should succeed") { 117 | return 118 | } 119 | 120 | doneCh0 := make(chan struct{}) 121 | go func() { 122 | buf := make([]byte, 1500) 123 | for { 124 | n, from, err2 := conn0.ReadFrom(buf) 125 | if err2 != nil { 126 | break 127 | } 128 | // echo back 129 | _, err2 = conn0.WriteTo(buf[:n], from) 130 | if err2 != nil { 131 | break 
132 | } 133 | } 134 | close(doneCh0) 135 | }() 136 | 137 | var wg sync.WaitGroup 138 | 139 | runEchoTest := func() { 140 | // Set up a client 141 | var numRecvd int 142 | const numToSend int = 400 143 | const pktSize int = 1200 144 | conn1, err2 := net0.ListenPacket("udp4", "0.0.0.0:0") 145 | if !assert.NoError(t, err2, "should succeed") { 146 | return 147 | } 148 | 149 | doneCh1 := make(chan struct{}) 150 | go func() { 151 | buf := make([]byte, 1500) 152 | for { 153 | n, _, err3 := conn1.ReadFrom(buf) 154 | if err3 != nil { 155 | break 156 | } 157 | 158 | if n != pktSize { 159 | break 160 | } 161 | 162 | numRecvd++ 163 | } 164 | close(doneCh1) 165 | }() 166 | 167 | buf := make([]byte, pktSize) 168 | to := conn0.LocalAddr() 169 | for i := 0; i < numToSend; i++ { 170 | _, err3 := conn1.WriteTo(buf, to) 171 | assert.NoError(t, err3, "should succeed") 172 | time.Sleep(10 * time.Millisecond) 173 | } 174 | 175 | time.Sleep(time.Second) 176 | 177 | err2 = conn1.Close() 178 | assert.NoError(t, err2, "should succeed") 179 | 180 | <-doneCh1 181 | 182 | // allow some packet loss 183 | assert.True(t, numRecvd >= numToSend*8/10, "majority should received") 184 | if numRecvd < numToSend { 185 | log.Infof("lost %d packets", numToSend-numRecvd) 186 | } 187 | 188 | wg.Done() 189 | } 190 | 191 | // Run echo tests concurrently 192 | for i := 0; i < 20; i++ { 193 | wg.Add(1) 194 | go runEchoTest() 195 | } 196 | wg.Wait() 197 | 198 | err = conn0.Close() 199 | assert.NoError(t, err, "should succeed") 200 | }) 201 | } 202 | -------------------------------------------------------------------------------- /vnet/tbf_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !wasm 5 | // +build !wasm 6 | 7 | package vnet 8 | 9 | import ( 10 | "context" 11 | "runtime" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | 17 | 
"github.com/pion/logging" 18 | "github.com/stretchr/testify/assert" 19 | ) 20 | 21 | func TestTokenBucketFilter(t *testing.T) { 22 | t.Run("bitrateBelowCapacity", func(t *testing.T) { 23 | mnic := newMockNIC(t) 24 | 25 | const payloadSize = 1200 26 | sent := 100 27 | if runtime.GOOS == "windows" { 28 | // stay under the default queue size 29 | // to avoid drops on the slower windows schedulers. 30 | sent = 40 31 | } 32 | 33 | tbf, err := NewTokenBucketFilter(mnic, TBFRate(10*MBit), TBFMaxBurst(10*MBit)) 34 | assert.NoError(t, err, "should succeed") 35 | 36 | var received atomic.Int32 37 | mnic.mockOnInboundChunk = func(Chunk) { 38 | received.Add(1) 39 | } 40 | 41 | time.Sleep(1 * time.Second) 42 | 43 | for i := 0; i < sent; i++ { 44 | tbf.onInboundChunk(&chunkUDP{ 45 | userData: make([]byte, payloadSize), 46 | }) 47 | } 48 | 49 | runtime.Gosched() 50 | assert.Eventually( 51 | t, 52 | func() bool { 53 | return int(received.Load()) == sent 54 | }, 55 | time.Second, 56 | 5*time.Millisecond, 57 | ) 58 | 59 | assert.NoError(t, tbf.Close()) 60 | 61 | assert.Equal(t, sent, int(received.Load())) 62 | }) 63 | 64 | subTest := func(t *testing.T, capacity int, maxBurst int, duration time.Duration) { 65 | t.Helper() 66 | 67 | log := logging.NewDefaultLoggerFactory().NewLogger("test") 68 | 69 | mnic := newMockNIC(t) 70 | 71 | tbf, err := NewTokenBucketFilter(mnic, TBFRate(capacity), TBFMaxBurst(maxBurst)) 72 | assert.NoError(t, err, "should succeed") 73 | 74 | chunkChan := make(chan Chunk) 75 | mnic.mockOnInboundChunk = func(c Chunk) { 76 | chunkChan <- c 77 | } 78 | 79 | var wg sync.WaitGroup 80 | wg.Add(1) 81 | 82 | ctx, cancel := context.WithCancel(context.Background()) 83 | 84 | go func() { 85 | defer wg.Done() 86 | 87 | totalBytesReceived := 0 88 | totalPacketsReceived := 0 89 | bytesReceived := 0 90 | packetsReceived := 0 91 | start := time.Now() 92 | last := time.Now() 93 | 94 | ticker := time.NewTicker(time.Second) 95 | defer ticker.Stop() 96 | 97 | for { 98 | select { 
99 | case <-ctx.Done(): 100 | bits := float64(totalBytesReceived) * 8.0 101 | rate := bits / time.Since(start).Seconds() 102 | mBitPerSecond := rate / float64(MBit) 103 | // Allow 5% more than capacity due to max bursts 104 | assert.Less(t, rate, 1.05*float64(capacity)) 105 | // Allow for timing variations on slow/contended CI runners 106 | assert.Greater(t, rate, 0.75*float64(capacity)) 107 | 108 | log.Infof( 109 | "duration=%v, bytesReceived=%v, packetsReceived=%v throughput=%.2f Mb/s", 110 | time.Since(start), 111 | bytesReceived, 112 | packetsReceived, 113 | mBitPerSecond, 114 | ) 115 | 116 | return 117 | case now := <-ticker.C: 118 | delta := now.Sub(last) 119 | last = now 120 | bits := float64(bytesReceived) * 8.0 121 | rate := bits / delta.Seconds() 122 | mBitPerSecond := rate / float64(MBit) 123 | log.Infof( 124 | "duration=%v, bytesReceived=%v, packetsReceived=%v throughput=%.2f Mb/s", 125 | delta, 126 | bytesReceived, 127 | packetsReceived, 128 | mBitPerSecond, 129 | ) 130 | // Allow 10% more than capacity due to max bursts 131 | assert.Less(t, rate, 1.10*float64(capacity)) 132 | // Be tolerant of per-second fluctuations on slower CI 133 | assert.Greater(t, rate, 0.60*float64(capacity)) 134 | bytesReceived = 0 135 | packetsReceived = 0 136 | 137 | case c := <-chunkChan: 138 | bytesReceived += len(c.UserData()) 139 | packetsReceived++ 140 | totalBytesReceived += len(c.UserData()) 141 | totalPacketsReceived++ 142 | } 143 | } 144 | }() 145 | 146 | go func() { 147 | defer cancel() 148 | bytesSent := 0 149 | packetsSent := 0 150 | var start time.Time 151 | for start = time.Now(); time.Since(start) < duration; { 152 | c := &chunkUDP{ 153 | userData: make([]byte, 1200), 154 | } 155 | tbf.onInboundChunk(c) 156 | bytesSent += len(c.UserData()) 157 | packetsSent++ 158 | runtime.Gosched() 159 | } 160 | bits := float64(bytesSent) * 8.0 161 | rate := bits / time.Since(start).Seconds() 162 | mBitPerSecond := rate / float64(MBit) 163 | log.Infof( 164 | "duration=%v, 
bytesSent=%v, packetsSent=%v throughput=%.2f Mb/s", 165 | time.Since(start), 166 | bytesSent, 167 | packetsSent, 168 | mBitPerSecond, 169 | ) 170 | 171 | assert.NoError(t, tbf.Close()) 172 | }() 173 | 174 | wg.Wait() 175 | } 176 | 177 | t.Run("500Kbit-s", func(t *testing.T) { 178 | subTest(t, 500*KBit, 10*KBit, 15*time.Second) 179 | }) 180 | 181 | t.Run("1Mbit-s", func(t *testing.T) { 182 | subTest(t, 1*MBit, 20*KBit, 10*time.Second) 183 | }) 184 | 185 | t.Run("2Mbit-s", func(t *testing.T) { 186 | subTest(t, 2*MBit, 40*KBit, 10*time.Second) 187 | }) 188 | 189 | t.Run("8Mbit-s", func(t *testing.T) { 190 | subTest(t, 8*MBit, 160*KBit, 10*time.Second) 191 | }) 192 | } 193 | -------------------------------------------------------------------------------- /vnet/delay_filter_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type TimestampedChunk struct { 14 | ts time.Time 15 | c Chunk 16 | } 17 | 18 | func initTest(t *testing.T) (*DelayFilter, chan TimestampedChunk) { 19 | t.Helper() 20 | nic := newMockNIC(t) 21 | delayFilter, err := NewDelayFilter(nic, 0) 22 | if !assert.NoError(t, err, "should succeed") { 23 | return nil, nil 24 | } 25 | 26 | receiveCh := make(chan TimestampedChunk) 27 | 28 | nic.mockOnInboundChunk = func(c Chunk) { 29 | receivedAt := time.Now() 30 | receiveCh <- TimestampedChunk{ 31 | ts: receivedAt, 32 | c: c, 33 | } 34 | } 35 | 36 | return delayFilter, receiveCh 37 | } 38 | 39 | func scheduleOnePacketAtATime( 40 | t *testing.T, 41 | delayFilter *DelayFilter, 42 | receiveCh chan TimestampedChunk, 43 | delay time.Duration, 44 | nrPackets int, 45 | ) { 46 | t.Helper() 47 | delayFilter.SetDelay(delay) 48 | lastNr := -1 49 | for i := 0; i < nrPackets; i++ { 50 | sent := time.Now() 51 | 
delayFilter.onInboundChunk(&chunkUDP{ 52 | chunkIP: chunkIP{timestamp: sent}, 53 | userData: []byte{byte(i)}, 54 | }) 55 | 56 | select { 57 | case chunk := <-receiveCh: 58 | nr := int(chunk.c.UserData()[0]) 59 | 60 | assert.Greater(t, nr, lastNr) 61 | lastNr = nr 62 | 63 | assert.Greater(t, chunk.ts.Sub(sent), delay) 64 | // Use generous timing tolerance for CI environments with high system load 65 | // and virtualization overhead. Function call overhead from DelayFilter 66 | // refactoring also contributes to timing variability. 67 | assert.Less(t, chunk.ts.Sub(sent), delay+200*time.Millisecond) 68 | case <-time.After(time.Second): 69 | assert.Fail(t, "expected to receive next chunk") 70 | } 71 | } 72 | } 73 | 74 | func scheduleManyPackets( 75 | t *testing.T, 76 | delayFilter *DelayFilter, 77 | receiveCh chan TimestampedChunk, 78 | delay time.Duration, 79 | nrPackets int, //nolint:unparam 80 | ) { 81 | t.Helper() 82 | delayFilter.SetDelay(delay) 83 | sent := time.Now() 84 | 85 | for i := 0; i < nrPackets; i++ { 86 | delayFilter.onInboundChunk(&chunkUDP{ 87 | chunkIP: chunkIP{timestamp: sent}, 88 | userData: []byte{byte(i)}, 89 | }) 90 | } 91 | 92 | // receive nrPackets chunks with a minimum delay 93 | for i := 0; i < nrPackets; i++ { 94 | select { 95 | case chunk := <-receiveCh: 96 | nr := int(chunk.c.UserData()[0]) 97 | assert.Equal(t, i, nr) 98 | assert.Greater(t, chunk.ts.Sub(sent), delay) 99 | assert.Less(t, chunk.ts.Sub(sent), delay+200*time.Millisecond) 100 | case <-time.After(time.Second): 101 | assert.Fail(t, "expected to receive next chunk") 102 | } 103 | } 104 | } 105 | 106 | func TestDelayFilter(t *testing.T) { 107 | t.Run("schedulesOnePacketAtATime", func(t *testing.T) { 108 | delayFilter, receiveCh := initTest(t) 109 | if delayFilter == nil { 110 | return 111 | } 112 | 113 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 100) 114 | assert.NoError(t, delayFilter.Close()) 115 | }) 116 | 117 | 
t.Run("schedulesSubsequentManyPackets", func(t *testing.T) { 118 | delayFilter, receiveCh := initTest(t) 119 | if delayFilter == nil { 120 | return 121 | } 122 | 123 | scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) 124 | assert.NoError(t, delayFilter.Close()) 125 | }) 126 | 127 | t.Run("scheduleIncreasingDelayOnePacketAtATime", func(t *testing.T) { 128 | delayFilter, receiveCh := initTest(t) 129 | if delayFilter == nil { 130 | return 131 | } 132 | 133 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 10) 134 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 50*time.Millisecond, 10) 135 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 100*time.Millisecond, 10) 136 | assert.NoError(t, delayFilter.Close()) 137 | }) 138 | 139 | t.Run("scheduleDecreasingDelayOnePacketAtATime", func(t *testing.T) { 140 | delayFilter, receiveCh := initTest(t) 141 | if delayFilter == nil { 142 | return 143 | } 144 | 145 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 100*time.Millisecond, 10) 146 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 50*time.Millisecond, 10) 147 | scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 10) 148 | assert.NoError(t, delayFilter.Close()) 149 | }) 150 | 151 | t.Run("scheduleIncreasingDelayManyPackets", func(t *testing.T) { 152 | delayFilter, receiveCh := initTest(t) 153 | if delayFilter == nil { 154 | return 155 | } 156 | 157 | scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) 158 | scheduleManyPackets(t, delayFilter, receiveCh, 50*time.Millisecond, 100) 159 | scheduleManyPackets(t, delayFilter, receiveCh, 100*time.Millisecond, 100) 160 | assert.NoError(t, delayFilter.Close()) 161 | }) 162 | 163 | t.Run("scheduleDecreasingDelayManyPackets", func(t *testing.T) { 164 | delayFilter, receiveCh := initTest(t) 165 | if delayFilter == nil { 166 | return 167 | } 168 | 169 | scheduleManyPackets(t, delayFilter, receiveCh, 100*time.Millisecond, 100) 
170 | scheduleManyPackets(t, delayFilter, receiveCh, 50*time.Millisecond, 100) 171 | scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) 172 | assert.NoError(t, delayFilter.Close()) 173 | }) 174 | } 175 | -------------------------------------------------------------------------------- /stdnet/net_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build !js 5 | // +build !js 6 | 7 | package stdnet 8 | 9 | import ( 10 | "net" 11 | "testing" 12 | 13 | "github.com/pion/logging" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestStdNet(t *testing.T) { //nolint:cyclop 18 | log := logging.NewDefaultLoggerFactory().NewLogger("test") 19 | 20 | t.Run("Interfaces", func(t *testing.T) { 21 | nw, err := NewNet() 22 | if !assert.Nil(t, err, "should succeed") { 23 | return 24 | } 25 | 26 | interfaces, err := nw.Interfaces() 27 | if !assert.NoError(t, err, "should succeed") { 28 | return 29 | } 30 | 31 | log.Debugf("interfaces: %+v", interfaces) 32 | for _, ifc := range interfaces { 33 | if ifc.Name == lo0String { 34 | _, err := ifc.Addrs() 35 | if !assert.NoError(t, err, "should succeed") { 36 | return 37 | } 38 | } 39 | 40 | if addrs, err := ifc.Addrs(); err == nil { 41 | for _, addr := range addrs { 42 | log.Debugf("[%d] %s:%s", 43 | ifc.Index, 44 | addr.Network(), 45 | addr.String()) 46 | } 47 | } 48 | } 49 | }) 50 | 51 | t.Run("ResolveUDPAddr", func(t *testing.T) { 52 | nw, err := NewNet() 53 | if !assert.Nil(t, err, "should succeed") { 54 | return 55 | } 56 | 57 | udpAddr, err := nw.ResolveUDPAddr(udpString, "localhost:1234") 58 | if !assert.NoError(t, err, "should succeed") { 59 | return 60 | } 61 | assert.Contains(t, []string{"127.0.0.1", "127.0.1.1"}, udpAddr.IP.String(), "should match") 62 | assert.Equal(t, 1234, udpAddr.Port, "should match") 63 | }) 64 | 65 | t.Run("ListenPacket", func(t 
*testing.T) { 66 | nw, err := NewNet() 67 | if !assert.Nil(t, err, "should succeed") { 68 | return 69 | } 70 | 71 | conn, err := nw.ListenPacket(udpString, "127.0.0.1:0") 72 | if !assert.NoError(t, err, "should succeed") { 73 | return 74 | } 75 | 76 | udpConn, ok := conn.(*net.UDPConn) 77 | assert.True(t, ok, "should succeed") 78 | log.Debugf("udpConn: %+v", udpConn) 79 | 80 | laddr := conn.LocalAddr().String() 81 | log.Debugf("laddr: %s", laddr) 82 | }) 83 | 84 | t.Run("ListenUDP random port", func(t *testing.T) { 85 | nw, err := NewNet() 86 | if !assert.Nil(t, err, "should succeed") { 87 | return 88 | } 89 | 90 | srcAddr := &net.UDPAddr{ 91 | IP: net.ParseIP("127.0.0.1"), 92 | } 93 | conn, err := nw.ListenUDP(udpString, srcAddr) 94 | assert.NoError(t, err, "should succeed") 95 | 96 | laddr := conn.LocalAddr().String() 97 | log.Debugf("laddr: %s", laddr) 98 | 99 | assert.NoError(t, conn.Close(), "should succeed") 100 | }) 101 | 102 | t.Run("Dial (UDP)", func(t *testing.T) { 103 | nw, err := NewNet() 104 | assert.Nil(t, err, "should succeed") 105 | 106 | conn, err := nw.Dial(udpString, "127.0.0.1:1234") 107 | assert.NoError(t, err, "should succeed") 108 | 109 | laddr := conn.LocalAddr() 110 | log.Debugf("laddr: %s", laddr.String()) 111 | 112 | raddr := conn.RemoteAddr() 113 | log.Debugf("raddr: %s", raddr.String()) 114 | 115 | assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert 116 | assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert 117 | assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") 118 | assert.NoError(t, conn.Close(), "should succeed") 119 | }) 120 | 121 | t.Run("DialUDP", func(t *testing.T) { 122 | nw, err := NewNet() 123 | assert.Nil(t, err, "should succeed") 124 | 125 | locAddr := &net.UDPAddr{ 126 | IP: net.IPv4(127, 0, 0, 1), 127 | Port: 0, 128 | } 129 | 130 | remAddr := &net.UDPAddr{ 131 | IP: net.IPv4(127, 0, 0, 1), 132 | Port: 1234, 133 | } 134 | 
135 | conn, err := nw.DialUDP(udpString, locAddr, remAddr) 136 | assert.NoError(t, err, "should succeed") 137 | 138 | laddr := conn.LocalAddr() 139 | log.Debugf("laddr: %s", laddr.String()) 140 | 141 | raddr := conn.RemoteAddr() 142 | log.Debugf("raddr: %s", raddr.String()) 143 | 144 | assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert 145 | assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert 146 | assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") 147 | assert.NoError(t, conn.Close(), "should succeed") 148 | }) 149 | 150 | t.Run("UDPLoopback", func(t *testing.T) { 151 | nw, err := NewNet() 152 | assert.Nil(t, err, "should succeed") 153 | 154 | conn, err := nw.ListenPacket(udpString, "127.0.0.1:0") 155 | assert.NoError(t, err, "should succeed") 156 | laddr := conn.LocalAddr() 157 | msg := "PING!" 158 | n, err := conn.WriteTo([]byte(msg), laddr) 159 | assert.NoError(t, err, "should succeed") 160 | assert.Equal(t, len(msg), n, "should match") 161 | 162 | buf := make([]byte, 1000) 163 | n, addr, err := conn.ReadFrom(buf) 164 | assert.NoError(t, err, "should succeed") 165 | assert.Equal(t, len(msg), n, "should match") 166 | assert.Equal(t, msg, string(buf[:n]), "should match") 167 | assert.Equal(t, laddr.(*net.UDPAddr).String(), addr.(*net.UDPAddr).String(), "should match") //nolint:forcetypeassert 168 | assert.NoError(t, conn.Close(), "should succeed") 169 | }) 170 | 171 | t.Run("Dialer", func(t *testing.T) { 172 | nw, err := NewNet() 173 | assert.Nil(t, err, "should succeed") 174 | 175 | dialer := nw.CreateDialer(&net.Dialer{ 176 | LocalAddr: &net.UDPAddr{ 177 | IP: net.ParseIP("127.0.0.1"), 178 | Port: 0, 179 | }, 180 | }) 181 | 182 | conn, err := dialer.Dial(udpString, "127.0.0.1:1234") 183 | assert.NoError(t, err, "should succeed") 184 | 185 | laddr := conn.LocalAddr() 186 | log.Debugf("laddr: %s", laddr.String()) 187 | 188 | raddr := conn.RemoteAddr() 189 | 
log.Debugf("raddr: %s", raddr.String()) 190 | 191 | assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert 192 | assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert 193 | assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") 194 | assert.NoError(t, conn.Close(), "should succeed") 195 | }) 196 | 197 | t.Run("Unexpected operations", func(t *testing.T) { 198 | // For portability of test, find a name of loopback interface name first 199 | var loName string 200 | ifs, err := net.Interfaces() 201 | assert.NoError(t, err, "should succeed") 202 | for _, ifc := range ifs { 203 | if ifc.Flags&net.FlagLoopback != 0 { 204 | loName = ifc.Name 205 | 206 | break 207 | } 208 | } 209 | 210 | nw, err := NewNet() 211 | assert.Nil(t, err, "should succeed") 212 | 213 | if len(loName) > 0 { 214 | // InterfaceByName 215 | ifc, err2 := nw.InterfaceByName(loName) 216 | assert.NoError(t, err2, "should succeed") 217 | assert.Equal(t, loName, ifc.Name, "should match") 218 | } 219 | 220 | _, err = nw.InterfaceByName("foo0") 221 | assert.Error(t, err, "should fail") 222 | }) 223 | } 224 | -------------------------------------------------------------------------------- /vnet/udpproxy.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "context" 8 | "net" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // UDPProxy is a proxy between real server(net.UDPConn) and vnet.UDPConn. 14 | // 15 | // High level design: 16 | // 17 | // .............................................. 18 | // : Virtual Network (vnet) : 19 | // : : 20 | // +-------+ * 1 +----+ +--------+ : 21 | // | :App |------------>|:Net|--o<-----|:Router | ............................. 
22 | // +-------+ +----+ | | : UDPProxy : 23 | // : | | +----+ +---------+ +---------+ +--------+ 24 | // : | |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real | 25 | // : | | +----+ | UDPConn | | UDPConn | | Server | 26 | // : | | : +---------+ +---------+ +--------+ 27 | // : | | ............................: 28 | // : +--------+ : 29 | // ............................................... 30 | type UDPProxy struct { 31 | // The router bind to. 32 | router *Router 33 | 34 | // Each vnet source, bind to a real socket to server. 35 | // key is real server addr, which is net.Addr 36 | // value is *aUDPProxyWorker 37 | workers sync.Map 38 | 39 | // For each endpoint, we never know when to start and stop proxy, 40 | // so we stop the endpoint when timeout. 41 | timeout time.Duration 42 | 43 | // For utest, to mock the target real server. 44 | // Optional, use the address of received client packet. 45 | mockRealServerAddr *net.UDPAddr 46 | } 47 | 48 | // NewProxy create a proxy, the router for this proxy belongs/bind to. If need to proxy for 49 | // please create a new proxy for each router. For all addresses we proxy, we will create a 50 | // vnet.Net in this router and proxy all packets. 51 | func NewProxy(router *Router) (*UDPProxy, error) { 52 | v := &UDPProxy{router: router, timeout: 2 * time.Minute} 53 | 54 | return v, nil 55 | } 56 | 57 | // Close the proxy, stop all workers. 58 | func (v *UDPProxy) Close() error { 59 | v.workers.Range(func(_, value any) bool { 60 | _ = value.(*aUDPProxyWorker).Close() //nolint:forcetypeassert 61 | 62 | return true 63 | }) 64 | 65 | return nil 66 | } 67 | 68 | // Proxy starts a worker for server, ignore if already started. 69 | func (v *UDPProxy) Proxy(client *Net, server *net.UDPAddr) error { 70 | // Note that even if the worker exists, it's also ok to create a same worker, 71 | // because the router will use the last one, and the real server will see a address 72 | // change event after we switch to the next worker. 
73 | if _, ok := v.workers.Load(server.String()); ok { 74 | // nolint:godox // TODO: Need to restart the stopped worker? 75 | return nil 76 | } 77 | 78 | // Not exists, create a new one. 79 | worker := &aUDPProxyWorker{ 80 | router: v.router, mockRealServerAddr: v.mockRealServerAddr, 81 | } 82 | 83 | // Create context for cleanup. 84 | var ctx context.Context 85 | ctx, worker.ctxDisposeCancel = context.WithCancel(context.Background()) 86 | 87 | v.workers.Store(server.String(), worker) 88 | 89 | return worker.Proxy(ctx, client, server) 90 | } 91 | 92 | // A proxy worker for a specified proxy server. 93 | type aUDPProxyWorker struct { 94 | router *Router 95 | mockRealServerAddr *net.UDPAddr 96 | 97 | // Each vnet source, bind to a real socket to server. 98 | // key is vnet client addr, which is net.Addr 99 | // value is *net.UDPConn 100 | endpoints sync.Map 101 | 102 | // For cleanup. 103 | ctxDisposeCancel context.CancelFunc 104 | wg sync.WaitGroup 105 | } 106 | 107 | func (v *aUDPProxyWorker) Close() error { 108 | // Notify all goroutines to dispose. 109 | v.ctxDisposeCancel() 110 | 111 | // Wait for all goroutines quit. 112 | v.wg.Wait() 113 | 114 | return nil 115 | } 116 | 117 | func (v *aUDPProxyWorker) Proxy(ctx context.Context, _ *Net, serverAddr *net.UDPAddr) error { // nolint:gocognit,cyclop 118 | // Create vnet for real server by serverAddr. 119 | nw, err := NewNet(&NetConfig{ 120 | StaticIP: serverAddr.IP.String(), 121 | }) 122 | if err != nil { 123 | return err 124 | } 125 | 126 | if err = v.router.AddNet(nw); err != nil { 127 | return err 128 | } 129 | 130 | // We must create a "same" vnet.UDPConn as the net.UDPConn, 131 | // which has the same ip:port, to copy packets between them. 132 | vnetSocket, err := nw.ListenUDP("udp4", serverAddr) 133 | if err != nil { 134 | return err 135 | } 136 | 137 | // User stop proxy, we should close the socket. 
138 | go func() { 139 | <-ctx.Done() 140 | _ = vnetSocket.Close() 141 | }() 142 | 143 | // Got new vnet client, start a new endpoint. 144 | findEndpointBy := func(addr net.Addr) (*net.UDPConn, error) { 145 | // Exists binding. 146 | if value, ok := v.endpoints.Load(addr.String()); ok { 147 | // Exists endpoint, reuse it. 148 | return value.(*net.UDPConn), nil //nolint:forcetypeassert 149 | } 150 | 151 | // The real server we proxy to, for utest to mock it. 152 | realAddr := serverAddr 153 | if v.mockRealServerAddr != nil { 154 | realAddr = v.mockRealServerAddr 155 | } 156 | 157 | // Got new vnet client, create new endpoint. 158 | realSocket, err := net.DialUDP("udp4", nil, realAddr) 159 | if err != nil { 160 | return nil, err 161 | } 162 | 163 | // User stop proxy, we should close the socket. 164 | go func() { 165 | <-ctx.Done() 166 | _ = realSocket.Close() 167 | }() 168 | 169 | // Bind address. 170 | v.endpoints.Store(addr.String(), realSocket) 171 | 172 | // Got packet from real serverAddr, we should proxy it to vnet. 173 | v.wg.Add(1) 174 | go func(vnetClientAddr net.Addr) { 175 | defer v.wg.Done() 176 | 177 | buf := make([]byte, 1500) 178 | for { 179 | n, _, err := realSocket.ReadFrom(buf) 180 | if err != nil { 181 | return 182 | } 183 | 184 | if n <= 0 { 185 | continue // Drop packet 186 | } 187 | 188 | if _, err := vnetSocket.WriteTo(buf[:n], vnetClientAddr); err != nil { 189 | return 190 | } 191 | } 192 | }(addr) 193 | 194 | return realSocket, nil 195 | } 196 | 197 | // Start a proxy goroutine. 198 | v.wg.Add(1) 199 | go func() { 200 | defer v.wg.Done() 201 | 202 | buf := make([]byte, 1500) 203 | 204 | for { 205 | n, addr, err := vnetSocket.ReadFrom(buf) 206 | if err != nil { 207 | return 208 | } 209 | 210 | if n <= 0 || addr == nil { 211 | continue // Drop packet 212 | } 213 | 214 | realSocket, err := findEndpointBy(addr) 215 | if err != nil { 216 | continue // Drop packet. 
217 | } 218 | 219 | if _, err := realSocket.Write(buf[:n]); err != nil { 220 | return 221 | } 222 | } 223 | }() 224 | 225 | return nil 226 | } 227 | -------------------------------------------------------------------------------- /vnet/chunk.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | "strconv" 10 | "strings" 11 | "sync/atomic" 12 | "time" 13 | ) 14 | 15 | type tcpFlag uint8 16 | 17 | const ( 18 | tcpFIN tcpFlag = 0x01 19 | tcpSYN tcpFlag = 0x02 20 | tcpRST tcpFlag = 0x04 21 | tcpPSH tcpFlag = 0x08 22 | tcpACK tcpFlag = 0x10 23 | ) 24 | 25 | func (f tcpFlag) String() string { 26 | var sa []string 27 | if f&tcpFIN != 0 { 28 | sa = append(sa, "FIN") 29 | } 30 | if f&tcpSYN != 0 { 31 | sa = append(sa, "SYN") 32 | } 33 | if f&tcpRST != 0 { 34 | sa = append(sa, "RST") 35 | } 36 | if f&tcpPSH != 0 { 37 | sa = append(sa, "PSH") 38 | } 39 | if f&tcpACK != 0 { 40 | sa = append(sa, "ACK") 41 | } 42 | 43 | return strings.Join(sa, "-") 44 | } 45 | 46 | // Generate a base36-encoded unique tag 47 | // See: https://play.golang.org/p/0ZaAID1q-HN 48 | var assignChunkTag = func() func() string { //nolint:gochecknoglobals 49 | var tagCtr uint64 50 | 51 | return func() string { 52 | n := atomic.AddUint64(&tagCtr, 1) 53 | 54 | return strconv.FormatUint(n, 36) 55 | } 56 | }() 57 | 58 | // Chunk represents a packet passed around in the vnet. 
59 | type Chunk interface { 60 | setTimestamp() time.Time // used by router 61 | getTimestamp() time.Time // used by router 62 | getSourceIP() net.IP // used by router 63 | getDestinationIP() net.IP // used by router 64 | setSourceAddr(address string) error // used by nat 65 | setDestinationAddr(address string) error // used by nat 66 | 67 | SourceAddr() net.Addr 68 | DestinationAddr() net.Addr 69 | UserData() []byte 70 | Tag() string 71 | Clone() Chunk 72 | Network() string // returns "udp" or "tcp" 73 | String() string 74 | } 75 | 76 | type chunkIP struct { 77 | timestamp time.Time 78 | sourceIP net.IP 79 | destinationIP net.IP 80 | tag string 81 | duplicate bool 82 | } 83 | 84 | func (c *chunkIP) setTimestamp() time.Time { 85 | c.timestamp = time.Now() 86 | 87 | return c.timestamp 88 | } 89 | 90 | func (c *chunkIP) getTimestamp() time.Time { 91 | return c.timestamp 92 | } 93 | 94 | func (c *chunkIP) getDestinationIP() net.IP { 95 | return c.destinationIP 96 | } 97 | 98 | func (c *chunkIP) getSourceIP() net.IP { 99 | return c.sourceIP 100 | } 101 | 102 | func (c *chunkIP) Tag() string { 103 | return c.tag 104 | } 105 | 106 | func (c *chunkIP) markDuplicate() { 107 | c.duplicate = true 108 | } 109 | 110 | func (c *chunkIP) isDuplicate() bool { 111 | return c.duplicate 112 | } 113 | 114 | type chunkUDP struct { 115 | chunkIP 116 | sourcePort int 117 | destinationPort int 118 | userData []byte 119 | } 120 | 121 | func newChunkUDP(srcAddr, dstAddr *net.UDPAddr) *chunkUDP { 122 | return &chunkUDP{ 123 | chunkIP: chunkIP{ 124 | sourceIP: srcAddr.IP, 125 | destinationIP: dstAddr.IP, 126 | tag: assignChunkTag(), 127 | }, 128 | sourcePort: srcAddr.Port, 129 | destinationPort: dstAddr.Port, 130 | } 131 | } 132 | 133 | func (c *chunkUDP) SourceAddr() net.Addr { 134 | return &net.UDPAddr{ 135 | IP: c.sourceIP, 136 | Port: c.sourcePort, 137 | } 138 | } 139 | 140 | func (c *chunkUDP) DestinationAddr() net.Addr { 141 | return &net.UDPAddr{ 142 | IP: c.destinationIP, 143 | Port: 
c.destinationPort, 144 | } 145 | } 146 | 147 | func (c *chunkUDP) UserData() []byte { 148 | return c.userData 149 | } 150 | 151 | func (c *chunkUDP) Clone() Chunk { 152 | var userData []byte 153 | if c.userData != nil { 154 | userData = make([]byte, len(c.userData)) 155 | copy(userData, c.userData) 156 | } 157 | 158 | return &chunkUDP{ 159 | chunkIP: chunkIP{ 160 | timestamp: c.timestamp, 161 | sourceIP: c.sourceIP, 162 | destinationIP: c.destinationIP, 163 | tag: c.tag, 164 | }, 165 | sourcePort: c.sourcePort, 166 | destinationPort: c.destinationPort, 167 | userData: userData, 168 | } 169 | } 170 | 171 | func (c *chunkUDP) Network() string { 172 | return udp 173 | } 174 | 175 | func (c *chunkUDP) String() string { 176 | src := c.SourceAddr() 177 | dst := c.DestinationAddr() 178 | 179 | return fmt.Sprintf("%s chunk %s %s => %s", 180 | src.Network(), 181 | c.tag, 182 | src.String(), 183 | dst.String(), 184 | ) 185 | } 186 | 187 | func (c *chunkUDP) setSourceAddr(address string) error { 188 | addr, err := net.ResolveUDPAddr(udp, address) 189 | if err != nil { 190 | return err 191 | } 192 | c.sourceIP = addr.IP 193 | c.sourcePort = addr.Port 194 | 195 | return nil 196 | } 197 | 198 | func (c *chunkUDP) setDestinationAddr(address string) error { 199 | addr, err := net.ResolveUDPAddr(udp, address) 200 | if err != nil { 201 | return err 202 | } 203 | c.destinationIP = addr.IP 204 | c.destinationPort = addr.Port 205 | 206 | return nil 207 | } 208 | 209 | type chunkTCP struct { 210 | chunkIP 211 | sourcePort int 212 | destinationPort int 213 | flags tcpFlag // control bits 214 | userData []byte // only with PSH flag 215 | // seq uint32 // always starts with 0 216 | // ack uint32 // always starts with 0 217 | } 218 | 219 | func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP { 220 | return &chunkTCP{ 221 | chunkIP: chunkIP{ 222 | sourceIP: srcAddr.IP, 223 | destinationIP: dstAddr.IP, 224 | tag: assignChunkTag(), 225 | }, 226 | sourcePort: srcAddr.Port, 
227 | destinationPort: dstAddr.Port, 228 | flags: flags, 229 | } 230 | } 231 | 232 | func (c *chunkTCP) SourceAddr() net.Addr { 233 | return &net.TCPAddr{ 234 | IP: c.sourceIP, 235 | Port: c.sourcePort, 236 | } 237 | } 238 | 239 | func (c *chunkTCP) DestinationAddr() net.Addr { 240 | return &net.TCPAddr{ 241 | IP: c.destinationIP, 242 | Port: c.destinationPort, 243 | } 244 | } 245 | 246 | func (c *chunkTCP) UserData() []byte { 247 | return c.userData 248 | } 249 | 250 | func (c *chunkTCP) Clone() Chunk { 251 | userData := make([]byte, len(c.userData)) 252 | copy(userData, c.userData) 253 | 254 | return &chunkTCP{ 255 | chunkIP: chunkIP{ 256 | timestamp: c.timestamp, 257 | sourceIP: c.sourceIP, 258 | destinationIP: c.destinationIP, 259 | }, 260 | sourcePort: c.sourcePort, 261 | destinationPort: c.destinationPort, 262 | userData: userData, 263 | } 264 | } 265 | 266 | func (c *chunkTCP) Network() string { 267 | return "tcp" 268 | } 269 | 270 | func (c *chunkTCP) String() string { 271 | src := c.SourceAddr() 272 | dst := c.DestinationAddr() 273 | 274 | return fmt.Sprintf("%s %s chunk %s %s => %s", 275 | src.Network(), 276 | c.flags.String(), 277 | c.tag, 278 | src.String(), 279 | dst.String(), 280 | ) 281 | } 282 | 283 | func (c *chunkTCP) setSourceAddr(address string) error { 284 | addr, err := net.ResolveTCPAddr("tcp", address) 285 | if err != nil { 286 | return err 287 | } 288 | c.sourceIP = addr.IP 289 | c.sourcePort = addr.Port 290 | 291 | return nil 292 | } 293 | 294 | func (c *chunkTCP) setDestinationAddr(address string) error { 295 | addr, err := net.ResolveTCPAddr("tcp", address) 296 | if err != nil { 297 | return err 298 | } 299 | c.destinationIP = addr.IP 300 | c.destinationPort = addr.Port 301 | 302 | return nil 303 | } 304 | -------------------------------------------------------------------------------- /connctx/connctx_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion 
community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package connctx 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "io" 10 | "net" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestRead(t *testing.T) { 18 | ca, cb := net.Pipe() 19 | defer func() { 20 | _ = ca.Close() 21 | }() 22 | 23 | data := []byte{0x01, 0x02, 0xFF} 24 | chErr := make(chan error) 25 | 26 | go func() { 27 | _, err := cb.Write(data) 28 | chErr <- err 29 | }() 30 | 31 | c := New(ca) 32 | b := make([]byte, 100) 33 | n, err := c.ReadContext(context.Background(), b) 34 | assert.NoError(t, err) 35 | assert.Len(t, data, n) 36 | assert.Equal(t, data, b[:n]) 37 | 38 | err = <-chErr 39 | assert.NoError(t, err) 40 | } 41 | 42 | func TestReadTimeout(t *testing.T) { 43 | ca, _ := net.Pipe() 44 | defer func() { 45 | _ = ca.Close() 46 | }() 47 | 48 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 49 | defer cancel() 50 | 51 | c := New(ca) 52 | b := make([]byte, 100) 53 | n, err := c.ReadContext(ctx, b) 54 | assert.Error(t, err) 55 | assert.Empty(t, n) 56 | } 57 | 58 | func TestReadCancel(t *testing.T) { 59 | ca, _ := net.Pipe() 60 | defer func() { 61 | _ = ca.Close() 62 | }() 63 | 64 | ctx, cancel := context.WithCancel(context.Background()) 65 | go func() { 66 | time.Sleep(10 * time.Millisecond) 67 | cancel() 68 | }() 69 | 70 | c := New(ca) 71 | b := make([]byte, 100) 72 | n, err := c.ReadContext(ctx, b) 73 | assert.Error(t, err) 74 | assert.Empty(t, n) 75 | } 76 | 77 | func TestReadClosed(t *testing.T) { 78 | ca, _ := net.Pipe() 79 | 80 | c := New(ca) 81 | _ = c.Close() 82 | 83 | b := make([]byte, 100) 84 | n, err := c.ReadContext(context.Background(), b) 85 | assert.ErrorIs(t, err, io.EOF) 86 | assert.Empty(t, n) 87 | } 88 | 89 | func TestWrite(t *testing.T) { 90 | ca, cb := net.Pipe() 91 | defer func() { 92 | _ = ca.Close() 93 | }() 94 | 95 | chErr := make(chan error) 96 | chRead := make(chan []byte) 97 | 98 | go func() { 99 | 
b := make([]byte, 100) 100 | n, err := cb.Read(b) 101 | chErr <- err 102 | chRead <- b[:n] 103 | }() 104 | 105 | c := New(ca) 106 | data := []byte{0x01, 0x02, 0xFF} 107 | n, err := c.WriteContext(context.Background(), data) 108 | assert.NoError(t, err) 109 | assert.Len(t, data, n) 110 | 111 | err = <-chErr 112 | b := <-chRead 113 | assert.Equal(t, b, data) 114 | assert.NoError(t, err) 115 | } 116 | 117 | func TestWriteTimeout(t *testing.T) { 118 | ca, _ := net.Pipe() 119 | defer func() { 120 | _ = ca.Close() 121 | }() 122 | 123 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 124 | defer cancel() 125 | 126 | c := New(ca) 127 | b := make([]byte, 100) 128 | n, err := c.WriteContext(ctx, b) 129 | assert.Error(t, err) 130 | assert.Empty(t, n) 131 | } 132 | 133 | func TestWriteCancel(t *testing.T) { 134 | ca, _ := net.Pipe() 135 | defer func() { 136 | _ = ca.Close() 137 | }() 138 | 139 | ctx, cancel := context.WithCancel(context.Background()) 140 | go func() { 141 | time.Sleep(10 * time.Millisecond) 142 | cancel() 143 | }() 144 | 145 | c := New(ca) 146 | b := make([]byte, 100) 147 | n, err := c.WriteContext(ctx, b) 148 | assert.Error(t, err) 149 | assert.Empty(t, n) 150 | } 151 | 152 | func TestWriteClosed(t *testing.T) { 153 | ca, _ := net.Pipe() 154 | 155 | c := New(ca) 156 | _ = c.Close() 157 | 158 | b := make([]byte, 100) 159 | n, err := c.WriteContext(context.Background(), b) 160 | assert.ErrorIs(t, err, ErrClosing) 161 | assert.Empty(t, n) 162 | } 163 | 164 | // Test for TestLocalAddrAndRemoteAddr. 
165 | type stringAddr struct { 166 | network string 167 | addr string 168 | } 169 | 170 | func (a stringAddr) Network() string { return a.network } 171 | func (a stringAddr) String() string { return a.addr } 172 | 173 | type connAddrMock struct{} 174 | 175 | func (*connAddrMock) RemoteAddr() net.Addr { return stringAddr{"remote_net", "remote_addr"} } 176 | func (*connAddrMock) LocalAddr() net.Addr { return stringAddr{"local_net", "local_addr"} } 177 | func (*connAddrMock) Read(_ []byte) (n int, err error) { 178 | panic("unimplemented") //nolint 179 | } 180 | 181 | func (*connAddrMock) Write(_ []byte) (n int, err error) { 182 | panic("unimplemented") //nolint 183 | } 184 | 185 | func (*connAddrMock) Close() error { 186 | panic("unimplemented") //nolint 187 | } 188 | 189 | func (*connAddrMock) SetDeadline(_ time.Time) error { 190 | panic("unimplemented") //nolint 191 | } 192 | 193 | func (*connAddrMock) SetReadDeadline(_ time.Time) error { 194 | panic("unimplemented") //nolint 195 | } 196 | 197 | func (*connAddrMock) SetWriteDeadline(_ time.Time) error { 198 | panic("unimplemented") //nolint 199 | } 200 | 201 | func TestLocalAddrAndRemoteAddr(t *testing.T) { 202 | c := New(&connAddrMock{}) 203 | al := c.LocalAddr() 204 | ar := c.RemoteAddr() 205 | 206 | assert.Equal(t, "local_addr", al.String()) 207 | assert.Equal(t, "remote_addr", ar.String()) 208 | } 209 | 210 | func BenchmarkBase(b *testing.B) { 211 | ca, cb := net.Pipe() 212 | defer func() { 213 | _ = ca.Close() 214 | }() 215 | 216 | data := make([]byte, 4096) 217 | for i := range data { 218 | data[i] = byte(i) 219 | } 220 | buf := make([]byte, len(data)) 221 | 222 | b.SetBytes(int64(len(data))) 223 | b.ResetTimer() 224 | 225 | go func(n int) { 226 | for i := 0; i < n; i++ { 227 | _, _ = cb.Write(data) 228 | } 229 | _ = cb.Close() 230 | }(b.N) 231 | 232 | count := 0 233 | for { 234 | n, err := ca.Read(buf) 235 | if err != nil { 236 | if !errors.Is(err, io.EOF) { 237 | b.Fatal(err) 238 | } 239 | 240 | break 241 | 
} 242 | if n != len(data) { 243 | b.Errorf("Expected %v, got %v", len(data), n) 244 | } 245 | count++ 246 | } 247 | if count != b.N { 248 | b.Errorf("Expected %v, got %v", b.N, count) 249 | } 250 | } 251 | 252 | func BenchmarkWrite(b *testing.B) { 253 | ca, cb := net.Pipe() 254 | defer func() { 255 | _ = ca.Close() 256 | }() 257 | 258 | data := make([]byte, 4096) 259 | for i := range data { 260 | data[i] = byte(i) 261 | } 262 | buf := make([]byte, len(data)) 263 | 264 | b.SetBytes(int64(len(data))) 265 | b.ResetTimer() 266 | 267 | go func(n int) { 268 | c := New(cb) 269 | for i := 0; i < n; i++ { 270 | _, _ = c.WriteContext(context.Background(), data) 271 | } 272 | _ = cb.Close() 273 | }(b.N) 274 | 275 | count := 0 276 | for { 277 | n, err := ca.Read(buf) 278 | if err != nil { 279 | if !errors.Is(err, io.EOF) { 280 | b.Fatal(err) 281 | } 282 | 283 | break 284 | } 285 | if n != len(data) { 286 | b.Errorf("Expected %v, got %v", len(data), n) 287 | } 288 | count++ 289 | } 290 | if count != b.N { 291 | b.Errorf("Expected %v, got %v", b.N, count) 292 | } 293 | } 294 | 295 | func BenchmarkRead(b *testing.B) { 296 | ca, cb := net.Pipe() 297 | defer func() { 298 | _ = ca.Close() 299 | }() 300 | 301 | data := make([]byte, 4096) 302 | for i := range data { 303 | data[i] = byte(i) 304 | } 305 | buf := make([]byte, len(data)) 306 | 307 | b.SetBytes(int64(len(data))) 308 | b.ResetTimer() 309 | 310 | go func(n int) { 311 | for i := 0; i < n; i++ { 312 | _, _ = cb.Write(data) 313 | } 314 | _ = cb.Close() 315 | }(b.N) 316 | 317 | c := New(ca) 318 | count := 0 319 | for { 320 | n, err := c.ReadContext(context.Background(), buf) 321 | if err != nil { 322 | if !errors.Is(err, io.EOF) { 323 | b.Fatal(err) 324 | } 325 | 326 | break 327 | } 328 | if n != len(data) { 329 | b.Errorf("Expected %v, got %v", len(data), n) 330 | } 331 | count++ 332 | } 333 | if count != b.N { 334 | b.Errorf("Expected %v, got %v", b.N, count) 335 | } 336 | } 337 | 
-------------------------------------------------------------------------------- /vnet/conn_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "errors" 8 | "net" 9 | "sync/atomic" 10 | "testing" 11 | "time" 12 | 13 | "github.com/pion/logging" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var errFailedToCovertToChuckUDP = errors.New("failed to convert chunk to chunkUDP") 18 | 19 | type dummyObserver struct { 20 | onWrite func(Chunk) error 21 | onOnClosed func(net.Addr) 22 | } 23 | 24 | func (o *dummyObserver) write(c Chunk) error { 25 | return o.onWrite(c) 26 | } 27 | 28 | func (o *dummyObserver) onClosed(addr net.Addr) { 29 | o.onOnClosed(addr) 30 | } 31 | 32 | func (o *dummyObserver) determineSourceIP(locIP, _ net.IP) net.IP { 33 | return locIP 34 | } 35 | 36 | func TestUDPConn(t *testing.T) { //nolint:cyclop,maintidx 37 | log := logging.NewDefaultLoggerFactory().NewLogger("test") 38 | 39 | t.Run("WriteTo ReadFrom", func(t *testing.T) { 40 | var nClosed int32 41 | var conn *UDPConn 42 | var err error 43 | data := []byte("Hello") 44 | srcAddr := &net.UDPAddr{ 45 | IP: net.ParseIP("127.0.0.1"), 46 | Port: 1234, 47 | } 48 | dstAddr := &net.UDPAddr{ 49 | IP: net.ParseIP("127.0.0.1"), 50 | Port: 5678, 51 | } 52 | 53 | obs := &dummyObserver{ 54 | onWrite: func(c Chunk) error { 55 | uc, ok := c.(*chunkUDP) 56 | if !ok { 57 | return errFailedToCovertToChuckUDP 58 | } 59 | chunk := newChunkUDP(uc.DestinationAddr().(*net.UDPAddr), uc.SourceAddr().(*net.UDPAddr)) //nolint:forcetypeassert 60 | chunk.userData = make([]byte, len(uc.userData)) 61 | copy(chunk.userData, uc.userData) 62 | conn.readCh <- chunk // echo back 63 | 64 | return nil 65 | }, 66 | onOnClosed: func(net.Addr) { 67 | atomic.AddInt32(&nClosed, 1) 68 | }, 69 | } 70 | conn, err = newUDPConn(srcAddr, nil, obs) 71 | 
assert.NoError(t, err, "should succeed") 72 | 73 | rcvdCh := make(chan struct{}) 74 | doneCh := make(chan struct{}) 75 | 76 | go func() { 77 | buf := make([]byte, 1500) 78 | 79 | for { 80 | n, addr, err2 := conn.ReadFrom(buf) 81 | if err2 != nil { 82 | log.Debug("conn closed. exiting the read loop") 83 | 84 | break 85 | } 86 | log.Debug("read data") 87 | assert.Equal(t, len(data), n, "should match") 88 | assert.Equal(t, string(data), string(data), "should match") 89 | assert.Equal(t, dstAddr.String(), addr.String(), "should match") 90 | rcvdCh <- struct{}{} 91 | } 92 | 93 | close(doneCh) 94 | }() 95 | 96 | var n int 97 | n, err = conn.WriteTo(data, dstAddr) 98 | if !assert.Nil(t, err, "should succeed") { 99 | return 100 | } 101 | assert.Equal(t, len(data), n, "should match") 102 | 103 | loop: 104 | for { 105 | select { 106 | case <-rcvdCh: 107 | log.Debug("closing conn..") 108 | err2 := conn.Close() 109 | assert.Nil(t, err2, "should succeed") 110 | case <-doneCh: 111 | break loop 112 | } 113 | } 114 | 115 | assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") 116 | }) 117 | 118 | t.Run("Write Read", func(t *testing.T) { 119 | var nClosed int32 120 | var conn *UDPConn 121 | var err error 122 | data := []byte("Hello") 123 | srcAddr := &net.UDPAddr{ 124 | IP: net.ParseIP("127.0.0.1"), 125 | Port: 1234, 126 | } 127 | dstAddr := &net.UDPAddr{ 128 | IP: net.ParseIP("127.0.0.1"), 129 | Port: 5678, 130 | } 131 | 132 | obs := &dummyObserver{ 133 | onWrite: func(c Chunk) error { 134 | uc, ok := c.(*chunkUDP) 135 | if !ok { 136 | return errFailedToCovertToChuckUDP 137 | } 138 | //nolint:forcetypeassert 139 | chunk := newChunkUDP( 140 | uc.DestinationAddr().(*net.UDPAddr), 141 | uc.SourceAddr().(*net.UDPAddr), 142 | ) 143 | chunk.userData = make([]byte, len(uc.userData)) 144 | copy(chunk.userData, uc.userData) 145 | conn.readCh <- chunk // echo back 146 | 147 | return nil 148 | }, 149 | onOnClosed: func(net.Addr) { 150 | atomic.AddInt32(&nClosed, 1) 
151 | }, 152 | } 153 | conn, err = newUDPConn(srcAddr, nil, obs) 154 | assert.NoError(t, err, "should succeed") 155 | conn.remAddr = dstAddr 156 | 157 | rcvdCh := make(chan struct{}) 158 | doneCh := make(chan struct{}) 159 | 160 | go func() { 161 | buf := make([]byte, 1500) 162 | 163 | for { 164 | n, err2 := conn.Read(buf) 165 | if err2 != nil { 166 | log.Debug("conn closed. exiting the read loop") 167 | 168 | break 169 | } 170 | log.Debug("read data") 171 | assert.Equal(t, len(data), n, "should match") 172 | assert.Equal(t, string(data), string(data), "should match") 173 | rcvdCh <- struct{}{} 174 | } 175 | 176 | close(doneCh) 177 | }() 178 | 179 | var n int 180 | n, err = conn.Write(data) 181 | if !assert.Nil(t, err, "should succeed") { 182 | return 183 | } 184 | 185 | assert.Equal(t, len(data), n, "should match") 186 | 187 | loop: 188 | for { 189 | select { 190 | case <-rcvdCh: 191 | log.Debug("closing conn..") 192 | err = conn.Close() 193 | assert.Nil(t, err, "should succeed") 194 | case <-doneCh: 195 | break loop 196 | } 197 | } 198 | 199 | assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") 200 | }) 201 | 202 | deadlineTest := func(t *testing.T, readOnly bool) { 203 | t.Helper() 204 | 205 | var nClosed int32 206 | var conn *UDPConn 207 | var err error 208 | srcAddr := &net.UDPAddr{ 209 | IP: net.ParseIP("127.0.0.1"), 210 | Port: 1234, 211 | } 212 | obs := &dummyObserver{ 213 | onOnClosed: func(net.Addr) { 214 | atomic.AddInt32(&nClosed, 1) 215 | }, 216 | } 217 | conn, err = newUDPConn(srcAddr, nil, obs) 218 | assert.NoError(t, err, "should succeed") 219 | 220 | doneCh := make(chan struct{}) 221 | 222 | if readOnly { 223 | err = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) 224 | } else { 225 | err = conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) 226 | } 227 | assert.Nil(t, err, "should succeed") 228 | 229 | go func() { 230 | buf := make([]byte, 1500) 231 | _, _, err := conn.ReadFrom(buf) 232 | assert.NotNil(t, 
err, "should return error") 233 | var ne *net.OpError 234 | if errors.As(err, &ne) { 235 | assert.True(t, ne.Timeout(), "should be a timeout") 236 | } else { 237 | assert.True(t, false, "should be an net.OpError") 238 | } 239 | 240 | assert.Nil(t, conn.Close(), "should succeed") 241 | close(doneCh) 242 | }() 243 | 244 | <-doneCh 245 | 246 | assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") 247 | } 248 | 249 | t.Run("SetReadDeadline", func(t *testing.T) { 250 | deadlineTest(t, true) 251 | }) 252 | 253 | t.Run("SetDeadline", func(t *testing.T) { 254 | deadlineTest(t, false) 255 | }) 256 | 257 | t.Run("Inbound during close", func(t *testing.T) { 258 | var conn *UDPConn 259 | var err error 260 | srcAddr := &net.UDPAddr{ 261 | IP: net.ParseIP("127.0.0.1"), 262 | Port: 1234, 263 | } 264 | obs := &dummyObserver{ 265 | onOnClosed: func(net.Addr) {}, 266 | } 267 | 268 | for i := 0; i < 1000; i++ { // nolint:staticcheck // (false positive detection) 269 | conn, err = newUDPConn(srcAddr, nil, obs) 270 | assert.NoError(t, err, "should succeed") 271 | 272 | chDone := make(chan struct{}) 273 | go func() { 274 | time.Sleep(20 * time.Millisecond) 275 | assert.NoError(t, conn.Close()) 276 | close(chDone) 277 | }() 278 | tick := time.NewTicker(10 * time.Millisecond) 279 | for { 280 | defer tick.Stop() 281 | select { 282 | case <-chDone: 283 | return 284 | case <-tick.C: 285 | conn.onInboundChunk(nil) 286 | } 287 | } 288 | } 289 | }) 290 | } 291 | -------------------------------------------------------------------------------- /netctx/conn_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package netctx 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "io" 10 | "net" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func TestRead(t *testing.T) { 18 | ca, cb := net.Pipe() 19 | 
defer func() { 20 | _ = ca.Close() 21 | }() 22 | 23 | data := []byte{0x01, 0x02, 0xFF} 24 | chErr := make(chan error) 25 | 26 | go func() { 27 | _, err := cb.Write(data) 28 | chErr <- err 29 | }() 30 | 31 | c := NewConn(ca) 32 | b := make([]byte, 100) 33 | n, err := c.ReadContext(context.Background(), b) 34 | assert.NoError(t, err) 35 | assert.Equal(t, len(data), n) 36 | assert.Equal(t, data, b[:n]) 37 | 38 | assert.NoError(t, <-chErr) 39 | } 40 | 41 | func TestReadTimeout(t *testing.T) { 42 | ca, _ := net.Pipe() 43 | defer func() { 44 | _ = ca.Close() 45 | }() 46 | 47 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 48 | defer cancel() 49 | 50 | c := NewConn(ca) 51 | b := make([]byte, 100) 52 | n, err := c.ReadContext(ctx, b) 53 | assert.Error(t, err) 54 | assert.Empty(t, n) 55 | } 56 | 57 | func TestReadCancel(t *testing.T) { 58 | ca, _ := net.Pipe() 59 | defer func() { 60 | _ = ca.Close() 61 | }() 62 | 63 | ctx, cancel := context.WithCancel(context.Background()) 64 | go func() { 65 | time.Sleep(10 * time.Millisecond) 66 | cancel() 67 | }() 68 | 69 | c := NewConn(ca) 70 | b := make([]byte, 100) 71 | n, err := c.ReadContext(ctx, b) 72 | assert.Error(t, err) 73 | assert.Empty(t, n) 74 | } 75 | 76 | func TestReadClosed(t *testing.T) { 77 | ca, _ := net.Pipe() 78 | 79 | c := NewConn(ca) 80 | _ = c.Close() 81 | 82 | b := make([]byte, 100) 83 | n, err := c.ReadContext(context.Background(), b) 84 | assert.ErrorIs(t, err, net.ErrClosed) 85 | assert.Empty(t, n) 86 | } 87 | 88 | func TestWrite(t *testing.T) { 89 | ca, cb := net.Pipe() 90 | defer func() { 91 | _ = ca.Close() 92 | }() 93 | 94 | chErr := make(chan error) 95 | chRead := make(chan []byte) 96 | 97 | go func() { 98 | b := make([]byte, 100) 99 | n, err := cb.Read(b) 100 | chErr <- err 101 | chRead <- b[:n] 102 | }() 103 | 104 | c := NewConn(ca) 105 | data := []byte{0x01, 0x02, 0xFF} 106 | n, err := c.WriteContext(context.Background(), data) 107 | assert.NoError(t, err) 108 | 
assert.Len(t, data, n) 109 | 110 | err = <-chErr 111 | b := <-chRead 112 | assert.NoError(t, err) 113 | assert.Equal(t, data, b) 114 | } 115 | 116 | func TestWriteTimeout(t *testing.T) { 117 | ca, _ := net.Pipe() 118 | defer func() { 119 | _ = ca.Close() 120 | }() 121 | 122 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 123 | defer cancel() 124 | 125 | c := NewConn(ca) 126 | b := make([]byte, 100) 127 | n, err := c.WriteContext(ctx, b) 128 | assert.Error(t, err) 129 | assert.Empty(t, n) 130 | } 131 | 132 | func TestWriteCancel(t *testing.T) { 133 | ca, _ := net.Pipe() 134 | defer func() { 135 | _ = ca.Close() 136 | }() 137 | 138 | ctx, cancel := context.WithCancel(context.Background()) 139 | go func() { 140 | time.Sleep(10 * time.Millisecond) 141 | cancel() 142 | }() 143 | 144 | c := NewConn(ca) 145 | b := make([]byte, 100) 146 | n, err := c.WriteContext(ctx, b) 147 | assert.Error(t, err) 148 | assert.Empty(t, n) 149 | } 150 | 151 | func TestWriteClosed(t *testing.T) { 152 | ca, _ := net.Pipe() 153 | 154 | c := NewConn(ca) 155 | _ = c.Close() 156 | 157 | b := make([]byte, 100) 158 | n, err := c.WriteContext(context.Background(), b) 159 | assert.ErrorIs(t, err, ErrClosing) 160 | assert.Empty(t, n) 161 | } 162 | 163 | // Test for TestLocalAddrAndRemoteAddr. 
164 | type stringAddr struct { 165 | network string 166 | addr string 167 | } 168 | 169 | func (a stringAddr) Network() string { return a.network } 170 | func (a stringAddr) String() string { return a.addr } 171 | 172 | type connAddrMock struct{} 173 | 174 | func (*connAddrMock) RemoteAddr() net.Addr { return stringAddr{"remote_net", "remote_addr"} } 175 | func (*connAddrMock) LocalAddr() net.Addr { return stringAddr{"local_net", "local_addr"} } 176 | 177 | func (*connAddrMock) Read(_ []byte) (n int, err error) { 178 | panic("unimplemented") //nolint 179 | } 180 | 181 | func (*connAddrMock) Write(_ []byte) (n int, err error) { 182 | panic("unimplemented") //nolint 183 | } 184 | 185 | func (*connAddrMock) Close() error { 186 | panic("unimplemented") //nolint 187 | } 188 | 189 | func (*connAddrMock) SetDeadline(_ time.Time) error { 190 | panic("unimplemented") //nolint 191 | } 192 | 193 | func (*connAddrMock) SetReadDeadline(_ time.Time) error { 194 | panic("unimplemented") //nolint 195 | } 196 | 197 | func (*connAddrMock) SetWriteDeadline(_ time.Time) error { 198 | panic("unimplemented") //nolint 199 | } 200 | 201 | func TestLocalAddrAndRemoteAddr(t *testing.T) { 202 | c := NewConn(&connAddrMock{}) 203 | al := c.LocalAddr() 204 | ar := c.RemoteAddr() 205 | 206 | assert.Equal(t, "local_addr", al.String()) 207 | assert.Equal(t, "remote_addr", ar.String()) 208 | } 209 | 210 | func BenchmarkBase(b *testing.B) { 211 | ca, cb := net.Pipe() 212 | defer func() { 213 | _ = ca.Close() 214 | }() 215 | 216 | data := make([]byte, 4096) 217 | for i := range data { 218 | data[i] = byte(i) 219 | } 220 | buf := make([]byte, len(data)) 221 | 222 | b.SetBytes(int64(len(data))) 223 | b.ResetTimer() 224 | 225 | go func(n int) { 226 | for i := 0; i < n; i++ { 227 | _, _ = cb.Write(data) 228 | } 229 | _ = cb.Close() 230 | }(b.N) 231 | 232 | count := 0 233 | for { 234 | n, err := ca.Read(buf) 235 | if err != nil { 236 | if !errors.Is(err, io.EOF) { 237 | b.Fatal(err) 238 | } 239 | 240 | 
break 241 | } 242 | if n != len(data) { 243 | b.Errorf("Expected %v, got %v", len(data), n) 244 | } 245 | count++ 246 | } 247 | if count != b.N { 248 | b.Errorf("Expected %v, got %v", b.N, count) 249 | } 250 | } 251 | 252 | func BenchmarkWrite(b *testing.B) { 253 | ca, cb := net.Pipe() 254 | defer func() { 255 | _ = ca.Close() 256 | }() 257 | 258 | data := make([]byte, 4096) 259 | for i := range data { 260 | data[i] = byte(i) 261 | } 262 | buf := make([]byte, len(data)) 263 | 264 | b.SetBytes(int64(len(data))) 265 | b.ResetTimer() 266 | 267 | go func(n int) { 268 | c := NewConn(cb) 269 | for i := 0; i < n; i++ { 270 | _, _ = c.WriteContext(context.Background(), data) 271 | } 272 | _ = cb.Close() 273 | }(b.N) 274 | 275 | count := 0 276 | for { 277 | n, err := ca.Read(buf) 278 | if err != nil { 279 | if !errors.Is(err, io.EOF) { 280 | b.Fatal(err) 281 | } 282 | 283 | break 284 | } 285 | if n != len(data) { 286 | b.Errorf("Expected %v, got %v", len(data), n) 287 | } 288 | count++ 289 | } 290 | if count != b.N { 291 | b.Errorf("Expected %v, got %v", b.N, count) 292 | } 293 | } 294 | 295 | func BenchmarkRead(b *testing.B) { 296 | ca, cb := net.Pipe() 297 | defer func() { 298 | _ = ca.Close() 299 | }() 300 | 301 | data := make([]byte, 4096) 302 | for i := range data { 303 | data[i] = byte(i) 304 | } 305 | buf := make([]byte, len(data)) 306 | 307 | b.SetBytes(int64(len(data))) 308 | b.ResetTimer() 309 | 310 | go func(n int) { 311 | for i := 0; i < n; i++ { 312 | _, _ = cb.Write(data) 313 | } 314 | _ = cb.Close() 315 | }(b.N) 316 | 317 | c := NewConn(ca) 318 | count := 0 319 | for { 320 | n, err := c.ReadContext(context.Background(), buf) 321 | if err != nil { 322 | if !errors.Is(err, io.EOF) { 323 | b.Fatal(err) 324 | } 325 | 326 | break 327 | } 328 | if n != len(data) { 329 | b.Errorf("Expected %v, got %v", len(data), n) 330 | } 331 | count++ 332 | } 333 | if count != b.N { 334 | b.Errorf("Expected %v, got %v", b.N, count) 335 | } 336 | } 337 | 
-------------------------------------------------------------------------------- /udp/batchconn_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | //go:build linux 5 | 6 | package udp 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | "time" 16 | 17 | "github.com/pion/transport/v3/test" 18 | "github.com/stretchr/testify/assert" 19 | ) 20 | 21 | func TestBatchConn_WriteBatchInterval(t *testing.T) { 22 | report := test.CheckRoutines(t) 23 | defer report() 24 | 25 | lc := ListenConfig{ 26 | Batch: BatchIOConfig{ 27 | Enable: true, 28 | ReadBatchSize: 10, 29 | WriteBatchSize: 3, 30 | WriteBatchInterval: 5 * time.Millisecond, 31 | }, 32 | ReadBufferSize: 64 * 1024, 33 | WriteBufferSize: 64 * 1024, 34 | } 35 | 36 | laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678} 37 | listener, err := lc.Listen("udp", laddr) 38 | assert.NoError(t, err) 39 | 40 | var serverConnWg sync.WaitGroup 41 | serverConnWg.Add(1) 42 | go func() { //nolint:dupl 43 | var exit int32 44 | defer func() { 45 | defer serverConnWg.Done() 46 | atomic.StoreInt32(&exit, 1) 47 | }() 48 | for { 49 | buf := make([]byte, 1400) 50 | conn, lerr := listener.Accept() 51 | if errors.Is(lerr, ErrClosedListener) { 52 | break 53 | } 54 | assert.NoError(t, lerr) 55 | serverConnWg.Add(1) 56 | go func() { 57 | defer func() { 58 | _ = conn.Close() 59 | serverConnWg.Done() 60 | }() 61 | for atomic.LoadInt32(&exit) != 1 { 62 | _ = conn.SetReadDeadline(time.Now().Add(time.Second)) 63 | n, rerr := conn.Read(buf) 64 | if rerr != nil { 65 | assert.ErrorContains(t, rerr, "timeout") 66 | } else { 67 | _, rerr = conn.Write(buf[:n]) 68 | assert.NoError(t, rerr) 69 | } 70 | } 71 | }() 72 | } 73 | }() 74 | 75 | raddr, _ := listener.Addr().(*net.UDPAddr) 76 | 77 | // test flush by WriteBatchInterval expired 78 | readBuf := make([]byte, 
1400) 79 | cli, err := net.DialUDP("udp", nil, raddr) 80 | assert.NoError(t, err) 81 | flushStr := "flushbytimer" 82 | _, err = cli.Write([]byte("flushbytimer")) 83 | assert.NoError(t, err) 84 | n, err := cli.Read(readBuf) 85 | assert.NoError(t, err) 86 | assert.Equal(t, flushStr, string(readBuf[:n])) 87 | 88 | _ = listener.Close() 89 | serverConnWg.Wait() 90 | } 91 | 92 | func TestBatchConn_WriteBatchSize(t *testing.T) { //nolint:cyclop 93 | report := test.CheckRoutines(t) 94 | defer report() 95 | 96 | lc := ListenConfig{ 97 | Batch: BatchIOConfig{ 98 | Enable: true, 99 | ReadBatchSize: 10, 100 | WriteBatchSize: 9, 101 | WriteBatchInterval: time.Minute, 102 | }, 103 | ReadBufferSize: 64 * 1024, 104 | WriteBufferSize: 64 * 1024, 105 | } 106 | 107 | laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678} 108 | listener, err := lc.Listen("udp", laddr) 109 | assert.NoError(t, err) 110 | 111 | var serverConnWg sync.WaitGroup 112 | serverConnWg.Add(1) 113 | go func() { //nolint:dupl 114 | var exit int32 115 | defer func() { 116 | defer serverConnWg.Done() 117 | atomic.StoreInt32(&exit, 1) 118 | }() 119 | for { 120 | buf := make([]byte, 1400) 121 | conn, lerr := listener.Accept() 122 | if errors.Is(lerr, ErrClosedListener) { 123 | break 124 | } 125 | assert.NoError(t, lerr) 126 | serverConnWg.Add(1) 127 | go func() { 128 | defer func() { 129 | _ = conn.Close() 130 | serverConnWg.Done() 131 | }() 132 | for atomic.LoadInt32(&exit) != 1 { 133 | _ = conn.SetReadDeadline(time.Now().Add(time.Second)) 134 | n, rerr := conn.Read(buf) 135 | if rerr != nil { 136 | assert.ErrorContains(t, rerr, "timeout") 137 | } else { 138 | _, rerr = conn.Write(buf[:n]) 139 | assert.NoError(t, rerr) 140 | } 141 | } 142 | }() 143 | } 144 | }() 145 | 146 | raddr, _ := listener.Addr().(*net.UDPAddr) 147 | 148 | // three clients writing three packets each, 149 | // server is batching 9 packets at a time, echoing back packets from client, 150 | // do two batches of writes from each client, 
two packets first cycle and one packet in second cycle 151 | // - should not be able to read any packets after first write as the server is batching packets 152 | // - should be able to read three packets from each client as server would have flushed nine packets 153 | // to ensure that write batch interval does not kick in, setting the write batch interval to a large value (1m) 154 | cc := 3 155 | var clients [3]*net.UDPConn 156 | wgs := sync.WaitGroup{} 157 | 158 | // first cycle, write two packets from each client, 159 | // should not be able to read any packets 160 | wgs.Add(cc) 161 | for i := 0; i < cc; i++ { 162 | sendStr := fmt.Sprintf("hello %d", i) 163 | 164 | idx := i 165 | go func() { 166 | defer wgs.Done() 167 | 168 | client, err := net.DialUDP("udp", nil, raddr) 169 | assert.NoError(t, err) 170 | clients[idx] = client 171 | 172 | for j := 0; j < 2; j++ { 173 | _, err = client.Write([]byte(sendStr)) 174 | assert.NoError(t, err) 175 | } 176 | 177 | err = client.SetReadDeadline(time.Now().Add(time.Second)) 178 | assert.NoError(t, err) 179 | 180 | buf := make([]byte, 1400) 181 | n, err := client.Read(buf) 182 | assert.Zero(t, n, "unexpected packet from client: %d", idx) 183 | assert.ErrorContains(t, err, "timeout", "expected timeout from client: %d", idx) 184 | }() 185 | } 186 | wgs.Wait() 187 | 188 | // second cycle, write two packets from each client, 189 | // should be able to read three packets on average per client 190 | // (ordering is not guaranteed due to goroutine scheduling, so just check for a total of 9 packets) 191 | wgs.Add(cc * 3) 192 | for i := 0; i < cc; i++ { 193 | sendStr := fmt.Sprintf("hello %d", i) 194 | 195 | idx := i 196 | go func() { 197 | for j := 0; j < 2; j++ { 198 | _, err := clients[idx].Write([]byte(sendStr)) 199 | assert.NoError(t, err) 200 | } 201 | 202 | buf := make([]byte, 1400) 203 | for j := 0; ; j++ { 204 | err := clients[idx].SetReadDeadline(time.Now().Add(time.Second)) 205 | assert.NoError(t, err) 206 | 207 | n, err 
:= clients[idx].Read(buf) 208 | if err == nil { 209 | assert.Equal(t, sendStr, string(buf[:n]), "mismatch in first read, client: %d, packet: %d", idx, j) 210 | wgs.Done() 211 | } else { 212 | break 213 | } 214 | } 215 | }() 216 | } 217 | wgs.Wait() 218 | 219 | // should not be able to read any packets as the next batch is not ready yet 220 | wgs.Add(cc) 221 | for i := 0; i < cc; i++ { 222 | idx := i 223 | go func() { 224 | defer wgs.Done() 225 | 226 | err := clients[idx].SetReadDeadline(time.Now().Add(time.Second)) 227 | assert.NoError(t, err) 228 | 229 | buf := make([]byte, 1400) 230 | n, err := clients[idx].Read(buf) 231 | assert.Zero(t, n, "unexpected packet from client: %d", idx) 232 | assert.ErrorContains(t, err, "timeout", "expected timeout from client: %d", idx) 233 | }() 234 | } 235 | wgs.Wait() 236 | 237 | // third cycle, write two packets from each client, 238 | // should be able to read three packets on average per client 239 | wgs.Add(cc * 3) 240 | for i := 0; i < cc; i++ { 241 | sendStr := fmt.Sprintf("hello %d", i) 242 | 243 | idx := i 244 | go func() { 245 | defer func() { _ = clients[idx].Close() }() 246 | 247 | for j := 0; j < 2; j++ { 248 | _, err := clients[idx].Write([]byte(sendStr)) 249 | assert.NoError(t, err) 250 | } 251 | 252 | buf := make([]byte, 1400) 253 | for j := 0; ; j++ { 254 | err := clients[idx].SetReadDeadline(time.Now().Add(time.Second)) 255 | assert.NoError(t, err) 256 | 257 | n, err := clients[idx].Read(buf) 258 | if err == nil { 259 | assert.Equal(t, sendStr, string(buf[:n]), "mismatch in second read, client: %d, packet: %d", idx, j) 260 | wgs.Done() 261 | } else { 262 | break 263 | } 264 | } 265 | }() 266 | } 267 | wgs.Wait() 268 | 269 | _ = listener.Close() 270 | serverConnWg.Wait() 271 | } 272 | -------------------------------------------------------------------------------- /udp/batchconn.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion 
community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package udp 5 | 6 | import ( 7 | "io" 8 | "net" 9 | "runtime" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "golang.org/x/net/ipv4" 15 | "golang.org/x/net/ipv6" 16 | ) 17 | 18 | // BatchWriter represents conn can write messages in batch. 19 | type BatchWriter interface { 20 | WriteBatch(ms []ipv4.Message, flags int) (int, error) 21 | } 22 | 23 | // BatchReader represents conn can read messages in batch. 24 | type BatchReader interface { 25 | ReadBatch(msg []ipv4.Message, flags int) (int, error) 26 | } 27 | 28 | // BatchPacketConn represents conn can read/write messages in batch. 29 | type BatchPacketConn interface { 30 | BatchWriter 31 | BatchReader 32 | io.Closer 33 | } 34 | 35 | // --------------------------------- 36 | 37 | type messageBatch struct { 38 | size int 39 | batchConn BatchPacketConn 40 | 41 | mu sync.Mutex 42 | messages []ipv4.Message 43 | writePos int 44 | } 45 | 46 | func newMessageBatch(size int, batchConn BatchPacketConn) *messageBatch { 47 | m := &messageBatch{ 48 | size: size, 49 | batchConn: batchConn, 50 | } 51 | m.init() 52 | 53 | return m 54 | } 55 | 56 | func (m *messageBatch) init() { 57 | m.messages = make([]ipv4.Message, m.size) 58 | for i := range m.messages { 59 | m.messages[i].Buffers = [][]byte{make([]byte, sendMTU)} 60 | } 61 | } 62 | 63 | func (m *messageBatch) isFull() bool { 64 | return m.writePos == m.size 65 | } 66 | 67 | func (m *messageBatch) EnqueueMessage(buf []byte, raddr net.Addr) (int, bool) { 68 | m.mu.Lock() 69 | defer m.mu.Unlock() 70 | 71 | if len(buf) == 0 || m.isFull() { 72 | return 0, m.isFull() 73 | } 74 | 75 | msg := &m.messages[m.writePos] 76 | // reset buffers 77 | msg.Buffers = msg.Buffers[:1] 78 | msg.Buffers[0] = msg.Buffers[0][:cap(msg.Buffers[0])] 79 | 80 | if raddr != nil { 81 | msg.Addr = raddr 82 | } 83 | if n := copy(msg.Buffers[0], buf); n < len(buf) { 84 | extraBuffer := make([]byte, len(buf)-n) 85 | copy(extraBuffer, buf[n:]) 86 | 
msg.Buffers = append(msg.Buffers, extraBuffer) 87 | } else { 88 | msg.Buffers[0] = msg.Buffers[0][:n] 89 | } 90 | m.writePos++ 91 | 92 | return len(buf), m.isFull() 93 | } 94 | 95 | func (m *messageBatch) Flush() { 96 | m.mu.Lock() 97 | defer m.mu.Unlock() 98 | 99 | var txN int 100 | for txN < m.writePos { 101 | n, err := m.batchConn.WriteBatch(m.messages[txN:m.writePos], 0) 102 | if err != nil { 103 | break 104 | } 105 | txN += n 106 | } 107 | 108 | m.writePos = 0 109 | } 110 | 111 | // --------------------------------- 112 | 113 | type pingPong struct { 114 | mu sync.Mutex 115 | batches [2]*messageBatch 116 | writeBatchIdx int 117 | readBatchIdx int 118 | flushPending bool 119 | 120 | writeReady chan struct{} 121 | flushCycleDone chan struct{} 122 | flusherDone chan struct{} 123 | 124 | closed int32 125 | } 126 | 127 | func newPingPong(size int, interval time.Duration, batchConn BatchPacketConn) *pingPong { 128 | p := &pingPong{ 129 | writeReady: make(chan struct{}), 130 | flushCycleDone: make(chan struct{}), 131 | flusherDone: make(chan struct{}), 132 | } 133 | for i := 0; i < len(p.batches); i++ { 134 | p.batches[i] = newMessageBatch(size, batchConn) 135 | } 136 | 137 | go p.flusher(interval) 138 | 139 | return p 140 | } 141 | 142 | func (p *pingPong) Close() { 143 | atomic.StoreInt32(&p.closed, 1) 144 | 145 | select { 146 | case p.writeReady <- struct{}{}: 147 | default: 148 | } 149 | 150 | <-p.flusherDone 151 | } 152 | 153 | func (p *pingPong) EnqueueMessage(buf []byte, raddr net.Addr) int { 154 | p.mu.Lock() 155 | var ( 156 | writeBatch *messageBatch 157 | n int 158 | isFull bool 159 | ) 160 | for { 161 | if writeBatch = p.getWriteBatch(); writeBatch != nil { 162 | n, isFull = writeBatch.EnqueueMessage(buf, raddr) 163 | if n == len(buf) { 164 | break 165 | } 166 | } 167 | 168 | p.mu.Unlock() 169 | select { 170 | case <-p.flushCycleDone: 171 | case <-time.After(100 * time.Microsecond): 172 | } 173 | 174 | if atomic.LoadInt32(&p.closed) == 1 { 175 | return 0 
176 | } 177 | 178 | p.mu.Lock() 179 | } 180 | p.mu.Unlock() 181 | 182 | // enqueuing given message fills up the write batch, queue up a flush 183 | if isFull { 184 | select { 185 | case p.writeReady <- struct{}{}: 186 | default: 187 | } 188 | } 189 | 190 | return n 191 | } 192 | 193 | func (p *pingPong) getWriteBatch() *messageBatch { 194 | if p.writeBatchIdx == p.readBatchIdx && p.flushPending { 195 | return nil 196 | } 197 | 198 | return p.batches[p.writeBatchIdx] 199 | } 200 | 201 | func (p *pingPong) updateWriteBatchAndGetReadBatch() *messageBatch { 202 | p.mu.Lock() 203 | defer p.mu.Unlock() 204 | 205 | if p.writeBatchIdx != p.readBatchIdx || p.flushPending { 206 | return nil 207 | } 208 | 209 | p.writeBatchIdx ^= 1 210 | p.flushPending = true 211 | 212 | return p.batches[p.readBatchIdx] 213 | } 214 | 215 | func (p *pingPong) updateReadBatch() { 216 | p.mu.Lock() 217 | defer p.mu.Unlock() 218 | 219 | if !p.flushPending { 220 | return 221 | } 222 | 223 | p.readBatchIdx ^= 1 224 | p.flushPending = false 225 | 226 | select { 227 | case p.flushCycleDone <- struct{}{}: 228 | default: 229 | } 230 | } 231 | 232 | func (p *pingPong) flusher(interval time.Duration) { 233 | defer close(p.flusherDone) 234 | 235 | writeTicker := time.NewTicker(interval / 2) 236 | defer writeTicker.Stop() 237 | 238 | lastFlushAt := time.Now().Add(-interval) 239 | for atomic.LoadInt32(&p.closed) != 1 { 240 | select { 241 | case <-writeTicker.C: 242 | if time.Since(lastFlushAt) < interval { 243 | continue 244 | } 245 | case <-p.writeReady: 246 | } 247 | 248 | readBatch := p.updateWriteBatchAndGetReadBatch() 249 | if readBatch == nil { 250 | continue 251 | } 252 | 253 | readBatch.Flush() 254 | p.updateReadBatch() 255 | 256 | lastFlushAt = time.Now() 257 | } 258 | } 259 | 260 | // ------------------------------ 261 | 262 | // BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch, 263 | // only available in linux. 
In other platform, it will use single Write/Read as same as net.Conn. 264 | type BatchConn struct { 265 | net.PacketConn 266 | 267 | batchConn BatchPacketConn 268 | 269 | // ping-pong the batches to be able to accept new packets while a batch is written to socket 270 | batchPingPong *pingPong 271 | } 272 | 273 | // NewBatchConn creates a *BatchConn from net.PacketConn with batch configs. 274 | func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn { 275 | bc := &BatchConn{ 276 | PacketConn: conn, 277 | } 278 | 279 | // batch write only supports linux 280 | if runtime.GOOS == "linux" { 281 | if pc4 := ipv4.NewPacketConn(conn); pc4 != nil { 282 | bc.batchConn = pc4 283 | } else if pc6 := ipv6.NewPacketConn(conn); pc6 != nil { 284 | bc.batchConn = pc6 285 | } 286 | 287 | bc.batchPingPong = newPingPong(batchWriteSize, batchWriteInterval, bc.batchConn) 288 | } 289 | 290 | return bc 291 | } 292 | 293 | // Close batchConn and the underlying PacketConn. 294 | func (c *BatchConn) Close() error { 295 | if c.batchPingPong != nil { 296 | c.batchPingPong.Close() 297 | } 298 | 299 | if c.batchConn != nil { 300 | return c.batchConn.Close() 301 | } 302 | 303 | return c.PacketConn.Close() 304 | } 305 | 306 | // WriteTo write message to an UDPAddr, addr should be nil if it is a connected socket. 307 | func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) { 308 | if c.batchConn == nil { 309 | return c.PacketConn.WriteTo(b, addr) 310 | } 311 | 312 | return c.batchPingPong.EnqueueMessage(b, addr), nil 313 | } 314 | 315 | // ReadBatch reads messages in batch, return length of message readed and error. 
316 | func (c *BatchConn) ReadBatch(msgs []ipv4.Message, flags int) (int, error) { 317 | if c.batchConn == nil { 318 | n, addr, err := c.PacketConn.ReadFrom(msgs[0].Buffers[0]) 319 | if err == nil { 320 | msgs[0].N = n 321 | msgs[0].Addr = addr 322 | 323 | return 1, nil 324 | } 325 | 326 | return 0, err 327 | } 328 | 329 | return c.batchConn.ReadBatch(msgs, flags) 330 | } 331 | -------------------------------------------------------------------------------- /vnet/conn_map_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package vnet 5 | 6 | import ( 7 | "net" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type myConnObserver struct{} 14 | 15 | func (obs *myConnObserver) write(Chunk) error { 16 | return nil 17 | } 18 | 19 | func (obs *myConnObserver) onClosed(net.Addr) { 20 | } 21 | 22 | func (obs *myConnObserver) determineSourceIP(net.IP, net.IP) net.IP { 23 | return net.IP{} 24 | } 25 | 26 | func TestUDPConnMap(t *testing.T) { 27 | // log := logging.NewDefaultLoggerFactory().NewLogger("test") 28 | 29 | t.Run("insert an UDPConn and remove it", func(t *testing.T) { 30 | connMap := newUDPConnMap() 31 | 32 | obs := &myConnObserver{} 33 | connIn, err := newUDPConn(&net.UDPAddr{ 34 | IP: net.ParseIP("127.0.0.1"), 35 | Port: 1234, 36 | }, nil, obs) 37 | assert.NoError(t, err, "should succeed") 38 | 39 | err = connMap.insert(connIn) 40 | assert.NoError(t, err, "should succeed") 41 | 42 | connOut, ok := connMap.find(connIn.LocalAddr()) 43 | assert.True(t, ok, "should succeed") 44 | assert.Equal(t, connIn, connOut, "should match") 45 | assert.Equal(t, 1, len(connMap.portMap), "should match") 46 | 47 | err = connMap.delete(connIn.LocalAddr()) 48 | assert.NoError(t, err, "should succeed") 49 | assert.Empty(t, connMap.portMap, "should match") 50 | 51 | err = connMap.delete(connIn.LocalAddr()) 52 | 
assert.Error(t, err, "should fail") 53 | }) 54 | 55 | t.Run("insert an UDPConn on 0.0.0.0 and remove it", func(t *testing.T) { 56 | connMap := newUDPConnMap() 57 | 58 | obs := &myConnObserver{} 59 | connIn, err := newUDPConn(&net.UDPAddr{ 60 | IP: net.ParseIP("0.0.0.0"), 61 | Port: 1234, 62 | }, nil, obs) 63 | assert.NoError(t, err, "should succeed") 64 | 65 | err = connMap.insert(connIn) 66 | assert.NoError(t, err, "should succeed") 67 | 68 | connOut, ok := connMap.find(connIn.LocalAddr()) 69 | assert.True(t, ok, "should succeed") 70 | assert.Equal(t, connIn, connOut, "should match") 71 | assert.Equal(t, 1, len(connMap.portMap), "should match") 72 | 73 | err = connMap.delete(connIn.LocalAddr()) 74 | assert.NoError(t, err, "should succeed") 75 | 76 | err = connMap.delete(connIn.LocalAddr()) 77 | assert.Error(t, err, "should fail") 78 | }) 79 | 80 | t.Run("find UDPConn on 0.0.0.0 by specified IP", func(t *testing.T) { 81 | connMap := newUDPConnMap() 82 | 83 | obs := &myConnObserver{} 84 | connIn, err := newUDPConn(&net.UDPAddr{ 85 | IP: net.ParseIP("0.0.0.0"), 86 | Port: 1234, 87 | }, nil, obs) 88 | assert.NoError(t, err, "should succeed") 89 | 90 | err = connMap.insert(connIn) 91 | assert.NoError(t, err, "should succeed") 92 | 93 | connOut, ok := connMap.find(&net.UDPAddr{ 94 | IP: net.ParseIP("192.168.0.1"), 95 | Port: 1234, 96 | }) 97 | assert.True(t, ok, "should succeed") 98 | assert.Equal(t, connIn, connOut, "should match") 99 | assert.Equal(t, 1, len(connMap.portMap), "should match") 100 | }) 101 | 102 | t.Run("insert many IPs with the same port", func(t *testing.T) { 103 | connMap := newUDPConnMap() 104 | 105 | obs := &myConnObserver{} 106 | connIn1, err := newUDPConn(&net.UDPAddr{ 107 | IP: net.ParseIP("10.1.2.1"), 108 | Port: 5678, 109 | }, nil, obs) 110 | assert.NoError(t, err, "should succeed") 111 | err = connMap.insert(connIn1) 112 | assert.NoError(t, err, "should succeed") 113 | 114 | connIn2, err := newUDPConn(&net.UDPAddr{ 115 | IP: 
net.ParseIP("10.1.2.2"), 116 | Port: 5678, 117 | }, nil, obs) 118 | assert.NoError(t, err, "should succeed") 119 | err = connMap.insert(connIn2) 120 | assert.NoError(t, err, "should succeed") 121 | 122 | connOut1, ok := connMap.find(&net.UDPAddr{ 123 | IP: net.ParseIP("10.1.2.1"), 124 | Port: 5678, 125 | }) 126 | assert.True(t, ok, "should succeed") 127 | assert.Equal(t, connIn1, connOut1, "should match") 128 | 129 | connOut2, ok := connMap.find(&net.UDPAddr{ 130 | IP: net.ParseIP("10.1.2.2"), 131 | Port: 5678, 132 | }) 133 | assert.True(t, ok, "should succeed") 134 | assert.Equal(t, connIn2, connOut2, "should match") 135 | 136 | assert.Equal(t, 1, len(connMap.portMap), "should match") 137 | }) 138 | 139 | t.Run("already in-use when inserting 0.0.0.0", func(t *testing.T) { 140 | connMap := newUDPConnMap() 141 | 142 | obs := &myConnObserver{} 143 | connIn1, err := newUDPConn(&net.UDPAddr{ 144 | IP: net.ParseIP("10.1.2.1"), 145 | Port: 5678, 146 | }, nil, obs) 147 | assert.NoError(t, err, "should succeed") 148 | err = connMap.insert(connIn1) 149 | assert.NoError(t, err, "should succeed") 150 | 151 | connIn2, err := newUDPConn(&net.UDPAddr{ 152 | IP: net.ParseIP("0.0.0.0"), 153 | Port: 5678, 154 | }, nil, obs) 155 | assert.NoError(t, err, "should succeed") 156 | 157 | err = connMap.insert(connIn2) 158 | assert.Error(t, err, "should fail") 159 | }) 160 | 161 | t.Run("already in-use when inserting a specified IP", func(t *testing.T) { 162 | connMap := newUDPConnMap() 163 | 164 | obs := &myConnObserver{} 165 | connIn1, err := newUDPConn(&net.UDPAddr{ 166 | IP: net.ParseIP("0.0.0.0"), 167 | Port: 5678, 168 | }, nil, obs) 169 | assert.NoError(t, err, "should succeed") 170 | err = connMap.insert(connIn1) 171 | assert.NoError(t, err, "should succeed") 172 | 173 | connIn2, err := newUDPConn(&net.UDPAddr{ 174 | IP: net.ParseIP("192.168.0.1"), 175 | Port: 5678, 176 | }, nil, obs) 177 | assert.NoError(t, err, "should succeed") 178 | 179 | err = connMap.insert(connIn2) 180 | 
assert.Error(t, err, "should fail") 181 | }) 182 | 183 | t.Run("already in-use when inserting the same specified IP", func(t *testing.T) { 184 | connMap := newUDPConnMap() 185 | 186 | obs := &myConnObserver{} 187 | connIn1, err := newUDPConn(&net.UDPAddr{ 188 | IP: net.ParseIP("192.168.0.1"), 189 | Port: 5678, 190 | }, nil, obs) 191 | assert.NoError(t, err, "should succeed") 192 | err = connMap.insert(connIn1) 193 | assert.NoError(t, err, "should succeed") 194 | 195 | connIn2, err := newUDPConn(&net.UDPAddr{ 196 | IP: net.ParseIP("192.168.0.1"), 197 | Port: 5678, 198 | }, nil, obs) 199 | assert.NoError(t, err, "should succeed") 200 | 201 | err = connMap.insert(connIn2) 202 | assert.Error(t, err, "should fail") 203 | }) 204 | 205 | t.Run("find failure 1", func(t *testing.T) { 206 | connMap := newUDPConnMap() 207 | 208 | obs := &myConnObserver{} 209 | connIn, err := newUDPConn(&net.UDPAddr{ 210 | IP: net.ParseIP("192.168.0.1"), 211 | Port: 5678, 212 | }, nil, obs) 213 | assert.NoError(t, err, "should succeed") 214 | err = connMap.insert(connIn) 215 | assert.NoError(t, err, "should succeed") 216 | 217 | _, ok := connMap.find(&net.UDPAddr{ 218 | IP: net.ParseIP("192.168.0.2"), 219 | Port: 5678, 220 | }) 221 | assert.False(t, ok, "should fail") 222 | }) 223 | 224 | t.Run("find failure 2", func(t *testing.T) { 225 | connMap := newUDPConnMap() 226 | 227 | obs := &myConnObserver{} 228 | connIn, err := newUDPConn(&net.UDPAddr{ 229 | IP: net.ParseIP("192.168.0.1"), 230 | Port: 5678, 231 | }, nil, obs) 232 | assert.NoError(t, err, "should succeed") 233 | err = connMap.insert(connIn) 234 | assert.NoError(t, err, "should succeed") 235 | 236 | _, ok := connMap.find(&net.UDPAddr{ 237 | IP: net.ParseIP("192.168.0.1"), 238 | Port: 1234, 239 | }) 240 | assert.False(t, ok, "should fail") 241 | }) 242 | 243 | t.Run("insert two UDPConns on the same port, then remove them", func(t *testing.T) { 244 | connMap := newUDPConnMap() 245 | 246 | obs := &myConnObserver{} 247 | 248 | connIn1, 
err := newUDPConn(&net.UDPAddr{ 249 | IP: net.ParseIP("192.168.0.1"), 250 | Port: 5678, 251 | }, nil, obs) 252 | assert.NoError(t, err, "should succeed") 253 | err = connMap.insert(connIn1) 254 | assert.NoError(t, err, "should succeed") 255 | 256 | connIn2, err := newUDPConn(&net.UDPAddr{ 257 | IP: net.ParseIP("192.168.0.2"), 258 | Port: 5678, 259 | }, nil, obs) 260 | assert.NoError(t, err, "should succeed") 261 | err = connMap.insert(connIn2) 262 | assert.NoError(t, err, "should succeed") 263 | 264 | err = connMap.delete(connIn1.LocalAddr()) 265 | assert.NoError(t, err, "should succeed") 266 | 267 | err = connMap.delete(connIn2.LocalAddr()) 268 | assert.NoError(t, err, "should succeed") 269 | }) 270 | } 271 | -------------------------------------------------------------------------------- /netctx/packetconn_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-FileCopyrightText: 2023 The Pion community 2 | // SPDX-License-Identifier: MIT 3 | 4 | package netctx 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "io" 10 | "net" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var _ net.PacketConn = wrapConn{} 18 | 19 | type wrapConn struct { 20 | c net.Conn 21 | } 22 | 23 | func (w wrapConn) ReadFrom(p []byte) (int, net.Addr, error) { 24 | n, err := w.c.Read(p) 25 | 26 | return n, nil, err 27 | } 28 | 29 | func (w wrapConn) WriteTo(p []byte, _ net.Addr) (n int, err error) { 30 | return w.c.Write(p) 31 | } 32 | 33 | func (w wrapConn) Close() error { 34 | return w.c.Close() 35 | } 36 | 37 | func (w wrapConn) LocalAddr() net.Addr { 38 | return w.c.LocalAddr() 39 | } 40 | 41 | func (w wrapConn) RemoteAddr() net.Addr { 42 | return w.c.RemoteAddr() 43 | } 44 | 45 | func (w wrapConn) SetDeadline(t time.Time) error { 46 | return w.c.SetDeadline(t) 47 | } 48 | 49 | func (w wrapConn) SetReadDeadline(t time.Time) error { 50 | return w.c.SetReadDeadline(t) 51 | } 52 | 53 | func (w wrapConn) 
SetWriteDeadline(t time.Time) error { 54 | return w.c.SetWriteDeadline(t) 55 | } 56 | 57 | func pipe() (net.PacketConn, net.PacketConn) { 58 | a, b := net.Pipe() 59 | 60 | return wrapConn{a}, wrapConn{b} 61 | } 62 | 63 | func TestReadFrom(t *testing.T) { 64 | ca, cb := pipe() 65 | defer func() { 66 | _ = ca.Close() 67 | }() 68 | 69 | data := []byte{0x01, 0x02, 0xFF} 70 | chErr := make(chan error) 71 | 72 | go func() { 73 | _, err := cb.WriteTo(data, nil) 74 | chErr <- err 75 | }() 76 | 77 | c := NewPacketConn(ca) 78 | b := make([]byte, 100) 79 | n, _, err := c.ReadFromContext(context.Background(), b) 80 | assert.NoError(t, err) 81 | assert.Len(t, data, n, "Wrong data length") 82 | assert.Equal(t, data, b[:n]) 83 | assert.NoError(t, <-chErr) 84 | } 85 | 86 | func TestReadFromTimeout(t *testing.T) { 87 | ca, _ := pipe() 88 | defer func() { 89 | _ = ca.Close() 90 | }() 91 | 92 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 93 | defer cancel() 94 | 95 | c := NewPacketConn(ca) 96 | b := make([]byte, 100) 97 | n, _, err := c.ReadFromContext(ctx, b) 98 | assert.Error(t, err) 99 | assert.Empty(t, n, "Wrong data length") 100 | } 101 | 102 | func TestReadFromCancel(t *testing.T) { 103 | ca, _ := pipe() 104 | defer func() { 105 | _ = ca.Close() 106 | }() 107 | 108 | ctx, cancel := context.WithCancel(context.Background()) 109 | go func() { 110 | time.Sleep(10 * time.Millisecond) 111 | cancel() 112 | }() 113 | 114 | c := NewPacketConn(ca) 115 | b := make([]byte, 100) 116 | n, _, err := c.ReadFromContext(ctx, b) 117 | assert.Error(t, err) 118 | assert.Empty(t, n, "Wrong data length") 119 | } 120 | 121 | func TestReadFromClosed(t *testing.T) { 122 | ca, _ := pipe() 123 | 124 | c := NewPacketConn(ca) 125 | _ = c.Close() 126 | 127 | b := make([]byte, 100) 128 | n, _, err := c.ReadFromContext(context.Background(), b) 129 | assert.ErrorIs(t, err, net.ErrClosed) 130 | assert.Empty(t, n, "Wrong data length") 131 | } 132 | 133 | func TestWriteTo(t 
*testing.T) { 134 | ca, cb := pipe() 135 | defer func() { 136 | _ = ca.Close() 137 | }() 138 | 139 | chErr := make(chan error) 140 | chRead := make(chan []byte) 141 | 142 | go func() { 143 | b := make([]byte, 100) 144 | n, _, err := cb.ReadFrom(b) 145 | chErr <- err 146 | chRead <- b[:n] 147 | }() 148 | 149 | c := NewPacketConn(ca) 150 | data := []byte{0x01, 0x02, 0xFF} 151 | n, err := c.WriteToContext(context.Background(), data, nil) 152 | assert.NoError(t, err) 153 | assert.Len(t, data, n, "Wrong data length") 154 | 155 | err = <-chErr 156 | b := <-chRead 157 | assert.NoError(t, err) 158 | assert.Equal(t, data, b) 159 | } 160 | 161 | func TestWriteToTimeout(t *testing.T) { 162 | ca, _ := pipe() 163 | defer func() { 164 | _ = ca.Close() 165 | }() 166 | 167 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 168 | defer cancel() 169 | 170 | c := NewPacketConn(ca) 171 | b := make([]byte, 100) 172 | n, err := c.WriteToContext(ctx, b, nil) 173 | assert.Error(t, err) 174 | assert.Empty(t, n, "Wrong data length") 175 | } 176 | 177 | func TestWriteToCancel(t *testing.T) { 178 | ca, _ := pipe() 179 | defer func() { 180 | _ = ca.Close() 181 | }() 182 | 183 | ctx, cancel := context.WithCancel(context.Background()) 184 | go func() { 185 | time.Sleep(10 * time.Millisecond) 186 | cancel() 187 | }() 188 | 189 | c := NewPacketConn(ca) 190 | b := make([]byte, 100) 191 | n, err := c.WriteToContext(ctx, b, nil) 192 | assert.Error(t, err) 193 | assert.Empty(t, n, "Wrong data length") 194 | } 195 | 196 | func TestWriteToClosed(t *testing.T) { 197 | ca, _ := pipe() 198 | 199 | c := NewPacketConn(ca) 200 | _ = c.Close() 201 | 202 | b := make([]byte, 100) 203 | n, err := c.WriteToContext(context.Background(), b, nil) 204 | assert.ErrorIs(t, err, ErrClosing) 205 | assert.Empty(t, n, "Wrong data length") 206 | } 207 | 208 | type packetConnAddrMock struct{} 209 | 210 | func (*packetConnAddrMock) LocalAddr() net.Addr { return stringAddr{"local_net", "local_addr"} 
} 211 | func (*packetConnAddrMock) ReadFrom([]byte) (int, net.Addr, error) { panic("unimplemented") } //nolint:forbidigo 212 | func (*packetConnAddrMock) WriteTo([]byte, net.Addr) (int, error) { panic("unimplemented") } //nolint:forbidigo 213 | func (*packetConnAddrMock) Close() error { panic("unimplemented") } //nolint:forbidigo 214 | func (*packetConnAddrMock) SetDeadline(_ time.Time) error { panic("unimplemented") } //nolint:forbidigo 215 | func (*packetConnAddrMock) SetReadDeadline(_ time.Time) error { panic("unimplemented") } //nolint:forbidigo 216 | func (*packetConnAddrMock) SetWriteDeadline(_ time.Time) error { panic("unimplemented") } //nolint:forbidigo 217 | 218 | func TestPacketConnLocalAddrAndRemoteAddr(t *testing.T) { 219 | c := NewPacketConn(&packetConnAddrMock{}) 220 | al := c.LocalAddr() 221 | 222 | assert.Equal(t, "local_addr", al.String()) 223 | } 224 | 225 | func BenchmarkPacketConnBase(b *testing.B) { 226 | ca, cb := pipe() 227 | defer func() { 228 | _ = ca.Close() 229 | }() 230 | 231 | data := make([]byte, 4096) 232 | for i := range data { 233 | data[i] = byte(i) 234 | } 235 | buf := make([]byte, len(data)) 236 | 237 | b.SetBytes(int64(len(data))) 238 | b.ResetTimer() 239 | 240 | go func(n int) { 241 | for i := 0; i < n; i++ { 242 | _, _ = cb.WriteTo(data, nil) 243 | } 244 | _ = cb.Close() 245 | }(b.N) 246 | 247 | count := 0 248 | for { 249 | n, _, err := ca.ReadFrom(buf) 250 | if err != nil { 251 | if !errors.Is(err, io.EOF) { 252 | b.Fatal(err) 253 | } 254 | 255 | break 256 | } 257 | if n != len(data) { 258 | b.Errorf("Expected %v, got %v", len(data), n) 259 | } 260 | count++ 261 | } 262 | if count != b.N { 263 | b.Errorf("Expected %v, got %v", b.N, count) 264 | } 265 | } 266 | 267 | func BenchmarkWriteTo(b *testing.B) { 268 | ca, cb := pipe() 269 | defer func() { 270 | _ = ca.Close() 271 | }() 272 | 273 | data := make([]byte, 4096) 274 | for i := range data { 275 | data[i] = byte(i) 276 | } 277 | buf := make([]byte, len(data)) 278 | 279 | 
b.SetBytes(int64(len(data))) 280 | b.ResetTimer() 281 | 282 | go func(n int) { 283 | c := NewPacketConn(cb) 284 | for i := 0; i < n; i++ { 285 | _, _ = c.WriteToContext(context.Background(), data, nil) 286 | } 287 | _ = cb.Close() 288 | }(b.N) 289 | 290 | count := 0 291 | for { 292 | n, _, err := ca.ReadFrom(buf) 293 | if err != nil { 294 | if !errors.Is(err, io.EOF) { 295 | b.Fatal(err) 296 | } 297 | 298 | break 299 | } 300 | if n != len(data) { 301 | b.Errorf("Expected %v, got %v", len(data), n) 302 | } 303 | count++ 304 | } 305 | if count != b.N { 306 | b.Errorf("Expected %v, got %v", b.N, count) 307 | } 308 | } 309 | 310 | func BenchmarkReadFrom(b *testing.B) { 311 | ca, cb := pipe() 312 | defer func() { 313 | _ = ca.Close() 314 | }() 315 | 316 | data := make([]byte, 4096) 317 | for i := range data { 318 | data[i] = byte(i) 319 | } 320 | buf := make([]byte, len(data)) 321 | 322 | b.SetBytes(int64(len(data))) 323 | b.ResetTimer() 324 | 325 | go func(n int) { 326 | for i := 0; i < n; i++ { 327 | _, _ = cb.WriteTo(data, nil) 328 | } 329 | _ = cb.Close() 330 | }(b.N) 331 | 332 | c := NewPacketConn(ca) 333 | count := 0 334 | for { 335 | n, _, err := c.ReadFromContext(context.Background(), buf) 336 | if err != nil { 337 | if !errors.Is(err, io.EOF) { 338 | b.Fatal(err) 339 | } 340 | 341 | break 342 | } 343 | if n != len(data) { 344 | b.Errorf("Expected %v, got %v", len(data), n) 345 | } 346 | count++ 347 | } 348 | if count != b.N { 349 | b.Errorf("Expected %v, got %v", b.N, count) 350 | } 351 | } 352 | --------------------------------------------------------------------------------