├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md ├── dependabot.yml └── workflows │ ├── container-image.yaml │ └── go-build.yaml ├── .gitignore ├── LICENSE ├── README.md ├── app ├── app.go └── wg.go ├── cmd ├── warp-plus │ ├── main.go │ ├── rootcmd.go │ └── versioncmd.go └── warp-scan │ └── main.go ├── example_config.json ├── go.mod ├── go.sum ├── ipscanner ├── engine │ ├── engine.go │ └── queue.go ├── iterator │ └── iterator.go ├── ping │ ├── ping.go │ └── warp.go ├── scanner.go └── statute │ ├── ping.go │ ├── queue.go │ └── statute.go ├── iputils └── iputils.go ├── proxy ├── README.md ├── example │ ├── customHandler │ │ └── main.go │ ├── minimal │ │ └── main.go │ └── udpClient │ │ └── main.go └── pkg │ ├── http │ ├── common.go │ └── server.go │ ├── mixed │ ├── handlers.go │ └── proxy.go │ ├── socks4 │ ├── common.go │ └── server.go │ ├── socks5 │ ├── common.go │ └── server.go │ └── statute │ ├── statute.go │ └── tunnel.go ├── psiphon └── p.go ├── termux.sh ├── warp ├── account.go ├── api.go ├── endpoint.go ├── key.go └── tls.go ├── wireguard ├── LICENSE.md ├── conn │ ├── bind_std.go │ ├── bind_std_test.go │ ├── bind_windows.go │ ├── bindtest │ │ └── bindtest.go │ ├── boundif_android.go │ ├── conn.go │ ├── conn_test.go │ ├── controlfns.go │ ├── controlfns_linux.go │ ├── controlfns_unix.go │ ├── controlfns_windows.go │ ├── default.go │ ├── errors_default.go │ ├── errors_linux.go │ ├── features_default.go │ ├── features_linux.go │ ├── gso_default.go │ ├── gso_linux.go │ ├── mark_default.go │ ├── mark_unix.go │ ├── sticky_default.go │ ├── sticky_linux.go │ ├── sticky_linux_test.go │ └── winrio │ │ └── rio_windows.go ├── device │ ├── allowedips.go │ ├── allowedips_rand_test.go │ ├── allowedips_test.go │ ├── bind_test.go │ ├── channels.go │ ├── constants.go │ ├── cookie.go │ ├── cookie_test.go │ ├── device.go │ ├── device_test.go │ ├── devicestate_string.go │ ├── endpoint_test.go │ ├── indextable.go │ ├── ip.go │ ├── kdf_test.go │ ├── keypair.go │ ├── logger.go │ ├── 
mobilequirks.go │ ├── noise-helpers.go │ ├── noise-protocol.go │ ├── noise-types.go │ ├── noise_test.go │ ├── peer.go │ ├── pools.go │ ├── pools_test.go │ ├── queueconstants_android.go │ ├── queueconstants_default.go │ ├── queueconstants_ios.go │ ├── queueconstants_windows.go │ ├── race_disabled_test.go │ ├── race_enabled_test.go │ ├── receive.go │ ├── send.go │ ├── sticky_default.go │ ├── sticky_linux.go │ ├── timers.go │ ├── tun.go │ └── uapi.go ├── ipc │ ├── namedpipe │ │ ├── file.go │ │ ├── namedpipe.go │ │ └── namedpipe_test.go │ ├── uapi_bsd.go │ ├── uapi_linux.go │ ├── uapi_unix.go │ ├── uapi_wasm.go │ └── uapi_windows.go ├── ratelimiter │ ├── ratelimiter.go │ └── ratelimiter_test.go ├── replay │ ├── replay.go │ └── replay_test.go ├── rwcancel │ ├── rwcancel.go │ └── rwcancel_stub.go ├── tai64n │ ├── tai64n.go │ └── tai64n_test.go └── tun │ ├── alignment_windows_test.go │ ├── checksum.go │ ├── checksum_test.go │ ├── errors.go │ ├── netstack │ ├── examples │ │ ├── http_client.go │ │ ├── http_server.go │ │ └── ping_client.go │ └── tun.go │ ├── offload_linux.go │ ├── offload_linux_test.go │ ├── operateonfd.go │ ├── tun.go │ ├── tun_darwin.go │ ├── tun_freebsd.go │ ├── tun_linux.go │ ├── tun_openbsd.go │ └── tuntest │ └── tuntest.go └── wiresocks ├── config.go ├── config_test.go ├── proxy.go ├── scanner.go └── udpfw.go /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Version number** 11 | The version of the application that has the bug. 12 | 13 | **Describe the bug** 14 | A clear and concise description of what the bug is. 15 | 16 | **To Reproduce** 17 | Steps to reproduce the behavior: 18 | 1. Go to '...' 19 | 2. Click on '....' 20 | 3. Scroll down to '....' 21 | 4. 
See error 22 | 23 | **Expected behavior** 24 | A clear and concise description of what you expected to happen. 25 | 26 | **Screenshots** 27 | If applicable, add screenshots to help explain your problem. 28 | 29 | **Desktop (please complete the following information):** 30 | - OS: [e.g. iOS] 31 | 32 | **Smartphone (please complete the following information):** 33 | - Device: [e.g. iPhone6] 34 | - OS: [e.g. iOS8.1] 35 | 36 | **Additional context** 37 | Add any other context about the problem here. 38 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/container-image.yaml: -------------------------------------------------------------------------------- 1 | name: Container Image 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - main 8 | tags: 9 | - v** 10 | release: 11 | types: [published] 12 | workflow_dispatch: 13 | 14 | permissions: 15 | contents: read 16 | packages: write 17 | 18 | jobs: 19 | build-publish: 20 | name: Build and publish container image 21 | runs-on: ubuntu-latest 22 | env: 23 | CGO_ENABLED: 0 24 | PLATFORMS: linux/arm/v7,linux/arm64,linux/amd64 25 | steps: 26 | - name: Checkout codebase 27 | uses: actions/checkout@v4 28 | 29 | - name: Set up Go 30 | uses: actions/setup-go@v5 31 | with: 32 | go-version: '1.24' 33 | check-latest: true 34 | 35 | - name: Setup `ko` 36 | # The latest release (@v0.6) of this action has a bug when the repository name contains uppercase letters 37 | uses: ko-build/setup-ko@main 38 | 39 | - name: Extract metadata 40 | uses: docker/metadata-action@v5 41 | id: meta 42 | with: 43 | # Image names are not required; only the tags are needed for the ko build step.
44 | images: "" 45 | tags: | 46 | type=ref,event=branch 47 | type=semver,pattern={{version}} 48 | type=semver,pattern={{major}}.{{minor}} 49 | type=semver,pattern={{major}} 50 | type=sha 51 | type=sha,format=long 52 | 53 | - name: Build and push image 54 | env: 55 | TAGS: ${{ steps.meta.outputs.tags }} 56 | GOFLAGS: "-ldflags=-checklinkname=0" 57 | run: ko build ./cmd/warp-plus --platform "${PLATFORMS}" --bare --tags $(echo $TAGS | tr ' ' ',') 58 | -------------------------------------------------------------------------------- /.github/workflows/go-build.yaml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: [published] 7 | push: 8 | 9 | jobs: 10 | build: 11 | permissions: 12 | contents: write 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | include: 17 | - goos: darwin 18 | goarch: amd64 19 | - goos: darwin 20 | goarch: arm64 21 | 22 | - goos: android 23 | goarch: arm64 24 | 25 | - goos: linux 26 | goarch: amd64 27 | - goos: linux 28 | goarch: arm64 29 | - goos: linux 30 | goarch: arm 31 | goarm: 7 32 | - goos: linux 33 | goarch: riscv64 34 | - goos: linux 35 | goarch: mips64 36 | - goos: linux 37 | goarch: mips64le 38 | - goos: linux 39 | goarch: mips 40 | - goos: linux 41 | goarch: mipsle 42 | - goos: linux 43 | goarch: mips64 44 | gomips: softfloat 45 | - goos: linux 46 | goarch: mips64le 47 | gomips: softfloat 48 | - goos: linux 49 | goarch: mips 50 | gomips: softfloat 51 | - goos: linux 52 | goarch: mipsle 53 | gomips: softfloat 54 | 55 | - goos: windows 56 | goarch: amd64 57 | - goos: windows 58 | goarch: arm64 59 | - goos: windows 60 | goarch: 386 61 | 62 | runs-on: ubuntu-latest 63 | env: 64 | GOOS: ${{ matrix.goos }} 65 | GOARCH: ${{ matrix.goarch }} 66 | GOARM: ${{ matrix.goarm }} 67 | GOMIPS: ${{ matrix.gomips }} 68 | CGO_ENABLED: 0 69 | steps: 70 | - name: Checkout codebase 71 | uses: actions/checkout@v4 72 | 73 | - name: Show workflow 
information 74 | run: | 75 | export _NAME=$GOOS-$GOARCH$GOARM$GOMIPS 76 | echo "GOOS: $GOOS, GOARCH: $GOARCH, GOARM: $GOARM, GOMIPS: $GOMIPS, RELEASE_NAME: $_NAME" 77 | echo "ASSET_NAME=$_NAME" >> $GITHUB_ENV 78 | echo "REF=${GITHUB_SHA::6}" >> $GITHUB_ENV 79 | 80 | - name: Set up Go 81 | uses: actions/setup-go@v5 82 | with: 83 | go-version: '1.24' 84 | check-latest: true 85 | 86 | - name: Build warp-plus 87 | run: | 88 | go build -v -o warp-plus_${{ env.ASSET_NAME }}/ -trimpath -ldflags "-s -w -buildid= -checklinkname=0 -X main.version=${{ github.ref }}" ./cmd/warp-plus 89 | go build -v -o warp-scan_${{ env.ASSET_NAME }}/ -trimpath -ldflags "-s -w -buildid= -checklinkname=0 -X main.version=${{ github.ref }}" ./cmd/warp-scan 90 | 91 | - name: Copy README.md & LICENSE 92 | run: | 93 | cp ${GITHUB_WORKSPACE}/README.md ./warp-plus_${{ env.ASSET_NAME }}/README.md 94 | cp ${GITHUB_WORKSPACE}/LICENSE ./warp-plus_${{ env.ASSET_NAME }}/LICENSE 95 | 96 | - name: Create ZIP archive 97 | shell: bash 98 | run: | 99 | pushd ./warp-plus_${{ env.ASSET_NAME }} || exit 1 100 | touch -mt $(date +%Y01010000) * 101 | zip -9vr ../warp-plus_${{ env.ASSET_NAME }}.zip . 
102 | popd || exit 1 103 | FILE=./warp-plus_${{ env.ASSET_NAME }}.zip 104 | DGST=$FILE.dgst 105 | for METHOD in {"md5","sha256","sha512"} 106 | do 107 | openssl dgst -$METHOD $FILE | sed 's/([^)]*)//g' >>$DGST 108 | done 109 | 110 | - name: Upload warp-plus files to Artifacts 111 | uses: actions/upload-artifact@v4 112 | with: 113 | name: warp-plus_${{ env.ASSET_NAME }}_${{ env.REF }} 114 | path: | 115 | ./warp-plus_${{ env.ASSET_NAME }}/* 116 | 117 | - name: Upload warp-scan files to Artifacts 118 | uses: actions/upload-artifact@v4 119 | with: 120 | name: warp-scan_${{ env.ASSET_NAME }}_${{ env.REF }} 121 | path: | 122 | ./warp-scan_${{ env.ASSET_NAME }}/* 123 | 124 | - name: Upload binaries to release 125 | uses: svenstaro/upload-release-action@v2 126 | if: github.event_name == 'release' 127 | with: 128 | repo_token: ${{ secrets.GITHUB_TOKEN }} 129 | file: ./warp-plus_${{ env.ASSET_NAME }}.zip* 130 | tag: ${{ github.ref }} 131 | file_glob: true 132 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /warp-plus 2 | /warp-plus.exe 3 | /warp-scan 4 | /warp-scan.exe 5 | .idea 6 | stuff/ 7 | .DS_Store 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of 2 | this software and associated documentation files (the "Software"), to deal in 3 | the Software without restriction, including without limitation the rights to 4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 5 | of the Software, and to permit persons to whom the Software is furnished to do 6 | so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all 9 | copies or substantial portions 
of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Warp-Plus 2 | 3 | Warp-Plus is an open-source implementation of Cloudflare's Warp, enhanced with Psiphon integration for circumventing censorship. This project aims to provide a robust and cross-platform VPN solution that can run Psiphon on top of Warp, or chain Warp in Warp, to change the user's virtual NAT location. 4 | 5 | ## Features 6 | 7 | - **Warp Integration**: Leverages Cloudflare's Warp to provide a fast and secure VPN service. 8 | - **Psiphon Chaining**: Integrates with Psiphon for censorship circumvention, allowing seamless access to the internet in restrictive environments. 9 | - **Warp in Warp Chaining**: Chains two instances of Warp together to bypass location restrictions. 10 | - **SOCKS5 Proxy Support**: Includes a SOCKS5 proxy for secure and private browsing.
11 | 12 | ## Getting Started 13 | 14 | ### Prerequisites 15 | 16 | - [Download the latest version from the releases page](https://github.com/bepass-org/warp-plus/releases) 17 | - Basic understanding of VPN and proxy configurations 18 | 19 | ### Usage 20 | 21 | ``` 22 | NAME 23 | warp-plus 24 | 25 | FLAGS 26 | -4 only use IPv4 for random warp endpoint 27 | -6 only use IPv6 for random warp endpoint 28 | -v, --verbose enable verbose logging 29 | -b, --bind STRING socks bind address (default: 127.0.0.1:8086) 30 | -e, --endpoint STRING warp endpoint 31 | -k, --key STRING warp key 32 | --dns STRING DNS address (default: 1.1.1.1) 33 | --gool enable gool mode (warp in warp) 34 | --cfon enable psiphon mode (must provide country as well) 35 | --country STRING psiphon country code (valid values: [AT AU BE BG CA CH CZ DE DK EE ES FI FR GB HR HU IE IN IT JP LV NL NO PL PT RO RS SE SG SK US]) (default: AT) 36 | --scan enable warp scanning 37 | --rtt DURATION scanner rtt limit (default: 1s) 38 | --cache-dir STRING directory to store generated profiles 39 | --fwmark UINT set linux firewall mark for tun mode (requires sudo/root/CAP_NET_ADMIN) (default: 0) 40 | --reserved STRING override wireguard reserved value (format: '1,2,3') 41 | --wgconf STRING path to a normal wireguard config 42 | --test-url STRING connectivity test url (default: http://connectivity.cloudflareclient.com/cdn-cgi/trace) 43 | -c, --config STRING path to config file 44 | --version displays version number 45 | ``` 46 | 47 | ### Country Codes for Psiphon 48 | 49 | - Austria (AT) 50 | - Australia (AU) 51 | - Belgium (BE) 52 | - Bulgaria (BG) 53 | - Canada (CA) 54 | - Switzerland (CH) 55 | - Czech Republic (CZ) 56 | - Germany (DE) 57 | - Denmark (DK) 58 | - Estonia (EE) 59 | - Spain (ES) 60 | - Finland (FI) 61 | - France (FR) 62 | - United Kingdom (GB) 63 | - Croatia (HR) 64 | - Hungary (HU) 65 | - Ireland (IE) 66 | - India (IN) 67 | - Italy (IT) 68 | - Japan (JP) 69 | - Latvia (LV) 70 | - Netherlands (NL) 71 | - 
Norway (NO) 72 | - Poland (PL) 73 | - Portugal (PT) 74 | - Romania (RO) 75 | - Serbia (RS) 76 | - Sweden (SE) 77 | - Singapore (SG) 78 | - Slovakia (SK) 79 | - United States (US) 80 | ![0](https://raw.githubusercontent.com/Ptechgithub/configs/main/media/line.gif) 81 | ### Termux 82 | 83 | ``` 84 | bash <(curl -fsSL https://raw.githubusercontent.com/bepass-org/warp-plus/master/termux.sh) 85 | ``` 86 | ![1](https://github.com/Ptechgithub/configs/blob/main/media/18.jpg?raw=true) 87 | 88 | - اگه حس کردی کانکت نمیشه یا خطا میده دستور `rm -rf .cache/warp-plus` رو بزن و مجدد warp رو وارد کن. 89 | - بعد از نصب برای اجرای مجدد فقط کافیه که `warp` یا `usef` یا `./warp` یا `warp-plus`را وارد کنید. همش یکیه هیچ فرقی ندارد. 90 | - اگر با 1 نصب نشد و خطا گرفتید ابتدا یک بار 3 را بزنید تا `Uninstall` شود سپس عدد 2 رو انتخاب کنید یعنی Arm. 91 | - برای نمایش راهنما ` warp -h` را وارد کنید. 92 | - ای پی و پورت `127.0.0.1:8086`پروتکل socks 93 | - در روش تبدیل اکانت warp به warp plus (گزینه 6) مقدار ID را وارد میکنید. پس از اجرای warp دو اکانت برای شما ساخته شده که پس از انتخاب گزینه 6 خودش مقدار ID هر دو اکانت را پیدا میکند و شما باید هر بار یکی را انتخاب کنید و یا میتوانید با انتخاب manual مقدار ID دیگری را وارد کنید (مثلا برای خود برنامه ی 1.1.1.1 یا جای دیگر) با این کار هر 20 ثانیه 1 GB به اکانت شما اضافه میشود. و اکانت شما از حالت رایگان به پلاس تبدیل میشود. 94 | - برای تغییر لوکیشن با استفاده از سایفون از طریق منو یا به صورت دستی (برای مثال به USA از دستور زیر استفاده کنید) 95 | - `warp --cfon --country US` 96 | - برای اسکن ای پی سالم وارپ از دستور `warp --scan` استفاده کنید. 97 | - برای ترکیب (chain) دو کانفیگ برای تغییر لوکیشن از دستور `warp --gool` استفاده کنید. 
98 | 
99 | ## Acknowledgements
100 | 
101 | - Cloudflare Warp
102 | - Psiphon
103 | - All contributors and supporters of this project
104 | 
--------------------------------------------------------------------------------
/app/wg.go:
--------------------------------------------------------------------------------
1 | package app
2 | 
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"context"
7 | 	"fmt"
8 | 	"io"
9 | 	"log/slog"
10 | 	"net/http"
11 | 	"strings"
12 | 	"time"
13 | 
14 | 	"github.com/bepass-org/warp-plus/wireguard/conn"
15 | 	"github.com/bepass-org/warp-plus/wireguard/device"
16 | 	wgtun "github.com/bepass-org/warp-plus/wireguard/tun"
17 | 	"github.com/bepass-org/warp-plus/wireguard/tun/netstack"
18 | 	"github.com/bepass-org/warp-plus/wiresocks"
19 | )
20 | 
21 | // usermodeTunTest checks connectivity through the userspace network stack
22 | // tnet by sending HTTP HEAD requests to url until one returns 200 OK.
23 | //
24 | // Failed attempts are retried after a short pause, for at most 5 seconds in
25 | // total (or until ctx is done). It returns nil on the first successful attempt
26 | // and the context's error once the time budget is exhausted.
27 | func usermodeTunTest(ctx context.Context, l *slog.Logger, tnet *netstack.Net, url string) error {
28 | 	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
29 | 	defer cancel()
30 | 
31 | 	// A single client is shared by all attempts so idle connections are reused.
32 | 	client := &http.Client{Transport: &http.Transport{
33 | 		DialContext:           tnet.DialContext,
34 | 		ResponseHeaderTimeout: 5 * time.Second,
35 | 	}}
36 | 	defer client.CloseIdleConnections()
37 | 
38 | 	for {
39 | 		if err := ctx.Err(); err != nil {
40 | 			return err
41 | 		}
42 | 
43 | 		if err := headOK(ctx, client, url); err != nil {
44 | 			l.Error("connection test failed", "error", err)
45 | 			// Back off briefly so a fast-failing dial does not spin the CPU
46 | 			// for the rest of the time budget.
47 | 			if err := sleepCtx(ctx, 200*time.Millisecond); err != nil {
48 | 				return err
49 | 			}
50 | 			continue
51 | 		}
52 | 
53 | 		l.Info("connection test successful")
54 | 		return nil
55 | 	}
56 | }
57 | 
58 | // headOK sends one HEAD request for url, bound to ctx, and returns nil only
59 | // when the response status is 200 OK. The response body is always drained
60 | // and closed so the underlying connection is released.
61 | func headOK(ctx context.Context, client *http.Client, url string) error {
62 | 	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
63 | 	if err != nil {
64 | 		return err
65 | 	}
66 | 
67 | 	resp, err := client.Do(req)
68 | 	if err != nil {
69 | 		return err
70 | 	}
71 | 	defer resp.Body.Close()
72 | 	_, _ = io.Copy(io.Discard, resp.Body) // best-effort drain; HEAD bodies are normally empty
73 | 
74 | 	if resp.StatusCode != http.StatusOK {
75 | 		return fmt.Errorf("unexpected status %q", resp.Status)
76 | 	}
77 | 	return nil
78 | }
79 | 
80 | // sleepCtx pauses for d, returning ctx's error early if ctx is done first.
81 | func sleepCtx(ctx context.Context, d time.Duration) error {
82 | 	t := time.NewTimer(d)
83 | 	defer t.Stop()
84 | 	select {
85 | 	case <-ctx.Done():
86 | 		return ctx.Err()
87 | 	case <-t.C:
88 | 		return nil
89 | 	}
90 | }
91 | 
92 | // waitHandshake polls dev's IPC state roughly once per second until the first
93 | // last_handshake_time_sec entry is non-zero, or until ctx is done (in which
94 | // case ctx's error is returned).
95 | func waitHandshake(ctx context.Context, l *slog.Logger, dev *device.Device) error {
96 | 	for {
97 | 		if err := ctx.Err(); err != nil {
98 | 			return err
99 | 		}
100 | 
101 | 		switch state, err := dev.IpcGet(); {
102 | 		case err != nil:
103 | 			// Treat a failed read like a miss: retry after the normal poll interval.
104 | 			l.Debug("reading device state failed", "error", err)
105 | 		case lastHandshakeSecs(state) != "0":
106 | 			l.Debug("handshake complete")
107 | 			return nil
108 | 		default:
109 | 			l.Debug("waiting on handshake")
110 | 		}
111 | 
112 | 		if err := sleepCtx(ctx, 1*time.Second); err != nil {
113 | 			return err
114 | 		}
115 | 	}
116 | }
117 | 
118 | // lastHandshakeSecs returns the value of the first last_handshake_time_sec
119 | // key in an IpcGet dump, scanning only up to the first blank line. It returns
120 | // "0" when no such key is found.
121 | func lastHandshakeSecs(state string) string {
122 | 	scanner := bufio.NewScanner(strings.NewReader(state))
123 | 	for scanner.Scan() {
124 | 		line := scanner.Text()
125 | 		if line == "" {
126 | 			break
127 | 		}
128 | 		if key, value, ok := strings.Cut(line, "="); ok && key == "last_handshake_time_sec" {
129 | 			return value
130 | 		}
131 | 	}
132 | 	return "0"
133 | }
134 | 
135 | // establishWireguard creates a wireguard-go device on tunDev, configures it
136 | // from conf through the device's IPC interface, brings it up, and waits up to
137 | // 15 seconds for the handshake.
138 | //
139 | // fwmark is only sent when non-zero, and t is forwarded as each peer's
140 | // "trick" setting. If configuration, bring-up, or the handshake fails, the
141 | // device is torn down (BindClose, then Close) before the error is returned.
142 | func establishWireguard(l *slog.Logger, conf *wiresocks.Configuration, tunDev wgtun.Device, fwmark uint32, t string) error {
143 | 	// Build the IPC message that configures the interface and its peers.
144 | 	var request bytes.Buffer
145 | 
146 | 	fmt.Fprintf(&request, "private_key=%s\n", conf.Interface.PrivateKey)
147 | 	if fwmark != 0 {
148 | 		fmt.Fprintf(&request, "fwmark=%d\n", fwmark)
149 | 	}
150 | 
151 | 	for _, peer := range conf.Peers {
152 | 		fmt.Fprintf(&request, "public_key=%s\n", peer.PublicKey)
153 | 		fmt.Fprintf(&request, "persistent_keepalive_interval=%d\n", peer.KeepAlive)
154 | 		fmt.Fprintf(&request, "preshared_key=%s\n", peer.PreSharedKey)
155 | 		fmt.Fprintf(&request, "endpoint=%s\n", peer.Endpoint)
156 | 		fmt.Fprintf(&request, "trick=%s\n", t)
157 | 		fmt.Fprintf(&request, "reserved=%d,%d,%d\n", peer.Reserved[0], peer.Reserved[1], peer.Reserved[2])
158 | 
159 | 		for _, cidr := range peer.AllowedIPs {
160 | 			fmt.Fprintf(&request, "allowed_ip=%s\n", cidr)
161 | 		}
162 | 	}
163 | 
164 | 	dev := device.NewDevice(
165 | 		tunDev,
166 | 		conn.NewDefaultBind(),
167 | 		device.NewSLogger(l.With("subsystem", "wireguard-go")),
168 | 	)
169 | 
170 | 	// closeOnError tears the device down and passes err through, so every
171 | 	// failure path after NewDevice releases it the same way.
172 | 	closeOnError := func(err error) error {
173 | 		dev.BindClose()
174 | 		dev.Close()
175 | 		return err
176 | 	}
177 | 
178 | 	if err := dev.IpcSet(request.String()); err != nil {
179 | 		return closeOnError(err)
180 | 	}
181 | 
182 | 	if err := dev.Up(); err != nil {
183 | 		return closeOnError(err)
184 | 	}
185 | 
186 | 	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
187 | 	defer cancel()
188 | 	if err := waitHandshake(ctx, l, dev); err != nil {
189 | 		return closeOnError(err)
190 | 	}
191 | 
192 | 	return nil
193 | }
194 | 
--------------------------------------------------------------------------------
/cmd/warp-plus/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"log/slog"
8 | 	"os"
9 | 	"os/signal"
10 | 	"syscall"
11 | 
12 | 	"github.com/peterbourgon/ff/v4"
13 | 	"github.com/peterbourgon/ff/v4/ffhelp"
14 | 	"github.com/peterbourgon/ff/v4/ffjson"
15 | )
16 | 
17 | const appName = "warp-plus"
18 | 
19 | // main parses the command line (merged with the JSON config file named by
20 | // --config, if any) and runs the selected command with a context that is
21 | // cancelled on os.Interrupt or SIGTERM.
22 | func main() {
23 | 	args := os.Args[1:]
24 | 	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
25 | 	defer stop()
26 | 	// Once the first signal has cancelled ctx, stop intercepting signals so a
27 | 	// second Ctrl-C falls back to the default behaviour and terminates the
28 | 	// process even if graceful shutdown hangs.
29 | 	go func() {
30 | 		<-ctx.Done()
31 | 		stop()
32 | 	}()
33 | 
34 | 	rootCmd := newRootCmd()
35 | 	versionCmd(rootCmd)
36 | 	err := rootCmd.command.Parse(
37 | 		args,
38 | 		ff.WithConfigFileFlag("config"),
39 | 		ff.WithConfigFileParser(ffjson.Parse),
40 | 	)
41 | 
42 | 	switch {
43 | 	case errors.Is(err, ff.ErrHelp):
44 | 		fmt.Fprintf(os.Stderr, "%s\n", ffhelp.Command(rootCmd.command))
45 | 		os.Exit(0)
46 | 	case err != nil:
47 | 		fmt.Fprintf(os.Stderr, "error: %v\n", err)
48 | 		os.Exit(1)
49 | 	}
50 | 
51 | 	if err := rootCmd.command.Run(ctx); err != nil {
52 | 		fmt.Fprintf(os.Stderr, "error: %v\n", err)
53 | 		os.Exit(1)
54 | 	}
55 | }
56 | 
57 | // fatal logs err at error level through l and exits the process with status 1.
58 | func fatal(l *slog.Logger, err error) {
59 | 	l.Error(err.Error())
60 | 	os.Exit(1)
61 | }
62 | 
--------------------------------------------------------------------------------
/cmd/warp-plus/versioncmd.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"os"
7 | 
8 | 	"github.com/carlmjohnson/versioninfo"
9 | 	"github.com/peterbourgon/ff/v4"
10 | )
11 | // version is injected at build time via -ldflags "-X main.version=..."; when empty, versioninfo.Short() is used.
12 | var version string = ""
13 | // versionCmd registers a "version" subcommand on rootConfig that prints the version to stderr.
14 | func versionCmd(rootConfig *rootConfig) {
15 | 	command := &ff.Command{
16 | 		Name:      "version",
17 | 		ShortHelp: "displays version",
18 | 		Exec: func(ctx context.Context, args []string) error {
19 | 			if version == "" {
20 | 				version = versioninfo.Short()
21 | 			}
22 | 			fmt.Fprintf(os.Stderr, "%s\n", version)
23 | 			return
nil
24 | 		},
25 | 	}
26 | 	rootConfig.command.Subcommands = append(rootConfig.command.Subcommands, command)
27 | }
28 | 
--------------------------------------------------------------------------------
/cmd/warp-scan/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"log/slog"
8 | 	"os"
9 | 	"time"
10 | 
11 | 	"github.com/bepass-org/warp-plus/ipscanner"
12 | 	"github.com/bepass-org/warp-plus/warp"
13 | 	"github.com/carlmjohnson/versioninfo"
14 | 	"github.com/fatih/color"
15 | 	"github.com/peterbourgon/ff/v4"
16 | 	"github.com/peterbourgon/ff/v4/ffhelp"
17 | 	"github.com/rodaine/table"
18 | )
19 | 
20 | const appName = "warp-scan"
21 | 
22 | var version string = "" // injected at build time via -ldflags "-X main.version=..."; falls back to versioninfo.Short()
23 | // main runs a single 30-second WARP endpoint scan and prints the results as a table.
24 | func main() {
25 | 	fs := ff.NewFlagSet(appName)
26 | 	var (
27 | 		v4      = fs.BoolShort('4', "only use IPv4 for random warp endpoint")
28 | 		v6      = fs.BoolShort('6', "only use IPv6 for random warp endpoint")
29 | 		rtt     = fs.DurationLong("rtt", 1000*time.Millisecond, "scanner rtt limit")
30 | 		verFlag = fs.BoolLong("version", "displays version number")
31 | 	)
32 | 
33 | 	err := ff.Parse(fs, os.Args[1:])
34 | 	switch {
35 | 	case errors.Is(err, ff.ErrHelp):
36 | 		fmt.Fprintf(os.Stderr, "%s\n", ffhelp.Flags(fs))
37 | 		os.Exit(0)
38 | 	case err != nil:
39 | 		fmt.Fprintf(os.Stderr, "error: %v\n", err)
40 | 		os.Exit(1)
41 | 	}
42 | 
43 | 	if *verFlag {
44 | 		if version == "" {
45 | 			version = versioninfo.Short()
46 | 		}
47 | 		fmt.Fprintf(os.Stderr, "%s\n", version)
48 | 		os.Exit(0)
49 | 	}
50 | 
51 | 	// Essentially doing XNOR to make sure that if they are both false
52 | 	// or both true, just set them both true.
53 | 	if *v4 == *v6 {
54 | 		*v4, *v6 = true, true
55 | 	}
56 | 
57 | 	// Build a scanner over the WARP prefixes with fixed WARP keys and a 0xffff-entry result queue.
58 | 	scanner := ipscanner.NewScanner(
59 | 		ipscanner.WithLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))),
60 | 		ipscanner.WithWarpPrivateKey("yGXeX7gMyUIZmK5QIgC7+XX5USUSskQvBYiQ6LdkiXI="),
61 | 		ipscanner.WithWarpPeerPublicKey("bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo="),
62 | 		ipscanner.WithUseIPv4(*v4),
63 | 		ipscanner.WithUseIPv6(*v6),
64 | 		ipscanner.WithMaxDesirableRTT(*rtt),
65 | 		ipscanner.WithCidrList(warp.WarpPrefixes()),
66 | 		ipscanner.WithIPQueueSize(0xffff),
67 | 	)
68 | 
69 | 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // fixed 30-second scan window
70 | 	defer cancel()
71 | 
72 | 	scanner.Run(ctx)
73 | 	<-ctx.Done() // Run only starts the engine in a goroutine, so wait out the whole window here
74 | 
75 | 	ipList := scanner.GetAvailableIPs() // sorted by ascending RTT
76 | 
77 | 	headerFmt := color.New(color.FgGreen, color.Underline).SprintfFunc()
78 | 	columnFmt := color.New(color.FgYellow).SprintfFunc()
79 | 
80 | 	tbl := table.New("Address", "RTT (ping)", "Time")
81 | 	tbl.WithHeaderFormatter(headerFmt).WithFirstColumnFormatter(columnFmt)
82 | 
83 | 	for _, info := range ipList {
84 | 		tbl.AddRow(info.AddrPort, info.RTT, info.CreatedAt.Format(time.DateTime))
85 | 	}
86 | 
87 | 	tbl.Print()
88 | }
89 | 
--------------------------------------------------------------------------------
/example_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "verbose": false,
3 |   "bind": "127.0.0.1:8086",
4 |   "endpoint": "",
5 |   "key": "",
6 |   "dns": "1.1.1.1",
7 |   "gool": false,
8 |   "cfon": false,
9 |   "country": "DE",
10 |   "scan": true,
11 |   "rtt": "1000ms",
12 |   "cache-dir": "",
13 |   "fwmark": "0x1375",
14 |   "wgconf": "",
15 |   "reserved": "",
16 |   "test-url": "",
17 |   "4": true,
18 |   "6": true
19 | }
20 | 
--------------------------------------------------------------------------------
/ipscanner/engine/engine.go:
--------------------------------------------------------------------------------
1 | package engine
2 | 
3 | 
| import ( 4 | "context" 5 | "errors" 6 | "log/slog" 7 | "net/netip" 8 | 9 | "github.com/bepass-org/warp-plus/ipscanner/iterator" 10 | "github.com/bepass-org/warp-plus/ipscanner/ping" 11 | "github.com/bepass-org/warp-plus/ipscanner/statute" 12 | ) 13 | 14 | type Engine struct { 15 | generator *iterator.IpGenerator 16 | ipQueue *IPQueue 17 | ping func(context.Context, netip.Addr) (statute.IPInfo, error) 18 | log *slog.Logger 19 | } 20 | 21 | func NewScannerEngine(opts *statute.ScannerOptions) *Engine { 22 | queue := NewIPQueue(opts) 23 | 24 | p := ping.Ping{ 25 | Options: opts, 26 | } 27 | return &Engine{ 28 | ipQueue: queue, 29 | ping: p.DoPing, 30 | generator: iterator.NewIterator(opts), 31 | log: opts.Logger, 32 | } 33 | } 34 | 35 | func (e *Engine) GetAvailableIPs(desc bool) []statute.IPInfo { 36 | if e.ipQueue != nil { 37 | return e.ipQueue.AvailableIPs(desc) 38 | } 39 | return nil 40 | } 41 | 42 | func (e *Engine) Run(ctx context.Context) { 43 | e.ipQueue.Init() 44 | 45 | select { 46 | case <-ctx.Done(): 47 | return 48 | case <-e.ipQueue.available: 49 | e.log.Debug("Started new scanning round") 50 | batch, err := e.generator.NextBatch() 51 | if err != nil { 52 | e.log.Error("Error while generating IP: %v", err) 53 | return 54 | } 55 | for _, ip := range batch { 56 | select { 57 | case <-ctx.Done(): 58 | return 59 | default: 60 | ipInfo, err := e.ping(ctx, ip) 61 | if err != nil { 62 | if !errors.Is(err, context.Canceled) { 63 | e.log.Error("ping error", "addr", ip, "error", err) 64 | } 65 | continue 66 | } 67 | e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT) 68 | e.ipQueue.Enqueue(ipInfo) 69 | } 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /ipscanner/engine/queue.go: -------------------------------------------------------------------------------- 1 | package engine 2 | 3 | import ( 4 | "log/slog" 5 | "sort" 6 | "sync" 7 | "time" 8 | 9 | 
"github.com/bepass-org/warp-plus/ipscanner/statute" 10 | ) 11 | 12 | type IPQueue struct { 13 | queue []statute.IPInfo 14 | maxQueueSize int 15 | mu sync.Mutex 16 | available chan struct{} 17 | maxTTL time.Duration 18 | rttThreshold time.Duration 19 | inIdealMode bool 20 | log *slog.Logger 21 | reserved statute.IPInfQueue 22 | } 23 | 24 | func NewIPQueue(opts *statute.ScannerOptions) *IPQueue { 25 | var reserved statute.IPInfQueue 26 | return &IPQueue{ 27 | queue: make([]statute.IPInfo, 0), 28 | maxQueueSize: opts.IPQueueSize, 29 | maxTTL: opts.IPQueueTTL, 30 | rttThreshold: opts.MaxDesirableRTT, 31 | available: make(chan struct{}, opts.IPQueueSize), 32 | log: opts.Logger, 33 | reserved: reserved, 34 | } 35 | } 36 | 37 | func (q *IPQueue) Enqueue(info statute.IPInfo) bool { 38 | q.mu.Lock() 39 | defer q.mu.Unlock() 40 | 41 | defer func() { 42 | q.log.Debug("queue change", "len", len(q.queue)) 43 | for _, ipInfo := range q.queue { 44 | q.log.Debug( 45 | "queue change", 46 | "created", ipInfo.CreatedAt, 47 | "addr", ipInfo.AddrPort, 48 | "rtt", ipInfo.RTT, 49 | ) 50 | } 51 | }() 52 | 53 | q.log.Debug("Enqueue: Sorting queue by RTT") 54 | sort.Slice(q.queue, func(i, j int) bool { 55 | return q.queue[i].RTT < q.queue[j].RTT 56 | }) 57 | 58 | if len(q.queue) == 0 { 59 | q.log.Debug("Enqueue: empty queue adding first available item") 60 | q.queue = append(q.queue, info) 61 | return false 62 | } 63 | 64 | if info.RTT <= q.rttThreshold { 65 | q.log.Debug("Enqueue: the new item's RTT is less than at least one of the members.") 66 | if len(q.queue) >= q.maxQueueSize && info.RTT < q.queue[len(q.queue)-1].RTT { 67 | q.log.Debug("Enqueue: the queue is full, remove the item with the highest RTT.") 68 | q.queue = q.queue[:len(q.queue)-1] 69 | } else if len(q.queue) < q.maxQueueSize { 70 | q.log.Debug("Enqueue: Insert the new item in a sorted position.") 71 | index := sort.Search(len(q.queue), func(i int) bool { return q.queue[i].RTT > info.RTT }) 72 | q.queue = 
append(q.queue[:index], append([]statute.IPInfo{info}, q.queue[index:]...)...) 73 | } else { 74 | q.log.Debug("Enqueue: The Queue is full but we keep the new item in the reserved queue.") 75 | q.reserved.Enqueue(info) 76 | } 77 | } 78 | 79 | q.log.Debug("Enqueue: Checking if any member has a higher RTT than the threshold.") 80 | for _, member := range q.queue { 81 | if member.RTT > q.rttThreshold { 82 | return false // If any member has a higher RTT than the threshold, return false. 83 | } 84 | } 85 | 86 | q.log.Debug("Enqueue: All members have an RTT lower than the threshold.") 87 | if len(q.queue) < q.maxQueueSize { 88 | // the queue isn't full dont wait 89 | return false 90 | } 91 | 92 | q.inIdealMode = true 93 | // ok wait for expiration signal 94 | q.log.Debug("Enqueue: All members have an RTT lower than the threshold. Waiting for expiration signal.") 95 | return true 96 | } 97 | 98 | func (q *IPQueue) Dequeue() (statute.IPInfo, bool) { 99 | defer func() { 100 | q.log.Debug("queue change", "len", len(q.queue)) 101 | for _, ipInfo := range q.queue { 102 | q.log.Debug( 103 | "queue change", 104 | "created", ipInfo.CreatedAt, 105 | "addr", ipInfo.AddrPort, 106 | "rtt", ipInfo.RTT, 107 | ) 108 | } 109 | }() 110 | q.mu.Lock() 111 | defer q.mu.Unlock() 112 | 113 | if len(q.queue) == 0 { 114 | return statute.IPInfo{}, false 115 | } 116 | 117 | info := q.queue[len(q.queue)-1] 118 | q.queue = q.queue[0 : len(q.queue)-1] 119 | 120 | q.available <- struct{}{} 121 | 122 | return info, true 123 | } 124 | 125 | func (q *IPQueue) Init() { 126 | q.mu.Lock() 127 | defer q.mu.Unlock() 128 | 129 | if !q.inIdealMode { 130 | q.available <- struct{}{} 131 | return 132 | } 133 | } 134 | 135 | func (q *IPQueue) Expire() { 136 | q.mu.Lock() 137 | defer q.mu.Unlock() 138 | 139 | q.log.Debug("Expire: In ideal mode") 140 | defer func() { 141 | q.log.Debug("queue change", "len", len(q.queue)) 142 | for _, ipInfo := range q.queue { 143 | q.log.Debug( 144 | "queue change", 145 | "created", 
ipInfo.CreatedAt, 146 | "addr", ipInfo.AddrPort, 147 | "rtt", ipInfo.RTT, 148 | ) 149 | } 150 | }() 151 | 152 | shouldStartNewScan := false 153 | resQ := make([]statute.IPInfo, 0) 154 | for i := 0; i < len(q.queue); i++ { 155 | if time.Since(q.queue[i].CreatedAt) > q.maxTTL { 156 | q.log.Debug("Expire: Removing expired item from queue") 157 | shouldStartNewScan = true 158 | } else { 159 | resQ = append(resQ, q.queue[i]) 160 | } 161 | } 162 | q.queue = resQ 163 | q.log.Debug("Expire: Adding reserved items to queue") 164 | for i := 0; i < q.maxQueueSize && i < q.reserved.Size(); i++ { 165 | q.queue = append(q.queue, q.reserved.Dequeue()) 166 | } 167 | if shouldStartNewScan { 168 | q.available <- struct{}{} 169 | } 170 | } 171 | 172 | func (q *IPQueue) AvailableIPs(desc bool) []statute.IPInfo { 173 | q.mu.Lock() 174 | defer q.mu.Unlock() 175 | 176 | // Create a separate slice for sorting 177 | sortedQueue := make([]statute.IPInfo, len(q.queue)) 178 | copy(sortedQueue, q.queue) 179 | 180 | // Sort by RTT ascending/descending 181 | sort.Slice(sortedQueue, func(i, j int) bool { 182 | if desc { 183 | return sortedQueue[i].RTT > sortedQueue[j].RTT 184 | } 185 | return sortedQueue[i].RTT < sortedQueue[j].RTT 186 | }) 187 | 188 | return sortedQueue 189 | } 190 | -------------------------------------------------------------------------------- /ipscanner/ping/ping.go: -------------------------------------------------------------------------------- 1 | package ping 2 | 3 | import ( 4 | "context" 5 | "net/netip" 6 | 7 | "github.com/bepass-org/warp-plus/ipscanner/statute" 8 | ) 9 | 10 | type Ping struct { 11 | Options *statute.ScannerOptions 12 | } 13 | 14 | // DoPing performs a ping on the given IP address. 
// [review] Covers the tail of ipscanner/ping/ping.go, all of ipscanner/scanner.go and the header of ipscanner/statute/ping.go.
// Ping.DoPing builds a pinger with NewWarpPing(ip, p.Options) (defined elsewhere) and passes it to calc. calc runs PingContext(ctx) and returns Result(), or the zero IPInfo plus Error() when that is non-nil.
// NewScanner starts from defaults: IPv4 and IPv6 on, statute.DefaultCFRanges(), slog.Default(), empty WARP keys, queue size 8, max desirable RTT 400ms and queue TTL 30s. It then applies the With* options in order; WithLogger sets both the scanner's own logger and options.Logger.
// GetAvailableIPs returns nil until Run has created the engine; after that it returns i.engine.GetAvailableIPs(false).
// NOTE(review): Run assigns i.engine and GetAvailableIPs reads it without any synchronization — confirm they are never called concurrently.
15 | func (p *Ping) DoPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { 16 | res, err := p.calc(ctx, NewWarpPing(ip, p.Options)) 17 | if err != nil { 18 | return statute.IPInfo{}, err 19 | } 20 | 21 | return res, nil 22 | } 23 | 24 | func (p *Ping) calc(ctx context.Context, tp statute.IPing) (statute.IPInfo, error) { 25 | pr := tp.PingContext(ctx) 26 | err := pr.Error() 27 | if err != nil { 28 | return statute.IPInfo{}, err 29 | } 30 | return pr.Result(), nil 31 | } 32 | -------------------------------------------------------------------------------- /ipscanner/scanner.go: -------------------------------------------------------------------------------- 1 | package ipscanner 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net/netip" 7 | "time" 8 | 9 | "github.com/bepass-org/warp-plus/ipscanner/engine" 10 | "github.com/bepass-org/warp-plus/ipscanner/statute" 11 | ) 12 | 13 | type IPScanner struct { 14 | options statute.ScannerOptions 15 | log *slog.Logger 16 | engine *engine.Engine 17 | } 18 | 19 | func NewScanner(options ...Option) *IPScanner { 20 | p := &IPScanner{ 21 | options: statute.ScannerOptions{ 22 | UseIPv4: true, 23 | UseIPv6: true, 24 | CidrList: statute.DefaultCFRanges(), 25 | Logger: slog.Default(), 26 | WarpPresharedKey: "", 27 | WarpPeerPublicKey: "", 28 | WarpPrivateKey: "", 29 | IPQueueSize: 8, 30 | MaxDesirableRTT: 400 * time.Millisecond, 31 | IPQueueTTL: 30 * time.Second, 32 | }, 33 | log: slog.Default(), 34 | } 35 | 36 | for _, option := range options { 37 | option(p) 38 | } 39 | 40 | return p 41 | } 42 | 43 | type Option func(*IPScanner) 44 | 45 | func WithUseIPv4(useIPv4 bool) Option { 46 | return func(i *IPScanner) { 47 | i.options.UseIPv4 = useIPv4 48 | } 49 | } 50 | 51 | func WithUseIPv6(useIPv6 bool) Option { 52 | return func(i *IPScanner) { 53 | i.options.UseIPv6 = useIPv6 54 | } 55 | } 56 | 57 | func WithLogger(logger *slog.Logger) Option { 58 | return func(i *IPScanner) { 59 | i.log = logger 60 | i.options.Logger =
logger 61 | } 62 | } 63 | 64 | func WithCidrList(cidrList []netip.Prefix) Option { 65 | return func(i *IPScanner) { 66 | i.options.CidrList = cidrList 67 | } 68 | } 69 | 70 | func WithIPQueueSize(size int) Option { 71 | return func(i *IPScanner) { 72 | i.options.IPQueueSize = size 73 | } 74 | } 75 | 76 | func WithMaxDesirableRTT(threshold time.Duration) Option { 77 | return func(i *IPScanner) { 78 | i.options.MaxDesirableRTT = threshold 79 | } 80 | } 81 | 82 | func WithIPQueueTTL(ttl time.Duration) Option { 83 | return func(i *IPScanner) { 84 | i.options.IPQueueTTL = ttl 85 | } 86 | } 87 | 88 | func WithWarpPrivateKey(privateKey string) Option { 89 | return func(i *IPScanner) { 90 | i.options.WarpPrivateKey = privateKey 91 | } 92 | } 93 | 94 | func WithWarpPeerPublicKey(peerPublicKey string) Option { 95 | return func(i *IPScanner) { 96 | i.options.WarpPeerPublicKey = peerPublicKey 97 | } 98 | } 99 | 100 | func WithWarpPreSharedKey(presharedKey string) Option { 101 | return func(i *IPScanner) { 102 | i.options.WarpPresharedKey = presharedKey 103 | } 104 | } 105 | 106 | // Run builds the scan engine from the current options and starts it in a new goroutine, returning immediately. 107 | // If both IPv4 and IPv6 are disabled it only logs an error. Cancelling ctx is presumably what stops the engine 108 | // (engine.Run is defined elsewhere — confirm). 109 | func (i *IPScanner) Run(ctx context.Context) { 110 | if !i.options.UseIPv4 && !i.options.UseIPv6 { 111 | i.log.Error("Fatal: both IPv4 and IPv6 are disabled, nothing to do") 112 | return 113 | } 114 | i.engine = engine.NewScannerEngine(&i.options) 115 | go i.engine.Run(ctx) 116 | } 117 | 118 | func (i *IPScanner) GetAvailableIPs() []statute.IPInfo { 119 | if i.engine != nil { 120 | return i.engine.GetAvailableIPs(false) 121 | } 122 | return nil 123 | } 124 | 125 | type IPInfo = statute.IPInfo 126 | -------------------------------------------------------------------------------- /ipscanner/statute/ping.go: 
-------------------------------------------------------------------------------- 1 | package statute 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 |
// [review] Covers the interfaces in ipscanner/statute/ping.go and ipscanner/statute/queue.go, up to the Size doc comment.
// IPingResult exposes Result(), Error() and String() (via fmt.Stringer); IPing offers Ping() and PingContext(ctx).
// IPInfQueue is a plain slice kept in ascending RTT order. Enqueue appends and then re-sorts with sort.Slice, which is not stable, so equal-RTT items may change relative order. Dequeue pops the lowest-RTT item and overwrites its CreatedAt with time.Now(). An empty queue yields the zero IPInfo. The type has no locking of its own.
type IPingResult interface { 9 | Result() IPInfo 10 | Error() error 11 | fmt.Stringer 12 | } 13 | 14 | type IPing interface { 15 | Ping() IPingResult 16 | PingContext(context.Context) IPingResult 17 | } 18 | -------------------------------------------------------------------------------- /ipscanner/statute/queue.go: -------------------------------------------------------------------------------- 1 | package statute 2 | 3 | import ( 4 | "sort" 5 | "time" 6 | ) 7 | 8 | type IPInfQueue struct { 9 | items []IPInfo 10 | } 11 | 12 | // Enqueue appends item and re-sorts the queue into ascending RTT order. 13 | func (q *IPInfQueue) Enqueue(item IPInfo) { 14 | q.items = append(q.items, item) 15 | sort.Slice(q.items, func(i, j int) bool { 16 | return q.items[i].RTT < q.items[j].RTT 17 | }) 18 | } 19 | 20 | // Dequeue removes and returns the lowest-RTT item, resetting its CreatedAt to time.Now(). 21 | func (q *IPInfQueue) Dequeue() IPInfo { 22 | if len(q.items) == 0 { 23 | return IPInfo{} // Returning an empty IPInfo when the queue is empty. 24 | } 25 | item := q.items[0] 26 | q.items = q.items[1:] 27 | item.CreatedAt = time.Now() 28 | return item 29 | } 30 | 31 | // Size returns the number of items in the queue.
// [review] Covers IPInfQueue.Size, all of ipscanner/statute/statute.go and the path header for iputils.go.
// IPInfo records a scanned endpoint (AddrPort), its RTT and the time the record was created. ScannerOptions carries every setting NewScanner exposes through its With* options.
// DefaultCFRanges builds a new slice on each call: 14 IPv4 and 7 IPv6 prefixes, parsed with netip.MustParsePrefix, which panics on a malformed literal.
// NOTE(review): these are presumably Cloudflare's published ranges (the name suggests so) — re-check them against Cloudflare's current list when updating.
32 | func (q *IPInfQueue) Size() int { 33 | return len(q.items) 34 | } 35 | -------------------------------------------------------------------------------- /ipscanner/statute/statute.go: -------------------------------------------------------------------------------- 1 | package statute 2 | 3 | import ( 4 | "log/slog" 5 | "net/netip" 6 | "time" 7 | ) 8 | 9 | type IPInfo struct { 10 | AddrPort netip.AddrPort 11 | RTT time.Duration 12 | CreatedAt time.Time 13 | } 14 | 15 | type ScannerOptions struct { 16 | UseIPv4 bool 17 | UseIPv6 bool 18 | CidrList []netip.Prefix // CIDR ranges to scan 19 | Logger *slog.Logger 20 | WarpPrivateKey string 21 | WarpPeerPublicKey string 22 | WarpPresharedKey string 23 | IPQueueSize int 24 | IPQueueTTL time.Duration 25 | MaxDesirableRTT time.Duration 26 | } 27 | 28 | func DefaultCFRanges() []netip.Prefix { 29 | return []netip.Prefix{ 30 | netip.MustParsePrefix("103.21.244.0/22"), 31 | netip.MustParsePrefix("103.22.200.0/22"), 32 | netip.MustParsePrefix("103.31.4.0/22"), 33 | netip.MustParsePrefix("104.16.0.0/12"), 34 | netip.MustParsePrefix("108.162.192.0/18"), 35 | netip.MustParsePrefix("131.0.72.0/22"), 36 | netip.MustParsePrefix("141.101.64.0/18"), 37 | netip.MustParsePrefix("162.158.0.0/15"), 38 | netip.MustParsePrefix("172.64.0.0/13"), 39 | netip.MustParsePrefix("173.245.48.0/20"), 40 | netip.MustParsePrefix("188.114.96.0/20"), 41 | netip.MustParsePrefix("190.93.240.0/20"), 42 | netip.MustParsePrefix("197.234.240.0/22"), 43 | netip.MustParsePrefix("198.41.128.0/17"), 44 | netip.MustParsePrefix("2400:cb00::/32"), 45 | netip.MustParsePrefix("2405:8100::/32"), 46 | netip.MustParsePrefix("2405:b500::/32"), 47 | netip.MustParsePrefix("2606:4700::/32"), 48 | netip.MustParsePrefix("2803:f800::/32"), 49 | netip.MustParsePrefix("2c0f:f248::/32"), 50 | netip.MustParsePrefix("2a06:98c0::/29"), 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /iputils/iputils.go:
// [review] Covers all of iputils/iputils.go.
// RandomIPFromPrefix masks the prefix and rejects v4-mapped-v6 input. It draws a random host offset below 2^(BitLen-Bits), adds it to the 16-byte form of the network address, and unmaps the result, so IPv4 prefixes yield plain IPv4 addresses. It seeds a new math/rand source from the clock on every call, so it is not suitable where unpredictability matters.
// ParseResolveAddressPort splits "host:port" and checks that the port is 1-65535. A literal IP host is returned directly — even an IPv6 one when includev6 is false. Otherwise the name is resolved with the pure-Go resolver against dnsServer port 53 over UDP, and the first IPv4 result is returned, or the first result of any family when includev6 is true.
// NOTE(review): the resolver's Dial callback ignores its ctx/network arguments and LookupIP uses context.Background(), so a caller cannot cancel or bound the lookup from outside.
-------------------------------------------------------------------------------- 1 | package iputils 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "math/big" 8 | "math/rand" 9 | "net" 10 | "net/netip" 11 | "strconv" 12 | "time" 13 | ) 14 | 15 | // RandomIPFromPrefix returns a random IP from the provided CIDR prefix. 16 | // Supports IPv4 and IPv6. Does not support mapped inputs. 17 | func RandomIPFromPrefix(cidr netip.Prefix) (netip.Addr, error) { 18 | startingAddress := cidr.Masked().Addr() 19 | if startingAddress.Is4In6() { 20 | return netip.Addr{}, errors.New("mapped v4 addresses not supported") 21 | } 22 | 23 | prefixLen := cidr.Bits() 24 | if prefixLen == -1 { 25 | return netip.Addr{}, fmt.Errorf("invalid cidr: %s", cidr) 26 | } 27 | 28 | // Initialise rand number generator 29 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 30 | 31 | // Find the bit length of the Host portion of the provided CIDR 32 | // prefix 33 | hostLen := big.NewInt(int64(startingAddress.BitLen() - prefixLen)) 34 | 35 | // Find the max value for our random number 36 | max := new(big.Int).Exp(big.NewInt(2), hostLen, nil) 37 | 38 | // Generate the random number 39 | randInt := new(big.Int).Rand(rng, max) 40 | 41 | // Get the first address in the CIDR prefix in 16-bytes form 42 | startingAddress16 := startingAddress.As16() 43 | 44 | // Convert the first address into a decimal number 45 | startingAddressInt := new(big.Int).SetBytes(startingAddress16[:]) 46 | 47 | // Add the random number to the decimal form of the starting address 48 | // to get a random address in the desired range 49 | randomAddressInt := new(big.Int).Add(startingAddressInt, randInt) 50 | 51 | // Convert the random address from decimal form back into netip.Addr 52 | randomAddress, ok := netip.AddrFromSlice(randomAddressInt.FillBytes(make([]byte, 16))) 53 | if !ok { 54 | return netip.Addr{}, fmt.Errorf("failed to generate random IP from CIDR: %s", cidr) 55 | } 56 | 57 | // Unmap any mapped v4 addresses
before return 58 | return randomAddress.Unmap(), nil 59 | } 60 | 61 | func ParseResolveAddressPort(hostname string, includev6 bool, dnsServer string) (netip.AddrPort, error) { 62 | // Attempt to split the hostname into a host and port 63 | host, port, err := net.SplitHostPort(hostname) 64 | if err != nil { 65 | return netip.AddrPort{}, fmt.Errorf("can't parse provided hostname into host and port: %w", err) 66 | } 67 | 68 | // Convert the string port to a uint16 69 | portInt, err := strconv.Atoi(port) 70 | if err != nil { 71 | return netip.AddrPort{}, fmt.Errorf("error parsing port: %w", err) 72 | } 73 | 74 | if portInt < 1 || portInt > 65535 { 75 | return netip.AddrPort{}, fmt.Errorf("port number %d is out of range", portInt) 76 | } 77 | 78 | // Attempt to parse the host into an IP. Return on success. 79 | addr, err := netip.ParseAddr(host) 80 | if err == nil { 81 | return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil 82 | } 83 | 84 | // Use Go's built-in DNS resolver 85 | resolver := &net.Resolver{ 86 | PreferGo: true, 87 | Dial: func(ctx context.Context, network, address string) (net.Conn, error) { 88 | return net.Dial("udp", net.JoinHostPort(dnsServer, "53")) 89 | }, 90 | } 91 | 92 | // If the host wasn't an IP, perform a lookup 93 | ips, err := resolver.LookupIP(context.Background(), "ip", host) 94 | if err != nil { 95 | return netip.AddrPort{}, fmt.Errorf("hostname lookup failed: %w", err) 96 | } 97 | 98 | for _, ip := range ips { 99 | // Return the first IPv4 address, or the first address of any family when includev6 is set 100 | addr, ok := netip.AddrFromSlice(ip) 101 | if !ok { 102 | continue 103 | } 104 | 105 | if addr.Unmap().Is4() { 106 | return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil 107 | } else if includev6 { 108 | return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil 109 | } 110 | } 111 | 112 | return netip.AddrPort{}, errors.New("no valid IP addresses found") 113 | } 114 | 
--------------------------------------------------------------------------------
/proxy/README.md: -------------------------------------------------------------------------------- 1 | # Table of Contents 2 | - [Introduction](#introduction) 3 | - [Features](#features) 4 | - [Installation](#installation) 5 | - [Examples](#examples) 6 | - [Minimal](#minimal) 7 | - [Customized](#customized) 8 | 9 | 10 | ## Introduction 11 | The proxy module simplifies connection handling and offers a generic way to work with both HTTP and SOCKS connections, 12 | making it a powerful tool for managing network traffic. 13 | 14 | 15 | ## Features 16 | The Inbound Proxy project offers the following features: 17 | 18 | - Full support for `HTTP`, `SOCKS5`, `SOCKS5h`, `SOCKS4` and `SOCKS4a` protocols. 19 | - Handling of `HTTP` and `HTTPS-connect` proxy requests. 20 | - Full support for both `IPv4` and `IPv6`. 21 | - Able to handle both `TCP` and `UDP` traffic. 22 | 23 | ## Installation 24 | 25 | ```bash 26 | go get github.com/bepass-org/proxy 27 | ``` 28 | 29 | ## Examples 30 | 31 | ### Minimal 32 | 33 | ```go 34 | package main 35 | 36 | import ( 37 | "github.com/bepass-org/proxy/pkg/mixed" 38 | ) 39 | 40 | func main() { 41 | proxy := mixed.NewProxy() 42 | _ = proxy.ListenAndServe() 43 | } 44 | ``` 45 | 46 | ### Customized 47 | 48 | ```go 49 | package main 50 | 51 | import ( 52 | "github.com/bepass-org/proxy/pkg/mixed" 53 | ) 54 | 55 | func main() { 56 | proxy := mixed.NewProxy( 57 | mixed.WithBindAddress("0.0.0.0:8080"), 58 | ) 59 | _ = proxy.ListenAndServe() 60 | } 61 | 62 | ``` 63 | 64 | There are other examples provided in the [example](https://github.com/bepass-org/proxy/tree/main/example) directory. 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /proxy/example/customHandler/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "net" 8 | 9 | "github.com/bepass-org/warp-plus/proxy/pkg/mixed" 10 |
// [review] Covers the rest of proxy/example/customHandler/main.go, proxy/example/minimal/main.go, proxy/example/udpClient/main.go and the header of proxy/pkg/http/common.go.
// customHandler: generalHandler now closes the upstream conn it dials (previously it was never closed); the following lines of that file are renumbered by one.
// udpClient: targetAddr was ":4444". Its empty host made net.ParseIP(dstIP).To4() nil, so the ATYP=IPv4 datagram went out without its 4 address bytes. It now names 127.0.0.1 explicitly.
// NOTE(review): udpClient still ignores the results of conn.Write, io.ReadFull and the UDP Write/Read, never checks the SOCKS5 reply codes, and buffer[10:n] panics if the read returns fewer than 10 bytes.
"github.com/bepass-org/warp-plus/proxy/pkg/statute" 11 | ) 12 | 13 | func main() { 14 | proxy := mixed.NewProxy( 15 | mixed.WithBindAddress("127.0.0.1:1080"), 16 | mixed.WithUserHandler(generalHandler), 17 | ) 18 | _ = proxy.ListenAndServe() 19 | } 20 | 21 | func generalHandler(req *statute.ProxyRequest) error { 22 | fmt.Println("handling request to", req.Destination) 23 | conn, err := net.Dial(req.Network, req.Destination) 24 | if err != nil { 25 | return err 26 | } 27 | defer conn.Close() 28 | go func() { 29 | _, err := io.Copy(conn, req.Conn) 30 | if err != nil { 31 | log.Println(err) 32 | } 33 | }() 34 | _, err = io.Copy(req.Conn, conn) 35 | return err 36 | } 37 | -------------------------------------------------------------------------------- /proxy/example/minimal/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/bepass-org/warp-plus/proxy/pkg/mixed" 5 | ) 6 | 7 | func main() { 8 | proxy := mixed.NewProxy() 9 | _ = proxy.ListenAndServe() 10 | } 11 | -------------------------------------------------------------------------------- /proxy/example/udpClient/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | "net" 8 | "strconv" 9 | ) 10 | 11 | func main() { 12 | proxyAddr := "127.0.0.1:1080" 13 | targetAddr := "127.0.0.1:4444" 14 | 15 | // Connect to SOCKS5 proxy 16 | conn, err := net.Dial("tcp", proxyAddr) 17 | if err != nil { 18 | panic(err) 19 | } 20 | defer conn.Close() 21 | 22 | // Send greeting to SOCKS5 proxy 23 | conn.Write([]byte{0x05, 0x01, 0x00}) 24 | 25 | // Read greeting response 26 | response := make([]byte, 2) 27 | io.ReadFull(conn, response) 28 | 29 | // Send UDP ASSOCIATE request 30 | conn.Write([]byte{0x05, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 31 | 32 | // Read UDP ASSOCIATE response 33 | response = make([]byte, 10) 34 | io.ReadFull(conn, response) 35 | 36
| // Extract the bind address and port 37 | bindIP := net.IP(response[4:8]) 38 | bindPort := binary.BigEndian.Uint16(response[8:10]) 39 | 40 | // Print the bind address 41 | fmt.Printf("Bind address: %s:%d\n", bindIP, bindPort) 42 | 43 | // Create UDP connection 44 | udpConn, err := net.Dial("udp", fmt.Sprintf("%s:%d", bindIP, bindPort)) 45 | if err != nil { 46 | panic(err) 47 | } 48 | defer udpConn.Close() 49 | 50 | // Extract target IP and port 51 | dstIP, dstPortStr, _ := net.SplitHostPort(targetAddr) 52 | dstPort, _ := strconv.Atoi(dstPortStr) 53 | 54 | // Construct the UDP packet with the target address and message 55 | packet := make([]byte, 0) 56 | packet = append(packet, 0x00, 0x00, 0x00) // RSV and FRAG 57 | packet = append(packet, 0x01) // ATYP for IPv4 58 | packet = append(packet, net.ParseIP(dstIP).To4()...) 59 | packet = append(packet, byte(dstPort>>8), byte(dstPort&0xFF)) 60 | packet = append(packet, []byte("Hello, UDP through SOCKS5!")...) 61 | 62 | // Send the UDP packet 63 | udpConn.Write(packet) 64 | 65 | // Read the response 66 | buffer := make([]byte, 1024) 67 | n, _ := udpConn.Read(buffer) 68 | fmt.Println("Received:", string(buffer[10:n])) 69 | } 70 | -------------------------------------------------------------------------------- /proxy/pkg/http/common.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "sync" 9 | ) 10 | 11 | // copyBuffer is a helper function to copy data between two net.Conn objects.
12 | // func copyBuffer(dst, src net.Conn, buf []byte) (int64, error) { 13 | // return io.CopyBuffer(dst, src, buf) 14 | // } 15 | 16 | type responseWriter struct { 17 | conn net.Conn 18 | headers http.Header 19 | status int 20 | written bool 21 | } 22 | 23 | func NewHTTPResponseWriter(conn net.Conn) http.ResponseWriter { 24 | return &responseWriter{ 25 | conn: conn, 26 | headers: http.Header{}, 27 | status: http.StatusOK, 28 | } 29 | } 30 | 31 | func (rw *responseWriter) Header() http.Header { 32 | return rw.headers 33 | } 34 | 35 | func (rw *responseWriter) WriteHeader(statusCode int) { 36 | if rw.written { 37 | return 38 | } 39 | rw.status = statusCode 40 | rw.written = true 41 | 42 | statusText := http.StatusText(statusCode) 43 | if statusText == "" { 44 | statusText = fmt.Sprintf("status code %d", statusCode) 45 | } 46 | _, _ = fmt.Fprintf(rw.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText) 47 | _ = rw.headers.Write(rw.conn) 48 | _, _ = rw.conn.Write([]byte("\r\n")) 49 | } 50 | 51 | func (rw *responseWriter) Write(data []byte) (int, error) { 52 | if !rw.written { 53 | rw.WriteHeader(http.StatusOK) 54 | } 55 | return rw.conn.Write(data) 56 | } 57 | 58 | type customConn struct { 59 | net.Conn 60 | req *http.Request 61 | initialData []byte 62 | once sync.Once 63 | } 64 | 65 | func (c *customConn) Read(p []byte) (n int, err error) { 66 | c.once.Do(func() { 67 | buf := &bytes.Buffer{} 68 | err = c.req.Write(buf) 69 | if err != nil { 70 | n = 0 71 | return 72 | } 73 | c.initialData = buf.Bytes() 74 | }) 75 | 76 | if len(c.initialData) > 0 { 77 | copy(p, c.initialData) 78 | n = len(p) 79 | c.initialData = nil 80 | return 81 | } 82 | 83 | return c.Conn.Read(p) 84 | } 85 | -------------------------------------------------------------------------------- /proxy/pkg/http/server.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "io" 7 | "log/slog" 8 | "net" 9 | "net/http" 
// [review] Covers the rest of proxy/pkg/http/server.go, proxy/pkg/mixed/handlers.go, proxy/pkg/mixed/proxy.go and the first part of proxy/pkg/socks4/common.go.
// Fixed the "becasue" typo in the socks4 reply.String text for noIdentdReply, and corrected the mixed.Proxy field comments (stale field name, "doesnt"). No other code changes.
// mixed.Proxy peeks at the first byte of each accepted connection and dispatches on it: 5 goes to SOCKS5, 4 to SOCKS4, and anything else to HTTP. Most With* options in handlers.go write the setting into the Proxy and into its socks5/socks4/http sub-servers; a few apply only to socks5.
// NOTE(review): http.Server.ListenAndServe and mixed.Proxy.ListenAndServe both log and `continue` on every Accept error, so a closed listener turns the loop into a busy error-logging spin (ctx is only checked between Accepts); consider returning on net.ErrClosed. Both also assert that Listener.Addr() is a *net.TCPAddr, so they panic for other listener types supplied via WithListener.
// NOTE(review): http.Server.ServeConn reads the request through its own bufio.Reader; client bytes already buffered there beyond the parsed request are not forwarded by handleHTTP/embedHandleHTTP.
10 | "strconv" 11 | 12 | "github.com/bepass-org/warp-plus/proxy/pkg/statute" 13 | ) 14 | 15 | type Server struct { 16 | // bind is the address to listen on 17 | Bind string 18 | 19 | Listener net.Listener 20 | 21 | // ProxyDial specifies the optional proxyDial function for 22 | // establishing the transport connection. 23 | ProxyDial statute.ProxyDialFunc 24 | // UserConnectHandle gives the user control to handle the TCP CONNECT requests 25 | UserConnectHandle statute.UserConnectHandler 26 | // Logger error log 27 | Logger *slog.Logger 28 | // Context is default context 29 | Context context.Context 30 | // BytesPool getting and returning temporary bytes for use by io.CopyBuffer 31 | BytesPool statute.BytesPool 32 | } 33 | 34 | func NewServer(options ...ServerOption) *Server { 35 | s := &Server{ 36 | Bind: statute.DefaultBindAddress, 37 | ProxyDial: statute.DefaultProxyDial(), 38 | Logger: slog.Default(), 39 | Context: statute.DefaultContext(), 40 | } 41 | 42 | for _, option := range options { 43 | option(s) 44 | } 45 | 46 | return s 47 | } 48 | 49 | type ServerOption func(*Server) 50 | 51 | func (s *Server) ListenAndServe() error { 52 | // Create a new listener 53 | if s.Listener == nil { 54 | ln, err := net.Listen("tcp", s.Bind) 55 | if err != nil { 56 | return err // Return error if binding was unsuccessful 57 | } 58 | s.Listener = ln 59 | } 60 | 61 | s.Bind = s.Listener.Addr().(*net.TCPAddr).String() 62 | 63 | // ensure listener will be closed 64 | defer func() { 65 | _ = s.Listener.Close() 66 | }() 67 | 68 | // Create a cancelable context based on s.Context 69 | ctx, cancel := context.WithCancel(s.Context) 70 | defer cancel() // Ensure resources are cleaned up 71 | 72 | // Start to accept connections and serve them 73 | for { 74 | select { 75 | case <-ctx.Done(): 76 | return ctx.Err() 77 | default: 78 | conn, err := s.Listener.Accept() 79 | if err != nil { 80 | s.Logger.Error(err.Error()) 81 | continue 82 | } 83 | 84 | // Start a new goroutine to handle each
connection 85 | // This way, the server can handle multiple connections concurrently 86 | go func() { 87 | err := s.ServeConn(conn) 88 | if err != nil { 89 | s.Logger.Error(err.Error()) // Log errors from ServeConn 90 | } 91 | }() 92 | } 93 | } 94 | } 95 | 96 | func WithLogger(logger *slog.Logger) ServerOption { 97 | return func(s *Server) { 98 | s.Logger = logger 99 | } 100 | } 101 | 102 | func WithBind(bindAddress string) ServerOption { 103 | return func(s *Server) { 104 | s.Bind = bindAddress 105 | } 106 | } 107 | 108 | func WithConnectHandle(handler statute.UserConnectHandler) ServerOption { 109 | return func(s *Server) { 110 | s.UserConnectHandle = handler 111 | } 112 | } 113 | 114 | func WithProxyDial(proxyDial statute.ProxyDialFunc) ServerOption { 115 | return func(s *Server) { 116 | s.ProxyDial = proxyDial 117 | } 118 | } 119 | 120 | func WithContext(ctx context.Context) ServerOption { 121 | return func(s *Server) { 122 | s.Context = ctx 123 | } 124 | } 125 | 126 | func WithBytesPool(bytesPool statute.BytesPool) ServerOption { 127 | return func(s *Server) { 128 | s.BytesPool = bytesPool 129 | } 130 | } 131 | 132 | func (s *Server) ServeConn(conn net.Conn) error { 133 | reader := bufio.NewReader(conn) 134 | req, err := http.ReadRequest(reader) 135 | if err != nil { 136 | return err 137 | } 138 | 139 | return s.handleHTTP(conn, req, req.Method == http.MethodConnect) 140 | } 141 | 142 | func (s *Server) handleHTTP(conn net.Conn, req *http.Request, isConnectMethod bool) error { 143 | if s.UserConnectHandle == nil { 144 | return s.embedHandleHTTP(conn, req, isConnectMethod) 145 | } 146 | 147 | if isConnectMethod { 148 | _, err := conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) 149 | if err != nil { 150 | return err 151 | } 152 | } else { 153 | cConn := &customConn{ 154 | Conn: conn, 155 | req: req, 156 | } 157 | conn = cConn 158 | } 159 | 160 | targetAddr := req.URL.Host 161 | host, portStr, err := net.SplitHostPort(targetAddr) 162 | if err !=
nil { 163 | host = targetAddr 164 | if req.URL.Scheme == "https" || isConnectMethod { 165 | portStr = "443" 166 | } else { 167 | portStr = "80" 168 | } 169 | targetAddr = net.JoinHostPort(host, portStr) 170 | } 171 | 172 | portInt, err := strconv.Atoi(portStr) 173 | if err != nil { 174 | return err // Handle the error if the port string is not a valid integer. 175 | } 176 | port := int32(portInt) 177 | 178 | proxyReq := &statute.ProxyRequest{ 179 | Conn: conn, 180 | Reader: io.Reader(conn), 181 | Writer: io.Writer(conn), 182 | Network: "tcp", 183 | Destination: targetAddr, 184 | DestHost: host, 185 | DestPort: port, 186 | } 187 | 188 | return s.UserConnectHandle(proxyReq) 189 | } 190 | 191 | func (s *Server) embedHandleHTTP(conn net.Conn, req *http.Request, isConnectMethod bool) error { 192 | defer func() { 193 | _ = conn.Close() 194 | }() 195 | 196 | host, portStr, err := net.SplitHostPort(req.URL.Host) 197 | if err != nil { 198 | host = req.URL.Host 199 | if req.URL.Scheme == "https" || isConnectMethod { 200 | portStr = "443" 201 | } else { 202 | portStr = "80" 203 | } 204 | } 205 | targetAddr := net.JoinHostPort(host, portStr) 206 | 207 | target, err := s.ProxyDial(s.Context, "tcp", targetAddr) 208 | if err != nil { 209 | http.Error( 210 | NewHTTPResponseWriter(conn), 211 | err.Error(), 212 | http.StatusServiceUnavailable, 213 | ) 214 | return err 215 | } 216 | defer func() { 217 | _ = target.Close() 218 | }() 219 | 220 | if isConnectMethod { 221 | _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) 222 | if err != nil { 223 | return err 224 | } 225 | } else { 226 | err = req.Write(target) 227 | if err != nil { 228 | return err 229 | } 230 | } 231 | 232 | var buf1, buf2 []byte 233 | if s.BytesPool != nil { 234 | buf1 = s.BytesPool.Get() 235 | buf2 = s.BytesPool.Get() 236 | defer func() { 237 | s.BytesPool.Put(buf1) 238 | s.BytesPool.Put(buf2) 239 | }() 240 | } else { 241 | buf1 = make([]byte, 32*1024) 242 | buf2 = make([]byte, 32*1024) 243 |
| } 244 | return statute.Tunnel(s.Context, target, conn, buf1, buf2) 245 | } 246 | -------------------------------------------------------------------------------- /proxy/pkg/mixed/handlers.go: -------------------------------------------------------------------------------- 1 | package mixed 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net" 7 | 8 | "github.com/bepass-org/warp-plus/proxy/pkg/statute" 9 | ) 10 | 11 | func WithBindAddress(binAddress string) Option { 12 | return func(p *Proxy) { 13 | p.bind = binAddress 14 | p.socks5Proxy.Bind = binAddress 15 | p.socks4Proxy.Bind = binAddress 16 | p.httpProxy.Bind = binAddress 17 | } 18 | } 19 | 20 | func WithListener(ln net.Listener) Option { 21 | return func(p *Proxy) { 22 | p.listener = ln 23 | p.socks5Proxy.Listener = ln 24 | p.socks4Proxy.Listener = ln 25 | p.httpProxy.Listener = ln 26 | } 27 | } 28 | 29 | func WithLogger(logger *slog.Logger) Option { 30 | return func(p *Proxy) { 31 | p.logger = logger 32 | p.socks5Proxy.Logger = logger 33 | p.socks4Proxy.Logger = logger 34 | p.httpProxy.Logger = logger 35 | } 36 | } 37 | 38 | func WithUserHandler(handler userHandler) Option { 39 | return func(p *Proxy) { 40 | p.userHandler = handler 41 | p.socks5Proxy.UserConnectHandle = statute.UserConnectHandler(handler) 42 | p.socks5Proxy.UserAssociateHandle = statute.UserAssociateHandler(handler) 43 | p.socks4Proxy.UserConnectHandle = statute.UserConnectHandler(handler) 44 | p.httpProxy.UserConnectHandle = statute.UserConnectHandler(handler) 45 | } 46 | } 47 | 48 | func WithUserTCPHandler(handler userHandler) Option { 49 | return func(p *Proxy) { 50 | p.userTCPHandler = handler 51 | p.socks5Proxy.UserConnectHandle = statute.UserConnectHandler(handler) 52 | p.socks4Proxy.UserConnectHandle = statute.UserConnectHandler(handler) 53 | p.httpProxy.UserConnectHandle = statute.UserConnectHandler(handler) 54 | } 55 | } 56 | 57 | func WithUserUDPHandler(handler userHandler) Option { 58 | return func(p *Proxy) { 59 |
p.userUDPHandler = handler 60 | p.socks5Proxy.UserAssociateHandle = statute.UserAssociateHandler(handler) 61 | } 62 | } 63 | 64 | func WithUserDialFunc(proxyDial statute.ProxyDialFunc) Option { 65 | return func(p *Proxy) { 66 | p.userDialFunc = proxyDial 67 | p.socks5Proxy.ProxyDial = proxyDial 68 | p.socks4Proxy.ProxyDial = proxyDial 69 | p.httpProxy.ProxyDial = proxyDial 70 | } 71 | } 72 | 73 | func WithUserListenPacketFunc(proxyListenPacket statute.ProxyListenPacket) Option { 74 | return func(p *Proxy) { 75 | p.socks5Proxy.ProxyListenPacket = proxyListenPacket 76 | } 77 | } 78 | 79 | func WithUserForwardAddressFunc(packetForwardAddress statute.PacketForwardAddress) Option { 80 | return func(p *Proxy) { 81 | p.socks5Proxy.PacketForwardAddress = packetForwardAddress 82 | } 83 | } 84 | 85 | func WithContext(ctx context.Context) Option { 86 | return func(p *Proxy) { 87 | p.ctx = ctx 88 | p.socks5Proxy.Context = ctx 89 | p.socks4Proxy.Context = ctx 90 | p.httpProxy.Context = ctx 91 | } 92 | } 93 | 94 | func WithBytesPool(bytesPool statute.BytesPool) Option { 95 | return func(p *Proxy) { 96 | p.socks5Proxy.BytesPool = bytesPool 97 | p.socks4Proxy.BytesPool = bytesPool 98 | p.httpProxy.BytesPool = bytesPool 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /proxy/pkg/mixed/proxy.go: -------------------------------------------------------------------------------- 1 | package mixed 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "log/slog" 7 | "net" 8 | 9 | "github.com/bepass-org/warp-plus/proxy/pkg/http" 10 | "github.com/bepass-org/warp-plus/proxy/pkg/socks4" 11 | "github.com/bepass-org/warp-plus/proxy/pkg/socks5" 12 | "github.com/bepass-org/warp-plus/proxy/pkg/statute" 13 | ) 14 | 15 | type userHandler func(request *statute.ProxyRequest) error 16 | 17 | type Proxy struct { 18 | // bind is the address to listen on 19 | bind string 20 | 21 | listener net.Listener 22 | 23 | // socks5Proxy is a socks5 server with tcp and
udp support 24 | socks5Proxy *socks5.Server 25 | // socks4Proxy is a socks4 server with tcp support 26 | socks4Proxy *socks4.Server 27 | // httpProxy is a http proxy server with http and http-connect support 28 | httpProxy *http.Server 29 | // userHandler is the general user handler for both tcp and udp requests 30 | userHandler userHandler 31 | // if user doesn't set userHandler, it can specify userTCPHandler for manual handling of tcp requests 32 | userTCPHandler userHandler 33 | // if user doesn't set userHandler, it can specify userUDPHandler for manual handling of udp requests 34 | userUDPHandler userHandler 35 | // overwrite dial functions of http, socks4, socks5 36 | userDialFunc statute.ProxyDialFunc 37 | // logger error log 38 | logger *slog.Logger 39 | // ctx is default context 40 | ctx context.Context 41 | } 42 | 43 | func NewProxy(options ...Option) *Proxy { 44 | p := &Proxy{ 45 | bind: statute.DefaultBindAddress, 46 | socks5Proxy: socks5.NewServer(), 47 | socks4Proxy: socks4.NewServer(), 48 | httpProxy: http.NewServer(), 49 | userDialFunc: statute.DefaultProxyDial(), 50 | logger: slog.Default(), 51 | ctx: statute.DefaultContext(), 52 | } 53 | 54 | for _, option := range options { 55 | option(p) 56 | } 57 | 58 | return p 59 | } 60 | 61 | type Option func(*Proxy) 62 | 63 | // SwitchConn wraps a net.Conn and a bufio.Reader 64 | type SwitchConn struct { 65 | net.Conn 66 | *bufio.Reader 67 | } 68 | 69 | // NewSwitchConn creates a new SwitchConn 70 | func NewSwitchConn(conn net.Conn) *SwitchConn { 71 | return &SwitchConn{ 72 | Conn: conn, 73 | Reader: bufio.NewReaderSize(conn, 2048), 74 | } 75 | } 76 | 77 | // Read reads data into p, first from the bufio.Reader, then from the net.Conn 78 | func (c *SwitchConn) Read(p []byte) (n int, err error) { 79 | return c.Reader.Read(p) 80 | } 81 | 82 | func (p *Proxy) ListenAndServe() error { 83 | // Create a new listener 84 | if p.listener == nil { 85 | ln, err := net.Listen("tcp", p.bind) 86 | if err != nil
{ 87 | return err // Return error if binding was unsuccessful 88 | } 89 | p.listener = ln 90 | } 91 | 92 | p.bind = p.listener.Addr().(*net.TCPAddr).String() 93 | 94 | // ensure listener will be closed 95 | defer func() { 96 | _ = p.listener.Close() 97 | }() 98 | 99 | // Create a cancelable context based on p.Context 100 | ctx, cancel := context.WithCancel(p.ctx) 101 | defer cancel() // Ensure resources are cleaned up 102 | 103 | // Start to accept connections and serve them 104 | for { 105 | select { 106 | case <-ctx.Done(): 107 | return ctx.Err() 108 | default: 109 | conn, err := p.listener.Accept() 110 | if err != nil { 111 | p.logger.Error(err.Error()) 112 | continue 113 | } 114 | 115 | // Start a new goroutine to handle each connection 116 | // This way, the server can handle multiple connections concurrently 117 | go func() { 118 | defer conn.Close() 119 | err := p.handleConnection(conn) 120 | if err != nil { 121 | p.logger.Error(err.Error()) // Log errors from ServeConn 122 | } 123 | }() 124 | } 125 | } 126 | } 127 | 128 | func (p *Proxy) handleConnection(conn net.Conn) error { 129 | // Create a SwitchConn 130 | switchConn := NewSwitchConn(conn) 131 | 132 | // Peek one byte to determine the protocol 133 | buf, err := switchConn.Peek(1) 134 | if err != nil { 135 | return err 136 | } 137 | 138 | switch buf[0] { 139 | case 5: 140 | err = p.socks5Proxy.ServeConn(switchConn) 141 | case 4: 142 | err = p.socks4Proxy.ServeConn(switchConn) 143 | default: 144 | err = p.httpProxy.ServeConn(switchConn) 145 | } 146 | 147 | return err 148 | } 149 | -------------------------------------------------------------------------------- /proxy/pkg/socks4/common.go: -------------------------------------------------------------------------------- 1 | package socks4 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "io" 7 | "net" 8 | "strconv" 9 | ) 10 | 11 | var ( 12 | isSocks4a = []byte{0, 0, 0, 1} 13 | isNone = []byte{0, 0, 0, 0} 14 | ) 15 | 16 | const ( 17 | socks4Version =
0x04 18 | ) 19 | 20 | const ( 21 | ConnectCommand Command = 0x01 22 | ) 23 | 24 | // Command is a SOCKS Command. 25 | type Command byte 26 | 27 | func (cmd Command) String() string { 28 | switch cmd { 29 | case ConnectCommand: 30 | return "socks connect" 31 | default: 32 | return "socks " + strconv.Itoa(int(cmd)) 33 | } 34 | } 35 | 36 | const ( 37 | grantedReply reply = 0x5a 38 | rejectedReply reply = 0x5b 39 | noIdentdReply reply = 0x5c 40 | invalidUserReply reply = 0x5d 41 | ) 42 | 43 | // reply is a SOCKS Command reply code. 44 | type reply byte 45 | 46 | func (code reply) String() string { 47 | switch code { 48 | case grantedReply: 49 | return "request granted" 50 | case rejectedReply: 51 | return "request rejected or failed" 52 | case noIdentdReply: 53 | return "request rejected because SOCKS server cannot connect to identd on the client" 54 | case invalidUserReply: 55 | return "request rejected because the client program and identd report different user-ids" 56 | default: 57 | return "unknown code: " + strconv.Itoa(int(code)) 58 | } 59 | } 60 | 61 | // address is a SOCKS-specific address. 62 | // Either Name or IP is used exclusively.
63 | type address struct { 64 | Name string // fully-qualified domain name 65 | IP net.IP 66 | Port int 67 | } 68 | 69 | func (a *address) Network() string { return "socks4" } 70 | 71 | func (a *address) String() string { 72 | if a == nil { 73 | return "" 74 | } 75 | return a.Address() 76 | } 77 | 78 | // Address returns a string suitable to dial; prefer returning IP-based 79 | // address, fallback to Name 80 | func (a address) Address() string { 81 | port := strconv.Itoa(a.Port) 82 | if a.Name != "" { 83 | return net.JoinHostPort(a.Name, port) 84 | } 85 | return net.JoinHostPort(a.IP.String(), port) 86 | } 87 | 88 | type AddrAnfUser struct { 89 | address 90 | Username string 91 | } 92 | 93 | func readBytes(r io.Reader) ([]byte, error) { 94 | buf := []byte{} 95 | var data [1]byte 96 | for { 97 | _, err := r.Read(data[:]) 98 | if err != nil { 99 | return nil, err 100 | } 101 | if data[0] == 0 { 102 | return buf, nil 103 | } 104 | buf = append(buf, data[0]) 105 | } 106 | } 107 | 108 | func readByte(r io.Reader) (byte, error) { 109 | var buf [1]byte 110 | _, err := r.Read(buf[:]) 111 | if err != nil { 112 | return 0, err 113 | } 114 | return buf[0], nil 115 | } 116 | 117 | func readAddrAndUser(r io.Reader) (*AddrAnfUser, error) { 118 | address := &AddrAnfUser{} 119 | var port [2]byte 120 | if _, err := io.ReadFull(r, port[:]); err != nil { 121 | return nil, err 122 | } 123 | address.Port = int(binary.BigEndian.Uint16(port[:])) 124 | ip := make(net.IP, net.IPv4len) 125 | if _, err := io.ReadFull(r, ip); err != nil { 126 | return nil, err 127 | } 128 | socks4a := bytes.Equal(ip, isSocks4a) 129 | 130 | username, err := readBytes(r) 131 | if err != nil { 132 | return nil, err 133 | } 134 | address.Username = string(username) 135 | if socks4a { 136 | hostname, err := readBytes(r) 137 | if err != nil { 138 | return nil, err 139 | } 140 | address.Name = string(hostname) 141 | } else { 142 | address.IP = ip 143 | } 144 | return address, nil 145 | } 146 | 147 | func 
writeAddr(w io.Writer, addr *address) error { 148 | var ip net.IP 149 | var port uint16 150 | if addr != nil { 151 | ip = addr.IP.To4() 152 | port = uint16(addr.Port) 153 | } 154 | var p [2]byte 155 | binary.BigEndian.PutUint16(p[:], port) 156 | _, err := w.Write(p[:]) 157 | if err != nil { 158 | return err 159 | } 160 | 161 | if ip == nil { 162 | _, err = w.Write(isNone) 163 | } else { 164 | _, err = w.Write(ip) 165 | } 166 | return err 167 | } 168 | -------------------------------------------------------------------------------- /proxy/pkg/statute/statute.go: -------------------------------------------------------------------------------- 1 | package statute 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | ) 9 | 10 | type Logger interface { 11 | Debug(v ...interface{}) 12 | Error(v ...interface{}) 13 | } 14 | 15 | type DefaultLogger struct{} 16 | 17 | func (l DefaultLogger) Debug(v ...interface{}) { 18 | fmt.Println(v...) 19 | } 20 | 21 | func (l DefaultLogger) Error(v ...interface{}) { 22 | fmt.Println(v...) 23 | } 24 | 25 | type ProxyRequest struct { 26 | Conn net.Conn 27 | Reader io.Reader 28 | Writer io.Writer 29 | Network string 30 | Destination string 31 | DestHost string 32 | DestPort int32 33 | } 34 | 35 | // UserConnectHandler is used for socks5, socks4 and http 36 | type UserConnectHandler func(request *ProxyRequest) error 37 | 38 | // UserAssociateHandler is used for socks5 39 | type UserAssociateHandler func(request *ProxyRequest) error 40 | 41 | // ProxyDialFunc is used for socks5, socks4 and http 42 | type ProxyDialFunc func(ctx context.Context, network string, address string) (net.Conn, error) 43 | 44 | // DefaultProxyDial for ProxyDialFunc type 45 | func DefaultProxyDial() ProxyDialFunc { 46 | var dialer net.Dialer 47 | return dialer.DialContext 48 | } 49 | 50 | // ProxyListenPacket specifies the optional proxyListenPacket function for 51 | // establishing the transport connection. 
52 | type ProxyListenPacket func(ctx context.Context, network string, address string) (net.PacketConn, error) 53 | 54 | // DefaultProxyListenPacket for ProxyListenPacket type 55 | func DefaultProxyListenPacket() ProxyListenPacket { 56 | var listener net.ListenConfig 57 | return listener.ListenPacket 58 | } 59 | 60 | // PacketForwardAddress specifies the packet forwarding address 61 | type PacketForwardAddress func(ctx context.Context, destinationAddr string, 62 | packet net.PacketConn, conn net.Conn) (net.IP, int, error) 63 | 64 | // BytesPool is an interface for getting and returning temporary 65 | // bytes for use by io.CopyBuffer. 66 | type BytesPool interface { 67 | Get() []byte 68 | Put([]byte) 69 | } 70 | 71 | // DefaultContext for context.Context type 72 | func DefaultContext() context.Context { 73 | return context.Background() 74 | } 75 | 76 | const DefaultBindAddress = "127.0.0.1:1080" 77 | -------------------------------------------------------------------------------- /proxy/pkg/statute/tunnel.go: -------------------------------------------------------------------------------- 1 | package statute 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "os" 8 | "reflect" 9 | "runtime" 10 | "strings" 11 | ) 12 | 13 | // isClosedConnError reports whether err is an error from use of a closed 14 | // network connection. 
15 | func isClosedConnError(err error) bool { 16 | if err == nil { 17 | return false 18 | } 19 | 20 | str := err.Error() 21 | if strings.Contains(str, "use of closed network connection") { 22 | return true 23 | } 24 | 25 | if runtime.GOOS == "windows" { 26 | if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { 27 | if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { 28 | const WSAECONNABORTED = 10053 29 | const WSAECONNRESET = 10054 30 | if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED { 31 | return true 32 | } 33 | } 34 | } 35 | } 36 | return false 37 | } 38 | 39 | func errno(v error) uintptr { 40 | if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr { 41 | return uintptr(rv.Uint()) 42 | } 43 | return 0 44 | } 45 | 46 | // Tunnel create tunnels for two io.ReadWriteCloser 47 | func Tunnel(ctx context.Context, c1, c2 io.ReadWriteCloser, buf1, buf2 []byte) error { 48 | ctx, cancel := context.WithCancel(ctx) 49 | var errs tunnelErr 50 | go func() { 51 | _, errs[0] = io.CopyBuffer(c1, c2, buf1) 52 | cancel() 53 | }() 54 | go func() { 55 | _, errs[1] = io.CopyBuffer(c2, c1, buf2) 56 | cancel() 57 | }() 58 | <-ctx.Done() 59 | errs[2] = c1.Close() 60 | errs[3] = c2.Close() 61 | errs[4] = ctx.Err() 62 | if errs[4] == context.Canceled { 63 | errs[4] = nil 64 | } 65 | return errs.FirstError() 66 | } 67 | 68 | type tunnelErr [5]error 69 | 70 | func (t tunnelErr) FirstError() error { 71 | for _, err := range t { 72 | if err != nil { 73 | if isClosedConnError(err) { 74 | return nil 75 | } 76 | return err 77 | } 78 | } 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /psiphon/p.go: -------------------------------------------------------------------------------- 1 | package psiphon 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log/slog" 10 | "net/netip" 11 | "path/filepath" 12 | 13 | 
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon" 14 | ) 15 | 16 | var Countries = []string{ 17 | "AT", 18 | "AU", 19 | "BE", 20 | "BG", 21 | "CA", 22 | "CH", 23 | "CZ", 24 | "DE", 25 | "DK", 26 | "EE", 27 | "ES", 28 | "FI", 29 | "FR", 30 | "GB", 31 | "HR", 32 | "HU", 33 | "IE", 34 | "IN", 35 | "IT", 36 | "JP", 37 | "LV", 38 | "NL", 39 | "NO", 40 | "PL", 41 | "PT", 42 | "RO", 43 | "RS", 44 | "SE", 45 | "SG", 46 | "SK", 47 | "US", 48 | } 49 | 50 | // NoticeEvent represents the notices emitted by tunnel core. It will be passed to 51 | // noticeReceiver, if supplied. 52 | // NOTE: Ordinary users of this library should never need this. 53 | type NoticeEvent struct { 54 | Data map[string]interface{} `json:"data"` 55 | Type string `json:"noticeType"` 56 | Timestamp string `json:"timestamp"` 57 | } 58 | 59 | func StartTunnel(ctx context.Context, l *slog.Logger, config *psiphon.Config) error { 60 | controllerCtx, cancel := context.WithCancel(ctx) 61 | // config.Commit must be called before calling config.SetParameters 62 | // or attempting to connect. 63 | if err := config.Commit(true); err != nil { 64 | return errors.New("config.Commit failed") 65 | } 66 | 67 | // Will receive a value when the tunnel has successfully connected. 68 | connected := make(chan struct{}) 69 | // Will receive a value if an error occurs during the connection sequence. 
70 | errored := make(chan error) 71 | 72 | // Set up notice handling 73 | psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver( 74 | func(notice []byte) { 75 | var event NoticeEvent 76 | if err := json.Unmarshal(notice, &event); err != nil { 77 | return 78 | } 79 | 80 | go func(event NoticeEvent) { 81 | l.Debug("psiphon core notice", "type", event.Type, "data", event.Data) 82 | switch event.Type { 83 | case "EstablishTunnelTimeout": 84 | select { 85 | case errored <- errors.New("clientlib: tunnel establishment timeout"): 86 | default: 87 | } 88 | case "Tunnels": 89 | if event.Data["count"].(float64) > 0 { 90 | select { 91 | case connected <- struct{}{}: 92 | default: 93 | } 94 | } 95 | } 96 | }(event) 97 | })) 98 | 99 | if err := psiphon.OpenDataStore(config); err != nil { 100 | return errors.New("failed to open data store") 101 | } 102 | 103 | if err := psiphon.ImportEmbeddedServerEntries(controllerCtx, config, "", ""); err != nil { 104 | return err 105 | } 106 | 107 | // Create the Psiphon controller 108 | controller, err := psiphon.NewController(config) 109 | if err != nil { 110 | return errors.New("psiphon.NewController failed") 111 | } 112 | 113 | // Begin tunnel connection 114 | go func() { 115 | // Start the tunnel. Only returns on error (or internal timeout). 
116 | controller.Run(controllerCtx) 117 | 118 | select { 119 | case errored <- errors.New("controller.Run exited unexpectedly"): 120 | default: 121 | } 122 | }() 123 | 124 | // Wait for an active tunnel or error 125 | select { 126 | case <-connected: 127 | return nil 128 | case err := <-errored: 129 | cancel() 130 | psiphon.CloseDataStore() 131 | psiphon.SetNoticeWriter(io.Discard) 132 | return err 133 | } 134 | } 135 | 136 | func RunPsiphon(ctx context.Context, l *slog.Logger, wgBind netip.AddrPort, dir string, localSocksAddr netip.AddrPort, country string) error { 137 | host := "" 138 | if !netip.MustParsePrefix("127.0.0.0/8").Contains(localSocksAddr.Addr()) { 139 | host = "any" 140 | } 141 | 142 | timeout := 60 143 | config := psiphon.Config{ 144 | EgressRegion: country, 145 | ListenInterface: host, 146 | LocalSocksProxyPort: int(localSocksAddr.Port()), 147 | UpstreamProxyURL: fmt.Sprintf("socks5://%s", wgBind), 148 | DisableLocalHTTPProxy: true, 149 | PropagationChannelId: "FFFFFFFFFFFFFFFF", 150 | RemoteServerListDownloadFilename: "remote_server_list", 151 | RemoteServerListSignaturePublicKey: "MIICIDANBgkqhkiG9w0BAQEFAAOCAg0AMIICCAKCAgEAt7Ls+/39r+T6zNW7GiVpJfzq/xvL9SBH5rIFnk0RXYEYavax3WS6HOD35eTAqn8AniOwiH+DOkvgSKF2caqk/y1dfq47Pdymtwzp9ikpB1C5OfAysXzBiwVJlCdajBKvBZDerV1cMvRzCKvKwRmvDmHgphQQ7WfXIGbRbmmk6opMBh3roE42KcotLFtqp0RRwLtcBRNtCdsrVsjiI1Lqz/lH+T61sGjSjQ3CHMuZYSQJZo/KrvzgQXpkaCTdbObxHqb6/+i1qaVOfEsvjoiyzTxJADvSytVtcTjijhPEV6XskJVHE1Zgl+7rATr/pDQkw6DPCNBS1+Y6fy7GstZALQXwEDN/qhQI9kWkHijT8ns+i1vGg00Mk/6J75arLhqcodWsdeG/M/moWgqQAnlZAGVtJI1OgeF5fsPpXu4kctOfuZlGjVZXQNW34aOzm8r8S0eVZitPlbhcPiR4gT/aSMz/wd8lZlzZYsje/Jr8u/YtlwjjreZrGRmG8KMOzukV3lLmMppXFMvl4bxv6YFEmIuTsOhbLTwFgh7KYNjodLj/LsqRVfwz31PgWQFTEPICV7GCvgVlPRxnofqKSjgTWI4mxDhBpVcATvaoBl1L/6WLbFvBsoAUBItWwctO2xalKxF5szhGm8lccoc5MZr8kfE0uxMgsxz4er68iCID+rsCAQM=", 152 | RemoteServerListUrl: "https://s3.amazonaws.com//psiphon/web/mjr4-p23r-puwl/server_list_compressed", 153 | SponsorId: "FFFFFFFFFFFFFFFF", 154 
| NetworkID: "test", 155 | ClientPlatform: "Android_4.0.4_com.example.exampleClientLibraryApp", 156 | AllowDefaultDNSResolverWithBindToDevice: true, 157 | EstablishTunnelTimeoutSeconds: &timeout, 158 | DataRootDirectory: dir, 159 | MigrateDataStoreDirectory: dir, 160 | MigrateObfuscatedServerListDownloadDirectory: dir, 161 | MigrateRemoteServerListDownloadFilename: filepath.Join(dir, "server_list_compressed"), 162 | } 163 | 164 | l.Info("starting handshake") 165 | if err := StartTunnel(ctx, l, &config); err != nil { 166 | return fmt.Errorf("Unable to start psiphon: %w", err) 167 | } 168 | l.Info("psiphon started successfully") 169 | return nil 170 | } 171 | -------------------------------------------------------------------------------- /warp/account.go: -------------------------------------------------------------------------------- 1 | package warp 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "log/slog" 7 | "os" 8 | "path/filepath" 9 | ) 10 | 11 | var identityFile = "wgcf-identity.json" 12 | 13 | func saveIdentity(a Identity, path string) error { 14 | file, err := os.Create(filepath.Join(path, identityFile)) 15 | if err != nil { 16 | return err 17 | } 18 | 19 | encoder := json.NewEncoder(file) 20 | encoder.SetIndent("", " ") 21 | err = encoder.Encode(a) 22 | if err != nil { 23 | return err 24 | } 25 | 26 | return file.Close() 27 | } 28 | 29 | func LoadOrCreateIdentity(l *slog.Logger, path, license string) (*Identity, error) { 30 | l = l.With("subsystem", "warp/account") 31 | 32 | warpAPI := NewWarpAPI(l) 33 | 34 | i, err := LoadIdentity(path) 35 | if err != nil { 36 | l.Info("failed to load identity", "path", path, "error", err) 37 | if err := os.RemoveAll(path); err != nil { 38 | return nil, err 39 | } 40 | 41 | if err := os.MkdirAll(path, os.ModePerm); err != nil { 42 | return nil, err 43 | } 44 | 45 | i, err = CreateIdentity(l, warpAPI, license) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | if err = saveIdentity(i, path); err != nil { 51 
| return nil, err 52 | } 53 | } 54 | 55 | if license != "" && i.Account.License != license { 56 | l.Info("updating account license key") 57 | _, err := warpAPI.UpdateAccount(i.Token, i.ID, license) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | iAcc, err := warpAPI.GetAccount(i.Token, i.ID) 63 | if err != nil { 64 | return nil, err 65 | } 66 | i.Account = iAcc 67 | 68 | if err = saveIdentity(i, path); err != nil { 69 | return nil, err 70 | } 71 | } 72 | 73 | l.Info("successfully loaded warp identity") 74 | return &i, nil 75 | } 76 | 77 | func LoadIdentity(path string) (Identity, error) { 78 | identityPath := filepath.Join(path, identityFile) 79 | _, err := os.Stat(identityPath) 80 | if err != nil { 81 | return Identity{}, err 82 | } 83 | 84 | fileBytes, err := os.ReadFile(identityPath) 85 | if err != nil { 86 | return Identity{}, err 87 | } 88 | 89 | i := &Identity{} 90 | err = json.Unmarshal(fileBytes, i) 91 | if err != nil { 92 | return Identity{}, err 93 | } 94 | 95 | if len(i.Config.Peers) < 1 { 96 | return Identity{}, errors.New("identity contains 0 peers") 97 | } 98 | 99 | return *i, nil 100 | } 101 | 102 | func CreateIdentity(l *slog.Logger, warpAPI *WarpAPI, license string) (Identity, error) { 103 | priv, err := GeneratePrivateKey() 104 | if err != nil { 105 | return Identity{}, err 106 | } 107 | 108 | privateKey, publicKey := priv.String(), priv.PublicKey().String() 109 | 110 | l.Info("creating new identity") 111 | i, err := warpAPI.Register(publicKey) 112 | if err != nil { 113 | return Identity{}, err 114 | } 115 | 116 | if license != "" { 117 | l.Info("updating account license key") 118 | _, err := warpAPI.UpdateAccount(i.Token, i.ID, license) 119 | if err != nil { 120 | return Identity{}, err 121 | } 122 | 123 | ac, err := warpAPI.GetAccount(i.Token, i.ID) 124 | if err != nil { 125 | return Identity{}, err 126 | } 127 | i.Account = ac 128 | } 129 | 130 | i.PrivateKey = privateKey 131 | 132 | return i, nil 133 | } 134 | 
-------------------------------------------------------------------------------- /warp/endpoint.go: -------------------------------------------------------------------------------- 1 | package warp 2 | 3 | import ( 4 | "math/rand" 5 | "net/netip" 6 | "time" 7 | 8 | "github.com/bepass-org/warp-plus/iputils" 9 | ) 10 | 11 | func WarpPrefixes() []netip.Prefix { 12 | return []netip.Prefix{ 13 | netip.MustParsePrefix("162.159.192.0/24"), 14 | netip.MustParsePrefix("162.159.195.0/24"), 15 | netip.MustParsePrefix("188.114.96.0/24"), 16 | netip.MustParsePrefix("188.114.97.0/24"), 17 | netip.MustParsePrefix("188.114.98.0/24"), 18 | netip.MustParsePrefix("188.114.99.0/24"), 19 | netip.MustParsePrefix("2606:4700:d0::/64"), 20 | netip.MustParsePrefix("2606:4700:d1::/64"), 21 | } 22 | } 23 | 24 | func RandomWarpPrefix(v4, v6 bool) netip.Prefix { 25 | if !v4 && !v6 { 26 | panic("Must choose a IP version for RandomWarpPrefix") 27 | } 28 | 29 | cidrs := WarpPrefixes() 30 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 31 | for { 32 | cidr := cidrs[rng.Intn(len(cidrs))] 33 | 34 | if v4 && cidr.Addr().Is4() { 35 | return cidr 36 | } 37 | 38 | if v6 && cidr.Addr().Is6() { 39 | return cidr 40 | } 41 | } 42 | } 43 | 44 | func WarpPorts() []uint16 { 45 | return []uint16{ 46 | 500, 47 | 854, 48 | 859, 49 | 864, 50 | 878, 51 | 880, 52 | 890, 53 | 891, 54 | 894, 55 | 903, 56 | 908, 57 | 928, 58 | 934, 59 | 939, 60 | 942, 61 | 943, 62 | 945, 63 | 946, 64 | 955, 65 | 968, 66 | 987, 67 | 988, 68 | 1002, 69 | 1010, 70 | 1014, 71 | 1018, 72 | 1070, 73 | 1074, 74 | 1180, 75 | 1387, 76 | 1701, 77 | 1843, 78 | 2371, 79 | 2408, 80 | 2506, 81 | 3138, 82 | 3476, 83 | 3581, 84 | 3854, 85 | 4177, 86 | 4198, 87 | 4233, 88 | 4500, 89 | 5279, 90 | 5956, 91 | 7103, 92 | 7152, 93 | 7156, 94 | 7281, 95 | 7559, 96 | 8319, 97 | 8742, 98 | 8854, 99 | 8886, 100 | } 101 | } 102 | 103 | func RandomWarpPort() uint16 { 104 | ports := WarpPorts() 105 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 
106 | return ports[rng.Intn(len(ports))] 107 | } 108 | 109 | func RandomWarpEndpoint(v4, v6 bool) (netip.AddrPort, error) { 110 | randomIP, err := iputils.RandomIPFromPrefix(RandomWarpPrefix(v4, v6)) 111 | if err != nil { 112 | return netip.AddrPort{}, err 113 | } 114 | 115 | return netip.AddrPortFrom(randomIP, RandomWarpPort()), nil 116 | } 117 | -------------------------------------------------------------------------------- /warp/key.go: -------------------------------------------------------------------------------- 1 | package warp 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/base64" 6 | "fmt" 7 | 8 | "golang.org/x/crypto/curve25519" 9 | ) 10 | 11 | // KeyLen is the expected key length for a WireGuard key. 12 | const KeyLen = 32 // wgh.KeyLen 13 | 14 | // A Key is a public, private, or pre-shared secret key. The Key constructor 15 | // functions in this package can be used to create Keys suitable for each of 16 | // these applications. 17 | type Key [KeyLen]byte 18 | 19 | // GenerateKey generates a Key suitable for use as a pre-shared secret key from 20 | // a cryptographically safe source. 21 | // 22 | // The output Key should not be used as a private key; use GeneratePrivateKey 23 | // instead. 24 | func GenerateKey() (Key, error) { 25 | b := make([]byte, KeyLen) 26 | if _, err := rand.Read(b); err != nil { 27 | return Key{}, fmt.Errorf("wgtypes: failed to read random bytes: %w", err) 28 | } 29 | 30 | return NewKey(b) 31 | } 32 | 33 | // GeneratePrivateKey generates a Key suitable for use as a private key from a 34 | // cryptographically safe source. 35 | func GeneratePrivateKey() (Key, error) { 36 | key, err := GenerateKey() 37 | if err != nil { 38 | return Key{}, err 39 | } 40 | 41 | // Modify random bytes using algorithm described at: 42 | // https://cr.yp.to/ecdh.html. 43 | key[0] &= 248 44 | key[31] &= 127 45 | key[31] |= 64 46 | 47 | return key, nil 48 | } 49 | 50 | // NewKey creates a Key from an existing byte slice. 
The byte slice must be 51 | // exactly 32 bytes in length. 52 | func NewKey(b []byte) (Key, error) { 53 | if len(b) != KeyLen { 54 | return Key{}, fmt.Errorf("wgtypes: incorrect key size: %d", len(b)) 55 | } 56 | 57 | var k Key 58 | copy(k[:], b) 59 | 60 | return k, nil 61 | } 62 | 63 | // PublicKey computes a public key from the private key k. 64 | // 65 | // PublicKey should only be called when k is a private key. 66 | func (k Key) PublicKey() Key { 67 | var ( 68 | pub [KeyLen]byte 69 | priv = [KeyLen]byte(k) 70 | ) 71 | 72 | // ScalarBaseMult uses the correct base value per https://cr.yp.to/ecdh.html, 73 | // so no need to specify it. 74 | curve25519.ScalarBaseMult(&pub, &priv) 75 | 76 | return Key(pub) 77 | } 78 | 79 | // String returns the base64-encoded string representation of a Key. 80 | // 81 | // ParseKey can be used to produce a new Key from this string. 82 | func (k Key) String() string { 83 | return base64.StdEncoding.EncodeToString(k[:]) 84 | } 85 | -------------------------------------------------------------------------------- /wireguard/LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /wireguard/conn/bindtest/bindtest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package bindtest 7 | 8 | import ( 9 | "fmt" 10 | "math/rand" 11 | "net" 12 | "net/netip" 13 | "os" 14 | 15 | "github.com/bepass-org/warp-plus/wireguard/conn" 16 | ) 17 | 18 | type ChannelBind struct { 19 | rx4, tx4 *chan []byte 20 | rx6, tx6 *chan []byte 21 | closeSignal chan bool 22 | source4, source6 ChannelEndpoint 23 | target4, target6 ChannelEndpoint 24 | } 25 | 26 | type ChannelEndpoint uint16 27 | 28 | var ( 29 | _ conn.Bind = (*ChannelBind)(nil) 30 | _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 | ) 32 | 33 | func NewChannelBinds() [2]conn.Bind { 34 | arx4 := make(chan []byte, 8192) 35 | brx4 := make(chan []byte, 8192) 36 | arx6 := make(chan []byte, 8192) 37 | brx6 := make(chan []byte, 8192) 38 | var binds [2]ChannelBind 39 | binds[0].rx4 = &arx4 40 | binds[0].tx4 = &brx4 41 | binds[1].rx4 = &brx4 42 | binds[1].tx4 = &arx4 43 | binds[0].rx6 = &arx6 44 | binds[0].tx6 = &brx6 45 | binds[1].rx6 = &brx6 46 | binds[1].tx6 = &arx6 47 | binds[0].target4 = ChannelEndpoint(1) 48 | binds[1].target4 = ChannelEndpoint(2) 49 | binds[0].target6 = ChannelEndpoint(3) 50 | binds[1].target6 = ChannelEndpoint(4) 51 | binds[0].source4 = binds[1].target4 52 | 
binds[0].source6 = binds[1].target6 53 | binds[1].source4 = binds[0].target4 54 | binds[1].source6 = binds[0].target6 55 | return [2]conn.Bind{&binds[0], &binds[1]} 56 | } 57 | 58 | func (c ChannelEndpoint) ClearSrc() {} 59 | 60 | func (c ChannelEndpoint) SrcToString() string { return "" } 61 | 62 | func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 | 64 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 | 66 | func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 | 68 | func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 | 70 | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 | c.closeSignal = make(chan bool) 72 | fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 | fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 | if rand.Uint32()&1 == 0 { 75 | return fns, uint16(c.source4), nil 76 | } else { 77 | return fns, uint16(c.source6), nil 78 | } 79 | } 80 | 81 | func (c *ChannelBind) Close() error { 82 | if c.closeSignal != nil { 83 | select { 84 | case <-c.closeSignal: 85 | default: 86 | close(c.closeSignal) 87 | } 88 | } 89 | return nil 90 | } 91 | 92 | func (c *ChannelBind) BatchSize() int { return 1 } 93 | 94 | func (c *ChannelBind) SetMark(mark uint32) error { return nil } 95 | 96 | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 97 | return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 98 | select { 99 | case <-c.closeSignal: 100 | return 0, net.ErrClosed 101 | case rx := <-ch: 102 | copied := copy(bufs[0], rx) 103 | sizes[0] = copied 104 | eps[0] = c.target6 105 | return 1, nil 106 | } 107 | } 108 | } 109 | 110 | func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { 111 | for _, b := range bufs { 112 | select { 113 | case <-c.closeSignal: 114 | return net.ErrClosed 115 | default: 116 | bc := make([]byte, len(b)) 117 | 
copy(bc, b) 118 | if ep.(ChannelEndpoint) == c.target4 { 119 | *c.tx4 <- bc 120 | } else if ep.(ChannelEndpoint) == c.target6 { 121 | *c.tx6 <- bc 122 | } else { 123 | return os.ErrInvalid 124 | } 125 | } 126 | } 127 | return nil 128 | } 129 | 130 | func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 131 | addr, err := netip.ParseAddrPort(s) 132 | if err != nil { 133 | return nil, err 134 | } 135 | return ChannelEndpoint(addr.Port()), nil 136 | } 137 | -------------------------------------------------------------------------------- /wireguard/conn/boundif_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { 9 | sysconn, err := s.ipv4.SyscallConn() 10 | if err != nil { 11 | return -1, err 12 | } 13 | err = sysconn.Control(func(f uintptr) { 14 | fd = int(f) 15 | }) 16 | if err != nil { 17 | return -1, err 18 | } 19 | return 20 | } 21 | 22 | func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { 23 | sysconn, err := s.ipv6.SyscallConn() 24 | if err != nil { 25 | return -1, err 26 | } 27 | err = sysconn.Control(func(f uintptr) { 28 | fd = int(f) 29 | }) 30 | if err != nil { 31 | return -1, err 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /wireguard/conn/conn.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package conn implements WireGuard's network connections. 
7 | package conn
8 | 
9 | import (
10 | 	"errors"
11 | 	"fmt"
12 | 	"net/netip"
13 | 	"reflect"
14 | 	"runtime"
15 | 	"strings"
16 | )
17 | 
18 | const (
19 | 	IdealBatchSize = 128 // maximum number of packets handled per read and write
20 | )
21 | 
22 | // A ReceiveFunc receives at least one packet from the network and writes them
23 | // into packets. On a successful read it returns the number of elements of
24 | // sizes, packets, and endpoints that should be evaluated. Some elements of
25 | // sizes may be zero, and callers should ignore them. Callers must pass a sizes
26 | // and eps slice with a length greater than or equal to the length of packets.
27 | // These lengths must not exceed the length of the associated Bind.BatchSize().
28 | type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
29 | 
30 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
31 | //
32 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
33 | // depending on the platform-specific implementation.
34 | type Bind interface {
35 | 	// Open puts the Bind into a listening state on a given port and reports the actual
36 | 	// port that it bound to. Passing zero results in a random selection.
37 | 	// fns is the set of functions that will be called to receive packets.
38 | 	Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
39 | 
40 | 	// Close closes the Bind listener.
41 | 	// All fns returned by Open must return net.ErrClosed after a call to Close.
42 | 	Close() error
43 | 
44 | 	// SetMark sets the mark for each packet sent through this Bind.
45 | 	// This mark is passed to the kernel as the socket option SO_MARK.
46 | 	SetMark(mark uint32) error
47 | 
48 | 	// Send writes one or more packets in bufs to address ep. The length of
49 | 	// bufs must not exceed BatchSize().
50 | 	Send(bufs [][]byte, ep Endpoint) error
51 | 
52 | 	// ParseEndpoint creates a new endpoint from a string.
53 | 	ParseEndpoint(s string) (Endpoint, error)
54 | 
55 | 	// BatchSize is the number of buffers expected to be passed to
56 | 	// the ReceiveFuncs, and the maximum expected to be passed to Send.
57 | 	BatchSize() int
58 | }
59 | 
60 | // BindSocketToInterface is implemented by Bind objects that support being
61 | // tied to a single network interface. Used by wireguard-windows.
62 | type BindSocketToInterface interface {
63 | 	BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
64 | 	BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
65 | }
66 | 
67 | // PeekLookAtSocketFd is implemented by Bind objects that support having their
68 | // file descriptor peeked at. Used by wireguard-android.
69 | type PeekLookAtSocketFd interface {
70 | 	PeekLookAtSocketFd4() (fd int, err error)
71 | 	PeekLookAtSocketFd6() (fd int, err error)
72 | }
73 | 
74 | // An Endpoint maintains the source/destination caching for a peer.
75 | //
76 | //	dst: the remote address of a peer ("endpoint" in uapi terminology)
77 | //	src: the local address from which datagrams originate going to the peer
78 | type Endpoint interface {
79 | 	ClearSrc()           // clears the source address
80 | 	SrcToString() string // returns the local source address (ip:port)
81 | 	DstToString() string // returns the destination address (ip:port)
82 | 	DstToBytes() []byte  // used for mac2 cookie calculations
83 | 	DstIP() netip.Addr   // returns the destination IP address
84 | 	SrcIP() netip.Addr   // returns the cached local source IP, or the zero Addr if none
85 | }
86 | 
87 | var (
88 | 	ErrBindAlreadyOpen   = errors.New("bind is already open")
89 | 	ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
90 | )
91 | // PrettyName returns a short label for fn from its symbol name: "v4"/"v6" when the name ends in IPv4/IPv6, otherwise the bare name (fn's address if no name remains).
92 | func (fn ReceiveFunc) PrettyName() string {
93 | 	name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
94 | 	// 0. cheese/taco.beansIPv6.func12.func21218-fm (method values carry a -fm suffix)
95 | 	name = strings.TrimSuffix(name, "-fm")
96 | 	// 1. cheese/taco.beansIPv6.func12.func21218
97 | 	if idx := strings.LastIndexByte(name, '/'); idx != -1 {
98 | 		name = name[idx+1:]
99 | 		// 2. taco.beansIPv6.func12.func21218
100 | 	}
101 | 	for {
102 | 		var idx int
103 | 		for idx = len(name) - 1; idx >= 0; idx-- {
104 | 			if name[idx] < '0' || name[idx] > '9' {
105 | 				break
106 | 			}
107 | 		}
108 | 		if idx == len(name)-1 { // no trailing digits left to strip
109 | 			break
110 | 		}
111 | 		const dotFunc = ".func"
112 | 		if !strings.HasSuffix(name[:idx+1], dotFunc) { // digits are not a .funcN closure suffix
113 | 			break
114 | 		}
115 | 		name = name[:idx+1-len(dotFunc)]
116 | 		// 3. taco.beansIPv6.func12
117 | 		// 4. taco.beansIPv6
118 | 	}
119 | 	if idx := strings.LastIndexByte(name, '.'); idx != -1 {
120 | 		name = name[idx+1:]
121 | 		// 5. beansIPv6
122 | 	}
123 | 	if name == "" { // no usable symbol name: fall back to the func value's address
124 | 		return fmt.Sprintf("%p", fn)
125 | 	}
126 | 	if strings.HasSuffix(name, "IPv4") {
127 | 		return "v4"
128 | 	}
129 | 	if strings.HasSuffix(name, "IPv6") {
130 | 		return "v6"
131 | 	}
132 | 	return name
133 | }
134 | 
--------------------------------------------------------------------------------
/wireguard/conn/conn_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package conn
7 | 
8 | import (
9 | 	"testing"
10 | )
11 | 
12 | func TestPrettyName(t *testing.T) {
13 | 	var (
14 | 		recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
15 | 	)
16 | 
17 | 	const want = "TestPrettyName"
18 | 
19 | 	t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
20 | 		if got := recvFunc.PrettyName(); got != want {
21 | 			t.Errorf("PrettyName() = %v, want %v", got, want)
22 | 		}
23 | 	})
24 | }
25 | 
--------------------------------------------------------------------------------
/wireguard/conn/controlfns.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package conn
7 | 
8 | import (
9 | 	"net"
10 | 	"syscall"
11 | )
12 | 
13 | // UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
14 | // the max supported by a default configuration of macOS. Some platforms will
15 | // silently clamp the value to other maximums, such as linux clamping to
16 | // net.core.{r,w}mem_max (see controlfns_linux.go for additional implementation that works
17 | // around this limitation).
18 | const socketBufferSize = 7 << 20
19 | 
20 | // controlFn is the callback function signature from net.ListenConfig.Control.
21 | // It is used to apply platform specific configuration to the socket prior to
22 | // bind.
23 | type controlFn func(network, address string, c syscall.RawConn) error
24 | 
25 | // controlFns is a list of functions that are called from the listen config
26 | // that can apply socket options.
27 | var controlFns = []controlFn{} // appended to by the platform-specific init functions
28 | 
29 | // listenConfig returns a net.ListenConfig that applies the controlFns to the
30 | // socket prior to bind. This is used to apply socket buffer sizing and packet
31 | // information OOB configuration for sticky sockets.
32 | func listenConfig() *net.ListenConfig {
33 | 	return &net.ListenConfig{
34 | 		Control: func(network, address string, c syscall.RawConn) error {
35 | 			for _, fn := range controlFns { // run in registration order; the first error aborts
36 | 				if err := fn(network, address, c); err != nil {
37 | 					return err
38 | 				}
39 | 			}
40 | 			return nil
41 | 		},
42 | 	}
43 | }
44 | 
--------------------------------------------------------------------------------
/wireguard/conn/controlfns_linux.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package conn
7 | 
8 | import (
9 | 	"fmt"
10 | 	"runtime"
11 | 	"syscall"
12 | 
13 | 	"golang.org/x/sys/unix"
14 | )
15 | 
16 | // init registers the Linux socket options that listenConfig applies to each
17 | // UDP socket before bind.
18 | func init() {
19 | 	controlFns = append(controlFns,
20 | 
21 | 		// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
22 | 		// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
23 | 		// fail silently - the result of failure is lower performance on very fast
24 | 		// links or high latency links.
25 | 		func(network, address string, c syscall.RawConn) error {
26 | 			return c.Control(func(fd uintptr) {
27 | 				// Set up to *mem_max
28 | 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
29 | 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
30 | 				// Set beyond *mem_max if CAP_NET_ADMIN
31 | 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
32 | 				_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
33 | 			})
34 | 		},
35 | 
36 | 		// Enable receiving of the packet information (IP_PKTINFO for IPv4,
37 | 		// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
38 | 		// udp6 sockets are also restricted to IPv6 (IPV6_V6ONLY). Both a failure
39 | 		// to reach the fd (Control) and a setsockopt failure are reported.
40 | 		func(network, address string, c syscall.RawConn) error {
41 | 			var err, controlErr error
42 | 			switch network {
43 | 			case "udp4":
44 | 				if runtime.GOOS != "android" {
45 | 					controlErr = c.Control(func(fd uintptr) {
46 | 						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
47 | 					})
48 | 				}
49 | 			case "udp6":
50 | 				controlErr = c.Control(func(fd uintptr) {
51 | 					if runtime.GOOS != "android" {
52 | 						err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
53 | 						if err != nil {
54 | 							return
55 | 						}
56 | 					}
57 | 					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
58 | 				})
59 | 			default:
60 | 				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
61 | 			}
62 | 			// If Control failed, the callback never ran and err is still nil,
63 | 			// so returning err alone would wrongly report success.
64 | 			if controlErr != nil {
65 | 				return controlErr
66 | 			}
67 | 			return err
68 | 		},
69 | 
70 | 		// Attempt to enable UDP_GRO. The setsockopt result is ignored (best
71 | 		// effort); only a Control failure is reported, as for buffer sizing.
72 | 		func(network, address string, c syscall.RawConn) error {
73 | 			return c.Control(func(fd uintptr) {
74 | 				_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
75 | 			})
76 | 		},
77 | 	)
78 | }
79 | 
--------------------------------------------------------------------------------
/wireguard/conn/controlfns_unix.go:
--------------------------------------------------------------------------------
1 | //go:build !windows && !linux && !wasm
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "syscall" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | func init() { 17 | controlFns = append(controlFns, 18 | func(network, address string, c syscall.RawConn) error { 19 | return c.Control(func(fd uintptr) { 20 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) 21 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) 22 | }) 23 | }, 24 | 25 | func(network, address string, c syscall.RawConn) error { 26 | var err error 27 | if network == "udp6" { 28 | c.Control(func(fd uintptr) { 29 | err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) 30 | }) 31 | } 32 | return err 33 | }, 34 | ) 35 | } 36 | -------------------------------------------------------------------------------- /wireguard/conn/controlfns_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "syscall" 10 | 11 | "golang.org/x/sys/windows" 12 | ) 13 | 14 | func init() { 15 | controlFns = append(controlFns, 16 | func(network, address string, c syscall.RawConn) error { 17 | return c.Control(func(fd uintptr) { 18 | _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) 19 | _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) 20 | }) 21 | }, 22 | ) 23 | } 24 | -------------------------------------------------------------------------------- /wireguard/conn/default.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 |  */
7 | 
8 | package conn
9 | // NewDefaultBind returns the standard-library-backed Bind (NewStdNetBind) on non-Windows builds.
10 | func NewDefaultBind() Bind { return NewStdNetBind() }
11 | 
--------------------------------------------------------------------------------
/wireguard/conn/errors_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | // errShouldDisableUDPGSO always reports false on non-Linux builds: no error disables UDP GSO here.
10 | func errShouldDisableUDPGSO(err error) bool {
11 | 	return false
12 | }
13 | 
--------------------------------------------------------------------------------
/wireguard/conn/errors_linux.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package conn
7 | 
8 | import (
9 | 	"errors"
10 | 	"os"
11 | 
12 | 	"golang.org/x/sys/unix"
13 | )
14 | // errShouldDisableUDPGSO reports whether err wraps an *os.SyscallError carrying EIO, meaning UDP GSO should be turned off.
15 | func errShouldDisableUDPGSO(err error) bool {
16 | 	var serr *os.SyscallError
17 | 	if errors.As(err, &serr) {
18 | 		// EIO is returned by udp_send_skb() if the device driver does not have
19 | 		// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
20 | 		// See:
21 | 		// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
22 | 		// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
23 | 		return serr.Err == unix.EIO
24 | 	}
25 | 	return false // not a syscall error: keep GSO enabled
26 | }
27 | 
--------------------------------------------------------------------------------
/wireguard/conn/features_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 | // +build !linux
3 | 
4 | /* SPDX-License-Identifier: MIT
5 |  *
6 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
7 |  */
8 | 
9 | package conn
10 | 
11 | import "net"
12 | // supportsUDPOffload reports no TX or RX UDP offload support on non-Linux builds.
13 | func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
14 | 	return
15 | }
16 | 
--------------------------------------------------------------------------------
/wireguard/conn/features_linux.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package conn
7 | 
8 | import (
9 | 	"net"
10 | 
11 | 	"golang.org/x/sys/unix"
12 | )
13 | // supportsUDPOffload probes conn's socket for UDP_SEGMENT (TX) and enabled UDP_GRO (RX); both are false if the fd cannot be reached.
14 | func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
15 | 	rc, err := conn.SyscallConn()
16 | 	if err != nil {
17 | 		return
18 | 	}
19 | 	err = rc.Control(func(fd uintptr) {
20 | 		_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
21 | 		txOffload = errSyscall == nil // a readable UDP_SEGMENT is taken as TX (GSO) support
22 | 		opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
23 | 		rxOffload = errSyscall == nil && opt == 1 // RX only if UDP_GRO is currently enabled
24 | 	})
25 | 	if err != nil {
26 | 		return false, false
27 | 	}
28 | 	return txOffload, rxOffload
29 | }
30 | 
--------------------------------------------------------------------------------
/wireguard/conn/gso_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | 
10 | // getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
11 | func getGSOSize(control []byte) (int, error) {
12 | 	return 0, nil
13 | }
14 | 
15 | // setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
16 | func setGSOSize(control *[]byte, gsoSize uint16) {
17 | }
18 | 
19 | // gsoControlSize is the recommended buffer size for pooling UDP offloading
20 | // control data; zero here because GSO/GRO control data is not used on non-Linux builds.
21 | const gsoControlSize = 0
22 | 
--------------------------------------------------------------------------------
/wireguard/conn/gso_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | 
10 | import (
11 | 	"fmt"
12 | 	"unsafe"
13 | 
14 | 	"golang.org/x/sys/unix"
15 | )
16 | 
17 | const (
18 | 	sizeOfGSOData = 2
19 | )
20 | 
21 | // getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
22 | func getGSOSize(control []byte) (int, error) {
23 | 	var (
24 | 		hdr  unix.Cmsghdr
25 | 		data []byte
26 | 		rem  = control
27 | 		err  error
28 | 	)
29 | 
30 | 	for len(rem) > unix.SizeofCmsghdr {
31 | 		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
32 | 		if err != nil {
33 | 			return 0, fmt.Errorf("error parsing socket control message: %w", err)
34 | 		}
35 | 		if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
36 | 			var gso uint16
37 | 			copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
38 | 			return int(gso), nil
39 | 		}
40 | 	}
41 | 	return 0, nil
42 | }
43 | 
44 | // setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
45 | // data in control untouched.
46 | func setGSOSize(control *[]byte, gsoSize uint16) {
47 | 	existingLen := len(*control)
48 | 	avail := cap(*control) - existingLen
49 | 	space := unix.CmsgSpace(sizeOfGSOData)
50 | 	if avail < space {
51 | 		return // not enough spare capacity: leave control unchanged (no GSO)
52 | 	}
53 | 	*control = (*control)[:cap(*control)]
54 | 	gsoControl := (*control)[existingLen:]
55 | 	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
56 | 	hdr.Level = unix.SOL_UDP
57 | 	hdr.Type = unix.UDP_SEGMENT
58 | 	hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
59 | 	copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
60 | 	*control = (*control)[:existingLen+space]
61 | }
62 | 
63 | // gsoControlSize is the recommended buffer size for pooling UDP offloading
64 | // control data: room for one cmsg carrying sizeOfGSOData bytes.
65 | var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
66 | 
--------------------------------------------------------------------------------
/wireguard/conn/mark_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux && !openbsd && !freebsd
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | // SetMark is a no-op on platforms without a supported mark socket option.
10 | func (s *StdNetBind) SetMark(mark uint32) error {
11 | 	return nil
12 | }
13 | 
--------------------------------------------------------------------------------
/wireguard/conn/mark_unix.go:
--------------------------------------------------------------------------------
1 | //go:build linux || openbsd || freebsd
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | 
10 | import (
11 | 	"runtime"
12 | 
13 | 	"golang.org/x/sys/unix"
14 | )
15 | // fwmarkIoctl is the SOL_SOCKET option SetMark uses; 0 (unknown GOOS) makes SetMark a no-op.
16 | var fwmarkIoctl int
17 | // init selects the mark socket option for the running GOOS.
18 | func init() {
19 | 	switch runtime.GOOS {
20 | 	case "linux", "android":
21 | 		fwmarkIoctl = 36 /* unix.SO_MARK */
22 | 	case "freebsd":
23 | 		fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
24 | 	case "openbsd":
25 | 		fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
26 | 	}
27 | }
28 | // SetMark applies mark via fwmarkIoctl to the IPv4 socket, then the IPv6 socket, stopping at the first error.
29 | func (s *StdNetBind) SetMark(mark uint32) error {
30 | 	var operr error
31 | 	if fwmarkIoctl == 0 {
32 | 		return nil
33 | 	}
34 | 	if s.ipv4 != nil {
35 | 		fd, err := s.ipv4.SyscallConn()
36 | 		if err != nil {
37 | 			return err
38 | 		}
39 | 		err = fd.Control(func(fd uintptr) {
40 | 			operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
41 | 		})
42 | 		if err == nil { // Control succeeded: report the setsockopt result instead
43 | 			err = operr
44 | 		}
45 | 		if err != nil {
46 | 			return err
47 | 		}
48 | 	}
49 | 	if s.ipv6 != nil {
50 | 		fd, err := s.ipv6.SyscallConn()
51 | 		if err != nil {
52 | 			return err
53 | 		}
54 | 		err = fd.Control(func(fd uintptr) {
55 | 			operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
56 | 		})
57 | 		if err == nil { // Control succeeded: report the setsockopt result instead
58 | 			err = operr
59 | 		}
60 | 		if err != nil {
61 | 			return err
62 | 		}
63 | 	}
64 | 	return nil
65 | }
66 | 
--------------------------------------------------------------------------------
/wireguard/conn/sticky_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux || android
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | 
10 | import "net/netip"
11 | // SrcIP returns the zero Addr: no source address is tracked on this build.
12 | func (e *StdNetEndpoint) SrcIP() netip.Addr {
13 | 	return netip.Addr{}
14 | }
15 | // SrcIfidx returns 0: no source interface is tracked on this build.
16 | func (e *StdNetEndpoint) SrcIfidx() int32 {
17 | 	return 0
18 | }
19 | // SrcToString returns an empty string: no source address is tracked on this build.
20 | func (e *StdNetEndpoint) SrcToString() string {
21 | 	return ""
22 | }
23 | 
24 | // TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
25 | // {get,set}srcControl feature set, but use alternatively named flags and need
26 | // ports and require testing.
27 | 
28 | // getSrcFromControl parses the control for PKTINFO and if found updates ep with
29 | // the source information found.
30 | func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
31 | }
32 | 
33 | // setSrcControl would write ep's source as an IP{V6}_PKTINFO into control;
34 | // sticky sockets are unsupported on this build, so control is left untouched.
35 | func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
36 | }
37 | 
38 | // stickyControlSize is the recommended buffer size for pooling sticky
39 | // control data; zero because sticky sockets are unsupported on this build.
40 | const stickyControlSize = 0
41 | // StdNetSupportsStickySockets reports whether this build implements sticky sockets.
42 | const StdNetSupportsStickySockets = false
43 | 
--------------------------------------------------------------------------------
/wireguard/conn/sticky_linux.go:
--------------------------------------------------------------------------------
1 | //go:build linux && !android
2 | 
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 | 
8 | package conn
9 | 
10 | import (
11 | 	"net/netip"
12 | 	"unsafe"
13 | 
14 | 	"golang.org/x/sys/unix"
15 | )
16 | // SrcIP returns the source address cached in e.src as a PKTINFO cmsg, or the zero Addr if none is cached.
17 | func (e *StdNetEndpoint) SrcIP() netip.Addr {
18 | 	switch len(e.src) { // the cached cmsg length distinguishes IPv4 from IPv6 PKTINFO
19 | 	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
20 | 		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
21 | 		return netip.AddrFrom4(info.Spec_dst)
22 | 	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
23 | 		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
24 | 		// TODO: set zone. in order to do so we need to check if the address is
25 | 		// link local, and if it is perform a syscall to turn the ifindex into a
26 | 		// zone string because netip uses string zones.
27 | 		return netip.AddrFrom16(info.Addr)
28 | 	}
29 | 	return netip.Addr{}
30 | }
31 | // SrcIfidx returns the interface index cached in e.src, or 0 if none is cached.
32 | func (e *StdNetEndpoint) SrcIfidx() int32 {
33 | 	switch len(e.src) {
34 | 	case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
35 | 		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
36 | 		return info.Ifindex
37 | 	case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
38 | 		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
39 | 		return int32(info.Ifindex)
40 | 	}
41 | 	return 0
42 | }
43 | // SrcToString formats SrcIP (the zero Addr formats as "invalid IP").
44 | func (e *StdNetEndpoint) SrcToString() string {
45 | 	return e.SrcIP().String()
46 | }
47 | 
48 | // getSrcFromControl parses the control for PKTINFO and if found updates ep with
49 | // the source information found.
50 | func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
51 | 	ep.ClearSrc() // drop any previously cached source first
52 | 
53 | 	var (
54 | 		hdr  unix.Cmsghdr
55 | 		data []byte
56 | 		rem  []byte = control
57 | 		err  error
58 | 	)
59 | 
60 | 	for len(rem) > unix.SizeofCmsghdr {
61 | 		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
62 | 		if err != nil {
63 | 			return
64 | 		}
65 | 
66 | 		if hdr.Level == unix.IPPROTO_IP &&
67 | 			hdr.Type == unix.IP_PKTINFO {
68 | 
69 | 			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
70 | 				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
71 | 			}
72 | 			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
73 | 
74 | 			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
75 | 			copy(ep.src, hdrBuf)
76 | 			copy(ep.src[unix.CmsgLen(0):], data)
77 | 			return
78 | 		}
79 | 
80 | 		if hdr.Level == unix.IPPROTO_IPV6 &&
81 | 			hdr.Type == unix.IPV6_PKTINFO {
82 | 
83 | 			if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
84 | 				ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
85 | 			}
86 | 
87 | 			ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
88 | 
89 | 			hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
90 | 			copy(ep.src, hdrBuf)
91 | 			copy(ep.src[unix.CmsgLen(0):], data)
92 | 			return
93 | 		}
94 | 	}
95 | }
96 | 
97 | // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
98 | // and source ifindex found in ep. control's len will be set to 0 in the event
99 | // that ep is a default value.
100 | func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
101 | 	if cap(*control) < len(ep.src) {
102 | 		return // control too small for the cached PKTINFO: leave it unchanged
103 | 	}
104 | 	*control = (*control)[:0]
105 | 	*control = append(*control, ep.src...)
106 | }
107 | 
108 | // stickyControlSize is the recommended buffer size for pooling sticky
109 | // control data: room for the larger (IPv6) PKTINFO cmsg.
110 | var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
111 | // StdNetSupportsStickySockets reports whether this build implements sticky sockets.
112 | const StdNetSupportsStickySockets = true
113 | 
--------------------------------------------------------------------------------
/wireguard/device/allowedips_rand_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package device
7 | 
8 | import (
9 | 	"math/rand"
10 | 	"net"
11 | 	"net/netip"
12 | 	"sort"
13 | 	"testing"
14 | )
15 | 
16 | const (
17 | 	NumberOfPeers        = 100
18 | 	NumberOfPeerRemovals = 4
19 | 	NumberOfAddresses    = 250
20 | 	NumberOfTests        = 10000
21 | )
22 | // SlowNode is one prefix entry in the naive reference router.
23 | type SlowNode struct {
24 | 	peer *Peer
25 | 	cidr uint8
26 | 	bits []byte
27 | }
28 | // SlowRouter is a naive longest-prefix-match table, kept sorted by descending cidr.
29 | type SlowRouter []*SlowNode
30 | // Len, Less and Swap implement sort.Interface, ordering longer prefixes first.
31 | func (r SlowRouter) Len() int {
32 | 	return len(r)
33 | }
34 | 
35 | func (r SlowRouter) Less(i, j int) bool {
36 | 	return r[i].cidr > r[j].cidr
37 | }
38 | 
39 | func (r SlowRouter) Swap(i, j int) {
40 | 	r[i], r[j] = r[j], r[i]
41 | }
42 | // Insert replaces the entry covering addr at the same cidr, or appends a new one and re-sorts.
43 | func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
44 | 	for _, t := range r {
45 | 		if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
46 | 			t.peer = peer
47 | 			t.bits = addr
48 | 			return r
49 | 		}
50 | 	}
51 | 	r = append(r, &SlowNode{
52 | 		cidr: cidr,
53 | 		bits: addr,
54 | 		peer: peer,
55 | 	})
56 | 	sort.Sort(r)
57 | 	return r
58 | }
59 | // Lookup returns the peer of the first (longest) prefix covering addr, or nil.
60 | func (r SlowRouter) Lookup(addr []byte) *Peer {
61 | 	for _, t := range r {
62 | 		common := commonBits(t.bits, addr)
63 | 		if common >= t.cidr {
64 | 			return t.peer
65 | 		}
66 | 	}
67 | 	return nil
68 | }
69 | // RemoveByPeer filters out every entry owned by peer, reusing r's backing array.
70 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
71 | 	n := 0
72 | 	for _, x := range r {
73 | 		if x.peer != peer {
74 | 			r[n] = x
75 | 			n++
76 | 		}
77 | 	}
78 | 	return r[:n]
79 | }
80 | // TestTrieRandom cross-checks AllowedIPs lookups and removals against the naive SlowRouter.
81 | func TestTrieRandom(t *testing.T) {
82 | 	var slow4, slow6 SlowRouter
83 | 	var peers []*Peer
84 | 	var allowedIPs AllowedIPs
85 | 
86 | 	rand.Seed(1) // fixed seed for reproducible failures (rand.Seed is deprecated since Go 1.20)
87 | 
88 | 	for n := 0; n < NumberOfPeers; n++ {
89 | 		peers = append(peers, &Peer{})
90 | 	}
91 | 
92 | 	for n := 0; n < NumberOfAddresses; n++ {
93 | 		var addr4 [4]byte
94 | 		rand.Read(addr4[:])
95 | 		cidr := uint8(rand.Intn(32) + 1)
96 | 		index := rand.Intn(NumberOfPeers)
97 | 		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
98 | 		slow4 = slow4.Insert(addr4[:], cidr, peers[index])
99 | 
100 | 		var addr6 [16]byte
101 | 		rand.Read(addr6[:])
102 | 		cidr = uint8(rand.Intn(128) + 1)
103 | 		index = rand.Intn(NumberOfPeers)
104 | 		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
105 | 		slow6 = slow6.Insert(addr6[:], cidr, peers[index])
106 | 	}
107 | 
108 | 	var p int
109 | 	for p = 0; ; p++ {
110 | 		for n := 0; n < NumberOfTests; n++ {
111 | 			var addr4 [4]byte
112 | 			rand.Read(addr4[:])
113 | 			peer1 := slow4.Lookup(addr4[:])
114 | 			peer2 := allowedIPs.Lookup(addr4[:])
115 | 			if peer1 != peer2 {
116 | 				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
117 | 			}
118 | 
119 | 			var addr6 [16]byte
120 | 			rand.Read(addr6[:])
121 | 			peer1 = slow6.Lookup(addr6[:])
122 | 			peer2 = allowedIPs.Lookup(addr6[:])
123 | 			if peer1 != peer2 {
124 | 				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
125 | 			}
126 | 		}
127 | 		if p >= len(peers) || p >= NumberOfPeerRemovals { // after a few checked removals, drop the rest unchecked
128 | 			break
129 | 		}
130 | 		allowedIPs.RemoveByPeer(peers[p])
131 | 		slow4 = slow4.RemoveByPeer(peers[p])
132 | 		slow6 = slow6.RemoveByPeer(peers[p])
133 | 	}
134 | 	for ; p < len(peers); p++ {
135 | 		allowedIPs.RemoveByPeer(peers[p])
136 | 	}
137 | 
138 | 	if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
139 | 		t.Error("Failed to remove all nodes from trie by peer")
140 | 	}
141 | }
142 | 
--------------------------------------------------------------------------------
/wireguard/device/bind_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package device
7 | 
8 | import (
9 | 	"errors"
10 | 
11 | 	"github.com/bepass-org/warp-plus/wireguard/conn"
12 | )
13 | // DummyDatagram is one queued message and its endpoint for DummyBind.
14 | type DummyDatagram struct {
15 | 	msg      []byte
16 | 	endpoint conn.Endpoint
17 | }
18 | // DummyBind is a channel-backed test double; note its Send signature differs from conn.Bind.Send.
19 | type DummyBind struct {
20 | 	in6    chan DummyDatagram
21 | 	in4    chan DummyDatagram
22 | 	closed bool
23 | }
24 | 
25 | func (b *DummyBind) SetMark(v uint32) error {
26 | 	return nil
27 | }
28 | 
29 | func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
30 | 	datagram, ok := <-b.in6
31 | 	if !ok {
32 | 		return 0, nil, errors.New("closed")
33 | 	}
34 | 	copy(buf, datagram.msg)
35 | 	return len(datagram.msg), datagram.endpoint, nil
36 | }
37 | 
38 | func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
39 | 	datagram, ok := <-b.in4
40 | 	if !ok {
41 | 		return 0, nil, errors.New("closed")
42 | 	}
43 | 	copy(buf, datagram.msg)
44 | 	return len(datagram.msg), datagram.endpoint, nil
45 | }
46 | 
47 | func (b *DummyBind) Close() error {
48 | 	close(b.in6)
49 | 	close(b.in4)
50 | 	b.closed = true
51 | 	return nil
52 | }
53 | 
54 | func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
55 | 	return nil
56 | }
57 | 
--------------------------------------------------------------------------------
/wireguard/device/channels.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package device
7 | 
8 | import (
9 | 	"runtime"
10 | 	"sync"
11 | )
12 | 
13 | // An outboundQueue is a channel of QueueOutboundElementsContainers awaiting encryption.
14 | // An outboundQueue is ref-counted using its wg field.
15 | // An outboundQueue created with newOutboundQueue has one reference.
16 | // Every additional writer must call wg.Add(1).
17 | // Every completed writer must call wg.Done().
18 | // When no further writers will be added,
19 | // call wg.Done to remove the initial reference.
20 | // When the refcount hits 0, the queue's channel is closed.
21 | type outboundQueue struct {
22 | 	c  chan *QueueOutboundElementsContainer
23 | 	wg sync.WaitGroup
24 | }
25 | 
26 | // newOutboundQueue returns an outboundQueue holding its initial reference.
27 | // A helper goroutine closes q.c once the reference count drops to zero.
28 | func newOutboundQueue() *outboundQueue {
29 | 	q := &outboundQueue{
30 | 		c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
31 | 	}
32 | 	q.wg.Add(1)
33 | 	go func() {
34 | 		q.wg.Wait()
35 | 		close(q.c)
36 | 	}()
37 | 	return q
38 | }
39 | 
40 | // An inboundQueue is similar to an outboundQueue; see those docs.
41 | type inboundQueue struct {
42 | 	c  chan *QueueInboundElementsContainer
43 | 	wg sync.WaitGroup
44 | }
45 | 
46 | // newInboundQueue is the inboundQueue counterpart of newOutboundQueue.
47 | func newInboundQueue() *inboundQueue {
48 | 	q := &inboundQueue{
49 | 		c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
50 | 	}
51 | 	q.wg.Add(1)
52 | 	go func() {
53 | 		q.wg.Wait()
54 | 		close(q.c)
55 | 	}()
56 | 	return q
57 | }
58 | 
59 | // A handshakeQueue is similar to an outboundQueue; see those docs.
60 | type handshakeQueue struct {
61 | 	c  chan QueueHandshakeElement
62 | 	wg sync.WaitGroup
63 | }
64 | 
65 | // newHandshakeQueue is the handshakeQueue counterpart of newOutboundQueue.
66 | func newHandshakeQueue() *handshakeQueue {
67 | 	q := &handshakeQueue{
68 | 		c: make(chan QueueHandshakeElement, QueueHandshakeSize),
69 | 	}
70 | 	q.wg.Add(1)
71 | 	go func() {
72 | 		q.wg.Wait()
73 | 		close(q.c)
74 | 	}()
75 | 	return q
76 | }
77 | 
78 | // An autodrainingInboundQueue is an inbound channel that is drained by a
79 | // finalizer instead of being closed; see newAutodrainingInboundQueue.
80 | type autodrainingInboundQueue struct {
81 | 	c chan *QueueInboundElementsContainer
82 | }
83 | 
84 | // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
85 | // It is useful in cases in which it is hard to manage the lifetime of the channel.
86 | // The returned channel must not be closed. Senders should signal shutdown using
87 | // some other means, such as sending a sentinel nil value.
88 | func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
89 | 	q := &autodrainingInboundQueue{
90 | 		c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
91 | 	}
92 | 	runtime.SetFinalizer(q, device.flushInboundQueue)
93 | 	return q
94 | }
95 | 
96 | // flushInboundQueue is q's finalizer: it empties q.c without blocking and
97 | // hands each container's buffers and elements back via the device's Put*
98 | // helpers. Nil shutdown sentinels are skipped: calling Lock on one would
99 | // panic inside the finalizer goroutine and crash the process.
100 | func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
101 | 	for {
102 | 		select {
103 | 		case elemsContainer := <-q.c:
104 | 			if elemsContainer == nil {
105 | 				continue
106 | 			}
107 | 			elemsContainer.Lock()
108 | 			for _, elem := range elemsContainer.elems {
109 | 				device.PutMessageBuffer(elem.buffer)
110 | 				device.PutInboundElement(elem)
111 | 			}
112 | 			device.PutInboundElementsContainer(elemsContainer)
113 | 		default:
114 | 			return
115 | 		}
116 | 	}
117 | }
118 | 
119 | // An autodrainingOutboundQueue is the outbound counterpart of
120 | // autodrainingInboundQueue; see newAutodrainingOutboundQueue.
121 | type autodrainingOutboundQueue struct {
122 | 	c chan *QueueOutboundElementsContainer
123 | }
124 | 
125 | // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
126 | // It is useful in cases in which it is hard to manage the lifetime of the channel.
127 | // The returned channel must not be closed. Senders should signal shutdown using
128 | // some other means, such as sending a sentinel nil value.
129 | // All sends to the channel must be best-effort, because there may be no receivers.
130 | func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
131 | 	q := &autodrainingOutboundQueue{
132 | 		c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
133 | 	}
134 | 	runtime.SetFinalizer(q, device.flushOutboundQueue)
135 | 	return q
136 | }
137 | 
138 | // flushOutboundQueue is q's finalizer: it empties q.c without blocking and
139 | // hands each container's buffers and elements back via the device's Put*
140 | // helpers. Nil shutdown sentinels are skipped so the finalizer cannot panic.
141 | func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
142 | 	for {
143 | 		select {
144 | 		case elemsContainer := <-q.c:
145 | 			if elemsContainer == nil {
146 | 				continue
147 | 			}
148 | 			elemsContainer.Lock()
149 | 			for _, elem := range elemsContainer.elems {
150 | 				device.PutMessageBuffer(elem.buffer)
151 | 				device.PutOutboundElement(elem)
152 | 			}
153 | 			device.PutOutboundElementsContainer(elemsContainer)
154 | 		default:
155 | 			return
156 | 		}
157 | 	}
158 | }
159 | 
--------------------------------------------------------------------------------
/wireguard/device/constants.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 | 
6 | package device
7 | 
8 | import (
9 | 	"time"
10 | )
11 | 
12 | /* Specification constants (protocol timers and message limits) */
13 | 
14 | const (
15 | 	RekeyAfterMessages      = (1 << 60)
16 | 	RejectAfterMessages     = (1 << 64) - (1 << 13) - 1
17 | 	RekeyAfterTime          = time.Second * 120
18 | 	RekeyAttemptTime        = time.Second * 90
19 | 	RekeyTimeout            = time.Second * 10
20 | 	MaxTimerHandshakes      = 90 / 5 /* RekeyAttemptTime / RekeyTimeout; NOTE(review): RekeyTimeout is 10s here, so that formula gives 9, not 18 - confirm which is intended */
21 | 	RekeyTimeoutJitterMaxMs = 334
22 | 	RejectAfterTime         = time.Second * 180
23 | 	KeepaliveTimeout        = time.Second * 5
24 | 	CookieRefreshTime       = time.Second * 120
25 | 	HandshakeInitationRate  = time.Second / 50
26 | 	PaddingMultiple         = 16
27 | )
28 | // Transport message size bounds.
29 | const (
30 | 	MinMessageSize = MessageKeepaliveSize                  // minimum size of transport message (keepalive)
31 | 	MaxMessageSize = MaxSegmentSize                        // maximum size of transport message
32 | 	MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
33 | )
34 | 
35 | /* Implementation constants */
36 | 
37 | const (
38 | 	UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
39 | 	MaxPeers           = 1 << 16     // maximum number of configured peers
40 | )
41 | 
--------------------------------------------------------------------------------
/wireguard/device/cookie.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "sync" 12 | "time" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/chacha20poly1305" 16 | ) 17 | 18 | type CookieChecker struct { 19 | sync.RWMutex 20 | mac1 struct { 21 | key [blake2s.Size]byte 22 | } 23 | mac2 struct { 24 | secret [blake2s.Size]byte 25 | secretSet time.Time 26 | encryptionKey [chacha20poly1305.KeySize]byte 27 | } 28 | } 29 | 30 | type CookieGenerator struct { 31 | sync.RWMutex 32 | mac1 struct { 33 | key [blake2s.Size]byte 34 | } 35 | mac2 struct { 36 | cookie [blake2s.Size128]byte 37 | cookieSet time.Time 38 | hasLastMAC1 bool 39 | lastMAC1 [blake2s.Size128]byte 40 | encryptionKey [chacha20poly1305.KeySize]byte 41 | } 42 | } 43 | 44 | func (st *CookieChecker) Init(pk NoisePublicKey) { 45 | st.Lock() 46 | defer st.Unlock() 47 | 48 | // mac1 state 49 | 50 | func() { 51 | hash, _ := blake2s.New256(nil) 52 | hash.Write([]byte(WGLabelMAC1)) 53 | hash.Write(pk[:]) 54 | hash.Sum(st.mac1.key[:0]) 55 | }() 56 | 57 | // mac2 state 58 | 59 | func() { 60 | hash, _ := blake2s.New256(nil) 61 | hash.Write([]byte(WGLabelCookie)) 62 | hash.Write(pk[:]) 63 | hash.Sum(st.mac2.encryptionKey[:0]) 64 | }() 65 | 66 | st.mac2.secretSet = time.Time{} 67 | } 68 | 69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool { 70 | st.RLock() 71 | defer st.RUnlock() 72 | 73 | size := len(msg) 74 | smac2 := size - blake2s.Size128 75 | smac1 := smac2 - blake2s.Size128 76 | 77 | var mac1 [blake2s.Size128]byte 78 | 79 | mac, _ := blake2s.New128(st.mac1.key[:]) 80 | mac.Write(msg[:smac1]) 81 | mac.Sum(mac1[:0]) 82 | 83 | return hmac.Equal(mac1[:], msg[smac1:smac2]) 84 | } 85 | 86 | func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { 87 | st.RLock() 88 | defer st.RUnlock() 89 | 90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 91 | return false 92 | } 93 | 94 | // derive cookie key 95 | 96 | var cookie [blake2s.Size128]byte 97 | func() { 98 | mac, _ := 
blake2s.New128(st.mac2.secret[:]) 99 | mac.Write(src) 100 | mac.Sum(cookie[:0]) 101 | }() 102 | 103 | // calculate mac of packet (including mac1) 104 | 105 | smac2 := len(msg) - blake2s.Size128 106 | 107 | var mac2 [blake2s.Size128]byte 108 | func() { 109 | mac, _ := blake2s.New128(cookie[:]) 110 | mac.Write(msg[:smac2]) 111 | mac.Sum(mac2[:0]) 112 | }() 113 | 114 | return hmac.Equal(mac2[:], msg[smac2:]) 115 | } 116 | 117 | func (st *CookieChecker) CreateReply( 118 | msg []byte, 119 | recv uint32, 120 | src []byte, 121 | ) (*MessageCookieReply, error) { 122 | st.RLock() 123 | 124 | // refresh cookie secret 125 | 126 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 127 | st.RUnlock() 128 | st.Lock() 129 | _, err := rand.Read(st.mac2.secret[:]) 130 | if err != nil { 131 | st.Unlock() 132 | return nil, err 133 | } 134 | st.mac2.secretSet = time.Now() 135 | st.Unlock() 136 | st.RLock() 137 | } 138 | 139 | // derive cookie 140 | 141 | var cookie [blake2s.Size128]byte 142 | func() { 143 | mac, _ := blake2s.New128(st.mac2.secret[:]) 144 | mac.Write(src) 145 | mac.Sum(cookie[:0]) 146 | }() 147 | 148 | // encrypt cookie 149 | 150 | size := len(msg) 151 | 152 | smac2 := size - blake2s.Size128 153 | smac1 := smac2 - blake2s.Size128 154 | 155 | reply := new(MessageCookieReply) 156 | reply.Type = MessageCookieReplyType 157 | reply.Receiver = recv 158 | 159 | _, err := rand.Read(reply.Nonce[:]) 160 | if err != nil { 161 | st.RUnlock() 162 | return nil, err 163 | } 164 | 165 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 166 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) 167 | 168 | st.RUnlock() 169 | 170 | return reply, nil 171 | } 172 | 173 | func (st *CookieGenerator) Init(pk NoisePublicKey) { 174 | st.Lock() 175 | defer st.Unlock() 176 | 177 | func() { 178 | hash, _ := blake2s.New256(nil) 179 | hash.Write([]byte(WGLabelMAC1)) 180 | hash.Write(pk[:]) 181 | hash.Sum(st.mac1.key[:0]) 182 | }() 183 | 184 | func() { 185 | 
hash, _ := blake2s.New256(nil) 186 | hash.Write([]byte(WGLabelCookie)) 187 | hash.Write(pk[:]) 188 | hash.Sum(st.mac2.encryptionKey[:0]) 189 | }() 190 | 191 | st.mac2.cookieSet = time.Time{} 192 | } 193 | 194 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { 195 | st.Lock() 196 | defer st.Unlock() 197 | 198 | if !st.mac2.hasLastMAC1 { 199 | return false 200 | } 201 | 202 | var cookie [blake2s.Size128]byte 203 | 204 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 205 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) 206 | if err != nil { 207 | return false 208 | } 209 | 210 | st.mac2.cookieSet = time.Now() 211 | st.mac2.cookie = cookie 212 | return true 213 | } 214 | 215 | func (st *CookieGenerator) AddMacs(msg []byte) { 216 | size := len(msg) 217 | 218 | smac2 := size - blake2s.Size128 219 | smac1 := smac2 - blake2s.Size128 220 | 221 | mac1 := msg[smac1:smac2] 222 | mac2 := msg[smac2:] 223 | 224 | st.Lock() 225 | defer st.Unlock() 226 | 227 | // set mac1 228 | 229 | func() { 230 | mac, _ := blake2s.New128(st.mac1.key[:]) 231 | mac.Write(msg[:smac1]) 232 | mac.Sum(mac1[:0]) 233 | }() 234 | copy(st.mac2.lastMAC1[:], mac1) 235 | st.mac2.hasLastMAC1 = true 236 | 237 | // set mac2 238 | 239 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime { 240 | return 241 | } 242 | 243 | func() { 244 | mac, _ := blake2s.New128(st.mac2.cookie[:]) 245 | mac.Write(msg[:smac2]) 246 | mac.Sum(mac2[:0]) 247 | }() 248 | } 249 | -------------------------------------------------------------------------------- /wireguard/device/devicestate_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. 
2 | 3 | package device 4 | 5 | import "strconv" 6 | 7 | const _deviceState_name = "DownUpClosed" 8 | 9 | var _deviceState_index = [...]uint8{0, 4, 6, 12} 10 | 11 | func (i deviceState) String() string { 12 | if i >= deviceState(len(_deviceState_index)-1) { 13 | return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" 14 | } 15 | return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] 16 | } 17 | -------------------------------------------------------------------------------- /wireguard/device/endpoint_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net/netip" 11 | ) 12 | 13 | type DummyEndpoint struct { 14 | src, dst netip.Addr 15 | } 16 | 17 | func CreateDummyEndpoint() (*DummyEndpoint, error) { 18 | var src, dst [16]byte 19 | if _, err := rand.Read(src[:]); err != nil { 20 | return nil, err 21 | } 22 | _, err := rand.Read(dst[:]) 23 | return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err 24 | } 25 | 26 | func (e *DummyEndpoint) ClearSrc() {} 27 | 28 | func (e *DummyEndpoint) SrcToString() string { 29 | return netip.AddrPortFrom(e.SrcIP(), 1000).String() 30 | } 31 | 32 | func (e *DummyEndpoint) DstToString() string { 33 | return netip.AddrPortFrom(e.DstIP(), 1000).String() 34 | } 35 | 36 | func (e *DummyEndpoint) DstToBytes() []byte { 37 | out := e.DstIP().AsSlice() 38 | out = append(out, byte(1000&0xff)) 39 | out = append(out, byte((1000>>8)&0xff)) 40 | return out 41 | } 42 | 43 | func (e *DummyEndpoint) DstIP() netip.Addr { 44 | return e.dst 45 | } 46 | 47 | func (e *DummyEndpoint) SrcIP() netip.Addr { 48 | return e.src 49 | } 50 | -------------------------------------------------------------------------------- /wireguard/device/indextable.go: 
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/rand" 10 | "encoding/binary" 11 | "sync" 12 | ) 13 | 14 | type IndexTableEntry struct { 15 | peer *Peer 16 | handshake *Handshake 17 | keypair *Keypair 18 | } 19 | 20 | type IndexTable struct { 21 | sync.RWMutex 22 | table map[uint32]IndexTableEntry 23 | } 24 | 25 | func randUint32() (uint32, error) { 26 | var integer [4]byte 27 | _, err := rand.Read(integer[:]) 28 | // Arbitrary endianness; both are intrinsified by the Go compiler. 29 | return binary.LittleEndian.Uint32(integer[:]), err 30 | } 31 | 32 | func (table *IndexTable) Init() { 33 | table.Lock() 34 | defer table.Unlock() 35 | table.table = make(map[uint32]IndexTableEntry) 36 | } 37 | 38 | func (table *IndexTable) Delete(index uint32) { 39 | table.Lock() 40 | defer table.Unlock() 41 | delete(table.table, index) 42 | } 43 | 44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { 45 | table.Lock() 46 | defer table.Unlock() 47 | entry, ok := table.table[index] 48 | if !ok { 49 | return 50 | } 51 | table.table[index] = IndexTableEntry{ 52 | peer: entry.peer, 53 | keypair: keypair, 54 | handshake: nil, 55 | } 56 | } 57 | 58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { 59 | for { 60 | // generate random index 61 | 62 | index, err := randUint32() 63 | if err != nil { 64 | return index, err 65 | } 66 | 67 | // check if index used 68 | 69 | table.RLock() 70 | _, ok := table.table[index] 71 | table.RUnlock() 72 | if ok { 73 | continue 74 | } 75 | 76 | // check again while locked 77 | 78 | table.Lock() 79 | _, found := table.table[index] 80 | if found { 81 | table.Unlock() 82 | continue 83 | } 84 | table.table[index] = IndexTableEntry{ 85 | peer: peer, 86 | handshake: handshake, 87 | keypair: 
nil, 88 | } 89 | table.Unlock() 90 | return index, nil 91 | } 92 | } 93 | 94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry { 95 | table.RLock() 96 | defer table.RUnlock() 97 | return table.table[id] 98 | } 99 | -------------------------------------------------------------------------------- /wireguard/device/ip.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "net" 10 | ) 11 | 12 | const ( 13 | IPv4offsetTotalLength = 2 14 | IPv4offsetSrc = 12 15 | IPv4offsetDst = IPv4offsetSrc + net.IPv4len 16 | ) 17 | 18 | const ( 19 | IPv6offsetPayloadLength = 4 20 | IPv6offsetSrc = 8 21 | IPv6offsetDst = IPv6offsetSrc + net.IPv6len 22 | ) 23 | -------------------------------------------------------------------------------- /wireguard/device/kdf_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "encoding/hex" 10 | "testing" 11 | 12 | "golang.org/x/crypto/blake2s" 13 | ) 14 | 15 | type KDFTest struct { 16 | key string 17 | input string 18 | t0 string 19 | t1 string 20 | t2 string 21 | } 22 | 23 | func assertEquals(t *testing.T, a, b string) { 24 | if a != b { 25 | t.Fatal("expected", a, "=", b) 26 | } 27 | } 28 | 29 | func TestKDF(t *testing.T) { 30 | tests := []KDFTest{ 31 | { 32 | key: "746573742d6b6579", 33 | input: "746573742d696e707574", 34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 37 | }, 38 | { 39 | key: "776972656775617264", 40 | input: "776972656775617264", 41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 44 | }, 45 | { 46 | key: "", 47 | input: "", 48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 51 | }, 52 | } 53 | 54 | var t0, t1, t2 [blake2s.Size]byte 55 | 56 | for _, test := range tests { 57 | key, _ := hex.DecodeString(test.key) 58 | input, _ := hex.DecodeString(test.input) 59 | KDF3(&t0, &t1, &t2, key, input) 60 | t0s := hex.EncodeToString(t0[:]) 61 | t1s := hex.EncodeToString(t1[:]) 62 | t2s := hex.EncodeToString(t2[:]) 63 | assertEquals(t, t0s, test.t0) 64 | assertEquals(t, t1s, test.t1) 65 | assertEquals(t, t2s, test.t2) 66 | } 67 | 68 | for _, test := range tests { 69 | key, _ := hex.DecodeString(test.key) 70 | input, _ := hex.DecodeString(test.input) 71 | KDF2(&t0, &t1, key, input) 72 | t0s := hex.EncodeToString(t0[:]) 73 | t1s := 
hex.EncodeToString(t1[:]) 74 | assertEquals(t, t0s, test.t0) 75 | assertEquals(t, t1s, test.t1) 76 | } 77 | 78 | for _, test := range tests { 79 | key, _ := hex.DecodeString(test.key) 80 | input, _ := hex.DecodeString(test.input) 81 | KDF1(&t0, key, input) 82 | t0s := hex.EncodeToString(t0[:]) 83 | assertEquals(t, t0s, test.t0) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /wireguard/device/keypair.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/cipher" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "github.com/bepass-org/warp-plus/wireguard/replay" 15 | ) 16 | 17 | /* Due to limitations in Go and /x/crypto there is currently 18 | * no way to ensure that key material is securely ereased in memory. 19 | * 20 | * Since this may harm the forward secrecy property, 21 | * we plan to resolve this issue; whenever Go allows us to do so. 
22 | */ 23 | 24 | type Keypair struct { 25 | sendNonce atomic.Uint64 26 | send cipher.AEAD 27 | receive cipher.AEAD 28 | replayFilter replay.Filter 29 | isInitiator bool 30 | created time.Time 31 | localIndex uint32 32 | remoteIndex uint32 33 | } 34 | 35 | type Keypairs struct { 36 | sync.RWMutex 37 | current *Keypair 38 | previous *Keypair 39 | next atomic.Pointer[Keypair] 40 | } 41 | 42 | func (kp *Keypairs) Current() *Keypair { 43 | kp.RLock() 44 | defer kp.RUnlock() 45 | return kp.current 46 | } 47 | 48 | func (device *Device) DeleteKeypair(key *Keypair) { 49 | if key != nil { 50 | device.indexTable.Delete(key.localIndex) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /wireguard/device/logger.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | "log" 11 | "log/slog" 12 | "os" 13 | ) 14 | 15 | // A Logger provides logging for a Device. 16 | // The functions are Printf-style functions. 17 | // They must be safe for concurrent use. 18 | // They do not require a trailing newline in the format. 19 | // If nil, that level of logging will be silent. 20 | type Logger struct { 21 | Verbosef func(format string, args ...any) 22 | Errorf func(format string, args ...any) 23 | } 24 | 25 | // Log levels for use with NewLogger. 26 | const ( 27 | LogLevelSilent = iota 28 | LogLevelError 29 | LogLevelVerbose 30 | ) 31 | 32 | // Function for use in Logger for discarding logged lines. 33 | func DiscardLogf(format string, args ...any) {} 34 | 35 | // NewLogger constructs a Logger that writes to stdout. 36 | // It logs at the specified log level and above. 37 | // It decorates log lines with the log level, date, time, and prepend. 
38 | func NewLogger(level int, prepend string) *Logger { 39 | logger := &Logger{DiscardLogf, DiscardLogf} 40 | logf := func(prefix string) func(string, ...any) { 41 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf 42 | } 43 | if level >= LogLevelVerbose { 44 | logger.Verbosef = logf("DEBUG") 45 | } 46 | if level >= LogLevelError { 47 | logger.Errorf = logf("ERROR") 48 | } 49 | return logger 50 | } 51 | 52 | func NewSLogger(l *slog.Logger) *Logger { 53 | return &Logger{ 54 | Verbosef: func(format string, v ...any) { 55 | l.Debug(fmt.Sprintf(format, v...)) 56 | }, 57 | Errorf: func(format string, v ...any) { 58 | l.Error(fmt.Sprintf(format, v...)) 59 | }, 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /wireguard/device/mobilequirks.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, 9 | // though it will try to deal with it, and race maybe, if called after. 10 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { 11 | device.net.brokenRoaming = true 12 | device.peers.RLock() 13 | for _, peer := range device.peers.keyMap { 14 | peer.endpoint.Lock() 15 | peer.endpoint.disableRoaming = peer.endpoint.val != nil 16 | peer.endpoint.Unlock() 17 | } 18 | device.peers.RUnlock() 19 | } 20 | -------------------------------------------------------------------------------- /wireguard/device/noise-helpers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/subtle" 12 | "errors" 13 | "hash" 14 | 15 | "golang.org/x/crypto/blake2s" 16 | "golang.org/x/crypto/curve25519" 17 | ) 18 | 19 | /* KDF related functions. 20 | * HMAC-based Key Derivation Function (HKDF) 21 | * https://tools.ietf.org/html/rfc5869 22 | */ 23 | 24 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { 25 | mac := hmac.New(func() hash.Hash { 26 | h, _ := blake2s.New256(nil) 27 | return h 28 | }, key) 29 | mac.Write(in0) 30 | mac.Sum(sum[:0]) 31 | } 32 | 33 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { 34 | mac := hmac.New(func() hash.Hash { 35 | h, _ := blake2s.New256(nil) 36 | return h 37 | }, key) 38 | mac.Write(in0) 39 | mac.Write(in1) 40 | mac.Sum(sum[:0]) 41 | } 42 | 43 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { 44 | HMAC1(t0, key, input) 45 | HMAC1(t0, t0[:], []byte{0x1}) 46 | } 47 | 48 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { 49 | var prk [blake2s.Size]byte 50 | HMAC1(&prk, key, input) 51 | HMAC1(t0, prk[:], []byte{0x1}) 52 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 53 | setZero(prk[:]) 54 | } 55 | 56 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { 57 | var prk [blake2s.Size]byte 58 | HMAC1(&prk, key, input) 59 | HMAC1(t0, prk[:], []byte{0x1}) 60 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 61 | HMAC2(t2, prk[:], t1[:], []byte{0x3}) 62 | setZero(prk[:]) 63 | } 64 | 65 | func isZero(val []byte) bool { 66 | acc := 1 67 | for _, b := range val { 68 | acc &= subtle.ConstantTimeByteEq(b, 0) 69 | } 70 | return acc == 1 71 | } 72 | 73 | /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ 74 | func setZero(arr []byte) { 75 | for i := range arr { 76 | arr[i] = 0 77 | } 78 | } 79 | 80 | func (sk *NoisePrivateKey) clamp() { 81 | sk[0] &= 248 82 | sk[31] = (sk[31] & 127) | 64 83 | } 84 | 85 | func newPrivateKey() (sk NoisePrivateKey, err 
error) { 86 | _, err = rand.Read(sk[:]) 87 | sk.clamp() 88 | return 89 | } 90 | 91 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { 92 | apk := (*[NoisePublicKeySize]byte)(&pk) 93 | ask := (*[NoisePrivateKeySize]byte)(sk) 94 | curve25519.ScalarBaseMult(apk, ask) 95 | return 96 | } 97 | 98 | var errInvalidPublicKey = errors.New("invalid public key") 99 | 100 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { 101 | apk := (*[NoisePublicKeySize]byte)(&pk) 102 | ask := (*[NoisePrivateKeySize]byte)(sk) 103 | curve25519.ScalarMult(&ss, ask, apk) 104 | if isZero(ss[:]) { 105 | return ss, errInvalidPublicKey 106 | } 107 | return ss, nil 108 | } 109 | -------------------------------------------------------------------------------- /wireguard/device/noise-types.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/subtle" 10 | "encoding/hex" 11 | "errors" 12 | ) 13 | 14 | const ( 15 | NoisePublicKeySize = 32 16 | NoisePrivateKeySize = 32 17 | NoisePresharedKeySize = 32 18 | ) 19 | 20 | type ( 21 | NoisePublicKey [NoisePublicKeySize]byte 22 | NoisePrivateKey [NoisePrivateKeySize]byte 23 | NoisePresharedKey [NoisePresharedKeySize]byte 24 | NoiseNonce uint64 // padded to 12-bytes 25 | ) 26 | 27 | func loadExactHex(dst []byte, src string) error { 28 | slice, err := hex.DecodeString(src) 29 | if err != nil { 30 | return err 31 | } 32 | if len(slice) != len(dst) { 33 | return errors.New("hex string does not fit the slice") 34 | } 35 | copy(dst, slice) 36 | return nil 37 | } 38 | 39 | func (key NoisePrivateKey) IsZero() bool { 40 | var zero NoisePrivateKey 41 | return key.Equals(zero) 42 | } 43 | 44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { 45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 46 | } 47 | 48 | func (key *NoisePrivateKey) FromHex(src string) (err error) { 49 | err = loadExactHex(key[:], src) 50 | key.clamp() 51 | return 52 | } 53 | 54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { 55 | err = loadExactHex(key[:], src) 56 | if key.IsZero() { 57 | return 58 | } 59 | key.clamp() 60 | return 61 | } 62 | 63 | func (key *NoisePublicKey) FromHex(src string) error { 64 | return loadExactHex(key[:], src) 65 | } 66 | 67 | func (key NoisePublicKey) IsZero() bool { 68 | var zero NoisePublicKey 69 | return key.Equals(zero) 70 | } 71 | 72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { 73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 74 | } 75 | 76 | func (key *NoisePresharedKey) FromHex(src string) error { 77 | return loadExactHex(key[:], src) 78 | } 79 | -------------------------------------------------------------------------------- /wireguard/device/noise_test.go: 
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "testing" 12 | 13 | "github.com/bepass-org/warp-plus/wireguard/conn" 14 | "github.com/bepass-org/warp-plus/wireguard/tun/tuntest" 15 | ) 16 | 17 | func TestCurveWrappers(t *testing.T) { 18 | sk1, err := newPrivateKey() 19 | assertNil(t, err) 20 | 21 | sk2, err := newPrivateKey() 22 | assertNil(t, err) 23 | 24 | pk1 := sk1.publicKey() 25 | pk2 := sk2.publicKey() 26 | 27 | ss1, err1 := sk1.sharedSecret(pk2) 28 | ss2, err2 := sk2.sharedSecret(pk1) 29 | 30 | if ss1 != ss2 || err1 != nil || err2 != nil { 31 | t.Fatal("Failed to compute shared secet") 32 | } 33 | } 34 | 35 | func randDevice(t *testing.T) *Device { 36 | sk, err := newPrivateKey() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | tun := tuntest.NewChannelTUN() 41 | logger := NewLogger(LogLevelError, "") 42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) 43 | device.SetPrivateKey(sk) 44 | return device 45 | } 46 | 47 | func assertNil(t *testing.T, err error) { 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | } 52 | 53 | func assertEqual(t *testing.T, a, b []byte) { 54 | if !bytes.Equal(a, b) { 55 | t.Fatal(a, "!=", b) 56 | } 57 | } 58 | 59 | func TestNoiseHandshake(t *testing.T) { 60 | dev1 := randDevice(t) 61 | dev2 := randDevice(t) 62 | 63 | defer dev1.Close() 64 | defer dev2.Close() 65 | 66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | peer1.Start() 75 | peer2.Start() 76 | 77 | assertEqual( 78 | t, 79 | peer1.handshake.precomputedStaticStatic[:], 80 | peer2.handshake.precomputedStaticStatic[:], 81 | ) 82 | 83 | /* 
simulate handshake */ 84 | 85 | // initiation message 86 | 87 | t.Log("exchange initiation message") 88 | 89 | msg1, err := dev1.CreateMessageInitiation(peer2) 90 | assertNil(t, err) 91 | 92 | packet := make([]byte, 0, 256) 93 | writer := bytes.NewBuffer(packet) 94 | err = binary.Write(writer, binary.LittleEndian, msg1) 95 | assertNil(t, err) 96 | peer := dev2.ConsumeMessageInitiation(msg1) 97 | if peer == nil { 98 | t.Fatal("handshake failed at initiation message") 99 | } 100 | 101 | assertEqual( 102 | t, 103 | peer1.handshake.chainKey[:], 104 | peer2.handshake.chainKey[:], 105 | ) 106 | 107 | assertEqual( 108 | t, 109 | peer1.handshake.hash[:], 110 | peer2.handshake.hash[:], 111 | ) 112 | 113 | // response message 114 | 115 | t.Log("exchange response message") 116 | 117 | msg2, err := dev2.CreateMessageResponse(peer1) 118 | assertNil(t, err) 119 | 120 | peer = dev1.ConsumeMessageResponse(msg2) 121 | if peer == nil { 122 | t.Fatal("handshake failed at response message") 123 | } 124 | 125 | assertEqual( 126 | t, 127 | peer1.handshake.chainKey[:], 128 | peer2.handshake.chainKey[:], 129 | ) 130 | 131 | assertEqual( 132 | t, 133 | peer1.handshake.hash[:], 134 | peer2.handshake.hash[:], 135 | ) 136 | 137 | // key pairs 138 | 139 | t.Log("deriving keys") 140 | 141 | err = peer1.BeginSymmetricSession() 142 | if err != nil { 143 | t.Fatal("failed to derive keypair for peer 1", err) 144 | } 145 | 146 | err = peer2.BeginSymmetricSession() 147 | if err != nil { 148 | t.Fatal("failed to derive keypair for peer 2", err) 149 | } 150 | 151 | key1 := peer1.keypairs.next.Load() 152 | key2 := peer2.keypairs.current 153 | 154 | // encrypting / decryption test 155 | 156 | t.Log("test key pairs") 157 | 158 | func() { 159 | testMsg := []byte("wireguard test message 1") 160 | var err error 161 | var out []byte 162 | var nonce [12]byte 163 | out = key1.send.Seal(out, nonce[:], testMsg, nil) 164 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 165 | assertNil(t, err) 166 | 
assertEqual(t, out, testMsg) 167 | }() 168 | 169 | func() { 170 | testMsg := []byte("wireguard test message 2") 171 | var err error 172 | var out []byte 173 | var nonce [12]byte 174 | out = key2.send.Seal(out, nonce[:], testMsg, nil) 175 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 176 | assertNil(t, err) 177 | assertEqual(t, out, testMsg) 178 | }() 179 | } 180 | -------------------------------------------------------------------------------- /wireguard/device/pools.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync" 10 | "sync/atomic" 11 | ) 12 | 13 | type WaitPool struct { 14 | pool sync.Pool 15 | cond sync.Cond 16 | lock sync.Mutex 17 | count atomic.Uint32 18 | max uint32 19 | } 20 | 21 | func NewWaitPool(max uint32, new func() any) *WaitPool { 22 | p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 | p.cond = sync.Cond{L: &p.lock} 24 | return p 25 | } 26 | 27 | func (p *WaitPool) Get() any { 28 | if p.max != 0 { 29 | p.lock.Lock() 30 | for p.count.Load() >= p.max { 31 | p.cond.Wait() 32 | } 33 | p.count.Add(1) 34 | p.lock.Unlock() 35 | } 36 | return p.pool.Get() 37 | } 38 | 39 | func (p *WaitPool) Put(x any) { 40 | p.pool.Put(x) 41 | if p.max == 0 { 42 | return 43 | } 44 | p.count.Add(^uint32(0)) 45 | p.cond.Signal() 46 | } 47 | 48 | func (device *Device) PopulatePools() { 49 | device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 50 | s := make([]*QueueInboundElement, 0, device.BatchSize()) 51 | return &QueueInboundElementsContainer{elems: s} 52 | }) 53 | device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 54 | s := make([]*QueueOutboundElement, 0, device.BatchSize()) 55 | return &QueueOutboundElementsContainer{elems: s} 56 | }) 57 | device.pool.messageBuffers = 
NewWaitPool(PreallocatedBuffersPerPool, func() any { 58 | return new([MaxMessageSize]byte) 59 | }) 60 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 61 | return new(QueueInboundElement) 62 | }) 63 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 64 | return new(QueueOutboundElement) 65 | }) 66 | } 67 | 68 | func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { 69 | c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) 70 | c.Mutex = sync.Mutex{} 71 | return c 72 | } 73 | 74 | func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { 75 | for i := range c.elems { 76 | c.elems[i] = nil 77 | } 78 | c.elems = c.elems[:0] 79 | device.pool.inboundElementsContainer.Put(c) 80 | } 81 | 82 | func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { 83 | c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) 84 | c.Mutex = sync.Mutex{} 85 | return c 86 | } 87 | 88 | func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { 89 | for i := range c.elems { 90 | c.elems[i] = nil 91 | } 92 | c.elems = c.elems[:0] 93 | device.pool.outboundElementsContainer.Put(c) 94 | } 95 | 96 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 97 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 98 | } 99 | 100 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 101 | device.pool.messageBuffers.Put(msg) 102 | } 103 | 104 | func (device *Device) GetInboundElement() *QueueInboundElement { 105 | return device.pool.inboundElements.Get().(*QueueInboundElement) 106 | } 107 | 108 | func (device *Device) PutInboundElement(elem *QueueInboundElement) { 109 | elem.clearPointers() 110 | device.pool.inboundElements.Put(elem) 111 | } 112 | 113 | func (device *Device) GetOutboundElement() *QueueOutboundElement { 114 | 
return device.pool.outboundElements.Get().(*QueueOutboundElement) 115 | } 116 | 117 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 118 | elem.clearPointers() 119 | device.pool.outboundElements.Put(elem) 120 | } 121 | -------------------------------------------------------------------------------- /wireguard/device/pools_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "runtime" 11 | "sync" 12 | "sync/atomic" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | func TestWaitPool(t *testing.T) { 18 | t.Skip("Currently disabled") 19 | var wg sync.WaitGroup 20 | var trials atomic.Int32 21 | startTrials := int32(100000) 22 | if raceEnabled { 23 | // This test can be very slow with -race. 24 | startTrials /= 10 25 | } 26 | trials.Store(startTrials) 27 | workers := runtime.NumCPU() + 2 28 | if workers-4 <= 0 { 29 | t.Skip("Not enough cores") 30 | } 31 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 32 | wg.Add(workers) 33 | var max atomic.Uint32 34 | updateMax := func() { 35 | count := p.count.Load() 36 | if count > p.max { 37 | t.Errorf("count (%d) > max (%d)", count, p.max) 38 | } 39 | for { 40 | old := max.Load() 41 | if count <= old { 42 | break 43 | } 44 | if max.CompareAndSwap(old, count) { 45 | break 46 | } 47 | } 48 | } 49 | for i := 0; i < workers; i++ { 50 | go func() { 51 | defer wg.Done() 52 | for trials.Add(-1) > 0 { 53 | updateMax() 54 | x := p.Get() 55 | updateMax() 56 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 57 | updateMax() 58 | p.Put(x) 59 | updateMax() 60 | } 61 | }() 62 | } 63 | wg.Wait() 64 | if max.Load() != p.max { 65 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) 66 | } 67 | } 68 | 69 | func BenchmarkWaitPool(b *testing.B) { 70 | var 
wg sync.WaitGroup 71 | var trials atomic.Int32 72 | trials.Store(int32(b.N)) 73 | workers := runtime.NumCPU() + 2 74 | if workers-4 <= 0 { 75 | b.Skip("Not enough cores") 76 | } 77 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 78 | wg.Add(workers) 79 | b.ResetTimer() 80 | for i := 0; i < workers; i++ { 81 | go func() { 82 | defer wg.Done() 83 | for trials.Add(-1) > 0 { 84 | x := p.Get() 85 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 86 | p.Put(x) 87 | } 88 | }() 89 | } 90 | wg.Wait() 91 | } 92 | 93 | func BenchmarkWaitPoolEmpty(b *testing.B) { 94 | var wg sync.WaitGroup 95 | var trials atomic.Int32 96 | trials.Store(int32(b.N)) 97 | workers := runtime.NumCPU() + 2 98 | if workers-4 <= 0 { 99 | b.Skip("Not enough cores") 100 | } 101 | p := NewWaitPool(0, func() any { return make([]byte, 16) }) 102 | wg.Add(workers) 103 | b.ResetTimer() 104 | for i := 0; i < workers; i++ { 105 | go func() { 106 | defer wg.Done() 107 | for trials.Add(-1) > 0 { 108 | x := p.Get() 109 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 110 | p.Put(x) 111 | } 112 | }() 113 | } 114 | wg.Wait() 115 | } 116 | 117 | func BenchmarkSyncPool(b *testing.B) { 118 | var wg sync.WaitGroup 119 | var trials atomic.Int32 120 | trials.Store(int32(b.N)) 121 | workers := runtime.NumCPU() + 2 122 | if workers-4 <= 0 { 123 | b.Skip("Not enough cores") 124 | } 125 | p := sync.Pool{New: func() any { return make([]byte, 16) }} 126 | wg.Add(workers) 127 | b.ResetTimer() 128 | for i := 0; i < workers; i++ { 129 | go func() { 130 | defer wg.Done() 131 | for trials.Add(-1) > 0 { 132 | x := p.Get() 133 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 134 | p.Put(x) 135 | } 136 | }() 137 | } 138 | wg.Wait() 139 | } 140 | -------------------------------------------------------------------------------- /wireguard/device/queueconstants_android.go: -------------------------------------------------------------------------------- 1 | /* 
SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import "github.com/bepass-org/warp-plus/wireguard/conn"

/* Reduce memory consumption for Android */

// Queue and pool sizing for Android builds (selected by the _android file
// name suffix). PreallocatedBuffersPerPool is non-zero here, so each pool
// built with NewWaitPool(PreallocatedBuffersPerPool, ...) in pools.go is
// capped instead of being allowed to grow without bound.
const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
	PreallocatedBuffersPerPool = 4096
)
--------------------------------------------------------------------------------
/wireguard/device/queueconstants_default.go:
--------------------------------------------------------------------------------
//go:build !android && !ios && !windows

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import "github.com/bepass-org/warp-plus/wireguard/conn"

// Queue and pool sizing for every platform not covered by the android, ios
// or windows variants (see the build constraint above).
const (
	QueueStagedSize            = conn.IdealBatchSize
	QueueOutboundSize          = 1024
	QueueInboundSize           = 1024
	QueueHandshakeSize         = 1024
	MaxSegmentSize             = (1 << 16) - 1 // largest possible UDP datagram
	PreallocatedBuffersPerPool = 0             // Disable and allow for infinite memory growth
)
--------------------------------------------------------------------------------
/wireguard/device/queueconstants_ios.go:
--------------------------------------------------------------------------------
//go:build ios

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
	QueueStagedSize    = 128
	QueueOutboundSize  = 1024
	QueueInboundSize   = 1024
	QueueHandshakeSize = 1024
	// Explicitly typed: as a variable (not an untyped constant) it must already
	// have the uint32 type that NewWaitPool's size parameter takes (pools_test.go
	// passes uint32(...) for that argument).
	PreallocatedBuffersPerPool uint32 = 1024
)

// NOTE(review): 1700 is far below the (1<<16)-1 segment size of the default
// build; presumably chosen for the iOS memory limits described above — confirm.
const MaxSegmentSize = 1700
--------------------------------------------------------------------------------
/wireguard/device/queueconstants_windows.go:
--------------------------------------------------------------------------------
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

// Queue and pool sizing for Windows builds (selected by the _windows file
// name suffix).
const (
	QueueStagedSize    = 128
	QueueOutboundSize  = 1024
	QueueInboundSize   = 1024
	QueueHandshakeSize = 1024
	// 2048-32 = 2016 bytes. This is a Windows-specific cap, NOT the largest
	// possible UDP datagram (queueconstants_default.go uses (1<<16)-1 for that).
	// NOTE(review): the rationale for this particular value is not visible here.
	MaxSegmentSize             = 2048 - 32
	PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)
--------------------------------------------------------------------------------
/wireguard/device/race_disabled_test.go:
--------------------------------------------------------------------------------
//go:build !race

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

// raceEnabled reports whether the tests were built with -race (see the build
// constraint above); tests such as TestWaitPool use it to shrink their workload.
const raceEnabled = false
--------------------------------------------------------------------------------
/wireguard/device/race_enabled_test.go:
--------------------------------------------------------------------------------
//go:build race

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = true 11 | -------------------------------------------------------------------------------- /wireguard/device/sticky_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package device 4 | 5 | import ( 6 | "github.com/bepass-org/warp-plus/wireguard/conn" 7 | "github.com/bepass-org/warp-plus/wireguard/rwcancel" 8 | ) 9 | 10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 11 | return nil, nil 12 | } 13 | -------------------------------------------------------------------------------- /wireguard/device/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | 11 | "github.com/bepass-org/warp-plus/wireguard/tun" 12 | ) 13 | 14 | const DefaultMTU = 1420 15 | 16 | func (device *Device) RoutineTUNEventReader() { 17 | device.log.Verbosef("Routine: event worker - started") 18 | 19 | for event := range device.tun.device.Events() { 20 | if event&tun.EventMTUUpdate != 0 { 21 | mtu, err := device.tun.device.MTU() 22 | if err != nil { 23 | device.log.Errorf("Failed to load updated MTU of device: %v", err) 24 | continue 25 | } 26 | if mtu < 0 { 27 | device.log.Errorf("MTU not updated to negative value: %v", mtu) 28 | continue 29 | } 30 | var tooLarge string 31 | if mtu > MaxContentSize { 32 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) 33 | mtu = MaxContentSize 34 | } 35 | old := device.tun.mtu.Swap(int32(mtu)) 36 | if int(old) != mtu { 37 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) 38 | } 39 | } 40 | 41 | if event&tun.EventUp != 0 { 42 | device.log.Verbosef("Interface up requested") 43 | device.Up() 44 | } 45 | 46 | if event&tun.EventDown != 0 { 47 | 
device.log.Verbosef("Interface down requested") 48 | device.Down() 49 | } 50 | } 51 | 52 | device.log.Verbosef("Routine: event worker - stopped") 53 | } 54 | -------------------------------------------------------------------------------- /wireguard/ipc/uapi_bsd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "net" 13 | "os" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | type UAPIListener struct { 20 | listener net.Listener // unix socket listener 21 | connNew chan net.Conn 22 | connErr chan error 23 | kqueueFd int 24 | keventFd int 25 | } 26 | 27 | func (l *UAPIListener) Accept() (net.Conn, error) { 28 | for { 29 | select { 30 | case conn := <-l.connNew: 31 | return conn, nil 32 | 33 | case err := <-l.connErr: 34 | return nil, err 35 | } 36 | } 37 | } 38 | 39 | func (l *UAPIListener) Close() error { 40 | err1 := unix.Close(l.kqueueFd) 41 | err2 := unix.Close(l.keventFd) 42 | err3 := l.listener.Close() 43 | if err1 != nil { 44 | return err1 45 | } 46 | if err2 != nil { 47 | return err2 48 | } 49 | return err3 50 | } 51 | 52 | func (l *UAPIListener) Addr() net.Addr { 53 | return l.listener.Addr() 54 | } 55 | 56 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 57 | // wrap file in listener 58 | 59 | listener, err := net.FileListener(file) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | uapi := &UAPIListener{ 65 | listener: listener, 66 | connNew: make(chan net.Conn, 1), 67 | connErr: make(chan error, 1), 68 | } 69 | 70 | if unixListener, ok := listener.(*net.UnixListener); ok { 71 | unixListener.SetUnlinkOnClose(true) 72 | } 73 | 74 | socketPath := sockPath(name) 75 | 76 | // watch for deletion of socket 77 | 78 | uapi.kqueueFd, err = unix.Kqueue() 79 | if err != 
nil { 80 | return nil, err 81 | } 82 | uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) 83 | if err != nil { 84 | unix.Close(uapi.kqueueFd) 85 | return nil, err 86 | } 87 | 88 | go func(l *UAPIListener) { 89 | event := unix.Kevent_t{ 90 | Filter: unix.EVFILT_VNODE, 91 | Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, 92 | Fflags: unix.NOTE_WRITE, 93 | } 94 | // Allow this assignment to work with both the 32-bit and 64-bit version 95 | // of the above struct. If you know another way, please submit a patch. 96 | *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) 97 | events := make([]unix.Kevent_t, 1) 98 | n := 1 99 | var kerr error 100 | for { 101 | // start with lstat to avoid race condition 102 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 103 | l.connErr <- err 104 | return 105 | } 106 | if (kerr != nil || n != 1) && kerr != unix.EINTR { 107 | if kerr != nil { 108 | l.connErr <- kerr 109 | } else { 110 | l.connErr <- errors.New("kqueue returned empty") 111 | } 112 | return 113 | } 114 | n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) 115 | } 116 | }(uapi) 117 | 118 | // watch for new connections 119 | 120 | go func(l *UAPIListener) { 121 | for { 122 | conn, err := l.listener.Accept() 123 | if err != nil { 124 | l.connErr <- err 125 | break 126 | } 127 | l.connNew <- conn 128 | } 129 | }(uapi) 130 | 131 | return uapi, nil 132 | } 133 | -------------------------------------------------------------------------------- /wireguard/ipc/uapi_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | "os" 11 | 12 | "github.com/bepass-org/warp-plus/wireguard/rwcancel" 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | type UAPIListener struct { 17 | listener net.Listener // unix socket listener 18 | connNew chan net.Conn 19 | connErr chan error 20 | inotifyFd int 21 | inotifyRWCancel *rwcancel.RWCancel 22 | } 23 | 24 | func (l *UAPIListener) Accept() (net.Conn, error) { 25 | for { 26 | select { 27 | case conn := <-l.connNew: 28 | return conn, nil 29 | 30 | case err := <-l.connErr: 31 | return nil, err 32 | } 33 | } 34 | } 35 | 36 | func (l *UAPIListener) Close() error { 37 | err1 := unix.Close(l.inotifyFd) 38 | err2 := l.inotifyRWCancel.Cancel() 39 | err3 := l.listener.Close() 40 | if err1 != nil { 41 | return err1 42 | } 43 | if err2 != nil { 44 | return err2 45 | } 46 | return err3 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 54 | // wrap file in listener 55 | 56 | listener, err := net.FileListener(file) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | if unixListener, ok := listener.(*net.UnixListener); ok { 62 | unixListener.SetUnlinkOnClose(true) 63 | } 64 | 65 | uapi := &UAPIListener{ 66 | listener: listener, 67 | connNew: make(chan net.Conn, 1), 68 | connErr: make(chan error, 1), 69 | } 70 | 71 | // watch for deletion of socket 72 | 73 | socketPath := sockPath(name) 74 | 75 | uapi.inotifyFd, err = unix.InotifyInit() 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | _, err = unix.InotifyAddWatch( 81 | uapi.inotifyFd, 82 | socketPath, 83 | unix.IN_ATTRIB| 84 | unix.IN_DELETE| 85 | unix.IN_DELETE_SELF, 86 | ) 87 | 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) 93 | if err != nil { 94 | unix.Close(uapi.inotifyFd) 95 | return nil, err 96 | } 97 | 98 | go func(l *UAPIListener) { 99 | var 
buf [0]byte 100 | for { 101 | defer uapi.inotifyRWCancel.Close() 102 | // start with lstat to avoid race condition 103 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 104 | l.connErr <- err 105 | return 106 | } 107 | _, err := uapi.inotifyRWCancel.Read(buf[:]) 108 | if err != nil { 109 | l.connErr <- err 110 | return 111 | } 112 | } 113 | }(uapi) 114 | 115 | // watch for new connections 116 | 117 | go func(l *UAPIListener) { 118 | for { 119 | conn, err := l.listener.Accept() 120 | if err != nil { 121 | l.connErr <- err 122 | break 123 | } 124 | l.connNew <- conn 125 | } 126 | }(uapi) 127 | 128 | return uapi, nil 129 | } 130 | -------------------------------------------------------------------------------- /wireguard/ipc/uapi_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "fmt" 13 | "net" 14 | "os" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | const ( 20 | IpcErrorIO = -int64(unix.EIO) 21 | IpcErrorProtocol = -int64(unix.EPROTO) 22 | IpcErrorInvalid = -int64(unix.EINVAL) 23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE) 24 | IpcErrorUnknown = -55 // ENOANO 25 | ) 26 | 27 | // socketDirectory is variable because it is modified by a linker 28 | // flag in wireguard-android. 
29 | var socketDirectory = "/var/run/wireguard" 30 | 31 | func sockPath(iface string) string { 32 | return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) 33 | } 34 | 35 | func UAPIOpen(name string) (*os.File, error) { 36 | if err := os.MkdirAll(socketDirectory, 0o755); err != nil { 37 | return nil, err 38 | } 39 | 40 | socketPath := sockPath(name) 41 | addr, err := net.ResolveUnixAddr("unix", socketPath) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | oldUmask := unix.Umask(0o077) 47 | defer unix.Umask(oldUmask) 48 | 49 | listener, err := net.ListenUnix("unix", addr) 50 | if err == nil { 51 | return listener.File() 52 | } 53 | 54 | // Test socket, if not in use cleanup and try again. 55 | if _, err := net.Dial("unix", socketPath); err == nil { 56 | return nil, errors.New("unix socket in use") 57 | } 58 | if err := os.Remove(socketPath); err != nil { 59 | return nil, err 60 | } 61 | listener, err = net.ListenUnix("unix", addr) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return listener.File() 66 | } 67 | -------------------------------------------------------------------------------- /wireguard/ipc/uapi_wasm.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | // Made up sentinel error codes for {js,wasip1}/wasm. 9 | const ( 10 | IpcErrorIO = 1 11 | IpcErrorInvalid = 2 12 | IpcErrorPortInUse = 3 13 | IpcErrorUnknown = 4 14 | IpcErrorProtocol = 5 15 | ) 16 | -------------------------------------------------------------------------------- /wireguard/ipc/uapi_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | 11 | "github.com/bepass-org/warp-plus/wireguard/ipc/namedpipe" 12 | "golang.org/x/sys/windows" 13 | ) 14 | 15 | // TODO: replace these with actual standard windows error numbers from the win package 16 | const ( 17 | IpcErrorIO = -int64(5) 18 | IpcErrorProtocol = -int64(71) 19 | IpcErrorInvalid = -int64(22) 20 | IpcErrorPortInUse = -int64(98) 21 | IpcErrorUnknown = -int64(55) 22 | ) 23 | 24 | type UAPIListener struct { 25 | listener net.Listener // unix socket listener 26 | connNew chan net.Conn 27 | connErr chan error 28 | kqueueFd int 29 | keventFd int 30 | } 31 | 32 | func (l *UAPIListener) Accept() (net.Conn, error) { 33 | for { 34 | select { 35 | case conn := <-l.connNew: 36 | return conn, nil 37 | 38 | case err := <-l.connErr: 39 | return nil, err 40 | } 41 | } 42 | } 43 | 44 | func (l *UAPIListener) Close() error { 45 | return l.listener.Close() 46 | } 47 | 48 | func (l *UAPIListener) Addr() net.Addr { 49 | return l.listener.Addr() 50 | } 51 | 52 | var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR 53 | 54 | func init() { 55 | var err error 56 | UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)") 57 | if err != nil { 58 | panic(err) 59 | } 60 | } 61 | 62 | func UAPIListen(name string) (net.Listener, error) { 63 | listener, err := (&namedpipe.ListenConfig{ 64 | SecurityDescriptor: UAPISecurityDescriptor, 65 | }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | uapi := &UAPIListener{ 71 | listener: listener, 72 | connNew: make(chan net.Conn, 1), 73 | connErr: make(chan error, 1), 74 | } 75 | 76 | go func(l *UAPIListener) { 77 | for { 78 | conn, err := l.listener.Accept() 79 | if err != nil { 80 | l.connErr <- err 81 | break 82 | } 83 | l.connNew <- conn 84 | } 85 | }(uapi) 86 | 87 | return uapi, nil 88 | } 89 | 
-------------------------------------------------------------------------------- /wireguard/ratelimiter/ratelimiter.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | const ( 15 | packetsPerSecond = 20 16 | packetsBurstable = 5 17 | garbageCollectTime = time.Second 18 | packetCost = 1000000000 / packetsPerSecond 19 | maxTokens = packetCost * packetsBurstable 20 | ) 21 | 22 | type RatelimiterEntry struct { 23 | mu sync.Mutex 24 | lastTime time.Time 25 | tokens int64 26 | } 27 | 28 | type Ratelimiter struct { 29 | mu sync.RWMutex 30 | timeNow func() time.Time 31 | 32 | stopReset chan struct{} // send to reset, close to stop 33 | table map[netip.Addr]*RatelimiterEntry 34 | } 35 | 36 | func (rate *Ratelimiter) Close() { 37 | rate.mu.Lock() 38 | defer rate.mu.Unlock() 39 | 40 | if rate.stopReset != nil { 41 | close(rate.stopReset) 42 | } 43 | } 44 | 45 | func (rate *Ratelimiter) Init() { 46 | rate.mu.Lock() 47 | defer rate.mu.Unlock() 48 | 49 | if rate.timeNow == nil { 50 | rate.timeNow = time.Now 51 | } 52 | 53 | // stop any ongoing garbage collection routine 54 | if rate.stopReset != nil { 55 | close(rate.stopReset) 56 | } 57 | 58 | rate.stopReset = make(chan struct{}) 59 | rate.table = make(map[netip.Addr]*RatelimiterEntry) 60 | 61 | stopReset := rate.stopReset // store in case Init is called again. 62 | 63 | // Start garbage collection routine. 
64 | go func() { 65 | ticker := time.NewTicker(time.Second) 66 | ticker.Stop() 67 | for { 68 | select { 69 | case _, ok := <-stopReset: 70 | ticker.Stop() 71 | if !ok { 72 | return 73 | } 74 | ticker = time.NewTicker(time.Second) 75 | case <-ticker.C: 76 | if rate.cleanup() { 77 | ticker.Stop() 78 | } 79 | } 80 | } 81 | }() 82 | } 83 | 84 | func (rate *Ratelimiter) cleanup() (empty bool) { 85 | rate.mu.Lock() 86 | defer rate.mu.Unlock() 87 | 88 | for key, entry := range rate.table { 89 | entry.mu.Lock() 90 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 91 | delete(rate.table, key) 92 | } 93 | entry.mu.Unlock() 94 | } 95 | 96 | return len(rate.table) == 0 97 | } 98 | 99 | func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 100 | var entry *RatelimiterEntry 101 | // lookup entry 102 | rate.mu.RLock() 103 | entry = rate.table[ip] 104 | rate.mu.RUnlock() 105 | 106 | // make new entry if not found 107 | if entry == nil { 108 | entry = new(RatelimiterEntry) 109 | entry.tokens = maxTokens - packetCost 110 | entry.lastTime = rate.timeNow() 111 | rate.mu.Lock() 112 | rate.table[ip] = entry 113 | if len(rate.table) == 1 { 114 | rate.stopReset <- struct{}{} 115 | } 116 | rate.mu.Unlock() 117 | return true 118 | } 119 | 120 | // add tokens to entry 121 | entry.mu.Lock() 122 | now := rate.timeNow() 123 | entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 124 | entry.lastTime = now 125 | if entry.tokens > maxTokens { 126 | entry.tokens = maxTokens 127 | } 128 | 129 | // subtract cost of packet 130 | if entry.tokens > packetCost { 131 | entry.tokens -= packetCost 132 | entry.mu.Unlock() 133 | return true 134 | } 135 | entry.mu.Unlock() 136 | return false 137 | } 138 | -------------------------------------------------------------------------------- /wireguard/ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. 
All Rights Reserved.
 */

package ratelimiter

import (
	"net/netip"
	"testing"
	"time"
)

// result is one expected step of TestRatelimiter: after the fake clock is
// advanced by wait (plus 1ns), every test address must get Allow == allowed.
type result struct {
	allowed bool
	text    string
	wait    time.Duration
}

// TestRatelimiter drives a Ratelimiter with a fake clock across several IPv4
// and IPv6 addresses. It checks the initial burst, rejection once the bucket
// is drained, and refill after one and two packet intervals.
func TestRatelimiter(t *testing.T) {
	var rate Ratelimiter
	var expectedResults []result

	nano := func(nano int64) time.Duration {
		return time.Nanosecond * time.Duration(nano)
	}

	add := func(res result) {
		expectedResults = append(
			expectedResults,
			res,
		)
	}

	for i := 0; i < packetsBurstable; i++ {
		add(result{
			allowed: true,
			text:    "initial burst",
		})
	}

	add(result{
		allowed: false,
		text:    "after burst",
	})

	add(result{
		allowed: true,
		wait:    nano(time.Second.Nanoseconds() / packetsPerSecond),
		text:    "filling tokens for single packet",
	})

	add(result{
		allowed: false,
		text:    "not having refilled enough",
	})

	add(result{
		allowed: true,
		wait:    2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
		text:    "filling tokens for two packet burst",
	})

	add(result{
		allowed: true,
		text:    "second packet in 2 packet burst",
	})

	add(result{
		allowed: false,
		text:    "packet following 2 packet burst",
	})

	ips := []netip.Addr{
		netip.MustParseAddr("127.0.0.1"),
		netip.MustParseAddr("192.168.1.1"),
		netip.MustParseAddr("172.167.2.3"),
		netip.MustParseAddr("97.231.252.215"),
		netip.MustParseAddr("248.97.91.167"),
		netip.MustParseAddr("188.208.233.47"),
		netip.MustParseAddr("104.2.183.179"),
		netip.MustParseAddr("72.129.46.120"),
		netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
		netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
		netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
		netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
		netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
		netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
	}

	// Fake clock. It is installed before Init, which only defaults timeNow when
	// it is nil. Allow and cleanup read it; timeSleep advances it.
	now := time.Now()
	rate.timeNow = func() time.Time {
		return now
	}
	defer func() {
		// Lock to avoid data race with cleanup goroutine from Init.
		rate.mu.Lock()
		defer rate.mu.Unlock()

		rate.timeNow = time.Now
	}()
	timeSleep := func(d time.Duration) {
		now = now.Add(d + 1)
		rate.cleanup()
	}

	rate.Init()
	defer rate.Close()

	for i, res := range expectedResults {
		timeSleep(res.wait)
		for _, ip := range ips {
			allowed := rate.Allow(ip)
			if allowed != res.allowed {
				t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
			}
		}
	}
}
--------------------------------------------------------------------------------
/wireguard/replay/replay.go:
--------------------------------------------------------------------------------
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay

// block is one word of the replay bitmap.
type block uint64

const (
	blockBitLog = 6                // 1<<6 == 64 bits
	blockBits   = 1 << blockBitLog // must be power of 2
	ringBlocks  = 1 << 7           // must be power of 2
	windowSize  = (ringBlocks - 1) * blockBits // 127*64 = 8128 counters
	blockMask   = ringBlocks - 1
	bitMask     = blockBits - 1
)

// A Filter rejects replayed messages by checking if message counter value is
// within a sliding window of previously received messages.
// The zero value for Filter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type Filter struct {
	last uint64            // highest counter that has moved the window forward
	ring [ringBlocks]block // bitmap of accepted counters, addressed modulo ringBlocks
}

// Reset resets the filter to empty state.
// Only ring[0] is cleared here; ValidateCounter clears each later block as
// the window advances into it, before that block is consulted.
func (f *Filter) Reset() {
	f.last = 0
	f.ring[0] = 0
}

// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter, limit uint64) bool {
	if counter >= limit {
		return false
	}
	indexBlock := counter >> blockBitLog
	if counter > f.last { // move window forward
		current := f.last >> blockBitLog
		diff := indexBlock - current
		if diff > ringBlocks {
			diff = ringBlocks // cap diff to clear the whole ring
		}
		for i := current + 1; i <= current+diff; i++ {
			f.ring[i&blockMask] = 0
		}
		f.last = counter
	} else if f.last-counter > windowSize { // behind current window
		return false
	}
	// check and set bit
	indexBlock &= blockMask
	indexBit := counter & bitMask
	old := f.ring[indexBlock]
	// NOTE(review): source text between "1<" and " 0; i-- {" was lost when this
	// file was extracted (apparently a "<...>" span was stripped). The rest of
	// ValidateCounter, the end of replay.go and the start of replay_test.go are
	// missing here and must be restored from the original sources. The
	// fragment below is kept exactly as found.
	new := old | 1< 0; i-- {
		T(i, true)
	}

	t.Log("Bulk test 4")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize + 2); i > 1; i-- {
		T(i, true)
	}
	T(0, false)

	t.Log("Bulk test 5")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize); i > 0; i-- {
		T(i, true)
	}
	T(windowSize+1, true)
	T(0, false)

	t.Log("Bulk test 6")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize); i > 0; i-- {
		T(i, true)
	}
	T(0, true)
	T(windowSize+1, true)
}
--------------------------------------------------------------------------------
/wireguard/rwcancel/rwcancel.go:
--------------------------------------------------------------------------------
//go:build !windows && !wasm

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | // Package rwcancel implements cancelable read/write operations on 9 | // a file descriptor. 10 | package rwcancel 11 | 12 | import ( 13 | "errors" 14 | "os" 15 | "syscall" 16 | 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type RWCancel struct { 21 | fd int 22 | closingReader *os.File 23 | closingWriter *os.File 24 | } 25 | 26 | func NewRWCancel(fd int) (*RWCancel, error) { 27 | err := unix.SetNonblock(fd, true) 28 | if err != nil { 29 | return nil, err 30 | } 31 | rwcancel := RWCancel{fd: fd} 32 | 33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &rwcancel, nil 39 | } 40 | 41 | func RetryAfterError(err error) bool { 42 | return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) 43 | } 44 | 45 | func (rw *RWCancel) ReadyRead() bool { 46 | closeFd := int32(rw.closingReader.Fd()) 47 | 48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} 49 | var err error 50 | for { 51 | _, err = unix.Poll(pollFds, -1) 52 | if err == nil || !RetryAfterError(err) { 53 | break 54 | } 55 | } 56 | if err != nil { 57 | return false 58 | } 59 | if pollFds[1].Revents != 0 { 60 | return false 61 | } 62 | return pollFds[0].Revents != 0 63 | } 64 | 65 | func (rw *RWCancel) ReadyWrite() bool { 66 | closeFd := int32(rw.closingReader.Fd()) 67 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} 68 | var err error 69 | for { 70 | _, err = unix.Poll(pollFds, -1) 71 | if err == nil || !RetryAfterError(err) { 72 | break 73 | } 74 | } 75 | if err != nil { 76 | return false 77 | } 78 | 79 | if pollFds[1].Revents != 0 { 80 | return false 81 | } 82 | return pollFds[0].Revents != 0 83 | } 84 | 85 | func (rw *RWCancel) Read(p []byte) (n int, err error) { 86 | for { 87 | n, err := unix.Read(rw.fd, p) 88 | if err == nil || !RetryAfterError(err) { 89 | return n, err 90 | } 91 | if 
!rw.ReadyRead() { 92 | return 0, os.ErrClosed 93 | } 94 | } 95 | } 96 | 97 | func (rw *RWCancel) Write(p []byte) (n int, err error) { 98 | for { 99 | n, err := unix.Write(rw.fd, p) 100 | if err == nil || !RetryAfterError(err) { 101 | return n, err 102 | } 103 | if !rw.ReadyWrite() { 104 | return 0, os.ErrClosed 105 | } 106 | } 107 | } 108 | 109 | func (rw *RWCancel) Cancel() (err error) { 110 | _, err = rw.closingWriter.Write([]byte{0}) 111 | return 112 | } 113 | 114 | func (rw *RWCancel) Close() { 115 | rw.closingReader.Close() 116 | rw.closingWriter.Close() 117 | } 118 | -------------------------------------------------------------------------------- /wireguard/rwcancel/rwcancel_stub.go: -------------------------------------------------------------------------------- 1 | //go:build windows || wasm 2 | 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rwcancel 6 | 7 | type RWCancel struct{} 8 | 9 | func (*RWCancel) Cancel() {} 10 | -------------------------------------------------------------------------------- /wireguard/tai64n/tai64n.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "time" 12 | ) 13 | /* base is 2^62 + 10, the TAI64 label offset that stamp adds to Unix seconds; the extra 10 s is presumably the historical TAI-UTC offset (confirm against the TAI64N spec). whitenerMask is 2^24 - 1: stamp clears those low nanosecond bits, so timestamps only advance in steps of 2^24 ns (about 16.8 ms), which whitens fine-grained clock information. */ 14 | const ( 15 | TimestampSize = 12 16 | base = uint64(0x400000000000000a) 17 | whitenerMask = uint32(0x1000000 - 1) 18 | ) 19 | /* Timestamp holds a 12-byte label as built by stamp: 8 bytes of seconds then 4 bytes of nanoseconds, both big-endian. */ 20 | type Timestamp [TimestampSize]byte 21 | /* stamp encodes t as base + t.Unix() (big-endian uint64) followed by t.Nanosecond() with the whitenerMask bits cleared (big-endian uint32). */ 22 | func stamp(t time.Time) Timestamp { 23 | var tai64n Timestamp 24 | secs := base + uint64(t.Unix()) 25 | nano := uint32(t.Nanosecond()) &^ whitenerMask 26 | binary.BigEndian.PutUint64(tai64n[:], secs) 27 | binary.BigEndian.PutUint32(tai64n[8:], nano) 28 | return tai64n 29 | } 30 | /* Now returns the whitened timestamp for the current time. */ 31 | func Now() Timestamp { 32 | return stamp(time.Now()) 33 | } 34 | /* After reports whether t1 is strictly later than t2; comparing bytes is enough because both fields are stored big-endian. */ 35 | func (t1 Timestamp) After(t2 Timestamp) bool { 36 | return bytes.Compare(t1[:], t2[:]) > 0 37 | } 38 | /* String converts t back to a local time.Time (seconds minus base, whitened nanoseconds) and formats it. */ 39 | func (t Timestamp) String() string { 40 | return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() 41 | } 42 | -------------------------------------------------------------------------------- /wireguard/tai64n/tai64n_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Test that timestamps are monotonic as required by Wireguard and that 14 | // nanosecond-level information is whitened to prevent side channel attacks. 15 | func TestMonotonic(t *testing.T) { 16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern 17 | // Whitening should reduce timestamp granularity 18 | // to more than 10 but fewer than 20 milliseconds.
19 | tests := []struct { 20 | name string 21 | t1, t2 time.Time 22 | wantAfter bool 23 | }{ 24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, 25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, 26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, 27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, 28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, 29 | } 30 | 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2) 34 | got := ts2.After(ts1) 35 | if got != tt.wantAfter { 36 | t.Errorf("after = %v; want %v", got, tt.wantAfter) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /wireguard/tun/alignment_windows_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | /* checkAlignment reports a test error when offset is not a multiple of 8, naming the field and how many bytes short of the next 8-byte boundary it is. */ 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestRateJugglerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime.
26 | func TestRateJugglerAlignment(t *testing.T) { 27 | var r rateJuggler 28 | /* Log the field layout first so an alignment failure below is easy to diagnose; the label uses the real type name rather than a copy-pasted one. */ 29 | typ := reflect.TypeOf(&r).Elem() 30 | t.Logf("%s type size: %d, with fields:", typ.Name(), typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) 42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) 43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) 44 | } 45 | 46 | // TestNativeTunAlignment checks that atomically-accessed fields are 47 | // aligned to 64-bit boundaries, as required by the atomic package. 48 | // 49 | // Unfortunately, violating this rule on 32-bit platforms results in a 50 | // hard segfault at runtime. 51 | func TestNativeTunAlignment(t *testing.T) { 52 | var tun NativeTun 53 | 54 | typ := reflect.TypeOf(&tun).Elem() 55 | t.Logf("%s type size: %d, with fields:", typ.Name(), typ.Size()) 56 | for i := 0; i < typ.NumField(); i++ { 57 | field := typ.Field(i) 58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 59 | field.Name, 60 | field.Offset, 61 | field.Type.Size(), 62 | field.Type.Align(), 63 | ) 64 | } 65 | 66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) 67 | } 68 | -------------------------------------------------------------------------------- /wireguard/tun/checksum.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import "encoding/binary" 4 | 5 | // TODO: Explore SIMD and/or other assembly optimizations. 6 | // TODO: Test native endian loads. See RFC 1071 section 2 part B.
7 | func checksumNoFold(b []byte, initial uint64) uint64 { 8 | ac := initial 9 | /* Add big-endian 32-bit words in unrolled chunks of 128, 64, 32, 16, 8 and 4 bytes, then a final 16-bit word and, for odd lengths, the last byte as the high half of a word. The 64-bit accumulator absorbs the carries; checksum folds them down to 16 bits. */ 10 | for len(b) >= 128 { 11 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 12 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 13 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 14 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 15 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 16 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 17 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 18 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 19 | ac += uint64(binary.BigEndian.Uint32(b[32:36])) 20 | ac += uint64(binary.BigEndian.Uint32(b[36:40])) 21 | ac += uint64(binary.BigEndian.Uint32(b[40:44])) 22 | ac += uint64(binary.BigEndian.Uint32(b[44:48])) 23 | ac += uint64(binary.BigEndian.Uint32(b[48:52])) 24 | ac += uint64(binary.BigEndian.Uint32(b[52:56])) 25 | ac += uint64(binary.BigEndian.Uint32(b[56:60])) 26 | ac += uint64(binary.BigEndian.Uint32(b[60:64])) 27 | ac += uint64(binary.BigEndian.Uint32(b[64:68])) 28 | ac += uint64(binary.BigEndian.Uint32(b[68:72])) 29 | ac += uint64(binary.BigEndian.Uint32(b[72:76])) 30 | ac += uint64(binary.BigEndian.Uint32(b[76:80])) 31 | ac += uint64(binary.BigEndian.Uint32(b[80:84])) 32 | ac += uint64(binary.BigEndian.Uint32(b[84:88])) 33 | ac += uint64(binary.BigEndian.Uint32(b[88:92])) 34 | ac += uint64(binary.BigEndian.Uint32(b[92:96])) 35 | ac += uint64(binary.BigEndian.Uint32(b[96:100])) 36 | ac += uint64(binary.BigEndian.Uint32(b[100:104])) 37 | ac += uint64(binary.BigEndian.Uint32(b[104:108])) 38 | ac += uint64(binary.BigEndian.Uint32(b[108:112])) 39 | ac += uint64(binary.BigEndian.Uint32(b[112:116])) 40 | ac += uint64(binary.BigEndian.Uint32(b[116:120])) 41 | ac += uint64(binary.BigEndian.Uint32(b[120:124])) 42 | ac += uint64(binary.BigEndian.Uint32(b[124:128])) 43 | b = b[128:] 44 | } 45 | if len(b) >= 64 { 46 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 47 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 48 | ac +=
uint64(binary.BigEndian.Uint32(b[8:12])) 49 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 50 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 51 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 52 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 53 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 54 | ac += uint64(binary.BigEndian.Uint32(b[32:36])) 55 | ac += uint64(binary.BigEndian.Uint32(b[36:40])) 56 | ac += uint64(binary.BigEndian.Uint32(b[40:44])) 57 | ac += uint64(binary.BigEndian.Uint32(b[44:48])) 58 | ac += uint64(binary.BigEndian.Uint32(b[48:52])) 59 | ac += uint64(binary.BigEndian.Uint32(b[52:56])) 60 | ac += uint64(binary.BigEndian.Uint32(b[56:60])) 61 | ac += uint64(binary.BigEndian.Uint32(b[60:64])) 62 | b = b[64:] 63 | } 64 | if len(b) >= 32 { 65 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 66 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 67 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 68 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 69 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 70 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 71 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 72 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 73 | b = b[32:] 74 | } 75 | if len(b) >= 16 { 76 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 77 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 78 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 79 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 80 | b = b[16:] 81 | } 82 | if len(b) >= 8 { 83 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 84 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 85 | b = b[8:] 86 | } 87 | if len(b) >= 4 { 88 | ac += uint64(binary.BigEndian.Uint32(b)) 89 | b = b[4:] 90 | } 91 | if len(b) >= 2 { 92 | ac += uint64(binary.BigEndian.Uint16(b)) 93 | b = b[2:] 94 | } 95 | if len(b) == 1 { 96 | ac += uint64(b[0]) << 8 97 | } 98 | 99 | return ac 100 | } 101 | /* checksum returns the 16-bit one's-complement sum of b seeded with initial. The four end-around-carry folds below are enough to reduce any 64-bit accumulator to 16 bits. The result is not complemented. */ 102 | func checksum(b []byte, initial uint64) uint16 { 103 | ac := checksumNoFold(b, initial) 104 | ac = (ac >>
16) + (ac & 0xffff) 105 | ac = (ac >> 16) + (ac & 0xffff) 106 | ac = (ac >> 16) + (ac & 0xffff) 107 | ac = (ac >> 16) + (ac & 0xffff) 108 | return uint16(ac) 109 | } 110 | /* pseudoHeaderChecksumNoFold returns the unfolded sum of srcAddr, dstAddr, the word (0, protocol) and totalLen - the fields of a TCP/UDP pseudo-header - presumably for the caller to add to the transport checksum before folding it with checksum. */ 111 | func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { 112 | sum := checksumNoFold(srcAddr, 0) 113 | sum = checksumNoFold(dstAddr, sum) 114 | sum = checksumNoFold([]byte{0, protocol}, sum) 115 | tmp := make([]byte, 2) 116 | binary.BigEndian.PutUint16(tmp, totalLen) 117 | return checksumNoFold(tmp, sum) 118 | } 119 | -------------------------------------------------------------------------------- /wireguard/tun/checksum_test.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | ) 8 | /* BenchmarkChecksum times checksum over fixed-seed random buffers of typical packet sizes; 9001 is odd, so it also covers the trailing-byte path. */ 9 | func BenchmarkChecksum(b *testing.B) { 10 | lengths := []int{ 11 | 64, 12 | 128, 13 | 256, 14 | 512, 15 | 1024, 16 | 1500, 17 | 2048, 18 | 4096, 19 | 8192, 20 | 9000, 21 | 9001, 22 | } 23 | 24 | for _, length := range lengths { 25 | b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { 26 | buf := make([]byte, length) 27 | rng := rand.New(rand.NewSource(1)) 28 | rng.Read(buf) 29 | b.ResetTimer() 30 | for i := 0; i < b.N; i++ { 31 | checksum(buf, 0) 32 | } 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /wireguard/tun/errors.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ( 8 | // ErrTooManySegments is returned by Device.Read() when segmentation 9 | // overflows the length of supplied buffers. This error should not cause 10 | // reads to cease.
11 | 	ErrTooManySegments = errors.New("too many segments")
12 | )
13 |
--------------------------------------------------------------------------------
/wireguard/tun/netstack/examples/http_client.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 |
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 |
8 | package main
9 |
10 | import (
11 | 	"io"
12 | 	"log"
13 | 	"net/http"
14 | 	"net/netip"
15 |
16 | 	"github.com/bepass-org/warp-plus/wireguard/conn"
17 | 	"github.com/bepass-org/warp-plus/wireguard/device"
18 | 	"github.com/bepass-org/warp-plus/wireguard/tun/netstack"
19 | )
20 |
21 | func main() {
22 | 	tun, tnet, err := netstack.CreateNetTUN(
23 | 		[]netip.Addr{netip.MustParseAddr("192.168.4.28")},
24 | 		[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
25 | 		1420)
26 | 	if err != nil {
27 | 		log.Panic(err)
28 | 	}
29 | 	dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
30 | 	err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
31 | public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
32 | allowed_ip=0.0.0.0/0
33 | endpoint=127.0.0.1:58120
34 | `)
35 | 	// Check the configuration error before dev.Up() overwrites err.
36 | 	if err != nil {
37 | 		log.Panic(err)
38 | 	}
39 | 	err = dev.Up()
40 | 	if err != nil {
41 | 		log.Panic(err)
42 | 	}
43 |
44 | 	client := http.Client{
45 | 		Transport: &http.Transport{
46 | 			DialContext: tnet.DialContext,
47 | 		},
48 | 	}
49 | 	resp, err := client.Get("http://192.168.4.29/")
50 | 	if err != nil {
51 | 		log.Panic(err)
52 | 	}
53 | 	defer resp.Body.Close()
54 | 	body, err := io.ReadAll(resp.Body)
55 | 	if err != nil {
56 | 		log.Panic(err)
57 | 	}
58 | 	log.Println(string(body))
59 | }
60 |
--------------------------------------------------------------------------------
/wireguard/tun/netstack/examples/http_server.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 |
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 |
8 | package main
9 |
10 | import (
11 | 	"io"
12 | 	"log"
13 | 	"net"
14 | 	"net/http"
15 | 	"net/netip"
16 |
17 | 	"github.com/bepass-org/warp-plus/wireguard/conn"
18 | 	"github.com/bepass-org/warp-plus/wireguard/device"
19 | 	"github.com/bepass-org/warp-plus/wireguard/tun/netstack"
20 | )
21 |
22 | func main() {
23 | 	tun, tnet, err := netstack.CreateNetTUN(
24 | 		[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
25 | 		[]netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
26 | 		1420,
27 | 	)
28 | 	if err != nil {
29 | 		log.Panic(err)
30 | 	}
31 | 	dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
32 | 	err = dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
33 | listen_port=58120
34 | public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
35 | allowed_ip=192.168.4.28/32
36 | persistent_keepalive_interval=25
37 | `)
38 | 	if err != nil {
39 | 		log.Panic(err)
40 | 	}
41 | 	err = dev.Up()
42 | 	if err != nil {
43 | 		log.Panic(err)
44 | 	}
45 | 	listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80})
46 | 	if err != nil {
47 | 		log.Panicln(err)
48 | 	}
49 | 	http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
50 | 		log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent())
51 | 		io.WriteString(writer, "Hello from userspace TCP!")
52 | 	})
53 | 	err = http.Serve(listener, nil)
54 | 	if err != nil {
55 | 		log.Panicln(err)
56 | 	}
57 | }
58 |
--------------------------------------------------------------------------------
/wireguard/tun/netstack/examples/ping_client.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 |
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 |
8 | package main
9 |
10 | import (
11 | 	"bytes"
12 | 	"log"
13 | 	"math/rand"
14 | 	"net/netip"
15 | 	"time"
16 |
17 | 	"golang.org/x/net/icmp"
18 | 	"golang.org/x/net/ipv4"
19 |
20 | 	"github.com/bepass-org/warp-plus/wireguard/conn"
21 | 	"github.com/bepass-org/warp-plus/wireguard/device"
22 | 	"github.com/bepass-org/warp-plus/wireguard/tun/netstack"
23 | )
24 |
25 | func main() {
26 | 	tun, tnet, err := netstack.CreateNetTUN(
27 | 		[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
28 | 		[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
29 | 		1420)
30 | 	if err != nil {
31 | 		log.Panic(err)
32 | 	}
33 | 	dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
34 | 	err = dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
35 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
36 | endpoint=163.172.161.0:12912
37 | allowed_ip=0.0.0.0/0
38 | `)
39 | 	if err != nil {
40 | 		log.Panic(err)
41 | 	}
42 | 	err = dev.Up()
43 | 	if err != nil {
44 | 		log.Panic(err)
45 | 	}
46 |
47 | 	socket, err := tnet.Dial("ping4", "zx2c4.com")
48 | 	if err != nil {
49 | 		log.Panic(err)
50 | 	}
51 | 	requestPing := icmp.Echo{
52 | 		Seq:  rand.Intn(1 << 16),
53 | 		Data: []byte("gopher burrow"),
54 | 	}
55 | 	icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
56 | 	socket.SetReadDeadline(time.Now().Add(time.Second * 10))
57 | 	start := time.Now()
58 | 	_, err = socket.Write(icmpBytes)
59 | 	if err != nil {
60 | 		log.Panic(err)
61 | 	}
62 | 	n, err := socket.Read(icmpBytes[:])
63 | 	if err != nil {
64 | 		log.Panic(err)
65 | 	}
66 | 	replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
67 | 	if err != nil {
68 | 		log.Panic(err)
69 | 	}
70 | 	replyPing, ok := replyPacket.Body.(*icmp.Echo)
71 | 	if !ok {
72 | 		log.Panicf("invalid reply type: %v", replyPacket)
73 | 	}
74 | 	if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
75 | 		log.Panicf("invalid ping reply: %v", replyPing)
76 | 	}
77 | 	log.Printf("Ping latency: %v", time.Since(start))
78 | }
79 |
--------------------------------------------------------------------------------
/wireguard/tun/operateonfd.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || freebsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 |  *
5 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 |  */
7 |
8 | package tun
9 |
10 | import (
11 | 	"fmt"
12 | )
13 |
14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
15 | 	sysconn, err := tun.tunFile.SyscallConn()
16 | 	if err != nil {
17 | 		tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
18 | 		return
19 | 	}
20 | 	err = sysconn.Control(fn)
21 | 	if err != nil {
22 | 		tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
23 | 	}
24 | }
25 |
--------------------------------------------------------------------------------
/wireguard/tun/tun.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 |  *
3 |  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 |  */
5 |
6 | package tun
7 |
8 | import (
9 | 	"os"
10 | )
11 |
12 | type Event int
13 |
14 | const (
15 | 	EventUp = 1 << iota
16 | 	EventDown
17 | 	EventMTUUpdate
18 | )
19 |
20 | type Device interface {
21 | 	// File returns the file descriptor of the device.
22 | 	File() *os.File
23 |
24 | 	// Read one or more packets from the Device (without any additional headers).
25 | 	// On a successful read it returns the number of packets read, and sets
26 | 	// packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
27 | 	// A nonzero offset can be used to instruct the Device on where to begin
28 | 	// reading into each element of the bufs slice.
29 | 	Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
30 |
31 | 	// Write one or more packets to the device (without any additional headers).
32 | // On a successful write it returns the number of packets written. A nonzero 33 | // offset can be used to instruct the Device on where to begin writing from 34 | // each packet contained within the bufs slice. 35 | Write(bufs [][]byte, offset int) (int, error) 36 | 37 | // MTU returns the MTU of the Device. 38 | MTU() (int, error) 39 | 40 | // Name returns the current name of the Device. 41 | Name() (string, error) 42 | 43 | // Events returns a channel of type Event, which is fed Device events. 44 | Events() <-chan Event 45 | 46 | // Close stops the Device and closes the Event channel. 47 | Close() error 48 | 49 | // BatchSize returns the preferred/max number of packets that can be read or 50 | // written in a single read/write call. BatchSize must not change over the 51 | // lifetime of a Device. 52 | BatchSize() int 53 | } 54 | -------------------------------------------------------------------------------- /wireguard/tun/tuntest/tuntest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tuntest 7 | 8 | import ( 9 | "encoding/binary" 10 | "io" 11 | "net/netip" 12 | "os" 13 | 14 | "github.com/bepass-org/warp-plus/wireguard/tun" 15 | ) 16 | /* Ping returns the IPv4 packet from src to dst that genICMPv4 builds around a 4-byte payload holding local port 1337 and sequence number 0, both big-endian. */ 17 | func Ping(dst, src netip.Addr) []byte { 18 | localPort := uint16(1337) 19 | seq := uint16(0) 20 | 21 | payload := make([]byte, 4) 22 | binary.BigEndian.PutUint16(payload[0:], localPort) 23 | binary.BigEndian.PutUint16(payload[2:], seq) 24 | 25 | return genICMPv4(payload, dst, src) 26 | } 27 | 28 | // checksum returns the one's complement of the folded 16-bit one's-complement sum of initial and the big-endian words of buf (an odd trailing byte is padded as the high byte) - the "internet checksum" from https://tools.ietf.org/html/rfc1071.
29 | func checksum(buf []byte, initial uint16) uint16 { 30 | v := uint32(initial) 31 | for i := 0; i < len(buf)-1; i += 2 { 32 | v += uint32(binary.BigEndian.Uint16(buf[i:])) 33 | } 34 | if len(buf)%2 == 1 { 35 | v += uint32(buf[len(buf)-1]) << 8 36 | } 37 | for v > 0xffff { 38 | v = (v >> 16) + (v & 0xffff) 39 | } 40 | return ^uint16(v) 41 | } 42 | /* genICMPv4 returns a 20-byte IPv4 header, an 8-byte ICMPv4 echo header (type 8, code 0) and payload. Both checksum fields hold the complemented sum required by RFC 791/792, so running checksum over the finished IP header, or over the ICMP header plus payload, yields 0. */ 43 | func genICMPv4(payload []byte, dst, src netip.Addr) []byte { 44 | const ( 45 | icmpv4ProtocolNumber = 1 46 | icmpv4Echo = 8 47 | icmpv4ChecksumOffset = 2 48 | icmpv4Size = 8 49 | ipv4Size = 20 50 | ipv4TotalLenOffset = 2 51 | ipv4ChecksumOffset = 10 52 | ttl = 65 53 | headerSize = ipv4Size + icmpv4Size 54 | ) 55 | 56 | pkt := make([]byte, headerSize+len(payload)) 57 | 58 | ip := pkt[0:ipv4Size] 59 | icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] 60 | /* checksum returns the complemented sum, so ^checksum(payload, 0) recovers the plain folded payload sum used to seed the header sum; the final complemented value is stored as-is. */ 61 | // https://tools.ietf.org/html/rfc792 62 | icmpv4[0] = icmpv4Echo // type 63 | icmpv4[1] = 0 // code 64 | chksum := checksum(icmpv4, ^checksum(payload, 0)) 65 | binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) 66 | /* The IP checksum covers only the header and is computed while its checksum field is still zero. */ 67 | // https://tools.ietf.org/html/rfc760 section 3.1 68 | length := uint16(len(pkt)) 69 | ip[0] = (4 << 4) | (ipv4Size / 4) 70 | binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) 71 | ip[8] = ttl 72 | ip[9] = icmpv4ProtocolNumber 73 | copy(ip[12:], src.AsSlice()) 74 | copy(ip[16:], dst.AsSlice()) 75 | chksum = checksum(ip[:], 0) 76 | binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) 77 | 78 | copy(pkt[headerSize:], payload) 79 | return pkt 80 | } 81 | /* ChannelTUN is an in-memory tun.Device for tests: the device reads packets sent on Outbound, and packets it writes are delivered on Inbound. */ 82 | type ChannelTUN struct { 83 | Inbound chan []byte // incoming packets written by the device; not closed on TUN close 84 | Outbound chan []byte // outbound packets, blocks forever on TUN close 85 | 86 | closed chan struct{} 87 | events chan tun.Event 88 | tun chTun 89 | } 90 | /* NewChannelTUN returns a ChannelTUN whose events channel already carries tun.EventUp. */ 91 | func NewChannelTUN() *ChannelTUN { 92 | c := &ChannelTUN{ 93 | Inbound: make(chan []byte), 94 | Outbound: make(chan []byte), 95 | closed: make(chan struct{}), 96 | events: make(chan tun.Event, 1), 97 | } 98 |
c.tun.c = c 99 | c.events <- tun.EventUp 100 | return c 101 | } 102 | /* TUN returns c as a tun.Device. */ 103 | func (c *ChannelTUN) TUN() tun.Device { 104 | return &c.tun 105 | } 106 | 107 | type chTun struct { 108 | c *ChannelTUN 109 | } 110 | 111 | func (t *chTun) File() *os.File { return nil } 112 | /* Read hands the device one packet taken from Outbound, copied into packets[0] at offset (truncated if it does not fit), or returns os.ErrClosed once the TUN is closed. */ 113 | func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { 114 | select { 115 | case <-t.c.closed: 116 | return 0, os.ErrClosed 117 | case msg := <-t.c.Outbound: 118 | n := copy(packets[0][offset:], msg) 119 | sizes[0] = n 120 | return 1, nil 121 | } 122 | } 123 | 124 | // Write is called by the wireguard device to deliver a packet for routing. 125 | func (t *chTun) Write(packets [][]byte, offset int) (int, error) { 126 | if offset == -1 { 127 | close(t.c.closed) 128 | close(t.c.events) 129 | return 0, io.EOF 130 | } 131 | for i, data := range packets { 132 | msg := make([]byte, len(data)-offset) 133 | copy(msg, data[offset:]) 134 | select { 135 | case <-t.c.closed: 136 | return i, os.ErrClosed 137 | case t.c.Inbound <- msg: 138 | } 139 | } 140 | return len(packets), nil 141 | } 142 | 143 | func (t *chTun) BatchSize() int { 144 | return 1 145 | } 146 | 147 | const DefaultMTU = 1420 148 | 149 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } 150 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } 151 | func (t *chTun) Events() <-chan tun.Event { return t.c.events } 152 | func (t *chTun) Close() error { 153 | t.Write(nil, -1) 154 | return nil 155 | } 156 | -------------------------------------------------------------------------------- /wiresocks/config_test.go: -------------------------------------------------------------------------------- 1 | package wiresocks 2 | 3 | import ( 4 | "net/netip" 5 | "testing" 6 | 7 | qt "github.com/frankban/quicktest" 8 | "github.com/go-ini/ini" 9 | "github.com/google/go-cmp/cmp/cmpopts" 10 | ) 11 | 12 | const testConfig = ` 13 | [Interface] 14 | PrivateKey = aK8FWhiV1CtKFbKUPssL13P+Tv+c5owmYcU5PCP6yFw= 15 |
DNS = 8.8.8.8 16 | Address = 172.16.0.2/24 17 | Address = 2606:4700:110:8cc0:1ad3:9155:6742:ea8d/128 18 | MTU = 1500 19 | [Peer] 20 | PublicKey = bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo= 21 | AllowedIPs = 0.0.0.0/0 22 | AllowedIPs = ::/0 23 | Endpoint = engage.cloudflareclient.com:2408 24 | PersistentKeepalive = 3 25 | Trick = true 26 | Reserved = 1,2,3 27 | ` 28 | const ( 29 | privateKeyBase64 = "68af055a1895d42b4a15b2943ecb0bd773fe4eff9ce68c2661c5393c23fac85c" 30 | publicKeyBase64 = "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a" 31 | presharedKeyBase64 = "0000000000000000000000000000000000000000000000000000000000000000" 32 | ) 33 | /* NOTE(review): despite their names, the *Base64 constants above hold hex-encoded keys - e.g. the base64 PrivateKey in testConfig (aK8F...) decodes to the bytes 68 af 05 ... - which is what these tests expect ParseInterface/ParsePeers to return. Consider renaming them (e.g. *Hex) if no other test file uses these names. */ 34 | func TestParseInterface(t *testing.T) { 35 | opts := ini.LoadOptions{ 36 | Insensitive: true, 37 | AllowShadows: true, 38 | AllowNonUniqueSections: true, 39 | } 40 | 41 | cfg, err := ini.LoadSources(opts, []byte(testConfig)) 42 | qt.Assert(t, err, qt.IsNil) 43 | 44 | device, err := ParseInterface(cfg) 45 | qt.Assert(t, err, qt.IsNil) 46 | 47 | want := InterfaceConfig{ 48 | PrivateKey: privateKeyBase64, 49 | Addresses: []netip.Addr{ 50 | netip.MustParseAddr("172.16.0.2"), 51 | netip.MustParseAddr("2606:4700:110:8cc0:1ad3:9155:6742:ea8d"), 52 | }, 53 | DNS: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 54 | MTU: 1500, 55 | } 56 | qt.Assert(t, device, qt.CmpEquals(cmpopts.EquateComparable(netip.Addr{})), want) 57 | t.Logf("%+v", device) 58 | } 59 | 60 | func TestParsePeers(t *testing.T) { 61 | opts := ini.LoadOptions{ 62 | Insensitive: true, 63 | AllowShadows: true, 64 | AllowNonUniqueSections: true, 65 | } 66 | 67 | cfg, err := ini.LoadSources(opts, []byte(testConfig)) 68 | qt.Assert(t, err, qt.IsNil) 69 | 70 | peers, err := ParsePeers(cfg) 71 | qt.Assert(t, err, qt.IsNil) 72 | 73 | want := []PeerConfig{{ 74 | PublicKey: publicKeyBase64, 75 | PreSharedKey: presharedKeyBase64, 76 | Endpoint: "engage.cloudflareclient.com:2408", 77 | KeepAlive: 3, 78 | AllowedIPs: []netip.Prefix{ 79 |
netip.MustParsePrefix("0.0.0.0/0"), 80 | netip.MustParsePrefix("::/0"), 81 | }, 82 | Trick: true, 83 | Reserved: [3]byte{1, 2, 3}, 84 | }} 85 | qt.Assert(t, peers, qt.CmpEquals(cmpopts.EquateComparable(netip.Prefix{})), want) 86 | t.Logf("%+v", peers) 87 | } 88 | -------------------------------------------------------------------------------- /wiresocks/proxy.go: -------------------------------------------------------------------------------- 1 | package wiresocks 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "log/slog" 8 | "net" 9 | "net/netip" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/bepass-org/warp-plus/proxy/pkg/mixed" 14 | "github.com/bepass-org/warp-plus/proxy/pkg/statute" 15 | "github.com/bepass-org/warp-plus/wireguard/device" 16 | "github.com/bepass-org/warp-plus/wireguard/tun/netstack" 17 | "github.com/sagernet/sing/common/buf" 18 | ) 19 | 20 | // VirtualTun bundles the netstack network with the logger, optional WireGuard device (Stop skips it when nil), context and buffer pool used to relay proxied connections. 21 | type VirtualTun struct { 22 | Tnet *netstack.Net 23 | Logger *slog.Logger 24 | Dev *device.Device 25 | Ctx context.Context 26 | // pool supplies the BuffSize relay buffers used by generalHandler. 27 | pool buf.Allocator 28 | } 29 | // BuffSize is the size in bytes of each pooled relay buffer; every proxied connection takes two, one per direction. 30 | var BuffSize = 65536 31 | 32 | // StartProxy starts a mixed-protocol proxy (proxy/pkg/mixed) on bindAddress that relays each request through tnet, and returns the address actually bound.
33 | func StartProxy(ctx context.Context, l *slog.Logger, tnet *netstack.Net, bindAddress netip.AddrPort) (netip.AddrPort, error) {
34 | 	ln, err := net.Listen("tcp", bindAddress.String())
35 | 	if err != nil {
36 | 		return netip.AddrPort{}, err // Return error if binding was unsuccessful
37 | 	}
38 |
39 | 	vt := VirtualTun{
40 | 		Tnet:   tnet,
41 | 		Logger: l.With("subsystem", "vtun"),
42 | 		Dev:    nil,
43 | 		Ctx:    ctx,
44 | 		pool:   buf.DefaultAllocator,
45 | 	}
46 |
47 | 	proxy := mixed.NewProxy(
48 | 		mixed.WithListener(ln),
49 | 		mixed.WithLogger(l),
50 | 		mixed.WithContext(ctx),
51 | 		mixed.WithUserHandler(func(request *statute.ProxyRequest) error {
52 | 			return vt.generalHandler(request)
53 | 		}),
54 | 	)
55 | 	go func() {
56 | 		_ = proxy.ListenAndServe()
57 | 	}()
58 | 	go func() {
59 | 		<-vt.Ctx.Done()
60 | 		vt.Stop()
61 | 	}()
62 |
63 | 	return ln.Addr().(*net.TCPAddr).AddrPort(), nil
64 | }
65 |
66 | // generalHandler dials req.Destination over the netstack and relays bytes
67 | // between it and req.Conn in both directions until either direction ends,
68 | // then closes both connections. Only a dial error is returned; relay errors
69 | // are logged.
70 | func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error {
71 | 	vt.Logger.Debug("handling connection", "protocol", req.Network, "destination", req.Destination)
72 | 	conn, err := vt.Tnet.Dial(req.Network, req.Destination)
73 | 	if err != nil {
74 | 		return err
75 | 	}
76 |
77 | 	// UDP flows have no end-of-stream, so each read times out after 15s of
78 | 	// inactivity; other networks read without a deadline.
79 | 	timeout := 0 * time.Second
80 | 	switch req.Network {
81 | 	case "udp", "udp4", "udp6":
82 | 		timeout = 15 * time.Second
83 | 	}
84 |
85 | 	// Each relay direction reports its result on done; room for both results
86 | 	// means neither goroutine ever blocks on its send.
87 | 	done := make(chan error, 2)
88 | 	relay := func(dst, src net.Conn) {
89 | 		b := vt.pool.Get(BuffSize)
90 | 		defer func() { _ = vt.pool.Put(b) }()
91 | 		_, cerr := copyConnTimeout(dst, src, b, timeout)
92 | 		// A reset from either peer is a normal way for a relay to end.
93 | 		if errors.Is(cerr, syscall.ECONNRESET) {
94 | 			cerr = nil
95 | 		}
96 | 		done <- cerr
97 | 	}
98 | 	go relay(conn, req.Conn) // client -> destination
99 | 	go relay(req.Conn, conn) // destination -> client
100 |
101 | 	// Wait for the first direction to finish.
102 | 	if rerr := <-done; rerr != nil {
103 | 		vt.Logger.Warn(rerr.Error())
104 | 	}
105 |
106 | 	// Close both connections so the other direction, which may be blocked in
107 | 	// a Read with no deadline, is released; then wait for it so its buffer is
108 | 	// back in the pool before returning. Its error (typically a closed
109 | 	// connection) is expected and not logged.
110 | 	_ = conn.Close()
111 | 	_ = req.Conn.Close()
112 | 	<-done
113 | 	return nil
114 | }
115 |
116 | // Stop brings the attached WireGuard device down, if there is one.
117 | func (vt *VirtualTun) Stop() {
118 | 	if vt.Dev != nil {
119 | 		if err := vt.Dev.Down(); err != nil {
120 | 			vt.Logger.Warn(err.Error())
121 | 		}
122 | 	}
123 | }
124 |
125 | var errInvalidWrite = errors.New("invalid write result")
126 |
127 | // copyConnTimeout copies from src to dst through buf until src reports EOF or
128 | // an error occurs, and returns the number of bytes written. EOF is not
129 | // reported as an error. When timeout is non-zero, every Read must complete
130 | // within timeout of its start, otherwise the deadline error ends the copy.
131 | func copyConnTimeout(dst net.Conn, src net.Conn, buf []byte, timeout time.Duration) (written int64, err error) {
132 | 	if buf != nil && len(buf) == 0 {
133 | 		panic("empty buffer in CopyBuffer")
134 | 	}
135 | 	if buf == nil {
136 | 		// Reading into a nil slice makes Read return 0 bytes without progress,
137 | 		// which would spin this loop; allocate like io.CopyBuffer does.
138 | 		buf = make([]byte, 32*1024)
139 | 	}
140 |
141 | 	for {
142 | 		deadline := time.Time{}
143 | 		if timeout != 0 {
144 | 			deadline = time.Now().Add(timeout)
145 | 		}
146 | 		if err := src.SetReadDeadline(deadline); err != nil {
147 | 			return written, err
148 | 		}
149 |
150 | 		nr, er := src.Read(buf)
151 | 		if nr > 0 {
152 | 			nw, ew := dst.Write(buf[0:nr])
153 | 			if nw < 0 || nr < nw {
154 | 				nw = 0
155 | 				if ew == nil {
156 | 					ew = errInvalidWrite
157 | 				}
158 | 			}
159 | 			written += int64(nw)
160 | 			if ew != nil {
161 | 				err = ew
162 | 				break
163 | 			}
164 | 			if nr != nw {
165 | 				err = io.ErrShortWrite
166 | 				break
167 | 			}
168 | 		}
169 | 		if er != nil {
170 | 			if er != io.EOF {
171 | 				err = er
172 | 			}
173 | 			break
174 | 		}
175 | 	}
176 | 	return written, err
177 | }
178 |
--------------------------------------------------------------------------------
/wiresocks/scanner.go:
--------------------------------------------------------------------------------
1 | package wiresocks
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"log/slog"
7 | 	"time"
8 |
9 | 	"github.com/bepass-org/warp-plus/ipscanner"
10 | 	"github.com/bepass-org/warp-plus/warp"
11 | )
12 |
13 | // ScanOptions configures RunScan.
14 | type ScanOptions struct {
15 | 	V4         bool          // passed to ipscanner.WithUseIPv4
16 | 	V6         bool          // passed to ipscanner.WithUseIPv6
17 | 	MaxRTT     time.Duration // passed to ipscanner.WithMaxDesirableRTT
18 | 	PrivateKey string        // passed to ipscanner.WithWarpPrivateKey
19 | 	PublicKey  string        // passed to ipscanner.WithWarpPeerPublicKey
20 | }
21 |
22 | // RunScan scans the WARP prefixes (warp.WarpPrefixes) with the given options
23 | // and returns the first two available endpoints once at least two are found.
24 | // The scan is bounded by a one-minute timeout in addition to ctx.
25 | func RunScan(ctx context.Context, l *slog.Logger, opts ScanOptions) (result []ipscanner.IPInfo, err error) {
26 | 	ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
27 | 	defer cancel()
28 |
29 | 	scanner := ipscanner.NewScanner(
30 | 		ipscanner.WithLogger(l.With(slog.String("subsystem", "scanner"))),
31 | 		ipscanner.WithWarpPrivateKey(opts.PrivateKey),
32 | 		ipscanner.WithWarpPeerPublicKey(opts.PublicKey),
33 | 		ipscanner.WithUseIPv4(opts.V4),
34 | 		ipscanner.WithUseIPv6(opts.V6),
35 | 		ipscanner.WithMaxDesirableRTT(opts.MaxRTT),
36 | 		ipscanner.WithCidrList(warp.WarpPrefixes()),
37 | 	)
38 |
39 | 	scanner.Run(ctx)
40 |
41 | 	t := time.NewTicker(1 * time.Second)
42 | 	defer t.Stop()
43 |
44 | 	for {
45 | 		ipList := scanner.GetAvailableIPs()
46 | 		if len(ipList) > 1 {
47 | 			return append(result, ipList[:2]...), nil
48 | 		}
49 |
50 | 		select {
51 | 		case <-ctx.Done():
52 | 			// A deadline (our one-minute limit or the caller's) is a timeout,
53 | 			// not a user cancellation.
54 | 			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
55 | 				return nil, errors.New("scan timed out before finding two usable endpoints")
56 | 			}
57 | 			return nil, errors.New("user canceled the operation")
58 | 		case <-t.C:
59 | 			// Poll once per second instead of spinning.
60 | 		}
61 | 	}
62 | }
63 |
--------------------------------------------------------------------------------
/wiresocks/udpfw.go:
--------------------------------------------------------------------------------
1 | package wiresocks
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"net"
7 | 	"net/netip"
8 | 	"sync"
9 |
10 | 	"github.com/bepass-org/warp-plus/wireguard/tun/netstack"
11 | )
12 |
13 | // NewVtunUDPForwarder listens for UDP on localBind and forwards every datagram
14 | // it receives to dest through tnet; datagrams coming back from dest are sent
15 | // to the most recent local sender. Both sockets are closed when ctx is done.
16 | // It returns the bound local address. mtu sizes each direction's read buffer.
17 | func NewVtunUDPForwarder(ctx context.Context, localBind netip.AddrPort, dest string, tnet *netstack.Net, mtu int) (netip.AddrPort, error) {
18 | 	destAddr, err := net.ResolveUDPAddr("udp", dest)
19 | 	if err != nil {
20 | 		return netip.AddrPort{}, err
21 | 	}
22 |
23 | 	listener, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(localBind))
24 | 	if err != nil {
25 | 		return netip.AddrPort{}, err
26 | 	}
27 |
28 | 	rconn, err := tnet.DialUDP(nil, destAddr)
29 | 	if err != nil {
30 | 		_ = listener.Close()
31 | 		return netip.AddrPort{}, err
32 | 	}
33 |
34 | 	// clientAddr is written by the local->remote goroutine and read by the
35 | 	// remote->local one, so it is guarded by mu.
36 | 	var (
37 | 		mu         sync.Mutex
38 | 		clientAddr *net.UDPAddr
39 | 	)
40 |
41 | 	// Both loops block in ReadFrom, where a ctx.Done() check never runs;
42 | 	// closing the sockets is what makes those reads fail so the loops exit.
43 | 	go func() {
44 | 		<-ctx.Done()
45 | 		_ = listener.Close()
46 | 		_ = rconn.Close()
47 | 	}()
48 |
49 | 	// local -> remote
50 | 	go func() {
51 | 		buffer := make([]byte, mtu)
52 | 		for {
53 | 			n, cAddr, err := listener.ReadFromUDP(buffer)
54 | 			if err != nil {
55 | 				if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
56 | 					return
57 | 				}
58 | 				continue // transient error: keep serving
59 | 			}
60 |
61 | 			mu.Lock()
62 | 			clientAddr = cAddr
63 | 			mu.Unlock()
64 |
65 | 			_, _ = rconn.WriteTo(buffer[:n], destAddr)
66 | 		}
67 | 	}()
68 |
69 | 	// remote -> local
70 | 	go func() {
71 | 		buffer := make([]byte, mtu)
72 | 		for {
73 | 			n, _, err := rconn.ReadFrom(buffer)
74 | 			if err != nil {
75 | 				if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
76 | 					return
77 | 				}
78 | 				continue
79 | 			}
80 |
81 | 			mu.Lock()
82 | 			addr := clientAddr
83 | 			mu.Unlock()
84 | 			if addr != nil {
85 | 				_, _ = listener.WriteTo(buffer[:n], addr)
86 | 			}
87 | 		}
88 | 	}()
89 |
90 | 	return listener.LocalAddr().(*net.UDPAddr).AddrPort(), nil
91 | }
92 |
--------------------------------------------------------------------------------