├── .github └── workflows │ └── build-if-tag.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── conn ├── bind_std.go ├── bind_std_test.go ├── bind_windows.go ├── bindtest │ └── bindtest.go ├── boundif_android.go ├── conn.go ├── conn_test.go ├── controlfns.go ├── controlfns_linux.go ├── controlfns_unix.go ├── controlfns_windows.go ├── default.go ├── errors_default.go ├── errors_linux.go ├── features_default.go ├── features_linux.go ├── gso_default.go ├── gso_linux.go ├── mark_default.go ├── mark_unix.go ├── sticky_default.go ├── sticky_linux.go ├── sticky_linux_test.go └── winrio │ └── rio_windows.go ├── device ├── allowedips.go ├── allowedips_rand_test.go ├── allowedips_test.go ├── bind_test.go ├── channels.go ├── constants.go ├── cookie.go ├── cookie_test.go ├── device.go ├── device_test.go ├── devicestate_string.go ├── endpoint_test.go ├── indextable.go ├── ip.go ├── junk_creator.go ├── junk_creator_test.go ├── kdf_test.go ├── keypair.go ├── logger.go ├── mobilequirks.go ├── noise-helpers.go ├── noise-protocol.go ├── noise-types.go ├── noise_test.go ├── peer.go ├── pools.go ├── pools_test.go ├── queueconstants_android.go ├── queueconstants_default.go ├── queueconstants_ios.go ├── queueconstants_windows.go ├── race_disabled_test.go ├── race_enabled_test.go ├── receive.go ├── send.go ├── sticky_default.go ├── sticky_linux.go ├── timers.go ├── tun.go └── uapi.go ├── format_test.go ├── go.mod ├── go.sum ├── ipc ├── namedpipe │ ├── file.go │ ├── namedpipe.go │ └── namedpipe_test.go ├── uapi_bsd.go ├── uapi_linux.go ├── uapi_unix.go ├── uapi_wasm.go └── uapi_windows.go ├── main.go ├── main_windows.go ├── ratelimiter ├── ratelimiter.go └── ratelimiter_test.go ├── replay ├── replay.go └── replay_test.go ├── rwcancel ├── rwcancel.go └── rwcancel_stub.go ├── tai64n ├── tai64n.go └── tai64n_test.go ├── tests └── netns.sh ├── tun ├── alignment_windows_test.go ├── checksum.go ├── checksum_test.go ├── errors.go ├── netstack │ ├── examples │ │ ├── 
http_client.go │ │ ├── http_server.go │ │ └── ping_client.go │ └── tun.go ├── offload_linux.go ├── offload_linux_test.go ├── operateonfd.go ├── tun.go ├── tun_darwin.go ├── tun_freebsd.go ├── tun_linux.go ├── tun_openbsd.go ├── tun_windows.go └── tuntest │ └── tuntest.go └── version.go /.github/workflows/build-if-tag.yml: -------------------------------------------------------------------------------- 1 | name: build-if-tag 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v[0-9]+.[0-9]+.[0-9]+' 7 | 8 | env: 9 | APP: amneziawg-go 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | name: build 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | ref: ${{ github.ref_name }} 20 | 21 | - name: Login to Docker Hub 22 | uses: docker/login-action@v3 23 | with: 24 | username: ${{ secrets.DOCKERHUB_USERNAME }} 25 | password: ${{ secrets.DOCKERHUB_TOKEN }} 26 | 27 | - name: Setup metadata 28 | uses: docker/metadata-action@v5 29 | id: metadata 30 | with: 31 | images: amneziavpn/${{ env.APP }} 32 | tags: type=semver,pattern={{version}} 33 | 34 | - name: Set up Docker Buildx 35 | uses: docker/setup-buildx-action@v3 36 | 37 | - name: Build 38 | uses: docker/build-push-action@v5 39 | with: 40 | push: true 41 | tags: ${{ steps.metadata.outputs.tags }} 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | amneziawg-go -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.24 AS awg 2 | COPY .
/awg 3 | WORKDIR /awg 4 | RUN go mod download && \ 5 | go mod verify && \ 6 | go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin 7 | 8 | FROM alpine:3.19 9 | ARG AWGTOOLS_RELEASE="1.0.20241018" 10 | RUN apk --no-cache add iproute2 iptables bash && \ 11 | cd /usr/bin/ && \ 12 | wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \ 13 | unzip -j alpine-3.19-amneziawg-tools.zip && \ 14 | chmod +x /usr/bin/awg /usr/bin/awg-quick && \ 15 | ln -s /usr/bin/awg /usr/bin/wg && \ 16 | ln -s /usr/bin/awg-quick /usr/bin/wg-quick 17 | COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of 2 | this software and associated documentation files (the "Software"), to deal in 3 | the Software without restriction, including without limitation the rights to 4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 5 | of the Software, and to permit persons to whom the Software is furnished to do 6 | so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all 9 | copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 
18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX ?= /usr 2 | DESTDIR ?= 3 | BINDIR ?= $(PREFIX)/bin 4 | export GO111MODULE := on 5 | 6 | all: generate-version-and-build 7 | 8 | MAKEFLAGS += --no-print-directory 9 | 10 | generate-version-and-build: 11 | @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ 12 | tag="$$(git describe --tags --dirty 2>/dev/null)" && \ 13 | ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ 14 | [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ 15 | echo "$$ver" > version.go && \ 16 | git update-index --assume-unchanged version.go || true 17 | @$(MAKE) amneziawg-go 18 | 19 | amneziawg-go: $(wildcard *.go) $(wildcard */*.go) 20 | go build -v -o "$@" 21 | 22 | install: amneziawg-go 23 | @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go" 24 | 25 | test: 26 | go test ./... 27 | 28 | clean: 29 | rm -f amneziawg-go 30 | 31 | .PHONY: all clean test install generate-version-and-build 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go Implementation of AmneziaWG 2 | 3 | AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original. 4 | 5 | The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures. 6 | AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic. 
7 | As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection. 8 | 9 | ## Usage 10 | 11 | Simply run: 12 | 13 | ``` 14 | $ amneziawg-go wg0 15 | ``` 16 | 17 | This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down. 18 | 19 | To run amneziawg-go without forking to the background, pass `-f` or `--foreground`: 20 | 21 | ``` 22 | $ amneziawg-go -f wg0 23 | ``` 24 | When an interface is running, you may use [`amneziawg-tools`](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. 25 | 26 | To run with more logging you may set the environment variable `LOG_LEVEL=debug`. 27 | 28 | ## Platforms 29 | 30 | ### Linux 31 | 32 | This runs on Linux; use it instead of the default Linux kernel WireGuard module. 33 | 34 | ### macOS 35 | 36 | This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. 
37 | This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple) 38 | 39 | ### Windows 40 | 41 | This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module. 42 | 43 | 44 | ## Building 45 | 46 | This requires an installation of the latest version of [Go](https://go.dev/). 47 | 48 | ``` 49 | $ git clone https://github.com/amnezia-vpn/amneziawg-go 50 | $ cd amneziawg-go 51 | $ make 52 | ``` 53 | -------------------------------------------------------------------------------- /conn/bind_std_test.go: -------------------------------------------------------------------------------- 1 | package conn 2 | 3 | import ( 4 | "encoding/binary" 5 | "net" 6 | "testing" 7 | 8 | "golang.org/x/net/ipv6" 9 | ) 10 | // TestStdNetBindReceiveFuncAfterClose calls every ReceiveFunc after Close; it passes as long as none of them panics. 11 | func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { 12 | bind := NewStdNetBind().(*StdNetBind) 13 | fns, _, err := bind.Open(0) 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | bind.Close() 18 | bufs := make([][]byte, 1) 19 | bufs[0] = make([]byte, 1) 20 | sizes := make([]int, 1) 21 | eps := make([]Endpoint, 1) 22 | for _, fn := range fns { 23 | // The ReceiveFuncs must not access conn-related fields on StdNetBind 24 | // unguarded. Close() nils the conn-related fields resulting in a panic 25 | // if they violate the mutex. 
26 | fn(bufs, sizes, eps) 27 | } 28 | } 29 | // mockSetGSOSize is a test stand-in GSO setter: it extends control to its capacity and writes gsoSize little-endian into its first two bytes. 30 | func mockSetGSOSize(control *[]byte, gsoSize uint16) { 31 | *control = (*control)[:cap(*control)] 32 | binary.LittleEndian.PutUint16(*control, gsoSize) 33 | } 34 | // Test_coalesceMessages checks how coalesceMessages packs consecutive buffers into messages, reading the recorded segment size back from OOB via mockGetGSOSize. 35 | func Test_coalesceMessages(t *testing.T) { 36 | cases := []struct { 37 | name string 38 | buffs [][]byte 39 | wantLens []int 40 | wantGSO []int 41 | }{ 42 | { 43 | name: "one message no coalesce", 44 | buffs: [][]byte{ 45 | make([]byte, 1, 1), 46 | }, 47 | wantLens: []int{1}, 48 | wantGSO: []int{0}, 49 | }, 50 | { 51 | name: "two messages equal len coalesce", 52 | buffs: [][]byte{ 53 | make([]byte, 1, 2), 54 | make([]byte, 1, 1), 55 | }, 56 | wantLens: []int{2}, 57 | wantGSO: []int{1}, 58 | }, 59 | { 60 | name: "two messages unequal len coalesce", 61 | buffs: [][]byte{ 62 | make([]byte, 2, 3), 63 | make([]byte, 1, 1), 64 | }, 65 | wantLens: []int{3}, 66 | wantGSO: []int{2}, 67 | }, 68 | { 69 | name: "three messages second unequal len coalesce", 70 | buffs: [][]byte{ 71 | make([]byte, 2, 3), 72 | make([]byte, 1, 1), 73 | make([]byte, 2, 2), 74 | }, 75 | wantLens: []int{3, 2}, 76 | wantGSO: []int{2, 0}, 77 | }, 78 | { 79 | name: "three messages limited cap coalesce", 80 | buffs: [][]byte{ 81 | make([]byte, 2, 4), 82 | make([]byte, 2, 2), 83 | make([]byte, 2, 2), 84 | }, 85 | wantLens: []int{4, 2}, 86 | wantGSO: []int{2, 0}, 87 | }, 88 | } 89 | 90 | for _, tt := range cases { 91 | t.Run(tt.name, func(t *testing.T) { 92 | addr := &net.UDPAddr{ 93 | IP: net.ParseIP("127.0.0.1").To4(), 94 | Port: 1, 95 | } 96 | msgs := make([]ipv6.Message, len(tt.buffs)) 97 | for i := range msgs { 98 | msgs[i].Buffers = make([][]byte, 1) 99 | msgs[i].OOB = make([]byte, 0, 2) 100 | } 101 | got := 
coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) 102 | if got != len(tt.wantLens) { 103 | t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) 104 | } 105 | for i := 0; i < got; i++ { 106 | if msgs[i].Addr != addr { 107 | 
t.Errorf("msgs[%d].Addr != passed addr", i) 108 | } 109 | gotLen := len(msgs[i].Buffers[0]) 110 | if gotLen != tt.wantLens[i] { 111 | t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) 112 | } 113 | gotGSO, err := mockGetGSOSize(msgs[i].OOB) 114 | if err != nil { 115 | t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) 116 | } 117 | if gotGSO != tt.wantGSO[i] { 118 | t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) 119 | } 120 | } 121 | }) 122 | } 123 | } 124 | // mockGetGSOSize is the inverse of mockSetGSOSize; control shorter than two bytes is reported as GSO size 0 with no error. 125 | func mockGetGSOSize(control []byte) (int, error) { 126 | if len(control) < 2 { 127 | return 0, nil 128 | } 129 | return int(binary.LittleEndian.Uint16(control)), nil 130 | } 131 | // Test_splitCoalescedMessages checks that messages starting at firstMsgAt are split into per-segment messages at the front of msgs, and that wantErr cases (e.g. overflow) report an error. 132 | func Test_splitCoalescedMessages(t *testing.T) { 133 | newMsg := func(n, gso int) ipv6.Message { 134 | msg := ipv6.Message{ 135 | Buffers: [][]byte{make([]byte, 1<<16-1)}, 136 | N: n, 137 | OOB: make([]byte, 2), 138 | } 139 | binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) 140 | if gso > 0 { 141 | msg.NN = 2 142 | } 143 | return msg 144 | } 145 | 146 | cases := []struct { 147 | name string 148 | msgs []ipv6.Message 149 | firstMsgAt int 150 | wantNumEval int 151 | wantMsgLens []int 152 | wantErr bool 153 | }{ 154 | { 155 | name: "second last split last empty", 156 | msgs: []ipv6.Message{ 157 | newMsg(0, 0), 158 | newMsg(0, 0), 159 | newMsg(3, 1), 160 | newMsg(0, 0), 161 | }, 162 | firstMsgAt: 2, 163 | wantNumEval: 3, 164 | wantMsgLens: []int{1, 1, 1, 0}, 165 | wantErr: false, 166 | }, 167 | { 168 | name: "second last no split last empty", 169 | msgs: []ipv6.Message{ 170 | newMsg(0, 0), 171 | newMsg(0, 0), 172 | newMsg(1, 0), 173 | newMsg(0, 0), 174 | }, 175 | firstMsgAt: 2, 176 | wantNumEval: 1, 177 | 
wantMsgLens: []int{1, 0, 0, 0}, 178 | wantErr: false, 179 | }, 180 | { 181 | name: "second last no split last no split", 182 | msgs: []ipv6.Message{ 183 | newMsg(0, 0), 184 | newMsg(0, 0), 185 | newMsg(1, 0), 186 | newMsg(1, 0), 187 | }, 188 | firstMsgAt: 2, 189 | wantNumEval: 2, 190 | 
wantMsgLens: []int{1, 1, 0, 0}, 191 | wantErr: false, 192 | }, 193 | { 194 | name: "second last no split last split", 195 | msgs: []ipv6.Message{ 196 | newMsg(0, 0), 197 | newMsg(0, 0), 198 | newMsg(1, 0), 199 | newMsg(3, 1), 200 | }, 201 | firstMsgAt: 2, 202 | wantNumEval: 4, 203 | wantMsgLens: []int{1, 1, 1, 1}, 204 | wantErr: false, 205 | }, 206 | { 207 | name: "second last split last split", 208 | msgs: []ipv6.Message{ 209 | newMsg(0, 0), 210 | newMsg(0, 0), 211 | newMsg(2, 1), 212 | newMsg(2, 1), 213 | }, 214 | firstMsgAt: 2, 215 | wantNumEval: 4, 216 | wantMsgLens: []int{1, 1, 1, 1}, 217 | wantErr: false, 218 | }, 219 | { 220 | name: "second last no split last split overflow", 221 | msgs: []ipv6.Message{ 222 | newMsg(0, 0), 223 | newMsg(0, 0), 224 | newMsg(1, 0), 225 | newMsg(4, 1), 226 | }, 227 | firstMsgAt: 2, 228 | wantNumEval: 4, 229 | wantMsgLens: []int{1, 1, 1, 1}, 230 | wantErr: true, 231 | }, 232 | } 233 | 234 | for _, tt := range cases { 235 | t.Run(tt.name, func(t *testing.T) { 236 | got, err := splitCoalescedMessages(tt.msgs, tt.firstMsgAt, mockGetGSOSize) 237 | if (err != nil) != tt.wantErr { 238 | t.Fatalf("err: %v, wantErr: %v", err, tt.wantErr) 239 | } 240 | if got != tt.wantNumEval { 241 | t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) 242 | } 243 | for i, msg := range tt.msgs { 244 | if msg.N != tt.wantMsgLens[i] { 245 | t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) 246 | } 247 | } 248 | }) 249 | } 250 | } 251 | -------------------------------------------------------------------------------- /conn/bindtest/bindtest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package bindtest 7 | 8 | import ( 9 | "fmt" 10 | "math/rand" 11 | "net" 12 | "net/netip" 13 | "os" 14 | 15 | "github.com/amnezia-vpn/amneziawg-go/conn" 16 | ) 17 | // ChannelBind is an in-memory conn.Bind for tests: packets travel between a pair of binds over buffered Go channels instead of sockets. 18 | type ChannelBind struct { 19 | rx4, tx4 *chan []byte 20 | rx6, tx6 *chan []byte 21 | closeSignal chan bool 22 | source4, source6 ChannelEndpoint 23 | target4, target6 ChannelEndpoint 24 | } 25 | // ChannelEndpoint identifies one side of a channel pair; its numeric value is what Open reports as the port and what ParseEndpoint extracts from the port. 26 | type ChannelEndpoint uint16 27 | 28 | var ( 29 | _ conn.Bind = (*ChannelBind)(nil) 30 | _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 | ) 32 | // NewChannelBinds returns two cross-connected binds: each bind's tx channels are the other's rx channels (8192-slot buffers), and each bind's source endpoints equal the peer's target endpoints. 33 | func NewChannelBinds() [2]conn.Bind { 34 | arx4 := make(chan []byte, 8192) 35 | brx4 := make(chan []byte, 8192) 36 | arx6 := make(chan []byte, 8192) 37 | brx6 := make(chan []byte, 8192) 38 | var binds [2]ChannelBind 39 | binds[0].rx4 = &arx4 40 | binds[0].tx4 = &brx4 41 | binds[1].rx4 = &brx4 42 | binds[1].tx4 = &arx4 43 | binds[0].rx6 = &arx6 44 | binds[0].tx6 = &brx6 45 | binds[1].rx6 = &brx6 46 | binds[1].tx6 = &arx6 47 | binds[0].target4 = ChannelEndpoint(1) 48 | binds[1].target4 = ChannelEndpoint(2) 49 | binds[0].target6 = ChannelEndpoint(3) 50 | binds[1].target6 = ChannelEndpoint(4) 51 | binds[0].source4 = binds[1].target4 52 | binds[0].source6 = binds[1].target6 53 | binds[1].source4 = binds[0].target4 54 | binds[1].source6 = binds[0].target6 55 | return [2]conn.Bind{&binds[0], &binds[1]} 56 | } 57 | // ClearSrc is a no-op: a ChannelEndpoint keeps no source address. 
58 | func (c ChannelEndpoint) ClearSrc() {} 59 | 60 | func (c ChannelEndpoint) SrcToString() string { return "" } 61 | // DstToString formats the endpoint as 127.0.0.1:<endpoint value>. 62 | func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 | 64 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 | // DstIP is always 127.0.0.1 regardless of the endpoint value. 66 | func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 | 68 | func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 | // Open ignores port, (re)creates the close signal, and returns one ReceiveFunc per address family; the reported port is source4 or source6, picked at random. 70 | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 | c.closeSignal = make(chan bool) 72 | fns = append(fns, 
c.makeReceiveFunc(*c.rx4)) 73 | fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 | if rand.Uint32()&1 == 0 { 75 | return fns, uint16(c.source4), nil 76 | } else { 77 | return fns, uint16(c.source6), nil 78 | } 79 | } 80 | // Close closes the close signal once; later calls are no-ops. NOTE(review): the check-then-close is unsynchronized, so concurrent Close calls could race. 81 | func (c *ChannelBind) Close() error { 82 | if c.closeSignal != nil { 83 | select { 84 | case <-c.closeSignal: 85 | default: 86 | close(c.closeSignal) 87 | } 88 | } 89 | return nil 90 | } 91 | // BatchSize is 1; the ReceiveFuncs only ever fill bufs[0]. 92 | func (c *ChannelBind) BatchSize() int { return 1 } 93 | // SetMark is a no-op. 94 | func (c *ChannelBind) SetMark(mark uint32) error { return nil } 95 | // makeReceiveFunc returns a ReceiveFunc that blocks until a packet arrives on ch or the bind is closed. NOTE(review): eps[0] is always c.target6, even for the IPv4 channel — confirm intended. 96 | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 97 | return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 98 | select { 99 | case <-c.closeSignal: 100 | return 0, net.ErrClosed 101 | case rx := <-ch: 102 | copied := copy(bufs[0], rx) 103 | sizes[0] = copied 104 | eps[0] = c.target6 105 | return 1, nil 106 | } 107 | } 108 | } 109 | // Send copies each buffer and queues it on the tx channel matching ep (target4 or target6); any other ChannelEndpoint yields os.ErrInvalid. NOTE(review): the unchecked ep.(ChannelEndpoint) assertion panics for other Endpoint types. 
110 | func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { 111 | for _, b := range bufs { 112 | select { 113 | case <-c.closeSignal: 114 | return net.ErrClosed 115 | default: 116 | bc := make([]byte, len(b)) 117 | copy(bc, b) 118 | if ep.(ChannelEndpoint) == c.target4 { 119 | *c.tx4 <- bc 120 | } else if ep.(ChannelEndpoint) == c.target6 { 121 | *c.tx6 <- bc 122 | } else { 123 | return os.ErrInvalid 124 | } 125 | } 126 | } 127 | return nil 128 | } 129 | // ParseEndpoint parses s as an ip:port and keeps only the port as the ChannelEndpoint. 130 | func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 131 | addr, err := netip.ParseAddrPort(s) 132 | if err != nil { 133 | return nil, err 134 | } 135 | return ChannelEndpoint(addr.Port()), nil 136 | } 137 | -------------------------------------------------------------------------------- /conn/boundif_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package conn 7 | // PeekLookAtSocketFd4 returns the file descriptor of the IPv4 socket, or -1 and an error if it cannot be obtained. 8 | func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { 9 | sysconn, err := s.ipv4.SyscallConn() 10 | if err != nil { 11 | return -1, err 12 | } 13 | err = sysconn.Control(func(f uintptr) { 14 | fd = int(f) 15 | }) 16 | if err != nil { 17 | return -1, err 18 | } 19 | return 20 | } 21 | // PeekLookAtSocketFd6 returns the file descriptor of the IPv6 socket, or -1 and an error if it cannot be obtained. 22 | func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { 23 | sysconn, err := s.ipv6.SyscallConn() 24 | if err != nil { 25 | return -1, err 26 | } 27 | err = sysconn.Control(func(f uintptr) { 28 | fd = int(f) 29 | }) 30 | if err != nil { 31 | return -1, err 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /conn/conn.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package conn implements WireGuard's network connections. 7 | package conn 8 | 9 | import ( 10 | "errors" 11 | "fmt" 12 | "net/netip" 13 | "reflect" 14 | "runtime" 15 | "strings" 16 | ) 17 | 18 | const ( 19 | IdealBatchSize = 128 // maximum number of packets handled per read and write 20 | ) 21 | 22 | // A ReceiveFunc receives at least one packet from the network and writes them 23 | // into packets. On a successful read it returns the number of elements of 24 | // sizes, packets, and endpoints that should be evaluated. Some elements of 25 | // sizes may be zero, and callers should ignore them. Callers must pass a sizes 26 | // and eps slice with a length greater than or equal to the length of packets. 27 | // These lengths must not exceed the value returned by the associated Bind's BatchSize(). 
28 | type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) 29 | 30 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. 
31 | // 32 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, 33 | // depending on the platform-specific implementation. 34 | type Bind interface { 35 | // Open puts the Bind into a listening state on a given port and reports the actual 36 | // port that it bound to. Passing zero results in a random selection. 37 | // fns is the set of functions that will be called to receive packets. 38 | Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) 39 | 40 | // Close closes the Bind listener. 41 | // All fns returned by Open must return net.ErrClosed after a call to Close. 42 | Close() error 43 | 44 | // SetMark sets the mark for each packet sent through this Bind. 45 | // This mark is passed to the kernel as the socket option SO_MARK. 46 | SetMark(mark uint32) error 47 | 48 | // Send writes one or more packets in bufs to address ep. The length of 49 | // bufs must not exceed BatchSize(). 50 | Send(bufs [][]byte, ep Endpoint) error 51 | 52 | // ParseEndpoint creates a new endpoint from a string. 53 | ParseEndpoint(s string) (Endpoint, error) 54 | 55 | // BatchSize is the number of buffers expected to be passed to 56 | // the ReceiveFuncs, and the maximum expected to be passed to Send. 57 | BatchSize() int 58 | } 59 | 60 | // BindSocketToInterface is implemented by Bind objects that support being 61 | // tied to a single network interface. Used by wireguard-windows. 62 | type BindSocketToInterface interface { 63 | BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error 64 | BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error 65 | } 66 | 67 | // PeekLookAtSocketFd is implemented by Bind objects that support having their 68 | // file descriptor peeked at. Used by wireguard-android. 69 | type PeekLookAtSocketFd interface { 70 | PeekLookAtSocketFd4() (fd int, err error) 71 | PeekLookAtSocketFd6() (fd int, err error) 72 | } 73 | 74 | // An Endpoint maintains the source/destination caching for a peer. 
75 | // 76 | // dst: the remote address of a peer ("endpoint" in uapi terminology) 77 | // src: the local address from which datagrams originate going to the peer 78 | type Endpoint interface { 79 | ClearSrc() // clears the source address 80 | SrcToString() string // returns the local source address (ip:port) 81 | DstToString() string // returns the destination address (ip:port) 82 | DstToBytes() []byte // used for mac2 cookie calculations 83 | DstIP() netip.Addr // returns the destination IP address 84 | SrcIP() netip.Addr // returns the local source IP address 85 | } 86 | // Sentinel errors shared by Bind implementations. 87 | var ( 88 | ErrBindAlreadyOpen = errors.New("bind is already open") 89 | ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") 90 | ) 91 | // PrettyName returns a short name for fn derived from its runtime function name: "v4"/"v6" when the innermost named function ends in IPv4/IPv6, otherwise that name, or fn's %p address if no name remains. 92 | func (fn ReceiveFunc) PrettyName() string { 93 | name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() 94 | // 0. cheese/taco.beansIPv6.func12.func21218-fm 95 | name = strings.TrimSuffix(name, "-fm") 96 | // 1. cheese/taco.beansIPv6.func12.func21218 97 | if idx := strings.LastIndexByte(name, '/'); idx != -1 { 98 | name = name[idx+1:] 99 | // 2. taco.beansIPv6.func12.func21218 100 | } 101 | for { 102 | var idx int 103 | for idx = len(name) - 1; idx >= 0; idx-- { 104 | if name[idx] < '0' || name[idx] > '9' { 105 | break 106 | } 107 | } 108 | if idx == len(name)-1 { 109 | break 110 | } 111 | const dotFunc = ".func" 112 | if !strings.HasSuffix(name[:idx+1], dotFunc) { 113 | break 114 | } 115 | name = name[:idx+1-len(dotFunc)] 116 | // 3. taco.beansIPv6.func12 117 | // 4. taco.beansIPv6 118 | } 119 | if idx := strings.LastIndexByte(name, '.'); idx != -1 { 120 | name = name[idx+1:] 121 | // 5. 
beansIPv6 122 | } 123 | if name == "" { 124 | return fmt.Sprintf("%p", fn) 125 | } 126 | if strings.HasSuffix(name, "IPv4") { 127 | return "v4" 128 | } 129 | if strings.HasSuffix(name, "IPv6") { 130 | return "v6" 131 | } 132 | return name 133 | } 134 | -------------------------------------------------------------------------------- /conn/conn_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "testing" 10 | ) 11 | 12 | func TestPrettyName(t *testing.T) { 13 | var ( 14 | recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } 15 | ) 16 | // recvFunc is an anonymous closure (TestPrettyName.funcN), so PrettyName should strip the .funcN suffix and yield the enclosing test's name. 17 | const want = "TestPrettyName" 18 | 19 | t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { 20 | if got := recvFunc.PrettyName(); got != want { 21 | t.Errorf("PrettyName() = %v, want %v", got, want) 22 | } 23 | }) 24 | } 25 | -------------------------------------------------------------------------------- /conn/controlfns.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "net" 10 | "syscall" 11 | ) 12 | 13 | // UDP socket read/write buffer size (7 MiB). The value of 7 MiB is chosen as it is 14 | // the max supported by a default configuration of macOS. Some platforms will 15 | // silently clamp the value to other maximums, such as linux clamping to 16 | // net.core.{r,w}mem_max (see controlfns_linux.go for additional implementation that works 17 | // around this limitation) 18 | const socketBufferSize = 7 << 20 19 | 20 | // controlFn is the callback function signature from net.ListenConfig.Control. 
21 | // It is used to apply platform specific configuration to the socket prior to 22 | // bind. 
23 | type controlFn func(network, address string, c syscall.RawConn) error 24 | 25 | // controlFns is the ordered list of functions run by listenConfig's Control hook to 26 | // apply socket options; platform-specific files append to it from init(). 27 | var controlFns = []controlFn{} 28 | 29 | // listenConfig returns a net.ListenConfig that applies the controlFns to the 30 | // socket prior to bind. This is used to apply socket buffer sizing and packet 31 | // information OOB configuration for sticky sockets. 32 | func listenConfig() *net.ListenConfig { 33 | return &net.ListenConfig{ 34 | Control: func(network, address string, c syscall.RawConn) error { 35 | for _, fn := range controlFns { 36 | if err := fn(network, address, c); err != nil { // the first failing controlFn aborts the listen 37 | return err 38 | } 39 | } 40 | return nil 41 | }, 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /conn/controlfns_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "fmt" 10 | "runtime" 11 | "syscall" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | // init registers the Linux socket-buffer and packet-info control functions. 16 | func init() { 17 | controlFns = append(controlFns, 18 | 19 | // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by 20 | // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to 21 | // fail silently - the result of failure is lower performance on very fast 22 | // links or high latency links. 
23 | func(network, address string, c syscall.RawConn) error { 24 | return c.Control(func(fd uintptr) { 25 | // Set up to *mem_max 26 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) 27 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) 28 | // Set beyond *mem_max if CAP_NET_ADMIN 29 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) 30 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) 31 | }) 32 | }, 33 | 34 | // Enable receiving of packet information (IP_PKTINFO for IPv4, IPV6_RECVPKTINFO for IPv6; skipped on android), 35 | // used to implement sticky socket support. udp6 sockets are also set IPV6_V6ONLY; other networks get EINVAL. 36 | func(network, address string, c syscall.RawConn) error { 37 | var err error 38 | switch network { 39 | case "udp4": 40 | if runtime.GOOS != "android" { 41 | c.Control(func(fd uintptr) { // NOTE(review): the error returned by Control itself is discarded 42 | err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) 43 | }) 44 | } 45 | case "udp6": 46 | c.Control(func(fd uintptr) { // NOTE(review): the error returned by Control itself is discarded 47 | if runtime.GOOS != "android" { 48 | err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) 49 | if err != nil { 50 | return 51 | } 52 | } 53 | err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) 54 | }) 55 | default: 56 | err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) 57 | } 58 | return err 59 | }, 60 | ) 61 | } 62 | -------------------------------------------------------------------------------- /conn/controlfns_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !linux && !wasm 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "syscall" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | // init registers best-effort socket buffer sizing and IPV6_V6ONLY on udp6 sockets. 16 | func init() { 17 | controlFns = append(controlFns, 18 | func(network, address string, c syscall.RawConn) error { 19 | return c.Control(func(fd uintptr) { 20 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) 21 | _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) 22 | }) 23 | }, 24 | // Restrict udp6 sockets to IPv6 traffic only; other networks are left untouched. 25 | func(network, address string, c syscall.RawConn) error { 26 | var err error 27 | if network == "udp6" { 28 | c.Control(func(fd uintptr) { // NOTE(review): the error returned by Control itself is discarded 29 | err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) 30 | }) 31 | } 32 | return err 33 | }, 34 | ) 35 | } 36 | -------------------------------------------------------------------------------- /conn/controlfns_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "syscall" 10 | 11 | "golang.org/x/sys/windows" 12 | ) 13 | // init registers best-effort SO_RCVBUF/SO_SNDBUF sizing; setsockopt errors are deliberately ignored. 14 | func init() { 15 | controlFns = append(controlFns, 16 | func(network, address string, c syscall.RawConn) error { 17 | return c.Control(func(fd uintptr) { 18 | _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) 19 | _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) 20 | }) 21 | }, 22 | ) 23 | } 24 | -------------------------------------------------------------------------------- /conn/default.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | // NewDefaultBind returns the default Bind for non-Windows builds: the StdNetBind from NewStdNetBind. 10 | func NewDefaultBind() Bind { return NewStdNetBind() } 11 | -------------------------------------------------------------------------------- /conn/errors_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | // errShouldDisableUDPGSO always reports false on non-Linux builds. 10 | func errShouldDisableUDPGSO(err error) bool { 11 | return false 12 | } 13 | -------------------------------------------------------------------------------- /conn/errors_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "errors" 10 | "os" 11 | 12 | "golang.org/x/sys/unix" 13 | ) 14 | // errShouldDisableUDPGSO reports whether err wraps an *os.SyscallError whose errno is EIO or EINVAL, the send failures after which UDP GSO should be disabled. 15 | func errShouldDisableUDPGSO(err error) bool { 16 | var serr *os.SyscallError 17 | if errors.As(err, &serr) { 18 | // EIO is returned by udp_send_skb() if the device driver does not have 19 | // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. 20 | // See: 21 | // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 22 | // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 23 | // If gso_size + udp + ip headers > fragment size EINVAL is returned. 24 | // It occurs when the peer mtu + wg headers is greater than path mtu.
25 | return serr.Err == unix.EIO || serr.Err == unix.EINVAL 26 | } 27 | return false 28 | } 29 | -------------------------------------------------------------------------------- /conn/features_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 7 | */ 8 | 9 | package conn 10 | 11 | import "net" 12 | // supportsUDPOffload reports no UDP TX or RX offload support on non-Linux builds. 13 | func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { 14 | return 15 | } 16 | -------------------------------------------------------------------------------- /conn/features_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "net" 10 | 11 | "golang.org/x/sys/unix" 12 | ) 13 | // supportsUDPOffload probes conn for UDP_SEGMENT (TX) and UDP_GRO (RX) support. The RX probe sets UDP_GRO=1, so it also enables GRO on the socket when it succeeds. 14 | func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { 15 | rc, err := conn.SyscallConn() 16 | if err != nil { 17 | return 18 | } 19 | err = rc.Control(func(fd uintptr) { 20 | _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) 21 | txOffload = errSyscall == nil 22 | // getsockopt(IPPROTO_UDP, UDP_GRO) is not supported on Android, 23 | // so probe with setsockopt instead. 24 | errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) 25 | rxOffload = errSyscall == nil 26 | }) 27 | if err != nil { 28 | return false, false 29 | } 30 | return txOffload, rxOffload 31 | } 32 | -------------------------------------------------------------------------------- /conn/gso_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package conn 9 | 10 | // getGSOSize is a stub on non-Linux builds: it always returns 0, nil. 11 | func getGSOSize(control []byte) (int, error) { 12 | return 0, nil 13 | } 14 | 15 | // setGSOSize is a no-op stub on non-Linux builds; control is left unmodified. 16 | func setGSOSize(control *[]byte, gsoSize uint16) { 17 | } 18 | 19 | // gsoControlSize is the recommended buffer size for pooling sticky and UDP 20 | // offloading control data; it is 0 here because getGSOSize/setGSOSize are stubs. 21 | const gsoControlSize = 0 22 | -------------------------------------------------------------------------------- /conn/gso_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "fmt" 12 | "unsafe" 13 | 14 | "golang.org/x/sys/unix" 15 | ) 16 | // sizeOfGSOData is the size in bytes of the uint16 GSO segment size carried in UDP_GRO/UDP_SEGMENT control messages. 17 | const ( 18 | sizeOfGSOData = 2 19 | ) 20 | 21 | // getGSOSize returns the GSO size from the first SOL_UDP/UDP_GRO message in control, 0 if there is none, or an error if control is malformed. 22 | func getGSOSize(control []byte) (int, error) { 23 | var ( 24 | hdr unix.Cmsghdr 25 | data []byte 26 | rem = control 27 | err error 28 | ) 29 | 30 | for len(rem) > unix.SizeofCmsghdr { 31 | hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) 32 | if err != nil { 33 | return 0, fmt.Errorf("error parsing socket control message: %w", err) 34 | } 35 | if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { 36 | var gso uint16 37 | copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) // raw byte copy, i.e. host byte order 38 | return int(gso), nil 39 | } 40 | } 41 | return 0, nil 42 | } 43 | 44 | // setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing 45 | // data in control untouched.
46 | func setGSOSize(control *[]byte, gsoSize uint16) { 47 | existingLen := len(*control) 48 | avail := cap(*control) - existingLen 49 | space := unix.CmsgSpace(sizeOfGSOData) 50 | if avail < space { // not enough spare capacity: control is left unchanged 51 | return 52 | } 53 | *control = (*control)[:cap(*control)] 54 | gsoControl := (*control)[existingLen:] 55 | hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) 56 | hdr.Level = unix.SOL_UDP 57 | hdr.Type = unix.UDP_SEGMENT 58 | hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) 59 | copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) 60 | *control = (*control)[:existingLen+space] 61 | } 62 | 63 | // gsoControlSize is the recommended buffer size for pooling UDP offloading 64 | // control data: the space needed for one UDP_SEGMENT/UDP_GRO control message. 65 | var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) 66 | -------------------------------------------------------------------------------- /conn/mark_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !openbsd && !freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | // SetMark is a no-op on builds other than linux, openbsd and freebsd (per the build constraint); it always returns nil. 10 | func (s *StdNetBind) SetMark(mark uint32) error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /conn/mark_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || openbsd || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "runtime" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | // fwmarkIoctl is the SOL_SOCKET option SetMark uses; it stays 0 (SetMark then does nothing) if runtime.GOOS matches none of the cases in init. 16 | var fwmarkIoctl int 17 | // init picks fwmarkIoctl from runtime.GOOS; the values are numeric literals, with the matching unix constant named in each comment. 18 | func init() { 19 | switch runtime.GOOS { 20 | case "linux", "android": 21 | fwmarkIoctl = 36 /* unix.SO_MARK */ 22 | case "freebsd": 23 | fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ 24 | case "openbsd": 25 | fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ 26 | } 27 | } 28 | // SetMark sets mark on whichever of the IPv4 and IPv6 sockets are open, returning the first error from SyscallConn, Control or setsockopt. 29 | func (s *StdNetBind) SetMark(mark uint32) error { 30 | var operr error 31 | if fwmarkIoctl == 0 { 32 | return nil 33 | } 34 | if s.ipv4 != nil { 35 | fd, err := s.ipv4.SyscallConn() 36 | if err != nil { 37 | return err 38 | } 39 | err = fd.Control(func(fd uintptr) { 40 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 41 | }) 42 | if err == nil { 43 | err = operr 44 | } 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | if s.ipv6 != nil { 50 | fd, err := s.ipv6.SyscallConn() 51 | if err != nil { 52 | return err 53 | } 54 | err = fd.Control(func(fd uintptr) { 55 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 56 | }) 57 | if err == nil { 58 | err = operr 59 | } 60 | if err != nil { 61 | return err 62 | } 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /conn/sticky_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux || android 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package conn 9 | 10 | import "net/netip" 11 | // SrcIP always returns the zero netip.Addr: this build does not support sticky sockets. 12 | func (e *StdNetEndpoint) SrcIP() netip.Addr { 13 | return netip.Addr{} 14 | } 15 | // SrcIfidx always returns 0 in this build. 16 | func (e *StdNetEndpoint) SrcIfidx() int32 { 17 | return 0 18 | } 19 | // SrcToString always returns the empty string in this build. 20 | func (e *StdNetEndpoint) SrcToString() string { 21 | return "" 22 | } 23 | 24 | // TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets 25 | // {get,set}srcControl feature set, but use alternatively named flags and need 26 | // ports and require testing. 27 | 28 | // getSrcFromControl is a no-op in this build: control is not parsed and ep is 29 | // left unchanged. 30 | func getSrcFromControl(control []byte, ep *StdNetEndpoint) { 31 | } 32 | 33 | // setSrcControl is a no-op in this build: no PKTINFO is written and control is 34 | // left unchanged. 35 | func setSrcControl(control *[]byte, ep *StdNetEndpoint) { 36 | } 37 | 38 | // stickyControlSize is the recommended buffer size for pooling sticky control 39 | // data; it is 0 because sticky sockets are unsupported in this build. 40 | const stickyControlSize = 0 41 | // StdNetSupportsStickySockets reports whether this build supports sticky sockets. 42 | const StdNetSupportsStickySockets = false 43 | -------------------------------------------------------------------------------- /conn/sticky_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux && !android 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "net/netip" 12 | "unsafe" 13 | 14 | "golang.org/x/sys/unix" 15 | ) 16 | // SrcIP returns the source address stored in e.src as an IPv4 or IPv6 PKTINFO control message, or the zero Addr if e.src holds neither. 17 | func (e *StdNetEndpoint) SrcIP() netip.Addr { 18 | switch len(e.src) { 19 | case unix.CmsgSpace(unix.SizeofInet4Pktinfo): 20 | info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) 21 | return netip.AddrFrom4(info.Spec_dst) 22 | case unix.CmsgSpace(unix.SizeofInet6Pktinfo): 23 | info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) 24 | // TODO: set zone.
in order to do so we need to check if the address is 25 | // link local, and if it is perform a syscall to turn the ifindex into a 26 | // zone string because netip uses string zones. 27 | return netip.AddrFrom16(info.Addr) 28 | } 29 | return netip.Addr{} 30 | } 31 | // SrcIfidx returns the interface index stored in e.src's PKTINFO, or 0 if e.src holds neither form. 32 | func (e *StdNetEndpoint) SrcIfidx() int32 { 33 | switch len(e.src) { 34 | case unix.CmsgSpace(unix.SizeofInet4Pktinfo): 35 | info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) 36 | return info.Ifindex 37 | case unix.CmsgSpace(unix.SizeofInet6Pktinfo): 38 | info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) 39 | return int32(info.Ifindex) 40 | } 41 | return 0 42 | } 43 | // SrcToString returns SrcIP() in string form. 44 | func (e *StdNetEndpoint) SrcToString() string { 45 | return e.SrcIP().String() 46 | } 47 | 48 | // getSrcFromControl clears ep's source, then stores the first IP_PKTINFO or 49 | // IPV6_PKTINFO message in control (header and data) in ep.src; parsing stops at the first malformed message. 50 | func getSrcFromControl(control []byte, ep *StdNetEndpoint) { 51 | ep.ClearSrc() 52 | 53 | var ( 54 | hdr unix.Cmsghdr 55 | data []byte 56 | rem []byte = control 57 | err error 58 | ) 59 | 60 | for len(rem) > unix.SizeofCmsghdr { 61 | hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) 62 | if err != nil { 63 | return 64 | } 65 | 66 | if hdr.Level == unix.IPPROTO_IP && 67 | hdr.Type == unix.IP_PKTINFO { 68 | 69 | if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { 70 | ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) 71 | } 72 | ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] 73 | 74 | hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) 75 | copy(ep.src, hdrBuf) 76 | copy(ep.src[unix.CmsgLen(0):], data) 77 | return 78 | } 79 | 80 | if hdr.Level == unix.IPPROTO_IPV6 && 81 | hdr.Type == unix.IPV6_PKTINFO { 82 | 83 | if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { 84 | ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) 85 | } 86 | 87 | ep.src =
ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] 88 | 89 | hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) 90 | copy(ep.src, hdrBuf) 91 | copy(ep.src[unix.CmsgLen(0):], data) 92 | return 93 | } 94 | } 95 | } 96 | 97 | // setSrcControl replaces control's contents with ep.src, the IP{V6}_PKTINFO 98 | // message recorded for ep, so control's len becomes 0 when ep has no source 99 | // set. Any existing data in control is discarded. 100 | func setSrcControl(control *[]byte, ep *StdNetEndpoint) { 101 | if cap(*control) < len(ep.src) { // too small for ep.src: control is left unchanged 102 | return 103 | } 104 | *control = (*control)[:0] 105 | *control = append(*control, ep.src...) 106 | } 107 | 108 | // stickyControlSize is the recommended buffer size for pooling sticky control 109 | // data: room for one IPv6 PKTINFO control message. 110 | var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) 111 | // StdNetSupportsStickySockets reports that this build supports sticky sockets. 112 | const StdNetSupportsStickySockets = true 113 | -------------------------------------------------------------------------------- /device/allowedips.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "encoding/binary" 11 | "errors" 12 | "math/bits" 13 | "net" 14 | "net/netip" 15 | "sync" 16 | "unsafe" 17 | ) 18 | // parentIndirection points at the slot that holds a node: a table root pointer or a parent's child entry. parentBitType is that child's index (0 or 1), or 2 for a root slot. 19 | type parentIndirection struct { 20 | parentBit **trieEntry 21 | parentBitType uint8 22 | } 23 | // trieEntry is a node of the binary trie that matches IP addresses against allowed prefixes; peer is nil for intermediate nodes. 24 | type trieEntry struct { 25 | peer *Peer 26 | child [2]*trieEntry 27 | parent parentIndirection 28 | cidr uint8 29 | bitAtByte uint8 30 | bitAtShift uint8 31 | bits []byte 32 | perPeerElem *list.Element 33 | } 34 | // commonBits returns the number of leading bits shared by ip1 and ip2. len(ip1) must be 4 or 16 (other lengths panic) and ip2 must be at least as long. 35 | func commonBits(ip1, ip2 []byte) uint8 { 36 | size := len(ip1) 37 | if size == net.IPv4len { 38 | a := binary.BigEndian.Uint32(ip1) 39 | b := binary.BigEndian.Uint32(ip2) 40 | x := a ^ b 41 | return uint8(bits.LeadingZeros32(x)) 42 | } else if size == net.IPv6len { 43 | a := binary.BigEndian.Uint64(ip1) 44 | b := binary.BigEndian.Uint64(ip2) 45 | x := a ^ b 46 | if x != 0 { 47 | return uint8(bits.LeadingZeros64(x)) 48 | } 49 | a = binary.BigEndian.Uint64(ip1[8:]) 50 | b = binary.BigEndian.Uint64(ip2[8:]) 51 | x = a ^ b 52 | return 64 + uint8(bits.LeadingZeros64(x)) 53 | } else { 54 | panic("Wrong size bit string") 55 | } 56 | } 57 | // addToPeerEntries appends node to its peer's trieEntries list and remembers the list element. 58 | func (node *trieEntry) addToPeerEntries() { 59 | node.perPeerElem = node.peer.trieEntries.PushBack(node) 60 | } 61 | // removeFromPeerEntries unlinks node from its peer's trieEntries list, if it is linked. 62 | func (node *trieEntry) removeFromPeerEntries() { 63 | if node.perPeerElem != nil { 64 | node.peer.trieEntries.Remove(node.perPeerElem) 65 | node.perPeerElem = nil 66 | } 67 | } 68 | // choose returns the bit of ip this node branches on: bit bitAtShift (from the least significant end) of byte bitAtByte. 69 | func (node *trieEntry) choose(ip []byte) byte { 70 | return (ip[node.bitAtByte] >> node.bitAtShift) & 1 71 | } 72 | // maskSelf clears every bit of node.bits beyond the first node.cidr bits. 73 | func (node *trieEntry) maskSelf() { 74 | mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 75 | for i := 0; i < len(mask); i++ { 76 | node.bits[i] &= mask[i] 77 | } 78 | } 79 | // zeroizePointers clears node's peer, child and parent pointers once it has been unlinked. 80 | func (node *trieEntry) zeroizePointers() { 81 | // Make the garbage collector's life slightly easier 82 | node.peer = nil 83 | node.child[0] = nil 84 | node.child[1] = nil 85 | node.parent.parentBit = nil 86 | } 87 | // nodePlacement descends from node and returns the deepest node whose prefix contains ip/cidr; exact reports that this node's cidr equals cidr. 88 |
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { 89 | for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { 90 | parent = node 91 | if parent.cidr == cidr { 92 | exact = true 93 | return 94 | } 95 | bit := node.choose(ip) 96 | node = node.child[bit] 97 | } 98 | return 99 | } 100 | // insert maps ip/cidr to peer in the trie rooted at *trie.parentBit, replacing the peer of an identical existing prefix. When a node is created, ip is masked in place and retained as its bits. 101 | func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { 102 | if *trie.parentBit == nil { 103 | node := &trieEntry{ 104 | peer: peer, 105 | parent: trie, 106 | bits: ip, 107 | cidr: cidr, 108 | bitAtByte: cidr / 8, 109 | bitAtShift: 7 - (cidr % 8), 110 | } 111 | node.maskSelf() 112 | node.addToPeerEntries() 113 | *trie.parentBit = node 114 | return 115 | } 116 | node, exact := (*trie.parentBit).nodePlacement(ip, cidr) 117 | if exact { 118 | node.removeFromPeerEntries() 119 | node.peer = peer 120 | node.addToPeerEntries() 121 | return 122 | } 123 | // No node has exactly this prefix: create one and link it in below node. 124 | newNode := &trieEntry{ 125 | peer: peer, 126 | bits: ip, 127 | cidr: cidr, 128 | bitAtByte: cidr / 8, 129 | bitAtShift: 7 - (cidr % 8), 130 | } 131 | newNode.maskSelf() 132 | newNode.addToPeerEntries() 133 | // down is the existing subtree at the position newNode would occupy. 134 | var down *trieEntry 135 | if node == nil { 136 | down = *trie.parentBit 137 | } else { 138 | bit := node.choose(ip) 139 | down = node.child[bit] 140 | if down == nil { 141 | newNode.parent = parentIndirection{&node.child[bit], bit} 142 | node.child[bit] = newNode 143 | return 144 | } 145 | } 146 | common := commonBits(down.bits, ip) 147 | if common < cidr { 148 | cidr = common 149 | } 150 | parent := node 151 | // If newNode's prefix contains down's, newNode goes directly above down. 152 | if newNode.cidr == cidr { 153 | bit := newNode.choose(down.bits) 154 | down.parent = parentIndirection{&newNode.child[bit], bit} 155 | newNode.child[bit] = down 156 | if parent == nil { 157 | newNode.parent = trie 158 | *trie.parentBit = newNode 159 | } else { 160 | bit := parent.choose(newNode.bits) 161 | newNode.parent = parentIndirection{&parent.child[bit], bit} 162 | parent.child[bit] = newNode 163 | } 164 | return 165 |
} 166 | // Otherwise add a peerless intermediate node at the common prefix, with down and newNode as its children. 167 | node = &trieEntry{ 168 | bits: append([]byte{}, newNode.bits...), 169 | cidr: cidr, 170 | bitAtByte: cidr / 8, 171 | bitAtShift: 7 - (cidr % 8), 172 | } 173 | node.maskSelf() 174 | 175 | bit := node.choose(down.bits) 176 | down.parent = parentIndirection{&node.child[bit], bit} 177 | node.child[bit] = down 178 | bit = node.choose(newNode.bits) 179 | newNode.parent = parentIndirection{&node.child[bit], bit} 180 | node.child[bit] = newNode 181 | if parent == nil { 182 | node.parent = trie 183 | *trie.parentBit = node 184 | } else { 185 | bit := parent.choose(node.bits) 186 | node.parent = parentIndirection{&parent.child[bit], bit} 187 | parent.child[bit] = node 188 | } 189 | } 190 | // lookup returns the peer of the most specific prefix in this subtree that contains ip, or nil. 191 | func (node *trieEntry) lookup(ip []byte) *Peer { 192 | var found *Peer 193 | size := uint8(len(ip)) 194 | for node != nil && commonBits(node.bits, ip) >= node.cidr { 195 | if node.peer != nil { 196 | found = node.peer 197 | } 198 | if node.bitAtByte == size { 199 | break 200 | } 201 | bit := node.choose(ip) 202 | node = node.child[bit] 203 | } 204 | return found 205 | } 206 | // AllowedIPs maps IPv4 and IPv6 prefixes to peers, one trie per address family, guarded by mutex. 207 | type AllowedIPs struct { 208 | IPv4 *trieEntry 209 | IPv6 *trieEntry 210 | mutex sync.RWMutex 211 | } 212 | // EntriesForPeer calls cb for each prefix assigned to peer, stopping early if cb returns false. The read lock is held throughout, so cb must not modify table. 213 | func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { 214 | table.mutex.RLock() 215 | defer table.mutex.RUnlock() 216 | 217 | for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 218 | node := elem.Value.(*trieEntry) 219 | a, _ := netip.AddrFromSlice(node.bits) 220 | if !cb(netip.PrefixFrom(a, int(node.cidr))) { 221 | return 222 | } 223 | } 224 | } 225 | // RemoveByPeer deletes every prefix assigned to peer, unlinking and pruning trie nodes that are no longer needed. 226 | func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 227 | table.mutex.Lock() 228 | defer table.mutex.Unlock() 229 | 230 | var next *list.Element 231 | for elem := peer.trieEntries.Front(); elem != nil; elem = next { 232 | next = elem.Next() 233 | node := elem.Value.(*trieEntry) 234 | 235 | node.removeFromPeerEntries() 236 | node.peer = nil 237 | if node.child[0] != nil
&& node.child[1] != nil { 238 | continue 239 | } 240 | bit := 0 241 | if node.child[0] == nil { 242 | bit = 1 243 | } 244 | child := node.child[bit] 245 | if child != nil { 246 | child.parent = node.parent 247 | } 248 | *node.parent.parentBit = child 249 | if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { 250 | node.zeroizePointers() 251 | continue 252 | } 253 | parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) // recover the parent node from the address of its child slot 254 | if parent.peer != nil { 255 | node.zeroizePointers() 256 | continue 257 | } 258 | child = parent.child[node.parent.parentBitType^1] 259 | if child != nil { 260 | child.parent = parent.parent 261 | } 262 | *parent.parent.parentBit = child 263 | node.zeroizePointers() 264 | parent.zeroizePointers() 265 | } 266 | } 267 | // Insert assigns prefix to peer, replacing any previous owner of the same prefix; it panics if prefix's address is neither IPv4 nor IPv6. 268 | func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { 269 | table.mutex.Lock() 270 | defer table.mutex.Unlock() 271 | // NOTE(review): prefix.IsValid() is not checked; an invalid Bits() of -1 would be stored as cidr 255. 272 | if prefix.Addr().Is6() { 273 | ip := prefix.Addr().As16() 274 | parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 275 | } else if prefix.Addr().Is4() { 276 | ip := prefix.Addr().As4() 277 | parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) 278 | } else { 279 | panic(errors.New("inserting unknown address type")) 280 | } 281 | } 282 | // Lookup returns the peer owning the most specific prefix that contains ip, which must be 4 or 16 bytes long (other lengths panic). 283 | func (table *AllowedIPs) Lookup(ip []byte) *Peer { 284 | table.mutex.RLock() 285 | defer table.mutex.RUnlock() 286 | switch len(ip) { 287 | case net.IPv6len: 288 | return table.IPv6.lookup(ip) 289 | case net.IPv4len: 290 | return table.IPv4.lookup(ip) 291 | default: 292 | panic(errors.New("looking up unknown address type")) 293 | } 294 | } 295 | -------------------------------------------------------------------------------- /device/allowedips_rand_test.go: -------------------------------------------------------------------------------- 1 | /*
SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "sort" 13 | "testing" 14 | ) 15 | // Sizes for the randomized comparison of AllowedIPs against the naive SlowRouter. 16 | const ( 17 | NumberOfPeers = 100 18 | NumberOfPeerRemovals = 4 19 | NumberOfAddresses = 250 20 | NumberOfTests = 10000 21 | ) 22 | // SlowNode is one prefix entry of the naive reference router. 23 | type SlowNode struct { 24 | peer *Peer 25 | cidr uint8 26 | bits []byte 27 | } 28 | // SlowRouter is a naive reference router kept sorted by descending cidr, so a linear scan meets the most specific match first. 29 | type SlowRouter []*SlowNode 30 | // Len, Less and Swap implement sort.Interface, ordering entries by descending cidr. 31 | func (r SlowRouter) Len() int { 32 | return len(r) 33 | } 34 | 35 | func (r SlowRouter) Less(i, j int) bool { 36 | return r[i].cidr > r[j].cidr 37 | } 38 | 39 | func (r SlowRouter) Swap(i, j int) { 40 | r[i], r[j] = r[j], r[i] 41 | } 42 | // Insert replaces the peer of an existing addr/cidr entry, or appends a new entry and re-sorts; callers must use the returned slice. 43 | func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 44 | for _, t := range r { 45 | if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 46 | t.peer = peer 47 | t.bits = addr 48 | return r 49 | } 50 | } 51 | r = append(r, &SlowNode{ 52 | cidr: cidr, 53 | bits: addr, 54 | peer: peer, 55 | }) 56 | sort.Sort(r) 57 | return r 58 | } 59 | // Lookup returns the peer of the first (most specific) entry whose prefix contains addr, or nil. 60 | func (r SlowRouter) Lookup(addr []byte) *Peer { 61 | for _, t := range r { 62 | common := commonBits(t.bits, addr) 63 | if common >= t.cidr { 64 | return t.peer 65 | } 66 | } 67 | return nil 68 | } 69 | // RemoveByPeer drops peer's entries in place and returns the shortened slice. 70 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 71 | n := 0 72 | for _, x := range r { 73 | if x.peer != peer { 74 | r[n] = x 75 | n++ 76 | } 77 | } 78 | return r[:n] 79 | } 80 | // TestTrieRandom cross-checks AllowedIPs against SlowRouter on random prefixes, removing peers between rounds, then checks that removing every peer empties both tries. 81 | func TestTrieRandom(t *testing.T) { 82 | var slow4, slow6 SlowRouter 83 | var peers []*Peer 84 | var allowedIPs AllowedIPs 85 | 86 | rand.Seed(1) 87 | 88 | for n := 0; n < NumberOfPeers; n++ { 89 | peers = append(peers, &Peer{}) 90 | } 91 | 92 | for n := 0; n < NumberOfAddresses; n++ { 93 | var addr4 [4]byte 94 | rand.Read(addr4[:]) 95 | cidr := uint8(rand.Intn(32) + 1) 96 | index := rand.Intn(NumberOfPeers) 97 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4),
int(cidr)), peers[index]) 98 | slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 99 | 100 | var addr6 [16]byte 101 | rand.Read(addr6[:]) 102 | cidr = uint8(rand.Intn(128) + 1) 103 | index = rand.Intn(NumberOfPeers) 104 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) 105 | slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 106 | } 107 | // Compare random lookups after removing 0, 1, ..., NumberOfPeerRemovals peers. 108 | var p int 109 | for p = 0; ; p++ { 110 | for n := 0; n < NumberOfTests; n++ { 111 | var addr4 [4]byte 112 | rand.Read(addr4[:]) 113 | peer1 := slow4.Lookup(addr4[:]) 114 | peer2 := allowedIPs.Lookup(addr4[:]) 115 | if peer1 != peer2 { 116 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 117 | } 118 | 119 | var addr6 [16]byte 120 | rand.Read(addr6[:]) 121 | peer1 = slow6.Lookup(addr6[:]) 122 | peer2 = allowedIPs.Lookup(addr6[:]) 123 | if peer1 != peer2 { 124 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 125 | } 126 | } 127 | if p >= len(peers) || p >= NumberOfPeerRemovals { 128 | break 129 | } 130 | allowedIPs.RemoveByPeer(peers[p]) 131 | slow4 = slow4.RemoveByPeer(peers[p]) 132 | slow6 = slow6.RemoveByPeer(peers[p]) 133 | } 134 | for ; p < len(peers); p++ { 135 | allowedIPs.RemoveByPeer(peers[p]) 136 | } 137 | // Removing all remaining peers must leave both tries empty. 138 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 139 | t.Error("Failed to remove all nodes from trie by peer") 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /device/allowedips_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "testing" 13 | ) 14 | // testPairCommonBits is one commonBits case: the expected number of leading bits shared by s1 and s2. 15 | type testPairCommonBits struct { 16 | s1 []byte 17 | s2 []byte 18 | match uint8 19 | } 20 | // TestCommonBits checks commonBits on 4-byte (IPv4-length) inputs. 21 | func TestCommonBits(t *testing.T) { 22 | tests := []testPairCommonBits{ 23 | {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, 24 | {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, 25 | {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, 26 | {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, 27 | {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, 28 | } 29 | 30 | for _, p := range tests { 31 | v := commonBits(p.s1, p.s2) 32 | if v != p.match { 33 | t.Error( 34 | "For slice", p.s1, p.s2, 35 | "expected match", p.match, 36 | ",but got", v, 37 | ) 38 | } 39 | } 40 | } 41 | // benchmarkTrie inserts addressNumber random addressLength-byte prefixes spread over peerNumber peers, then times b.N lookups of random addresses. 42 | func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { 43 | var trie *trieEntry 44 | var peers []*Peer 45 | root := parentIndirection{&trie, 2} 46 | 47 | rand.Seed(1) 48 | 49 | // Size keys by addressLength so the IPv6 benchmarks really use 16-byte addresses. 50 | 51 | for n := 0; n < peerNumber; n++ { 52 | peers = append(peers, &Peer{}) 53 | } 54 | 55 | for n := 0; n < addressNumber; n++ { 56 | addr := make([]byte, addressLength) 57 | rand.Read(addr) 58 | cidr := uint8(rand.Uint32() % uint32(addressLength*8)) 59 | index := rand.Int() % peerNumber 60 | root.insert(addr, cidr, peers[index]) 61 | } 62 | lookupAddr := make([]byte, addressLength) 63 | b.ResetTimer() // exclude trie construction from the measured time 64 | for n := 0; n < b.N; n++ { 65 | rand.Read(lookupAddr) 66 | trie.lookup(lookupAddr) 67 | } 68 | } 69 | 70 | func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { 71 | benchmarkTrie(100, 1000, net.IPv4len, b) 72 | } 73 | 74 | func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { 75 | benchmarkTrie(10, 10, net.IPv4len, b) 76 | } 77 | 78 | func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { 79 | benchmarkTrie(100, 1000, net.IPv6len, b) 80 | } 81 | 82 | func BenchmarkTrieIPv6Peers10Addresses10(b
*testing.B) { 83 | benchmarkTrie(10, 10, net.IPv6len, b) 84 | } 85 | 86 | /* Test ported from kernel implementation: 87 | * selftest/allowedips.h 88 | */ 89 | func TestTrieIPv4(t *testing.T) { 90 | a := &Peer{} 91 | b := &Peer{} 92 | c := &Peer{} 93 | d := &Peer{} 94 | e := &Peer{} 95 | g := &Peer{} 96 | h := &Peer{} 97 | 98 | var allowedIPs AllowedIPs 99 | 100 | insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { 101 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) 102 | } 103 | 104 | assertEQ := func(peer *Peer, a, b, c, d byte) { 105 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 106 | if p != peer { 107 | t.Error("Assert EQ failed") 108 | } 109 | } 110 | 111 | assertNEQ := func(peer *Peer, a, b, c, d byte) { 112 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 113 | if p == peer { 114 | t.Error("Assert NEQ failed") 115 | } 116 | } 117 | 118 | insert(a, 192, 168, 4, 0, 24) 119 | insert(b, 192, 168, 4, 4, 32) 120 | insert(c, 192, 168, 0, 0, 16) 121 | insert(d, 192, 95, 5, 64, 27) 122 | insert(c, 192, 95, 5, 65, 27) 123 | insert(e, 0, 0, 0, 0, 0) 124 | insert(g, 64, 15, 112, 0, 20) 125 | insert(h, 64, 15, 123, 211, 25) 126 | insert(a, 10, 0, 0, 0, 25) 127 | insert(b, 10, 0, 0, 128, 25) 128 | insert(a, 10, 1, 0, 0, 30) 129 | insert(b, 10, 1, 0, 4, 30) 130 | insert(c, 10, 1, 0, 8, 29) 131 | insert(d, 10, 1, 0, 16, 29) 132 | 133 | assertEQ(a, 192, 168, 4, 20) 134 | assertEQ(a, 192, 168, 4, 0) 135 | assertEQ(b, 192, 168, 4, 4) 136 | assertEQ(c, 192, 168, 200, 182) 137 | assertEQ(c, 192, 95, 5, 68) 138 | assertEQ(e, 192, 95, 5, 96) 139 | assertEQ(g, 64, 15, 116, 26) 140 | assertEQ(g, 64, 15, 127, 3) 141 | 142 | insert(a, 1, 0, 0, 0, 32) 143 | insert(a, 64, 0, 0, 0, 32) 144 | insert(a, 128, 0, 0, 0, 32) 145 | insert(a, 192, 0, 0, 0, 32) 146 | insert(a, 255, 0, 0, 0, 32) 147 | 148 | assertEQ(a, 1, 0, 0, 0) 149 | assertEQ(a, 64, 0, 0, 0) 150 | assertEQ(a, 128, 0, 0, 0) 151 | assertEQ(a, 192, 0, 0, 0) 152 | assertEQ(a, 255, 0, 0,
0) 153 | 154 | allowedIPs.RemoveByPeer(a) 155 | 156 | assertNEQ(a, 1, 0, 0, 0) 157 | assertNEQ(a, 64, 0, 0, 0) 158 | assertNEQ(a, 128, 0, 0, 0) 159 | assertNEQ(a, 192, 0, 0, 0) 160 | assertNEQ(a, 255, 0, 0, 0) 161 | 162 | allowedIPs.RemoveByPeer(a) 163 | allowedIPs.RemoveByPeer(b) 164 | allowedIPs.RemoveByPeer(c) 165 | allowedIPs.RemoveByPeer(d) 166 | allowedIPs.RemoveByPeer(e) 167 | allowedIPs.RemoveByPeer(g) 168 | allowedIPs.RemoveByPeer(h) 169 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 170 | t.Error("Expected removing all the peers to empty trie, but it did not") 171 | } 172 | 173 | insert(a, 192, 168, 0, 0, 16) 174 | insert(a, 192, 168, 0, 0, 24) 175 | 176 | allowedIPs.RemoveByPeer(a) 177 | 178 | assertNEQ(a, 192, 168, 0, 1) 179 | } 180 | 181 | /* Test ported from kernel implementation: 182 | * selftest/allowedips.h 183 | */ 184 | func TestTrieIPv6(t *testing.T) { 185 | a := &Peer{} 186 | b := &Peer{} 187 | c := &Peer{} 188 | d := &Peer{} 189 | e := &Peer{} 190 | f := &Peer{} 191 | g := &Peer{} 192 | h := &Peer{} 193 | 194 | var allowedIPs AllowedIPs 195 | 196 | expand := func(a uint32) []byte { 197 | var out [4]byte 198 | out[0] = byte(a >> 24 & 0xff) 199 | out[1] = byte(a >> 16 & 0xff) 200 | out[2] = byte(a >> 8 & 0xff) 201 | out[3] = byte(a & 0xff) 202 | return out[:] 203 | } 204 | 205 | insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { 206 | var addr []byte 207 | addr = append(addr, expand(a)...) 208 | addr = append(addr, expand(b)...) 209 | addr = append(addr, expand(c)...) 210 | addr = append(addr, expand(d)...) 211 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) 212 | } 213 | 214 | assertEQ := func(peer *Peer, a, b, c, d uint32) { 215 | var addr []byte 216 | addr = append(addr, expand(a)...) 217 | addr = append(addr, expand(b)...) 218 | addr = append(addr, expand(c)...) 219 | addr = append(addr, expand(d)...)
220 | p := allowedIPs.Lookup(addr) 221 | if p != peer { 222 | t.Error("Assert EQ failed") 223 | } 224 | } 225 | 226 | insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) 227 | insert(c, 0x26075300, 0x60006b00, 0, 0, 64) 228 | insert(e, 0, 0, 0, 0, 0) 229 | insert(f, 0, 0, 0, 0, 0) 230 | insert(g, 0x24046800, 0, 0, 0, 32) 231 | insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) 232 | insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) 233 | insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) 234 | insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) 235 | 236 | assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) 237 | assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) 238 | assertEQ(f, 0x26075300, 0x60006b01, 0, 0) 239 | assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) 240 | assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) 241 | assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) 242 | assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) 243 | assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) 244 | assertEQ(h, 0x24046800, 0x40040800, 0, 0) 245 | assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) 246 | assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) 247 | } 248 | -------------------------------------------------------------------------------- /device/bind_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "errors" 10 | 11 | "github.com/amnezia-vpn/amneziawg-go/conn" 12 | ) 13 | 14 | type DummyDatagram struct { 15 | msg []byte 16 | endpoint conn.Endpoint 17 | } 18 | 19 | type DummyBind struct { 20 | in6 chan DummyDatagram 21 | in4 chan DummyDatagram 22 | closed bool 23 | } 24 | 25 | func (b *DummyBind) SetMark(v uint32) error { 26 | return nil 27 | } 28 | 29 | func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { 30 | datagram, ok := <-b.in6 31 | if !ok { 32 | return 0, nil, errors.New("closed") 33 | } 34 | copy(buf, datagram.msg) 35 | return len(datagram.msg), datagram.endpoint, nil 36 | } 37 | 38 | func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { 39 | datagram, ok := <-b.in4 40 | if !ok { 41 | return 0, nil, errors.New("closed") 42 | } 43 | copy(buf, datagram.msg) 44 | return len(datagram.msg), datagram.endpoint, nil 45 | } 46 | 47 | func (b *DummyBind) Close() error { 48 | close(b.in6) 49 | close(b.in4) 50 | b.closed = true 51 | return nil 52 | } 53 | 54 | func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /device/channels.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "runtime" 10 | "sync" 11 | ) 12 | 13 | // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. 14 | // An outboundQueue is ref-counted using its wg field. 15 | // An outboundQueue created with newOutboundQueue has one reference. 16 | // Every additional writer must call wg.Add(1). 17 | // Every completed writer must call wg.Done(). 18 | // When no further writers will be added, 19 | // call wg.Done to remove the initial reference. 
20 | // When the refcount hits 0, the queue's channel is closed. 21 | type outboundQueue struct { 22 | c chan *QueueOutboundElementsContainer 23 | wg sync.WaitGroup 24 | } 25 | 26 | func newOutboundQueue() *outboundQueue { 27 | q := &outboundQueue{ 28 | c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), 29 | } 30 | q.wg.Add(1) 31 | go func() { 32 | q.wg.Wait() 33 | close(q.c) 34 | }() 35 | return q 36 | } 37 | 38 | // A inboundQueue is similar to an outboundQueue; see those docs. 39 | type inboundQueue struct { 40 | c chan *QueueInboundElementsContainer 41 | wg sync.WaitGroup 42 | } 43 | 44 | func newInboundQueue() *inboundQueue { 45 | q := &inboundQueue{ 46 | c: make(chan *QueueInboundElementsContainer, QueueInboundSize), 47 | } 48 | q.wg.Add(1) 49 | go func() { 50 | q.wg.Wait() 51 | close(q.c) 52 | }() 53 | return q 54 | } 55 | 56 | // A handshakeQueue is similar to an outboundQueue; see those docs. 57 | type handshakeQueue struct { 58 | c chan QueueHandshakeElement 59 | wg sync.WaitGroup 60 | } 61 | 62 | func newHandshakeQueue() *handshakeQueue { 63 | q := &handshakeQueue{ 64 | c: make(chan QueueHandshakeElement, QueueHandshakeSize), 65 | } 66 | q.wg.Add(1) 67 | go func() { 68 | q.wg.Wait() 69 | close(q.c) 70 | }() 71 | return q 72 | } 73 | 74 | type autodrainingInboundQueue struct { 75 | c chan *QueueInboundElementsContainer 76 | } 77 | 78 | // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. 79 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 80 | // The returned channel must not be closed. Senders should signal shutdown using 81 | // some other means, such as sending a sentinel nil values. 
82 | func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { 83 | q := &autodrainingInboundQueue{ 84 | c: make(chan *QueueInboundElementsContainer, QueueInboundSize), 85 | } 86 | runtime.SetFinalizer(q, device.flushInboundQueue) 87 | return q 88 | } 89 | 90 | func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { 91 | for { 92 | select { 93 | case elemsContainer := <-q.c: 94 | elemsContainer.Lock() 95 | for _, elem := range elemsContainer.elems { 96 | device.PutMessageBuffer(elem.buffer) 97 | device.PutInboundElement(elem) 98 | } 99 | device.PutInboundElementsContainer(elemsContainer) 100 | default: 101 | return 102 | } 103 | } 104 | } 105 | 106 | type autodrainingOutboundQueue struct { 107 | c chan *QueueOutboundElementsContainer 108 | } 109 | 110 | // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. 111 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 112 | // The returned channel must not be closed. Senders should signal shutdown using 113 | // some other means, such as sending a sentinel nil values. 114 | // All sends to the channel must be best-effort, because there may be no receivers. 
115 | func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { 116 | q := &autodrainingOutboundQueue{ 117 | c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), 118 | } 119 | runtime.SetFinalizer(q, device.flushOutboundQueue) 120 | return q 121 | } 122 | 123 | func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { 124 | for { 125 | select { 126 | case elemsContainer := <-q.c: 127 | elemsContainer.Lock() 128 | for _, elem := range elemsContainer.elems { 129 | device.PutMessageBuffer(elem.buffer) 130 | device.PutOutboundElement(elem) 131 | } 132 | device.PutOutboundElementsContainer(elemsContainer) 133 | default: 134 | return 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /device/constants.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "time" 10 | ) 11 | 12 | /* Specification constants */ 13 | 14 | const ( 15 | RekeyAfterMessages = (1 << 60) 16 | RejectAfterMessages = (1 << 64) - (1 << 13) - 1 17 | RekeyAfterTime = time.Second * 120 18 | RekeyAttemptTime = time.Second * 90 19 | RekeyTimeout = time.Second * 5 20 | MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ 21 | RekeyTimeoutJitterMaxMs = 334 22 | RejectAfterTime = time.Second * 180 23 | KeepaliveTimeout = time.Second * 10 24 | CookieRefreshTime = time.Second * 120 25 | HandshakeInitationRate = time.Second / 50 26 | PaddingMultiple = 16 27 | ) 28 | 29 | const ( 30 | MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) 31 | MaxMessageSize = MaxSegmentSize // maximum size of transport message 32 | MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content 33 | ) 34 | 35 | /* Implementation constants */ 36 | 37 | const ( 38 | UnderLoadAfterTime = time.Second // how long does the device remain under load after detected 39 | MaxPeers = 1 << 16 // maximum number of configured peers 40 | ) 41 | -------------------------------------------------------------------------------- /device/cookie.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "sync" 12 | "time" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/chacha20poly1305" 16 | ) 17 | 18 | type CookieChecker struct { 19 | sync.RWMutex 20 | mac1 struct { 21 | key [blake2s.Size]byte 22 | } 23 | mac2 struct { 24 | secret [blake2s.Size]byte 25 | secretSet time.Time 26 | encryptionKey [chacha20poly1305.KeySize]byte 27 | } 28 | } 29 | 30 | type CookieGenerator struct { 31 | sync.RWMutex 32 | mac1 struct { 33 | key [blake2s.Size]byte 34 | } 35 | mac2 struct { 36 | cookie [blake2s.Size128]byte 37 | cookieSet time.Time 38 | hasLastMAC1 bool 39 | lastMAC1 [blake2s.Size128]byte 40 | encryptionKey [chacha20poly1305.KeySize]byte 41 | } 42 | } 43 | 44 | func (st *CookieChecker) Init(pk NoisePublicKey) { 45 | st.Lock() 46 | defer st.Unlock() 47 | 48 | // mac1 state 49 | 50 | func() { 51 | hash, _ := blake2s.New256(nil) 52 | hash.Write([]byte(WGLabelMAC1)) 53 | hash.Write(pk[:]) 54 | hash.Sum(st.mac1.key[:0]) 55 | }() 56 | 57 | // mac2 state 58 | 59 | func() { 60 | hash, _ := blake2s.New256(nil) 61 | hash.Write([]byte(WGLabelCookie)) 62 | hash.Write(pk[:]) 63 | hash.Sum(st.mac2.encryptionKey[:0]) 64 | }() 65 | 66 | st.mac2.secretSet = time.Time{} 67 | } 68 | 69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool { 70 | st.RLock() 71 | defer st.RUnlock() 72 | 73 | size := len(msg) 74 | smac2 := size - blake2s.Size128 75 | smac1 := smac2 - blake2s.Size128 76 | 77 | var mac1 [blake2s.Size128]byte 78 | 79 | mac, _ := blake2s.New128(st.mac1.key[:]) 80 | mac.Write(msg[:smac1]) 81 | mac.Sum(mac1[:0]) 82 | 83 | return hmac.Equal(mac1[:], msg[smac1:smac2]) 84 | } 85 | 86 | func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { 87 | st.RLock() 88 | defer st.RUnlock() 89 | 90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 91 | return false 92 | } 93 | 94 | // derive cookie key 95 | 96 | var cookie [blake2s.Size128]byte 97 | func() { 98 | mac, _ := 
blake2s.New128(st.mac2.secret[:]) 99 | mac.Write(src) 100 | mac.Sum(cookie[:0]) 101 | }() 102 | 103 | // calculate mac of packet (including mac1) 104 | 105 | smac2 := len(msg) - blake2s.Size128 106 | 107 | var mac2 [blake2s.Size128]byte 108 | func() { 109 | mac, _ := blake2s.New128(cookie[:]) 110 | mac.Write(msg[:smac2]) 111 | mac.Sum(mac2[:0]) 112 | }() 113 | 114 | return hmac.Equal(mac2[:], msg[smac2:]) 115 | } 116 | 117 | func (st *CookieChecker) CreateReply( 118 | msg []byte, 119 | recv uint32, 120 | src []byte, 121 | ) (*MessageCookieReply, error) { 122 | st.RLock() 123 | 124 | // refresh cookie secret 125 | 126 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 127 | st.RUnlock() 128 | st.Lock() 129 | _, err := rand.Read(st.mac2.secret[:]) 130 | if err != nil { 131 | st.Unlock() 132 | return nil, err 133 | } 134 | st.mac2.secretSet = time.Now() 135 | st.Unlock() 136 | st.RLock() 137 | } 138 | 139 | // derive cookie 140 | 141 | var cookie [blake2s.Size128]byte 142 | func() { 143 | mac, _ := blake2s.New128(st.mac2.secret[:]) 144 | mac.Write(src) 145 | mac.Sum(cookie[:0]) 146 | }() 147 | 148 | // encrypt cookie 149 | 150 | size := len(msg) 151 | 152 | smac2 := size - blake2s.Size128 153 | smac1 := smac2 - blake2s.Size128 154 | 155 | reply := new(MessageCookieReply) 156 | reply.Type = MessageCookieReplyType 157 | reply.Receiver = recv 158 | 159 | _, err := rand.Read(reply.Nonce[:]) 160 | if err != nil { 161 | st.RUnlock() 162 | return nil, err 163 | } 164 | 165 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 166 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) 167 | 168 | st.RUnlock() 169 | 170 | return reply, nil 171 | } 172 | 173 | func (st *CookieGenerator) Init(pk NoisePublicKey) { 174 | st.Lock() 175 | defer st.Unlock() 176 | 177 | func() { 178 | hash, _ := blake2s.New256(nil) 179 | hash.Write([]byte(WGLabelMAC1)) 180 | hash.Write(pk[:]) 181 | hash.Sum(st.mac1.key[:0]) 182 | }() 183 | 184 | func() { 185 | 
hash, _ := blake2s.New256(nil) 186 | hash.Write([]byte(WGLabelCookie)) 187 | hash.Write(pk[:]) 188 | hash.Sum(st.mac2.encryptionKey[:0]) 189 | }() 190 | 191 | st.mac2.cookieSet = time.Time{} 192 | } 193 | 194 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { 195 | st.Lock() 196 | defer st.Unlock() 197 | 198 | if !st.mac2.hasLastMAC1 { 199 | return false 200 | } 201 | 202 | var cookie [blake2s.Size128]byte 203 | 204 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 205 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) 206 | if err != nil { 207 | return false 208 | } 209 | 210 | st.mac2.cookieSet = time.Now() 211 | st.mac2.cookie = cookie 212 | return true 213 | } 214 | 215 | func (st *CookieGenerator) AddMacs(msg []byte) { 216 | size := len(msg) 217 | 218 | smac2 := size - blake2s.Size128 219 | smac1 := smac2 - blake2s.Size128 220 | 221 | mac1 := msg[smac1:smac2] 222 | mac2 := msg[smac2:] 223 | 224 | st.Lock() 225 | defer st.Unlock() 226 | 227 | // set mac1 228 | 229 | func() { 230 | mac, _ := blake2s.New128(st.mac1.key[:]) 231 | mac.Write(msg[:smac1]) 232 | mac.Sum(mac1[:0]) 233 | }() 234 | copy(st.mac2.lastMAC1[:], mac1) 235 | st.mac2.hasLastMAC1 = true 236 | 237 | // set mac2 238 | 239 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime { 240 | return 241 | } 242 | 243 | func() { 244 | mac, _ := blake2s.New128(st.mac2.cookie[:]) 245 | mac.Write(msg[:smac2]) 246 | mac.Sum(mac2[:0]) 247 | }() 248 | } 249 | -------------------------------------------------------------------------------- /device/cookie_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
*/

package device

import (
	"testing"
)

// TestCookieMAC1 exercises CookieGenerator and CookieChecker end to end.
// MAC1 must verify before any cookie exists and MAC2 must not. A cookie reply
// must round-trip. After that, MAC2 must verify only for the unmodified
// message and the original source bytes.
func TestCookieMAC1(t *testing.T) {
	// setup generator / checker

	var (
		generator CookieGenerator
		checker   CookieChecker
	)

	sk, err := newPrivateKey()
	if err != nil {
		t.Fatal(err)
	}
	pk := sk.publicKey()

	generator.Init(pk)
	checker.Init(pk)

	// check mac1

	// src is the source byte string that CheckMAC2 derives the cookie from.
	src := []byte{192, 168, 13, 37, 10, 10, 10}

	// No cookie has been exchanged yet, so MAC1 must verify and MAC2 must not.
	checkMAC1 := func(msg []byte) {
		generator.AddMacs(msg)
		if !checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}
	}

	checkMAC1([]byte{
		0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd,
		0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62,
		0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64,
		0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91,
		0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4,
		0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c,
		0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b,
		0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5,
		0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8,
		0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8,
		0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e,
		0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82,
		0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53,
		0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd,
		0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad,
		0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f,
	})

	checkMAC1([]byte{
		0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c,
		0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56,
		0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a,
		0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c,
		0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
		0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
		0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
		0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
	})

	// 32 bytes: exactly two MAC fields with an empty body.
	checkMAC1([]byte{
		0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
		0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
		0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
		0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
	})

	// exchange cookie reply

	// The checker answers with a cookie reply, which the generator must be
	// able to decrypt using the MAC1 it just produced for msg.
	func() {
		msg := []byte{
			0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf,
			0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8,
			0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5,
			0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f,
			0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2,
			0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b,
			0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e,
			0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9,
			0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22,
			0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27,
			0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f,
			0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2,
			0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b,
			0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb,
			0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61,
			0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
		}
		generator.AddMacs(msg)
		reply, err := checker.CreateReply(msg, 1377, src)
		if err != nil {
			t.Fatal("Failed to create cookie reply:", err)
		}
		if !generator.ConsumeReply(reply) {
			t.Fatal("Failed to consume cookie reply")
		}
	}()

	// check mac2

	// With a fresh cookie, both MACs must verify. Flipping one bit of the
	// body must break both, and a different src must break MAC2.
	checkMAC2 := func(msg []byte) {
		generator.AddMacs(msg)

		if !checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if !checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}

		// Corrupt one bit of byte 5.
		msg[5] ^= 0x20

		if checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}

		// Restore the original message.
		msg[5] ^= 0x20

		srcBad1 := []byte{192, 168, 13, 37, 40, 1}
		if checker.CheckMAC2(msg, srcBad1) {
			t.Fatal("MAC2 generation/verification failed")
		}

		srcBad2 := []byte{192, 168, 13, 38, 40, 1}
		if checker.CheckMAC2(msg, srcBad2) {
			t.Fatal("MAC2 generation/verification failed")
		}
	}

	checkMAC2([]byte{
		0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3,
		0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15,
		0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f,
		0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f,
		0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69,
		0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb,
		0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c,
		0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38,
	})

	checkMAC2([]byte{
		0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3,
		0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85,
		0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84,
		0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5,
		0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85,
		0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55,
		0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a,
		0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca,
		0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59,
		0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca,
		0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b,
		0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7,
		0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55,
		0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2,
		0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f,
		0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d,
		0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d,
		0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f,
		0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2,
		0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0,
		0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35,
		0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a,
		0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee,
		0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe,
		0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e,
		0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4,
		0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06,
		0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93,
		0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa,
		0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64,
		0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b,
		0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa,
	})
}
--------------------------------------------------------------------------------
/device/devicestate_string.go:
--------------------------------------------------------------------------------
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.

package device

import "strconv"

const _deviceState_name = "DownUpClosed"

var _deviceState_index = [...]uint8{0, 4, 6, 12}

func (i deviceState) String() string {
	if i >= deviceState(len(_deviceState_index)-1) {
		return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}
--------------------------------------------------------------------------------
/device/endpoint_test.go:
--------------------------------------------------------------------------------
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net/netip" 11 | ) 12 | 13 | type DummyEndpoint struct { 14 | src, dst netip.Addr 15 | } 16 | 17 | func CreateDummyEndpoint() (*DummyEndpoint, error) { 18 | var src, dst [16]byte 19 | if _, err := rand.Read(src[:]); err != nil { 20 | return nil, err 21 | } 22 | _, err := rand.Read(dst[:]) 23 | return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err 24 | } 25 | 26 | func (e *DummyEndpoint) ClearSrc() {} 27 | 28 | func (e *DummyEndpoint) SrcToString() string { 29 | return netip.AddrPortFrom(e.SrcIP(), 1000).String() 30 | } 31 | 32 | func (e *DummyEndpoint) DstToString() string { 33 | return netip.AddrPortFrom(e.DstIP(), 1000).String() 34 | } 35 | 36 | func (e *DummyEndpoint) DstToBytes() []byte { 37 | out := e.DstIP().AsSlice() 38 | out = append(out, byte(1000&0xff)) 39 | out = append(out, byte((1000>>8)&0xff)) 40 | return out 41 | } 42 | 43 | func (e *DummyEndpoint) DstIP() netip.Addr { 44 | return e.dst 45 | } 46 | 47 | func (e *DummyEndpoint) SrcIP() netip.Addr { 48 | return e.src 49 | } 50 | -------------------------------------------------------------------------------- /device/indextable.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/rand" 10 | "encoding/binary" 11 | "sync" 12 | ) 13 | 14 | type IndexTableEntry struct { 15 | peer *Peer 16 | handshake *Handshake 17 | keypair *Keypair 18 | } 19 | 20 | type IndexTable struct { 21 | sync.RWMutex 22 | table map[uint32]IndexTableEntry 23 | } 24 | 25 | func randUint32() (uint32, error) { 26 | var integer [4]byte 27 | _, err := rand.Read(integer[:]) 28 | // Arbitrary endianness; both are intrinsified by the Go compiler. 
29 | return binary.LittleEndian.Uint32(integer[:]), err 30 | } 31 | 32 | func (table *IndexTable) Init() { 33 | table.Lock() 34 | defer table.Unlock() 35 | table.table = make(map[uint32]IndexTableEntry) 36 | } 37 | 38 | func (table *IndexTable) Delete(index uint32) { 39 | table.Lock() 40 | defer table.Unlock() 41 | delete(table.table, index) 42 | } 43 | 44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { 45 | table.Lock() 46 | defer table.Unlock() 47 | entry, ok := table.table[index] 48 | if !ok { 49 | return 50 | } 51 | table.table[index] = IndexTableEntry{ 52 | peer: entry.peer, 53 | keypair: keypair, 54 | handshake: nil, 55 | } 56 | } 57 | 58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { 59 | for { 60 | // generate random index 61 | 62 | index, err := randUint32() 63 | if err != nil { 64 | return index, err 65 | } 66 | 67 | // check if index used 68 | 69 | table.RLock() 70 | _, ok := table.table[index] 71 | table.RUnlock() 72 | if ok { 73 | continue 74 | } 75 | 76 | // check again while locked 77 | 78 | table.Lock() 79 | _, found := table.table[index] 80 | if found { 81 | table.Unlock() 82 | continue 83 | } 84 | table.table[index] = IndexTableEntry{ 85 | peer: peer, 86 | handshake: handshake, 87 | keypair: nil, 88 | } 89 | table.Unlock() 90 | return index, nil 91 | } 92 | } 93 | 94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry { 95 | table.RLock() 96 | defer table.RUnlock() 97 | return table.table[id] 98 | } 99 | -------------------------------------------------------------------------------- /device/ip.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "net" 10 | ) 11 | 12 | const ( 13 | IPv4offsetTotalLength = 2 14 | IPv4offsetSrc = 12 15 | IPv4offsetDst = IPv4offsetSrc + net.IPv4len 16 | ) 17 | 18 | const ( 19 | IPv6offsetPayloadLength = 4 20 | IPv6offsetSrc = 8 21 | IPv6offsetDst = IPv6offsetSrc + net.IPv6len 22 | ) 23 | -------------------------------------------------------------------------------- /device/junk_creator.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "bytes" 5 | crand "crypto/rand" 6 | "fmt" 7 | v2 "math/rand/v2" 8 | ) 9 | 10 | type junkCreator struct { 11 | device *Device 12 | cha8Rand *v2.ChaCha8 13 | } 14 | 15 | func NewJunkCreator(d *Device) (junkCreator, error) { 16 | buf := make([]byte, 32) 17 | _, err := crand.Read(buf) 18 | if err != nil { 19 | return junkCreator{}, err 20 | } 21 | return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil 22 | } 23 | 24 | // Should be called with aSecMux RLocked 25 | func (jc *junkCreator) createJunkPackets() ([][]byte, error) { 26 | if jc.device.aSecCfg.junkPacketCount == 0 { 27 | return nil, nil 28 | } 29 | 30 | junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount) 31 | for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ { 32 | packetSize := jc.randomPacketSize() 33 | junk, err := jc.randomJunkWithSize(packetSize) 34 | if err != nil { 35 | return nil, fmt.Errorf("Failed to create junk packet: %v", err) 36 | } 37 | junks = append(junks, junk) 38 | } 39 | return junks, nil 40 | } 41 | 42 | // Should be called with aSecMux RLocked 43 | func (jc *junkCreator) randomPacketSize() int { 44 | return int( 45 | jc.cha8Rand.Uint64()%uint64( 46 | jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize, 47 | ), 48 | ) + jc.device.aSecCfg.junkPacketMinSize 49 | } 50 | 51 | // Should be called with aSecMux RLocked 52 | func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error { 53 
| headerJunk, err := jc.randomJunkWithSize(size) 54 | if err != nil { 55 | return fmt.Errorf("failed to create header junk: %v", err) 56 | } 57 | _, err = writer.Write(headerJunk) 58 | if err != nil { 59 | return fmt.Errorf("failed to write header junk: %v", err) 60 | } 61 | return nil 62 | } 63 | 64 | // Should be called with aSecMux RLocked 65 | func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) { 66 | junk := make([]byte, size) 67 | _, err := jc.cha8Rand.Read(junk) 68 | return junk, err 69 | } 70 | -------------------------------------------------------------------------------- /device/junk_creator_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/amnezia-vpn/amneziawg-go/conn/bindtest" 9 | "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" 10 | ) 11 | 12 | func setUpJunkCreator(t *testing.T) (junkCreator, error) { 13 | cfg, _ := genASecurityConfigs(t) 14 | tun := tuntest.NewChannelTUN() 15 | binds := bindtest.NewChannelBinds() 16 | level := LogLevelVerbose 17 | dev := NewDevice( 18 | tun.TUN(), 19 | binds[0], 20 | NewLogger(level, ""), 21 | ) 22 | 23 | if err := dev.IpcSet(cfg[0]); err != nil { 24 | t.Errorf("failed to configure device %v", err) 25 | dev.Close() 26 | return junkCreator{}, err 27 | } 28 | 29 | jc, err := NewJunkCreator(dev) 30 | 31 | if err != nil { 32 | t.Errorf("failed to create junk creator %v", err) 33 | dev.Close() 34 | return junkCreator{}, err 35 | } 36 | 37 | return jc, nil 38 | } 39 | 40 | func Test_junkCreator_createJunkPackets(t *testing.T) { 41 | jc, err := setUpJunkCreator(t) 42 | if err != nil { 43 | return 44 | } 45 | t.Run("", func(t *testing.T) { 46 | got, err := jc.createJunkPackets() 47 | if err != nil { 48 | t.Errorf( 49 | "junkCreator.createJunkPackets() = %v; failed", 50 | err, 51 | ) 52 | return 53 | } 54 | seen := make(map[string]bool) 55 | for _, junk := range got { 56 | key 
:= string(junk) 57 | if seen[key] { 58 | t.Errorf( 59 | "junkCreator.createJunkPackets() = %v, duplicate key: %v", 60 | got, 61 | junk, 62 | ) 63 | return 64 | } 65 | seen[key] = true 66 | } 67 | }) 68 | } 69 | 70 | func Test_junkCreator_randomJunkWithSize(t *testing.T) { 71 | t.Run("", func(t *testing.T) { 72 | jc, err := setUpJunkCreator(t) 73 | if err != nil { 74 | return 75 | } 76 | r1, _ := jc.randomJunkWithSize(10) 77 | r2, _ := jc.randomJunkWithSize(10) 78 | fmt.Printf("%v\n%v\n", r1, r2) 79 | if bytes.Equal(r1, r2) { 80 | t.Errorf("same junks %v", err) 81 | jc.device.Close() 82 | return 83 | } 84 | }) 85 | } 86 | 87 | func Test_junkCreator_randomPacketSize(t *testing.T) { 88 | jc, err := setUpJunkCreator(t) 89 | if err != nil { 90 | return 91 | } 92 | for range [30]struct{}{} { 93 | t.Run("", func(t *testing.T) { 94 | if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got || 95 | got > jc.device.aSecCfg.junkPacketMaxSize { 96 | t.Errorf( 97 | "junkCreator.randomPacketSize() = %v, not between range [%v,%v]", 98 | got, 99 | jc.device.aSecCfg.junkPacketMinSize, 100 | jc.device.aSecCfg.junkPacketMaxSize, 101 | ) 102 | } 103 | }) 104 | } 105 | } 106 | 107 | func Test_junkCreator_appendJunk(t *testing.T) { 108 | jc, err := setUpJunkCreator(t) 109 | if err != nil { 110 | return 111 | } 112 | t.Run("", func(t *testing.T) { 113 | s := "apple" 114 | buffer := bytes.NewBuffer([]byte(s)) 115 | err := jc.appendJunk(buffer, 30) 116 | if err != nil && 117 | buffer.Len() != len(s)+30 { 118 | t.Errorf("appendWithJunk() size don't match") 119 | } 120 | read := make([]byte, 50) 121 | buffer.Read(read) 122 | fmt.Println(string(read)) 123 | }) 124 | } 125 | -------------------------------------------------------------------------------- /device/kdf_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "encoding/hex" 10 | "testing" 11 | 12 | "golang.org/x/crypto/blake2s" 13 | ) 14 | 15 | type KDFTest struct { 16 | key string 17 | input string 18 | t0 string 19 | t1 string 20 | t2 string 21 | } 22 | 23 | func assertEquals(t *testing.T, a, b string) { 24 | if a != b { 25 | t.Fatal("expected", a, "=", b) 26 | } 27 | } 28 | 29 | func TestKDF(t *testing.T) { 30 | tests := []KDFTest{ 31 | { 32 | key: "746573742d6b6579", 33 | input: "746573742d696e707574", 34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 37 | }, 38 | { 39 | key: "776972656775617264", 40 | input: "776972656775617264", 41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 44 | }, 45 | { 46 | key: "", 47 | input: "", 48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 51 | }, 52 | } 53 | 54 | var t0, t1, t2 [blake2s.Size]byte 55 | 56 | for _, test := range tests { 57 | key, _ := hex.DecodeString(test.key) 58 | input, _ := hex.DecodeString(test.input) 59 | KDF3(&t0, &t1, &t2, key, input) 60 | t0s := hex.EncodeToString(t0[:]) 61 | t1s := hex.EncodeToString(t1[:]) 62 | t2s := hex.EncodeToString(t2[:]) 63 | assertEquals(t, t0s, test.t0) 64 | assertEquals(t, t1s, test.t1) 65 | assertEquals(t, t2s, test.t2) 66 | } 67 | 68 | for _, test := range tests { 69 | key, _ := hex.DecodeString(test.key) 70 | input, _ := hex.DecodeString(test.input) 71 | KDF2(&t0, &t1, key, input) 72 | t0s := hex.EncodeToString(t0[:]) 73 | t1s := 
hex.EncodeToString(t1[:]) 74 | assertEquals(t, t0s, test.t0) 75 | assertEquals(t, t1s, test.t1) 76 | } 77 | 78 | for _, test := range tests { 79 | key, _ := hex.DecodeString(test.key) 80 | input, _ := hex.DecodeString(test.input) 81 | KDF1(&t0, key, input) 82 | t0s := hex.EncodeToString(t0[:]) 83 | assertEquals(t, t0s, test.t0) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /device/keypair.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/cipher" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "github.com/amnezia-vpn/amneziawg-go/replay" 15 | ) 16 | 17 | /* Due to limitations in Go and /x/crypto there is currently 18 | * no way to ensure that key material is securely ereased in memory. 19 | * 20 | * Since this may harm the forward secrecy property, 21 | * we plan to resolve this issue; whenever Go allows us to do so. 
22 | */ 23 | 24 | type Keypair struct { 25 | sendNonce atomic.Uint64 26 | send cipher.AEAD 27 | receive cipher.AEAD 28 | replayFilter replay.Filter 29 | isInitiator bool 30 | created time.Time 31 | localIndex uint32 32 | remoteIndex uint32 33 | } 34 | 35 | type Keypairs struct { 36 | sync.RWMutex 37 | current *Keypair 38 | previous *Keypair 39 | next atomic.Pointer[Keypair] 40 | } 41 | 42 | func (kp *Keypairs) Current() *Keypair { 43 | kp.RLock() 44 | defer kp.RUnlock() 45 | return kp.current 46 | } 47 | 48 | func (device *Device) DeleteKeypair(key *Keypair) { 49 | if key != nil { 50 | device.indexTable.Delete(key.localIndex) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /device/logger.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "log" 10 | "os" 11 | ) 12 | 13 | // A Logger provides logging for a Device. 14 | // The functions are Printf-style functions. 15 | // They must be safe for concurrent use. 16 | // They do not require a trailing newline in the format. 17 | // If nil, that level of logging will be silent. 18 | type Logger struct { 19 | Verbosef func(format string, args ...any) 20 | Errorf func(format string, args ...any) 21 | } 22 | 23 | // Log levels for use with NewLogger. 24 | const ( 25 | LogLevelSilent = iota 26 | LogLevelError 27 | LogLevelVerbose 28 | ) 29 | 30 | // Function for use in Logger for discarding logged lines. 31 | func DiscardLogf(format string, args ...any) {} 32 | 33 | // NewLogger constructs a Logger that writes to stdout. 34 | // It logs at the specified log level and above. 35 | // It decorates log lines with the log level, date, time, and prepend. 
36 | func NewLogger(level int, prepend string) *Logger { 37 | logger := &Logger{DiscardLogf, DiscardLogf} 38 | logf := func(prefix string) func(string, ...any) { 39 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf 40 | } 41 | if level >= LogLevelVerbose { 42 | logger.Verbosef = logf("DEBUG") 43 | } 44 | if level >= LogLevelError { 45 | logger.Errorf = logf("ERROR") 46 | } 47 | return logger 48 | } 49 | -------------------------------------------------------------------------------- /device/mobilequirks.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, 9 | // though it will try to deal with it, and race maybe, if called after. 10 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { 11 | device.net.brokenRoaming = true 12 | device.peers.RLock() 13 | for _, peer := range device.peers.keyMap { 14 | peer.endpoint.Lock() 15 | peer.endpoint.disableRoaming = peer.endpoint.val != nil 16 | peer.endpoint.Unlock() 17 | } 18 | device.peers.RUnlock() 19 | } 20 | -------------------------------------------------------------------------------- /device/noise-helpers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/subtle" 12 | "errors" 13 | "hash" 14 | 15 | "golang.org/x/crypto/blake2s" 16 | "golang.org/x/crypto/curve25519" 17 | ) 18 | 19 | /* KDF related functions. 
20 | * HMAC-based Key Derivation Function (HKDF) 21 | * https://tools.ietf.org/html/rfc5869 22 | */ 23 | 24 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { 25 | mac := hmac.New(func() hash.Hash { 26 | h, _ := blake2s.New256(nil) 27 | return h 28 | }, key) 29 | mac.Write(in0) 30 | mac.Sum(sum[:0]) 31 | } 32 | 33 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { 34 | mac := hmac.New(func() hash.Hash { 35 | h, _ := blake2s.New256(nil) 36 | return h 37 | }, key) 38 | mac.Write(in0) 39 | mac.Write(in1) 40 | mac.Sum(sum[:0]) 41 | } 42 | 43 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { 44 | HMAC1(t0, key, input) 45 | HMAC1(t0, t0[:], []byte{0x1}) 46 | } 47 | 48 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { 49 | var prk [blake2s.Size]byte 50 | HMAC1(&prk, key, input) 51 | HMAC1(t0, prk[:], []byte{0x1}) 52 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 53 | setZero(prk[:]) 54 | } 55 | 56 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { 57 | var prk [blake2s.Size]byte 58 | HMAC1(&prk, key, input) 59 | HMAC1(t0, prk[:], []byte{0x1}) 60 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 61 | HMAC2(t2, prk[:], t1[:], []byte{0x3}) 62 | setZero(prk[:]) 63 | } 64 | 65 | func isZero(val []byte) bool { 66 | acc := 1 67 | for _, b := range val { 68 | acc &= subtle.ConstantTimeByteEq(b, 0) 69 | } 70 | return acc == 1 71 | } 72 | 73 | /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ 74 | func setZero(arr []byte) { 75 | for i := range arr { 76 | arr[i] = 0 77 | } 78 | } 79 | 80 | func (sk *NoisePrivateKey) clamp() { 81 | sk[0] &= 248 82 | sk[31] = (sk[31] & 127) | 64 83 | } 84 | 85 | func newPrivateKey() (sk NoisePrivateKey, err error) { 86 | _, err = rand.Read(sk[:]) 87 | sk.clamp() 88 | return 89 | } 90 | 91 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { 92 | apk := (*[NoisePublicKeySize]byte)(&pk) 93 | ask := (*[NoisePrivateKeySize]byte)(sk) 94 | 
curve25519.ScalarBaseMult(apk, ask) 95 | return 96 | } 97 | 98 | var errInvalidPublicKey = errors.New("invalid public key") 99 | 100 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { 101 | apk := (*[NoisePublicKeySize]byte)(&pk) 102 | ask := (*[NoisePrivateKeySize]byte)(sk) 103 | curve25519.ScalarMult(&ss, ask, apk) 104 | if isZero(ss[:]) { 105 | return ss, errInvalidPublicKey 106 | } 107 | return ss, nil 108 | } 109 | -------------------------------------------------------------------------------- /device/noise-types.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/subtle" 10 | "encoding/hex" 11 | "errors" 12 | ) 13 | 14 | const ( 15 | NoisePublicKeySize = 32 16 | NoisePrivateKeySize = 32 17 | NoisePresharedKeySize = 32 18 | ) 19 | 20 | type ( 21 | NoisePublicKey [NoisePublicKeySize]byte 22 | NoisePrivateKey [NoisePrivateKeySize]byte 23 | NoisePresharedKey [NoisePresharedKeySize]byte 24 | NoiseNonce uint64 // padded to 12-bytes 25 | ) 26 | 27 | func loadExactHex(dst []byte, src string) error { 28 | slice, err := hex.DecodeString(src) 29 | if err != nil { 30 | return err 31 | } 32 | if len(slice) != len(dst) { 33 | return errors.New("hex string does not fit the slice") 34 | } 35 | copy(dst, slice) 36 | return nil 37 | } 38 | 39 | func (key NoisePrivateKey) IsZero() bool { 40 | var zero NoisePrivateKey 41 | return key.Equals(zero) 42 | } 43 | 44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { 45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 46 | } 47 | 48 | func (key *NoisePrivateKey) FromHex(src string) (err error) { 49 | err = loadExactHex(key[:], src) 50 | key.clamp() 51 | return 52 | } 53 | 54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { 55 | err = 
loadExactHex(key[:], src) 56 | if key.IsZero() { 57 | return 58 | } 59 | key.clamp() 60 | return 61 | } 62 | 63 | func (key *NoisePublicKey) FromHex(src string) error { 64 | return loadExactHex(key[:], src) 65 | } 66 | 67 | func (key NoisePublicKey) IsZero() bool { 68 | var zero NoisePublicKey 69 | return key.Equals(zero) 70 | } 71 | 72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { 73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 74 | } 75 | 76 | func (key *NoisePresharedKey) FromHex(src string) error { 77 | return loadExactHex(key[:], src) 78 | } 79 | -------------------------------------------------------------------------------- /device/noise_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "testing" 12 | 13 | "github.com/amnezia-vpn/amneziawg-go/conn" 14 | "github.com/amnezia-vpn/amneziawg-go/tun/tuntest" 15 | ) 16 | 17 | func TestCurveWrappers(t *testing.T) { 18 | sk1, err := newPrivateKey() 19 | assertNil(t, err) 20 | 21 | sk2, err := newPrivateKey() 22 | assertNil(t, err) 23 | 24 | pk1 := sk1.publicKey() 25 | pk2 := sk2.publicKey() 26 | 27 | ss1, err1 := sk1.sharedSecret(pk2) 28 | ss2, err2 := sk2.sharedSecret(pk1) 29 | 30 | if ss1 != ss2 || err1 != nil || err2 != nil { 31 | t.Fatal("Failed to compute shared secet") 32 | } 33 | } 34 | 35 | func randDevice(t *testing.T) *Device { 36 | sk, err := newPrivateKey() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | tun := tuntest.NewChannelTUN() 41 | logger := NewLogger(LogLevelError, "") 42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) 43 | device.SetPrivateKey(sk) 44 | return device 45 | } 46 | 47 | func assertNil(t *testing.T, err error) { 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | } 52 | 53 | func assertEqual(t *testing.T, a, 
b []byte) { 54 | if !bytes.Equal(a, b) { 55 | t.Fatal(a, "!=", b) 56 | } 57 | } 58 | 59 | func TestNoiseHandshake(t *testing.T) { 60 | dev1 := randDevice(t) 61 | dev2 := randDevice(t) 62 | 63 | defer dev1.Close() 64 | defer dev2.Close() 65 | 66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | peer1.Start() 75 | peer2.Start() 76 | 77 | assertEqual( 78 | t, 79 | peer1.handshake.precomputedStaticStatic[:], 80 | peer2.handshake.precomputedStaticStatic[:], 81 | ) 82 | 83 | /* simulate handshake */ 84 | 85 | // initiation message 86 | 87 | t.Log("exchange initiation message") 88 | 89 | msg1, err := dev1.CreateMessageInitiation(peer2) 90 | assertNil(t, err) 91 | 92 | packet := make([]byte, 0, 256) 93 | writer := bytes.NewBuffer(packet) 94 | err = binary.Write(writer, binary.LittleEndian, msg1) 95 | assertNil(t, err) 96 | peer := dev2.ConsumeMessageInitiation(msg1) 97 | if peer == nil { 98 | t.Fatal("handshake failed at initiation message") 99 | } 100 | 101 | assertEqual( 102 | t, 103 | peer1.handshake.chainKey[:], 104 | peer2.handshake.chainKey[:], 105 | ) 106 | 107 | assertEqual( 108 | t, 109 | peer1.handshake.hash[:], 110 | peer2.handshake.hash[:], 111 | ) 112 | 113 | // response message 114 | 115 | t.Log("exchange response message") 116 | 117 | msg2, err := dev2.CreateMessageResponse(peer1) 118 | assertNil(t, err) 119 | 120 | peer = dev1.ConsumeMessageResponse(msg2) 121 | if peer == nil { 122 | t.Fatal("handshake failed at response message") 123 | } 124 | 125 | assertEqual( 126 | t, 127 | peer1.handshake.chainKey[:], 128 | peer2.handshake.chainKey[:], 129 | ) 130 | 131 | assertEqual( 132 | t, 133 | peer1.handshake.hash[:], 134 | peer2.handshake.hash[:], 135 | ) 136 | 137 | // key pairs 138 | 139 | t.Log("deriving keys") 140 | 141 | err = peer1.BeginSymmetricSession() 142 | if 
err != nil { 143 | t.Fatal("failed to derive keypair for peer 1", err) 144 | } 145 | 146 | err = peer2.BeginSymmetricSession() 147 | if err != nil { 148 | t.Fatal("failed to derive keypair for peer 2", err) 149 | } 150 | 151 | key1 := peer1.keypairs.next.Load() 152 | key2 := peer2.keypairs.current 153 | 154 | // encrypting / decryption test 155 | 156 | t.Log("test key pairs") 157 | 158 | func() { 159 | testMsg := []byte("wireguard test message 1") 160 | var err error 161 | var out []byte 162 | var nonce [12]byte 163 | out = key1.send.Seal(out, nonce[:], testMsg, nil) 164 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 165 | assertNil(t, err) 166 | assertEqual(t, out, testMsg) 167 | }() 168 | 169 | func() { 170 | testMsg := []byte("wireguard test message 2") 171 | var err error 172 | var out []byte 173 | var nonce [12]byte 174 | out = key2.send.Seal(out, nonce[:], testMsg, nil) 175 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 176 | assertNil(t, err) 177 | assertEqual(t, out, testMsg) 178 | }() 179 | } 180 | -------------------------------------------------------------------------------- /device/pools.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync" 10 | "sync/atomic" 11 | ) 12 | 13 | type WaitPool struct { 14 | pool sync.Pool 15 | cond sync.Cond 16 | lock sync.Mutex 17 | count atomic.Uint32 18 | max uint32 19 | } 20 | 21 | func NewWaitPool(max uint32, new func() any) *WaitPool { 22 | p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 | p.cond = sync.Cond{L: &p.lock} 24 | return p 25 | } 26 | 27 | func (p *WaitPool) Get() any { 28 | if p.max != 0 { 29 | p.lock.Lock() 30 | for p.count.Load() >= p.max { 31 | p.cond.Wait() 32 | } 33 | p.count.Add(1) 34 | p.lock.Unlock() 35 | } 36 | return p.pool.Get() 37 | } 38 | 39 | func (p *WaitPool) Put(x any) { 40 | p.pool.Put(x) 41 | if p.max == 0 { 42 | return 43 | } 44 | p.count.Add(^uint32(0)) 45 | p.cond.Signal() 46 | } 47 | 48 | func (device *Device) PopulatePools() { 49 | device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 50 | s := make([]*QueueInboundElement, 0, device.BatchSize()) 51 | return &QueueInboundElementsContainer{elems: s} 52 | }) 53 | device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 54 | s := make([]*QueueOutboundElement, 0, device.BatchSize()) 55 | return &QueueOutboundElementsContainer{elems: s} 56 | }) 57 | device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 58 | return new([MaxMessageSize]byte) 59 | }) 60 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 61 | return new(QueueInboundElement) 62 | }) 63 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 64 | return new(QueueOutboundElement) 65 | }) 66 | } 67 | 68 | func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { 69 | c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) 70 | c.Mutex = sync.Mutex{} 71 | return c 72 | } 73 | 74 | func (device *Device) PutInboundElementsContainer(c 
*QueueInboundElementsContainer) { 75 | for i := range c.elems { 76 | c.elems[i] = nil 77 | } 78 | c.elems = c.elems[:0] 79 | device.pool.inboundElementsContainer.Put(c) 80 | } 81 | 82 | func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { 83 | c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) 84 | c.Mutex = sync.Mutex{} 85 | return c 86 | } 87 | 88 | func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { 89 | for i := range c.elems { 90 | c.elems[i] = nil 91 | } 92 | c.elems = c.elems[:0] 93 | device.pool.outboundElementsContainer.Put(c) 94 | } 95 | 96 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 97 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 98 | } 99 | 100 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 101 | device.pool.messageBuffers.Put(msg) 102 | } 103 | 104 | func (device *Device) GetInboundElement() *QueueInboundElement { 105 | return device.pool.inboundElements.Get().(*QueueInboundElement) 106 | } 107 | 108 | func (device *Device) PutInboundElement(elem *QueueInboundElement) { 109 | elem.clearPointers() 110 | device.pool.inboundElements.Put(elem) 111 | } 112 | 113 | func (device *Device) GetOutboundElement() *QueueOutboundElement { 114 | return device.pool.outboundElements.Get().(*QueueOutboundElement) 115 | } 116 | 117 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 118 | elem.clearPointers() 119 | device.pool.outboundElements.Put(elem) 120 | } 121 | -------------------------------------------------------------------------------- /device/pools_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "runtime" 11 | "sync" 12 | "sync/atomic" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | func TestWaitPool(t *testing.T) { 18 | t.Skip("Currently disabled") 19 | var wg sync.WaitGroup 20 | var trials atomic.Int32 21 | startTrials := int32(100000) 22 | if raceEnabled { 23 | // This test can be very slow with -race. 24 | startTrials /= 10 25 | } 26 | trials.Store(startTrials) 27 | workers := runtime.NumCPU() + 2 28 | if workers-4 <= 0 { 29 | t.Skip("Not enough cores") 30 | } 31 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 32 | wg.Add(workers) 33 | var max atomic.Uint32 34 | updateMax := func() { 35 | count := p.count.Load() 36 | if count > p.max { 37 | t.Errorf("count (%d) > max (%d)", count, p.max) 38 | } 39 | for { 40 | old := max.Load() 41 | if count <= old { 42 | break 43 | } 44 | if max.CompareAndSwap(old, count) { 45 | break 46 | } 47 | } 48 | } 49 | for i := 0; i < workers; i++ { 50 | go func() { 51 | defer wg.Done() 52 | for trials.Add(-1) > 0 { 53 | updateMax() 54 | x := p.Get() 55 | updateMax() 56 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 57 | updateMax() 58 | p.Put(x) 59 | updateMax() 60 | } 61 | }() 62 | } 63 | wg.Wait() 64 | if max.Load() != p.max { 65 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) 66 | } 67 | } 68 | 69 | func BenchmarkWaitPool(b *testing.B) { 70 | var wg sync.WaitGroup 71 | var trials atomic.Int32 72 | trials.Store(int32(b.N)) 73 | workers := runtime.NumCPU() + 2 74 | if workers-4 <= 0 { 75 | b.Skip("Not enough cores") 76 | } 77 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 78 | wg.Add(workers) 79 | b.ResetTimer() 80 | for i := 0; i < workers; i++ { 81 | go func() { 82 | defer wg.Done() 83 | for trials.Add(-1) > 0 { 84 | x := p.Get() 85 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 86 | p.Put(x) 87 | } 88 | }() 89 | } 90 | wg.Wait() 91 | } 
92 | 93 | func BenchmarkWaitPoolEmpty(b *testing.B) { 94 | var wg sync.WaitGroup 95 | var trials atomic.Int32 96 | trials.Store(int32(b.N)) 97 | workers := runtime.NumCPU() + 2 98 | if workers-4 <= 0 { 99 | b.Skip("Not enough cores") 100 | } 101 | p := NewWaitPool(0, func() any { return make([]byte, 16) }) 102 | wg.Add(workers) 103 | b.ResetTimer() 104 | for i := 0; i < workers; i++ { 105 | go func() { 106 | defer wg.Done() 107 | for trials.Add(-1) > 0 { 108 | x := p.Get() 109 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 110 | p.Put(x) 111 | } 112 | }() 113 | } 114 | wg.Wait() 115 | } 116 | 117 | func BenchmarkSyncPool(b *testing.B) { 118 | var wg sync.WaitGroup 119 | var trials atomic.Int32 120 | trials.Store(int32(b.N)) 121 | workers := runtime.NumCPU() + 2 122 | if workers-4 <= 0 { 123 | b.Skip("Not enough cores") 124 | } 125 | p := sync.Pool{New: func() any { return make([]byte, 16) }} 126 | wg.Add(workers) 127 | b.ResetTimer() 128 | for i := 0; i < workers; i++ { 129 | go func() { 130 | defer wg.Done() 131 | for trials.Add(-1) > 0 { 132 | x := p.Get() 133 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 134 | p.Put(x) 135 | } 136 | }() 137 | } 138 | wg.Wait() 139 | } 140 | -------------------------------------------------------------------------------- /device/queueconstants_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import "github.com/amnezia-vpn/amneziawg-go/conn" 9 | 10 | /* Reduce memory consumption for Android */ 11 | 12 | const ( 13 | QueueStagedSize = conn.IdealBatchSize 14 | QueueOutboundSize = 1024 15 | QueueInboundSize = 1024 16 | QueueHandshakeSize = 1024 17 | MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram 18 | PreallocatedBuffersPerPool = 4096 19 | ) 20 | -------------------------------------------------------------------------------- /device/queueconstants_default.go: -------------------------------------------------------------------------------- 1 | //go:build !android && !ios && !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | import "github.com/amnezia-vpn/amneziawg-go/conn" 11 | 12 | const ( 13 | QueueStagedSize = conn.IdealBatchSize 14 | QueueOutboundSize = 1024 15 | QueueInboundSize = 1024 16 | QueueHandshakeSize = 1024 17 | MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram 18 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 19 | ) 20 | -------------------------------------------------------------------------------- /device/queueconstants_ios.go: -------------------------------------------------------------------------------- 1 | //go:build ios 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | // Fit within memory limits for iOS's Network Extension API, which has stricter requirements. 11 | // These are vars instead of consts, because heavier network extensions might want to reduce 12 | // them further. 
13 | var ( 14 | QueueStagedSize = 128 15 | QueueOutboundSize = 1024 16 | QueueInboundSize = 1024 17 | QueueHandshakeSize = 1024 18 | PreallocatedBuffersPerPool uint32 = 1024 19 | ) 20 | 21 | const MaxSegmentSize = 1700 22 | -------------------------------------------------------------------------------- /device/queueconstants_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | const ( 9 | QueueStagedSize = 128 10 | QueueOutboundSize = 1024 11 | QueueInboundSize = 1024 12 | QueueHandshakeSize = 1024 13 | MaxSegmentSize = 2048 - 32 // largest possible UDP datagram 14 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 15 | ) 16 | -------------------------------------------------------------------------------- /device/race_disabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build !race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = false 11 | -------------------------------------------------------------------------------- /device/race_enabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = true 11 | -------------------------------------------------------------------------------- /device/sticky_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package device 4 | 5 | import ( 6 | "github.com/amnezia-vpn/amneziawg-go/conn" 7 | "github.com/amnezia-vpn/amneziawg-go/rwcancel" 8 | ) 9 | 10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 11 | return nil, nil 12 | } 13 | -------------------------------------------------------------------------------- /device/sticky_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This implements userspace semantics of "sticky sockets", modeled after 6 | * WireGuard's kernelspace implementation. This is more or less a straight port 7 | * of the sticky-sockets.c example code: 8 | * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 | * 10 | * Currently there is no way to achieve this within the net package: 11 | * See e.g. https://github.com/golang/go/issues/17930 12 | * So this code is remains platform dependent. 
13 | */ 14 | 15 | package device 16 | 17 | import ( 18 | "sync" 19 | "unsafe" 20 | 21 | "golang.org/x/sys/unix" 22 | 23 | "github.com/amnezia-vpn/amneziawg-go/conn" 24 | "github.com/amnezia-vpn/amneziawg-go/rwcancel" 25 | ) 26 | 27 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 | if !conn.StdNetSupportsStickySockets { 29 | return nil, nil 30 | } 31 | if _, ok := bind.(*conn.StdNetBind); !ok { 32 | return nil, nil 33 | } 34 | 35 | netlinkSock, err := createNetlinkRouteSocket() 36 | if err != nil { 37 | return nil, err 38 | } 39 | netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 40 | if err != nil { 41 | unix.Close(netlinkSock) 42 | return nil, err 43 | } 44 | 45 | go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 46 | 47 | return netlinkCancel, nil 48 | } 49 | 50 | func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 51 | type peerEndpointPtr struct { 52 | peer *Peer 53 | endpoint *conn.Endpoint 54 | } 55 | var reqPeer map[uint32]peerEndpointPtr 56 | var reqPeerLock sync.Mutex 57 | 58 | defer netlinkCancel.Close() 59 | defer unix.Close(netlinkSock) 60 | 61 | for msg := make([]byte, 1<<16); ; { 62 | var err error 63 | var msgn int 64 | for { 65 | msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 66 | if err == nil || !rwcancel.RetryAfterError(err) { 67 | break 68 | } 69 | if !netlinkCancel.ReadyRead() { 70 | return 71 | } 72 | } 73 | if err != nil { 74 | return 75 | } 76 | 77 | for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 78 | 79 | hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 80 | 81 | if uint(hdr.Len) > uint(len(remain)) { 82 | break 83 | } 84 | 85 | switch hdr.Type { 86 | case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 87 | if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 88 | if uint(len(remain)) < uint(hdr.Len) { 89 | break 90 | } 91 | if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 92 | attr := 
remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 93 | for { 94 | if uint(len(attr)) < uint(unix.SizeofRtAttr) { 95 | break 96 | } 97 | attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 98 | if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 99 | break 100 | } 101 | if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 102 | ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 103 | reqPeerLock.Lock() 104 | if reqPeer == nil { 105 | reqPeerLock.Unlock() 106 | break 107 | } 108 | pePtr, ok := reqPeer[hdr.Seq] 109 | reqPeerLock.Unlock() 110 | if !ok { 111 | break 112 | } 113 | pePtr.peer.endpoint.Lock() 114 | if &pePtr.peer.endpoint.val != pePtr.endpoint { 115 | pePtr.peer.endpoint.Unlock() 116 | break 117 | } 118 | if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { 119 | pePtr.peer.endpoint.Unlock() 120 | break 121 | } 122 | pePtr.peer.endpoint.clearSrcOnTx = true 123 | pePtr.peer.endpoint.Unlock() 124 | } 125 | attr = attr[attrhdr.Len:] 126 | } 127 | } 128 | break 129 | } 130 | reqPeerLock.Lock() 131 | reqPeer = make(map[uint32]peerEndpointPtr) 132 | reqPeerLock.Unlock() 133 | go func() { 134 | device.peers.RLock() 135 | i := uint32(1) 136 | for _, peer := range device.peers.keyMap { 137 | peer.endpoint.Lock() 138 | if peer.endpoint.val == nil { 139 | peer.endpoint.Unlock() 140 | continue 141 | } 142 | nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) 143 | if nativeEP == nil { 144 | peer.endpoint.Unlock() 145 | continue 146 | } 147 | if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { 148 | peer.endpoint.Unlock() 149 | break 150 | } 151 | nlmsg := struct { 152 | hdr unix.NlMsghdr 153 | msg unix.RtMsg 154 | dsthdr unix.RtAttr 155 | dst [4]byte 156 | srchdr unix.RtAttr 157 | src [4]byte 158 | markhdr unix.RtAttr 159 | mark uint32 160 | }{ 161 | unix.NlMsghdr{ 162 | Type: uint16(unix.RTM_GETROUTE), 163 | Flags: unix.NLM_F_REQUEST, 164 | Seq: i, 165 | }, 166 | unix.RtMsg{ 167 | 
Family: unix.AF_INET, 168 | Dst_len: 32, 169 | Src_len: 32, 170 | }, 171 | unix.RtAttr{ 172 | Len: 8, 173 | Type: unix.RTA_DST, 174 | }, 175 | nativeEP.DstIP().As4(), 176 | unix.RtAttr{ 177 | Len: 8, 178 | Type: unix.RTA_SRC, 179 | }, 180 | nativeEP.SrcIP().As4(), 181 | unix.RtAttr{ 182 | Len: 8, 183 | Type: unix.RTA_MARK, 184 | }, 185 | device.net.fwmark, 186 | } 187 | nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 188 | reqPeerLock.Lock() 189 | reqPeer[i] = peerEndpointPtr{ 190 | peer: peer, 191 | endpoint: &peer.endpoint.val, 192 | } 193 | reqPeerLock.Unlock() 194 | peer.endpoint.Unlock() 195 | i++ 196 | _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 197 | if err != nil { 198 | break 199 | } 200 | } 201 | device.peers.RUnlock() 202 | }() 203 | } 204 | remain = remain[hdr.Len:] 205 | } 206 | } 207 | } 208 | 209 | func createNetlinkRouteSocket() (int, error) { 210 | sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 211 | if err != nil { 212 | return -1, err 213 | } 214 | saddr := &unix.SockaddrNetlink{ 215 | Family: unix.AF_NETLINK, 216 | Groups: unix.RTMGRP_IPV4_ROUTE, 217 | } 218 | err = unix.Bind(sock, saddr) 219 | if err != nil { 220 | unix.Close(sock) 221 | return -1, err 222 | } 223 | return sock, nil 224 | } 225 | -------------------------------------------------------------------------------- /device/timers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This is based heavily on timers.c from the kernel implementation. 6 | */ 7 | 8 | package device 9 | 10 | import ( 11 | "sync" 12 | "time" 13 | _ "unsafe" 14 | ) 15 | 16 | //go:linkname fastrandn runtime.fastrandn 17 | func fastrandn(n uint32) uint32 18 | 19 | // A Timer manages time-based aspects of the WireGuard protocol. 
20 | // Timer roughly copies the interface of the Linux kernel's struct timer_list. 21 | type Timer struct { 22 | *time.Timer 23 | modifyingLock sync.RWMutex 24 | runningLock sync.Mutex 25 | isPending bool 26 | } 27 | 28 | func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { 29 | timer := &Timer{} 30 | timer.Timer = time.AfterFunc(time.Hour, func() { 31 | timer.runningLock.Lock() 32 | defer timer.runningLock.Unlock() 33 | 34 | timer.modifyingLock.Lock() 35 | if !timer.isPending { 36 | timer.modifyingLock.Unlock() 37 | return 38 | } 39 | timer.isPending = false 40 | timer.modifyingLock.Unlock() 41 | 42 | expirationFunction(peer) 43 | }) 44 | timer.Stop() 45 | return timer 46 | } 47 | 48 | func (timer *Timer) Mod(d time.Duration) { 49 | timer.modifyingLock.Lock() 50 | timer.isPending = true 51 | timer.Reset(d) 52 | timer.modifyingLock.Unlock() 53 | } 54 | 55 | func (timer *Timer) Del() { 56 | timer.modifyingLock.Lock() 57 | timer.isPending = false 58 | timer.Stop() 59 | timer.modifyingLock.Unlock() 60 | } 61 | 62 | func (timer *Timer) DelSync() { 63 | timer.Del() 64 | timer.runningLock.Lock() 65 | timer.Del() 66 | timer.runningLock.Unlock() 67 | } 68 | 69 | func (timer *Timer) IsPending() bool { 70 | timer.modifyingLock.RLock() 71 | defer timer.modifyingLock.RUnlock() 72 | return timer.isPending 73 | } 74 | 75 | func (peer *Peer) timersActive() bool { 76 | return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() 77 | } 78 | 79 | func expiredRetransmitHandshake(peer *Peer) { 80 | if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { 81 | peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) 82 | 83 | if peer.timersActive() { 84 | peer.timers.sendKeepalive.Del() 85 | } 86 | 87 | /* We drop all packets without a keypair and don't try again, 88 | * if we try unsuccessfully for too long to make a handshake. 
89 | */ 90 | peer.FlushStagedPackets() 91 | 92 | /* We set a timer for destroying any residue that might be left 93 | * of a partial exchange. 94 | */ 95 | if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { 96 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 97 | } 98 | } else { 99 | peer.timers.handshakeAttempts.Add(1) 100 | peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) 101 | 102 | /* We clear the endpoint address src address, in case this is the cause of trouble. */ 103 | peer.markEndpointSrcForClearing() 104 | 105 | peer.SendHandshakeInitiation(true) 106 | } 107 | } 108 | 109 | func expiredSendKeepalive(peer *Peer) { 110 | peer.SendKeepalive() 111 | if peer.timers.needAnotherKeepalive.Load() { 112 | peer.timers.needAnotherKeepalive.Store(false) 113 | if peer.timersActive() { 114 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 115 | } 116 | } 117 | } 118 | 119 | func expiredNewHandshake(peer *Peer) { 120 | peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) 121 | /* We clear the endpoint address src address, in case this is the cause of trouble. */ 122 | peer.markEndpointSrcForClearing() 123 | peer.SendHandshakeInitiation(false) 124 | } 125 | 126 | func expiredZeroKeyMaterial(peer *Peer) { 127 | peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) 128 | peer.ZeroAndFlushAll() 129 | } 130 | 131 | func expiredPersistentKeepalive(peer *Peer) { 132 | if peer.persistentKeepaliveInterval.Load() > 0 { 133 | peer.SendKeepalive() 134 | } 135 | } 136 | 137 | /* Should be called after an authenticated data packet is sent. 
*/ 138 | func (peer *Peer) timersDataSent() { 139 | if peer.timersActive() && !peer.timers.newHandshake.IsPending() { 140 | peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 141 | } 142 | } 143 | 144 | /* Should be called after an authenticated data packet is received. */ 145 | func (peer *Peer) timersDataReceived() { 146 | if peer.timersActive() { 147 | if !peer.timers.sendKeepalive.IsPending() { 148 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 149 | } else { 150 | peer.timers.needAnotherKeepalive.Store(true) 151 | } 152 | } 153 | } 154 | 155 | /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ 156 | func (peer *Peer) timersAnyAuthenticatedPacketSent() { 157 | if peer.timersActive() { 158 | peer.timers.sendKeepalive.Del() 159 | } 160 | } 161 | 162 | /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ 163 | func (peer *Peer) timersAnyAuthenticatedPacketReceived() { 164 | if peer.timersActive() { 165 | peer.timers.newHandshake.Del() 166 | } 167 | } 168 | 169 | /* Should be called after a handshake initiation message is sent. */ 170 | func (peer *Peer) timersHandshakeInitiated() { 171 | if peer.timersActive() { 172 | peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 173 | } 174 | } 175 | 176 | /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. 
*/ 177 | func (peer *Peer) timersHandshakeComplete() { 178 | if peer.timersActive() { 179 | peer.timers.retransmitHandshake.Del() 180 | } 181 | peer.timers.handshakeAttempts.Store(0) 182 | peer.timers.sentLastMinuteHandshake.Store(false) 183 | peer.lastHandshakeNano.Store(time.Now().UnixNano()) 184 | } 185 | 186 | /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ 187 | func (peer *Peer) timersSessionDerived() { 188 | if peer.timersActive() { 189 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 190 | } 191 | } 192 | 193 | /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ 194 | func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { 195 | keepalive := peer.persistentKeepaliveInterval.Load() 196 | if keepalive > 0 && peer.timersActive() { 197 | peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) 198 | } 199 | } 200 | 201 | func (peer *Peer) timersInit() { 202 | peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) 203 | peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) 204 | peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) 205 | peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) 206 | peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) 207 | } 208 | 209 | func (peer *Peer) timersStart() { 210 | peer.timers.handshakeAttempts.Store(0) 211 | peer.timers.sentLastMinuteHandshake.Store(false) 212 | peer.timers.needAnotherKeepalive.Store(false) 213 | } 214 | 215 | func (peer *Peer) timersStop() { 216 | peer.timers.retransmitHandshake.DelSync() 217 | peer.timers.sendKeepalive.DelSync() 218 | peer.timers.newHandshake.DelSync() 219 | peer.timers.zeroKeyMaterial.DelSync() 220 | peer.timers.persistentKeepalive.DelSync() 221 | } 222 | 
-------------------------------------------------------------------------------- /device/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | 11 | "github.com/amnezia-vpn/amneziawg-go/tun" 12 | ) 13 | 14 | const DefaultMTU = 1420 15 | 16 | func (device *Device) RoutineTUNEventReader() { 17 | device.log.Verbosef("Routine: event worker - started") 18 | 19 | for event := range device.tun.device.Events() { 20 | if event&tun.EventMTUUpdate != 0 { 21 | mtu, err := device.tun.device.MTU() 22 | if err != nil { 23 | device.log.Errorf("Failed to load updated MTU of device: %v", err) 24 | continue 25 | } 26 | if mtu < 0 { 27 | device.log.Errorf("MTU not updated to negative value: %v", mtu) 28 | continue 29 | } 30 | var tooLarge string 31 | if mtu > MaxContentSize { 32 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) 33 | mtu = MaxContentSize 34 | } 35 | old := device.tun.mtu.Swap(int32(mtu)) 36 | if int(old) != mtu { 37 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) 38 | } 39 | } 40 | 41 | if event&tun.EventUp != 0 { 42 | device.log.Verbosef("Interface up requested") 43 | device.Up() 44 | } 45 | 46 | if event&tun.EventDown != 0 { 47 | device.log.Verbosef("Interface down requested") 48 | device.Down() 49 | } 50 | } 51 | 52 | device.log.Verbosef("Routine: event worker - stopped") 53 | } 54 | -------------------------------------------------------------------------------- /format_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "go/format" 10 | "io/fs" 11 | "os" 12 | "path/filepath" 13 | "runtime" 14 | "sync" 15 | "testing" 16 | ) 17 | 18 | func TestFormatting(t *testing.T) { 19 | var wg sync.WaitGroup 20 | filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { 21 | if err != nil { 22 | t.Errorf("unable to walk %s: %v", path, err) 23 | return nil 24 | } 25 | if d.IsDir() || filepath.Ext(path) != ".go" { 26 | return nil 27 | } 28 | wg.Add(1) 29 | go func(path string) { 30 | defer wg.Done() 31 | src, err := os.ReadFile(path) 32 | if err != nil { 33 | t.Errorf("unable to read %s: %v", path, err) 34 | return 35 | } 36 | if runtime.GOOS == "windows" { 37 | src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) 38 | } 39 | formatted, err := format.Source(src) 40 | if err != nil { 41 | t.Errorf("unable to format %s: %v", path, err) 42 | return 43 | } 44 | if !bytes.Equal(src, formatted) { 45 | t.Errorf("unformatted code: %s", path) 46 | } 47 | }(path) 48 | return nil 49 | }) 50 | wg.Wait() 51 | } 52 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/amnezia-vpn/amneziawg-go 2 | 3 | go 1.24 4 | 5 | require ( 6 | github.com/tevino/abool/v2 v2.1.0 7 | golang.org/x/crypto v0.36.0 8 | golang.org/x/net v0.37.0 9 | golang.org/x/sys v0.31.0 10 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 11 | gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 12 | ) 13 | 14 | require ( 15 | github.com/google/btree v1.1.3 // indirect 16 | golang.org/x/time v0.9.0 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= 2 | github.com/google/btree v1.1.3/go.mod 
h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= 3 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 4 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 5 | github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= 6 | github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= 7 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= 8 | golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= 9 | golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= 10 | golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 11 | golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= 12 | golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 13 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 14 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 15 | golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= 16 | golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 17 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= 18 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= 19 | gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ= 20 | gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= 21 | -------------------------------------------------------------------------------- /ipc/namedpipe/file.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The Go Authors. All rights reserved. 
2 | // Copyright 2015 Microsoft 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file. 5 | 6 | //go:build windows 7 | 8 | package namedpipe 9 | 10 | import ( 11 | "io" 12 | "os" 13 | "runtime" 14 | "sync" 15 | "sync/atomic" 16 | "time" 17 | "unsafe" 18 | 19 | "golang.org/x/sys/windows" 20 | ) 21 | 22 | type timeoutChan chan struct{} 23 | 24 | var ( 25 | ioInitOnce sync.Once 26 | ioCompletionPort windows.Handle 27 | ) 28 | 29 | // ioResult contains the result of an asynchronous IO operation 30 | type ioResult struct { 31 | bytes uint32 32 | err error 33 | } 34 | 35 | // ioOperation represents an outstanding asynchronous Win32 IO 36 | type ioOperation struct { 37 | o windows.Overlapped 38 | ch chan ioResult 39 | } 40 | 41 | func initIo() { 42 | h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 43 | if err != nil { 44 | panic(err) 45 | } 46 | ioCompletionPort = h 47 | go ioCompletionProcessor(h) 48 | } 49 | 50 | // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. 51 | // It takes ownership of this handle and will close it if it is garbage collected. 
52 | type file struct { 53 | handle windows.Handle 54 | wg sync.WaitGroup 55 | wgLock sync.RWMutex 56 | closing atomic.Bool 57 | socket bool 58 | readDeadline deadlineHandler 59 | writeDeadline deadlineHandler 60 | } 61 | 62 | type deadlineHandler struct { 63 | setLock sync.Mutex 64 | channel timeoutChan 65 | channelLock sync.RWMutex 66 | timer *time.Timer 67 | timedout atomic.Bool 68 | } 69 | 70 | // makeFile makes a new file from an existing file handle 71 | func makeFile(h windows.Handle) (*file, error) { 72 | f := &file{handle: h} 73 | ioInitOnce.Do(initIo) 74 | _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) 75 | if err != nil { 76 | return nil, err 77 | } 78 | err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) 79 | if err != nil { 80 | return nil, err 81 | } 82 | f.readDeadline.channel = make(timeoutChan) 83 | f.writeDeadline.channel = make(timeoutChan) 84 | return f, nil 85 | } 86 | 87 | // closeHandle closes the resources associated with a Win32 handle 88 | func (f *file) closeHandle() { 89 | f.wgLock.Lock() 90 | // Atomically set that we are closing, releasing the resources only once. 91 | if f.closing.Swap(true) == false { 92 | f.wgLock.Unlock() 93 | // cancel all IO and wait for it to complete 94 | windows.CancelIoEx(f.handle, nil) 95 | f.wg.Wait() 96 | // at this point, no new IO can start 97 | windows.Close(f.handle) 98 | f.handle = 0 99 | } else { 100 | f.wgLock.Unlock() 101 | } 102 | } 103 | 104 | // Close closes a file. 105 | func (f *file) Close() error { 106 | f.closeHandle() 107 | return nil 108 | } 109 | 110 | // prepareIo prepares for a new IO operation. 111 | // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. 
112 | func (f *file) prepareIo() (*ioOperation, error) { 113 | f.wgLock.RLock() 114 | if f.closing.Load() { 115 | f.wgLock.RUnlock() 116 | return nil, os.ErrClosed 117 | } 118 | f.wg.Add(1) 119 | f.wgLock.RUnlock() 120 | c := &ioOperation{} 121 | c.ch = make(chan ioResult) 122 | return c, nil 123 | } 124 | 125 | // ioCompletionProcessor processes completed async IOs forever 126 | func ioCompletionProcessor(h windows.Handle) { 127 | for { 128 | var bytes uint32 129 | var key uintptr 130 | var op *ioOperation 131 | err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) 132 | if op == nil { 133 | panic(err) 134 | } 135 | op.ch <- ioResult{bytes, err} 136 | } 137 | } 138 | 139 | // asyncIo processes the return value from ReadFile or WriteFile, blocking until 140 | // the operation has actually completed. 141 | func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { 142 | if err != windows.ERROR_IO_PENDING { 143 | return int(bytes), err 144 | } 145 | 146 | if f.closing.Load() { 147 | windows.CancelIoEx(f.handle, &c.o) 148 | } 149 | 150 | var timeout timeoutChan 151 | if d != nil { 152 | d.channelLock.Lock() 153 | timeout = d.channel 154 | d.channelLock.Unlock() 155 | } 156 | 157 | var r ioResult 158 | select { 159 | case r = <-c.ch: 160 | err = r.err 161 | if err == windows.ERROR_OPERATION_ABORTED { 162 | if f.closing.Load() { 163 | err = os.ErrClosed 164 | } 165 | } else if err != nil && f.socket { 166 | // err is from Win32. Query the overlapped structure to get the winsock error. 
167 | var bytes, flags uint32 168 | err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) 169 | } 170 | case <-timeout: 171 | windows.CancelIoEx(f.handle, &c.o) 172 | r = <-c.ch 173 | err = r.err 174 | if err == windows.ERROR_OPERATION_ABORTED { 175 | err = os.ErrDeadlineExceeded 176 | } 177 | } 178 | 179 | // runtime.KeepAlive is needed, as c is passed via native 180 | // code to ioCompletionProcessor, c must remain alive 181 | // until the channel read is complete. 182 | runtime.KeepAlive(c) 183 | return int(r.bytes), err 184 | } 185 | 186 | // Read reads from a file handle. 187 | func (f *file) Read(b []byte) (int, error) { 188 | c, err := f.prepareIo() 189 | if err != nil { 190 | return 0, err 191 | } 192 | defer f.wg.Done() 193 | 194 | if f.readDeadline.timedout.Load() { 195 | return 0, os.ErrDeadlineExceeded 196 | } 197 | 198 | var bytes uint32 199 | err = windows.ReadFile(f.handle, b, &bytes, &c.o) 200 | n, err := f.asyncIo(c, &f.readDeadline, bytes, err) 201 | runtime.KeepAlive(b) 202 | 203 | // Handle EOF conditions. 204 | if err == nil && n == 0 && len(b) != 0 { 205 | return 0, io.EOF 206 | } else if err == windows.ERROR_BROKEN_PIPE { 207 | return 0, io.EOF 208 | } else { 209 | return n, err 210 | } 211 | } 212 | 213 | // Write writes to a file handle. 
214 | func (f *file) Write(b []byte) (int, error) { 215 | c, err := f.prepareIo() 216 | if err != nil { 217 | return 0, err 218 | } 219 | defer f.wg.Done() 220 | 221 | if f.writeDeadline.timedout.Load() { 222 | return 0, os.ErrDeadlineExceeded 223 | } 224 | 225 | var bytes uint32 226 | err = windows.WriteFile(f.handle, b, &bytes, &c.o) 227 | n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) 228 | runtime.KeepAlive(b) 229 | return n, err 230 | } 231 | 232 | func (f *file) SetReadDeadline(deadline time.Time) error { 233 | return f.readDeadline.set(deadline) 234 | } 235 | 236 | func (f *file) SetWriteDeadline(deadline time.Time) error { 237 | return f.writeDeadline.set(deadline) 238 | } 239 | 240 | func (f *file) Flush() error { 241 | return windows.FlushFileBuffers(f.handle) 242 | } 243 | 244 | func (f *file) Fd() uintptr { 245 | return uintptr(f.handle) 246 | } 247 | 248 | func (d *deadlineHandler) set(deadline time.Time) error { 249 | d.setLock.Lock() 250 | defer d.setLock.Unlock() 251 | 252 | if d.timer != nil { 253 | if !d.timer.Stop() { 254 | <-d.channel 255 | } 256 | d.timer = nil 257 | } 258 | d.timedout.Store(false) 259 | 260 | select { 261 | case <-d.channel: 262 | d.channelLock.Lock() 263 | d.channel = make(chan struct{}) 264 | d.channelLock.Unlock() 265 | default: 266 | } 267 | 268 | if deadline.IsZero() { 269 | return nil 270 | } 271 | 272 | timeoutIO := func() { 273 | d.timedout.Store(true) 274 | close(d.channel) 275 | } 276 | 277 | now := time.Now() 278 | duration := deadline.Sub(now) 279 | if deadline.After(now) { 280 | // Deadline is in the future, set a timer to wait 281 | d.timer = time.AfterFunc(duration, timeoutIO) 282 | } else { 283 | // Deadline is in the past. Cancel all pending IO now. 
284 | timeoutIO() 285 | } 286 | return nil 287 | } 288 | -------------------------------------------------------------------------------- /ipc/uapi_bsd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "net" 13 | "os" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | type UAPIListener struct { 20 | listener net.Listener // unix socket listener 21 | connNew chan net.Conn 22 | connErr chan error 23 | kqueueFd int 24 | keventFd int 25 | } 26 | 27 | func (l *UAPIListener) Accept() (net.Conn, error) { 28 | for { 29 | select { 30 | case conn := <-l.connNew: 31 | return conn, nil 32 | 33 | case err := <-l.connErr: 34 | return nil, err 35 | } 36 | } 37 | } 38 | 39 | func (l *UAPIListener) Close() error { 40 | err1 := unix.Close(l.kqueueFd) 41 | err2 := unix.Close(l.keventFd) 42 | err3 := l.listener.Close() 43 | if err1 != nil { 44 | return err1 45 | } 46 | if err2 != nil { 47 | return err2 48 | } 49 | return err3 50 | } 51 | 52 | func (l *UAPIListener) Addr() net.Addr { 53 | return l.listener.Addr() 54 | } 55 | 56 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 57 | // wrap file in listener 58 | 59 | listener, err := net.FileListener(file) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | uapi := &UAPIListener{ 65 | listener: listener, 66 | connNew: make(chan net.Conn, 1), 67 | connErr: make(chan error, 1), 68 | } 69 | 70 | if unixListener, ok := listener.(*net.UnixListener); ok { 71 | unixListener.SetUnlinkOnClose(true) 72 | } 73 | 74 | socketPath := sockPath(name) 75 | 76 | // watch for deletion of socket 77 | 78 | uapi.kqueueFd, err = unix.Kqueue() 79 | if err != nil { 80 | return nil, err 81 | } 82 | uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) 83 | if 
err != nil { 84 | unix.Close(uapi.kqueueFd) 85 | return nil, err 86 | } 87 | 88 | go func(l *UAPIListener) { 89 | event := unix.Kevent_t{ 90 | Filter: unix.EVFILT_VNODE, 91 | Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, 92 | Fflags: unix.NOTE_WRITE, 93 | } 94 | // Allow this assignment to work with both the 32-bit and 64-bit version 95 | // of the above struct. If you know another way, please submit a patch. 96 | *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) 97 | events := make([]unix.Kevent_t, 1) 98 | n := 1 99 | var kerr error 100 | for { 101 | // start with lstat to avoid race condition 102 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 103 | l.connErr <- err 104 | return 105 | } 106 | if (kerr != nil || n != 1) && kerr != unix.EINTR { 107 | if kerr != nil { 108 | l.connErr <- kerr 109 | } else { 110 | l.connErr <- errors.New("kqueue returned empty") 111 | } 112 | return 113 | } 114 | n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) 115 | } 116 | }(uapi) 117 | 118 | // watch for new connections 119 | 120 | go func(l *UAPIListener) { 121 | for { 122 | conn, err := l.listener.Accept() 123 | if err != nil { 124 | l.connErr <- err 125 | break 126 | } 127 | l.connNew <- conn 128 | } 129 | }(uapi) 130 | 131 | return uapi, nil 132 | } 133 | -------------------------------------------------------------------------------- /ipc/uapi_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | "os" 11 | 12 | "github.com/amnezia-vpn/amneziawg-go/rwcancel" 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | type UAPIListener struct { 17 | listener net.Listener // unix socket listener 18 | connNew chan net.Conn 19 | connErr chan error 20 | inotifyFd int 21 | inotifyRWCancel *rwcancel.RWCancel 22 | } 23 | 24 | func (l *UAPIListener) Accept() (net.Conn, error) { 25 | for { 26 | select { 27 | case conn := <-l.connNew: 28 | return conn, nil 29 | 30 | case err := <-l.connErr: 31 | return nil, err 32 | } 33 | } 34 | } 35 | 36 | func (l *UAPIListener) Close() error { 37 | err1 := unix.Close(l.inotifyFd) 38 | err2 := l.inotifyRWCancel.Cancel() 39 | err3 := l.listener.Close() 40 | if err1 != nil { 41 | return err1 42 | } 43 | if err2 != nil { 44 | return err2 45 | } 46 | return err3 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 54 | // wrap file in listener 55 | 56 | listener, err := net.FileListener(file) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | if unixListener, ok := listener.(*net.UnixListener); ok { 62 | unixListener.SetUnlinkOnClose(true) 63 | } 64 | 65 | uapi := &UAPIListener{ 66 | listener: listener, 67 | connNew: make(chan net.Conn, 1), 68 | connErr: make(chan error, 1), 69 | } 70 | 71 | // watch for deletion of socket 72 | 73 | socketPath := sockPath(name) 74 | 75 | uapi.inotifyFd, err = unix.InotifyInit() 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | _, err = unix.InotifyAddWatch( 81 | uapi.inotifyFd, 82 | socketPath, 83 | unix.IN_ATTRIB| 84 | unix.IN_DELETE| 85 | unix.IN_DELETE_SELF, 86 | ) 87 | 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) 93 | if err != nil { 94 | unix.Close(uapi.inotifyFd) 95 | return nil, err 96 | } 97 | 98 | go func(l *UAPIListener) { 99 | var buf 
[0]byte 100 | for { 101 | defer uapi.inotifyRWCancel.Close() 102 | // start with lstat to avoid race condition 103 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 104 | l.connErr <- err 105 | return 106 | } 107 | _, err := uapi.inotifyRWCancel.Read(buf[:]) 108 | if err != nil { 109 | l.connErr <- err 110 | return 111 | } 112 | } 113 | }(uapi) 114 | 115 | // watch for new connections 116 | 117 | go func(l *UAPIListener) { 118 | for { 119 | conn, err := l.listener.Accept() 120 | if err != nil { 121 | l.connErr <- err 122 | break 123 | } 124 | l.connNew <- conn 125 | } 126 | }(uapi) 127 | 128 | return uapi, nil 129 | } 130 | -------------------------------------------------------------------------------- /ipc/uapi_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "fmt" 13 | "net" 14 | "os" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | const ( 20 | IpcErrorIO = -int64(unix.EIO) 21 | IpcErrorProtocol = -int64(unix.EPROTO) 22 | IpcErrorInvalid = -int64(unix.EINVAL) 23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE) 24 | IpcErrorUnknown = -55 // ENOANO 25 | ) 26 | 27 | // socketDirectory is variable because it is modified by a linker 28 | // flag in wireguard-android. 
29 | var socketDirectory = "/var/run/amneziawg" 30 | 31 | func sockPath(iface string) string { 32 | return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) 33 | } 34 | 35 | func UAPIOpen(name string) (*os.File, error) { 36 | if err := os.MkdirAll(socketDirectory, 0o755); err != nil { 37 | return nil, err 38 | } 39 | 40 | socketPath := sockPath(name) 41 | addr, err := net.ResolveUnixAddr("unix", socketPath) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | oldUmask := unix.Umask(0o077) 47 | defer unix.Umask(oldUmask) 48 | 49 | listener, err := net.ListenUnix("unix", addr) 50 | if err == nil { 51 | return listener.File() 52 | } 53 | 54 | // Test socket, if not in use cleanup and try again. 55 | if _, err := net.Dial("unix", socketPath); err == nil { 56 | return nil, errors.New("unix socket in use") 57 | } 58 | if err := os.Remove(socketPath); err != nil { 59 | return nil, err 60 | } 61 | listener, err = net.ListenUnix("unix", addr) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return listener.File() 66 | } 67 | -------------------------------------------------------------------------------- /ipc/uapi_wasm.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | // Made up sentinel error codes for {js,wasip1}/wasm. 9 | const ( 10 | IpcErrorIO = 1 11 | IpcErrorInvalid = 2 12 | IpcErrorPortInUse = 3 13 | IpcErrorUnknown = 4 14 | IpcErrorProtocol = 5 15 | ) 16 | -------------------------------------------------------------------------------- /ipc/uapi_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | 11 | "github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe" 12 | "golang.org/x/sys/windows" 13 | ) 14 | 15 | // TODO: replace these with actual standard windows error numbers from the win package 16 | const ( 17 | IpcErrorIO = -int64(5) 18 | IpcErrorProtocol = -int64(71) 19 | IpcErrorInvalid = -int64(22) 20 | IpcErrorPortInUse = -int64(98) 21 | IpcErrorUnknown = -int64(55) 22 | ) 23 | 24 | type UAPIListener struct { 25 | listener net.Listener // unix socket listener 26 | connNew chan net.Conn 27 | connErr chan error 28 | kqueueFd int 29 | keventFd int 30 | } 31 | 32 | func (l *UAPIListener) Accept() (net.Conn, error) { 33 | for { 34 | select { 35 | case conn := <-l.connNew: 36 | return conn, nil 37 | 38 | case err := <-l.connErr: 39 | return nil, err 40 | } 41 | } 42 | } 43 | 44 | func (l *UAPIListener) Close() error { 45 | return l.listener.Close() 46 | } 47 | 48 | func (l *UAPIListener) Addr() net.Addr { 49 | return l.listener.Addr() 50 | } 51 | 52 | var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR 53 | 54 | func init() { 55 | var err error 56 | UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)") 57 | if err != nil { 58 | panic(err) 59 | } 60 | } 61 | 62 | func UAPIListen(name string) (net.Listener, error) { 63 | listener, err := (&namedpipe.ListenConfig{ 64 | SecurityDescriptor: UAPISecurityDescriptor, 65 | }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | uapi := &UAPIListener{ 71 | listener: listener, 72 | connNew: make(chan net.Conn, 1), 73 | connErr: make(chan error, 1), 74 | } 75 | 76 | go func(l *UAPIListener) { 77 | for { 78 | conn, err := l.listener.Accept() 79 | if err != nil { 80 | l.connErr <- err 81 | break 82 | } 83 | l.connNew <- conn 84 | } 85 | }(uapi) 86 | 87 | return uapi, nil 88 | } 89 | 
-------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "fmt" 12 | "os" 13 | "os/signal" 14 | "runtime" 15 | "strconv" 16 | 17 | "github.com/amnezia-vpn/amneziawg-go/conn" 18 | "github.com/amnezia-vpn/amneziawg-go/device" 19 | "github.com/amnezia-vpn/amneziawg-go/ipc" 20 | "github.com/amnezia-vpn/amneziawg-go/tun" 21 | "golang.org/x/sys/unix" 22 | ) 23 | 24 | const ( 25 | ExitSetupSuccess = 0 26 | ExitSetupFailed = 1 27 | ) 28 | 29 | const ( 30 | ENV_WG_TUN_FD = "WG_TUN_FD" 31 | ENV_WG_UAPI_FD = "WG_UAPI_FD" 32 | ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" 33 | ) 34 | 35 | func printUsage() { 36 | fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) 37 | } 38 | 39 | func warning() { 40 | switch runtime.GOOS { 41 | case "linux", "freebsd", "openbsd": 42 | if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { 43 | return 44 | } 45 | default: 46 | return 47 | } 48 | 49 | fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐") 50 | fmt.Fprintln(os.Stderr, "│ │") 51 | fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │") 52 | fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. 
For │") 53 | fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") 54 | fmt.Fprintln(os.Stderr, "│ please visit: │") 55 | fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │") 56 | fmt.Fprintln(os.Stderr, "│ │") 57 | fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘") 58 | } 59 | 60 | func main() { 61 | if len(os.Args) == 2 && os.Args[1] == "--version" { 62 | fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH) 63 | return 64 | } 65 | 66 | warning() 67 | 68 | var foreground bool 69 | var interfaceName string 70 | if len(os.Args) < 2 || len(os.Args) > 3 { 71 | printUsage() 72 | return 73 | } 74 | 75 | switch os.Args[1] { 76 | 77 | case "-f", "--foreground": 78 | foreground = true 79 | if len(os.Args) != 3 { 80 | printUsage() 81 | return 82 | } 83 | interfaceName = os.Args[2] 84 | 85 | default: 86 | foreground = false 87 | if len(os.Args) != 2 { 88 | printUsage() 89 | return 90 | } 91 | interfaceName = os.Args[1] 92 | } 93 | 94 | if !foreground { 95 | foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" 96 | } 97 | 98 | // get log level (default: info) 99 | 100 | logLevel := func() int { 101 | switch os.Getenv("LOG_LEVEL") { 102 | case "verbose", "debug": 103 | return device.LogLevelVerbose 104 | case "error": 105 | return device.LogLevelError 106 | case "silent": 107 | return device.LogLevelSilent 108 | } 109 | return device.LogLevelError 110 | }() 111 | 112 | // open TUN device (or use supplied fd) 113 | 114 | tdev, err := func() (tun.Device, error) { 115 | tunFdStr := os.Getenv(ENV_WG_TUN_FD) 116 | if tunFdStr == "" { 117 | return tun.CreateTUN(interfaceName, device.DefaultMTU) 118 | } 119 | 120 | // construct tun device from supplied fd 121 | 122 | fd, err := strconv.ParseUint(tunFdStr, 10, 32) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | 
err = unix.SetNonblock(int(fd), true) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | file := os.NewFile(uintptr(fd), "") 133 | return tun.CreateTUNFromFile(file, device.DefaultMTU) 134 | }() 135 | 136 | if err == nil { 137 | realInterfaceName, err2 := tdev.Name() 138 | if err2 == nil { 139 | interfaceName = realInterfaceName 140 | } 141 | } 142 | 143 | logger := device.NewLogger( 144 | logLevel, 145 | fmt.Sprintf("(%s) ", interfaceName), 146 | ) 147 | 148 | logger.Verbosef("Starting amneziawg-go version %s", Version) 149 | 150 | if err != nil { 151 | logger.Errorf("Failed to create TUN device: %v", err) 152 | os.Exit(ExitSetupFailed) 153 | } 154 | 155 | // open UAPI file (or use supplied fd) 156 | 157 | fileUAPI, err := func() (*os.File, error) { 158 | uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) 159 | if uapiFdStr == "" { 160 | return ipc.UAPIOpen(interfaceName) 161 | } 162 | 163 | // use supplied fd 164 | 165 | fd, err := strconv.ParseUint(uapiFdStr, 10, 32) 166 | if err != nil { 167 | return nil, err 168 | } 169 | 170 | return os.NewFile(uintptr(fd), ""), nil 171 | }() 172 | if err != nil { 173 | logger.Errorf("UAPI listen error: %v", err) 174 | os.Exit(ExitSetupFailed) 175 | return 176 | } 177 | // daemonize the process 178 | 179 | if !foreground { 180 | env := os.Environ() 181 | env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) 182 | env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) 183 | env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) 184 | files := [3]*os.File{} 185 | if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent { 186 | files[0], _ = os.Open(os.DevNull) 187 | files[1] = os.Stdout 188 | files[2] = os.Stderr 189 | } else { 190 | files[0], _ = os.Open(os.DevNull) 191 | files[1], _ = os.Open(os.DevNull) 192 | files[2], _ = os.Open(os.DevNull) 193 | } 194 | attr := &os.ProcAttr{ 195 | Files: []*os.File{ 196 | files[0], // stdin 197 | files[1], // stdout 198 | files[2], // stderr 199 | tdev.File(), 200 
| fileUAPI, 201 | }, 202 | Dir: ".", 203 | Env: env, 204 | } 205 | 206 | path, err := os.Executable() 207 | if err != nil { 208 | logger.Errorf("Failed to determine executable: %v", err) 209 | os.Exit(ExitSetupFailed) 210 | } 211 | 212 | process, err := os.StartProcess( 213 | path, 214 | os.Args, 215 | attr, 216 | ) 217 | if err != nil { 218 | logger.Errorf("Failed to daemonize: %v", err) 219 | os.Exit(ExitSetupFailed) 220 | } 221 | process.Release() 222 | return 223 | } 224 | 225 | device := device.NewDevice(tdev, conn.NewDefaultBind(), logger) 226 | 227 | logger.Verbosef("Device started") 228 | 229 | errs := make(chan error) 230 | term := make(chan os.Signal, 1) 231 | 232 | uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) 233 | if err != nil { 234 | logger.Errorf("Failed to listen on uapi socket: %v", err) 235 | os.Exit(ExitSetupFailed) 236 | } 237 | 238 | go func() { 239 | for { 240 | conn, err := uapi.Accept() 241 | if err != nil { 242 | errs <- err 243 | return 244 | } 245 | go device.IpcHandle(conn) 246 | } 247 | }() 248 | 249 | logger.Verbosef("UAPI listener started") 250 | 251 | // wait for program to terminate 252 | 253 | signal.Notify(term, unix.SIGTERM) 254 | signal.Notify(term, os.Interrupt) 255 | 256 | select { 257 | case <-term: 258 | case <-errs: 259 | case <-device.Wait(): 260 | } 261 | 262 | // clean up 263 | 264 | uapi.Close() 265 | device.Close() 266 | 267 | logger.Verbosef("Shutting down") 268 | } 269 | -------------------------------------------------------------------------------- /main_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package main 7 | 8 | import ( 9 | "fmt" 10 | "os" 11 | "os/signal" 12 | 13 | "golang.org/x/sys/windows" 14 | 15 | "github.com/amnezia-vpn/amneziawg-go/conn" 16 | "github.com/amnezia-vpn/amneziawg-go/device" 17 | "github.com/amnezia-vpn/amneziawg-go/ipc" 18 | 19 | "github.com/amnezia-vpn/amneziawg-go/tun" 20 | ) 21 | 22 | const ( 23 | ExitSetupSuccess = 0 24 | ExitSetupFailed = 1 25 | ) 26 | 27 | func main() { 28 | if len(os.Args) != 2 { 29 | os.Exit(ExitSetupFailed) 30 | } 31 | interfaceName := os.Args[1] 32 | 33 | fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org") 34 | 35 | logger := device.NewLogger( 36 | device.LogLevelVerbose, 37 | fmt.Sprintf("(%s) ", interfaceName), 38 | ) 39 | logger.Verbosef("Starting amneziawg-go version %s", Version) 40 | 41 | tun, err := tun.CreateTUN(interfaceName, 0) 42 | if err == nil { 43 | realInterfaceName, err2 := tun.Name() 44 | if err2 == nil { 45 | interfaceName = realInterfaceName 46 | } 47 | } else { 48 | logger.Errorf("Failed to create TUN device: %v", err) 49 | os.Exit(ExitSetupFailed) 50 | } 51 | 52 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 53 | err = device.Up() 54 | if err != nil { 55 | logger.Errorf("Failed to bring up device: %v", err) 56 | os.Exit(ExitSetupFailed) 57 | } 58 | logger.Verbosef("Device started") 59 | 60 | uapi, err := ipc.UAPIListen(interfaceName) 61 | if err != nil { 62 | logger.Errorf("Failed to listen on uapi socket: %v", err) 63 | os.Exit(ExitSetupFailed) 64 | } 65 | 66 | errs := make(chan error) 67 | term := make(chan os.Signal, 1) 68 | 69 | go func() { 70 | for { 71 | conn, err := uapi.Accept() 72 | if err != nil { 73 | errs <- err 74 | return 75 | } 76 | go device.IpcHandle(conn) 77 | } 78 | }() 79 | logger.Verbosef("UAPI listener started") 80 | 81 | // wait for program to terminate 82 | 83 | signal.Notify(term, 
os.Interrupt) 84 | signal.Notify(term, os.Kill) 85 | signal.Notify(term, windows.SIGTERM) 86 | 87 | select { 88 | case <-term: 89 | case <-errs: 90 | case <-device.Wait(): 91 | } 92 | 93 | // clean up 94 | 95 | uapi.Close() 96 | device.Close() 97 | 98 | logger.Verbosef("Shutting down") 99 | } 100 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | const ( 15 | packetsPerSecond = 20 16 | packetsBurstable = 5 17 | garbageCollectTime = time.Second 18 | packetCost = 1000000000 / packetsPerSecond 19 | maxTokens = packetCost * packetsBurstable 20 | ) 21 | 22 | type RatelimiterEntry struct { 23 | mu sync.Mutex 24 | lastTime time.Time 25 | tokens int64 26 | } 27 | 28 | type Ratelimiter struct { 29 | mu sync.RWMutex 30 | timeNow func() time.Time 31 | 32 | stopReset chan struct{} // send to reset, close to stop 33 | table map[netip.Addr]*RatelimiterEntry 34 | } 35 | 36 | func (rate *Ratelimiter) Close() { 37 | rate.mu.Lock() 38 | defer rate.mu.Unlock() 39 | 40 | if rate.stopReset != nil { 41 | close(rate.stopReset) 42 | } 43 | } 44 | 45 | func (rate *Ratelimiter) Init() { 46 | rate.mu.Lock() 47 | defer rate.mu.Unlock() 48 | 49 | if rate.timeNow == nil { 50 | rate.timeNow = time.Now 51 | } 52 | 53 | // stop any ongoing garbage collection routine 54 | if rate.stopReset != nil { 55 | close(rate.stopReset) 56 | } 57 | 58 | rate.stopReset = make(chan struct{}) 59 | rate.table = make(map[netip.Addr]*RatelimiterEntry) 60 | 61 | stopReset := rate.stopReset // store in case Init is called again. 62 | 63 | // Start garbage collection routine. 
64 | go func() { 65 | ticker := time.NewTicker(time.Second) 66 | ticker.Stop() 67 | for { 68 | select { 69 | case _, ok := <-stopReset: 70 | ticker.Stop() 71 | if !ok { 72 | return 73 | } 74 | ticker = time.NewTicker(time.Second) 75 | case <-ticker.C: 76 | if rate.cleanup() { 77 | ticker.Stop() 78 | } 79 | } 80 | } 81 | }() 82 | } 83 | 84 | func (rate *Ratelimiter) cleanup() (empty bool) { 85 | rate.mu.Lock() 86 | defer rate.mu.Unlock() 87 | 88 | for key, entry := range rate.table { 89 | entry.mu.Lock() 90 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 91 | delete(rate.table, key) 92 | } 93 | entry.mu.Unlock() 94 | } 95 | 96 | return len(rate.table) == 0 97 | } 98 | 99 | func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 100 | var entry *RatelimiterEntry 101 | // lookup entry 102 | rate.mu.RLock() 103 | entry = rate.table[ip] 104 | rate.mu.RUnlock() 105 | 106 | // make new entry if not found 107 | if entry == nil { 108 | entry = new(RatelimiterEntry) 109 | entry.tokens = maxTokens - packetCost 110 | entry.lastTime = rate.timeNow() 111 | rate.mu.Lock() 112 | rate.table[ip] = entry 113 | if len(rate.table) == 1 { 114 | rate.stopReset <- struct{}{} 115 | } 116 | rate.mu.Unlock() 117 | return true 118 | } 119 | 120 | // add tokens to entry 121 | entry.mu.Lock() 122 | now := rate.timeNow() 123 | entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 124 | entry.lastTime = now 125 | if entry.tokens > maxTokens { 126 | entry.tokens = maxTokens 127 | } 128 | 129 | // subtract cost of packet 130 | if entry.tokens > packetCost { 131 | entry.tokens -= packetCost 132 | entry.mu.Unlock() 133 | return true 134 | } 135 | entry.mu.Unlock() 136 | return false 137 | } 138 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. 
All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | type result struct { 15 | allowed bool 16 | text string 17 | wait time.Duration 18 | } 19 | 20 | func TestRatelimiter(t *testing.T) { 21 | var rate Ratelimiter 22 | var expectedResults []result 23 | 24 | nano := func(nano int64) time.Duration { 25 | return time.Nanosecond * time.Duration(nano) 26 | } 27 | 28 | add := func(res result) { 29 | expectedResults = append( 30 | expectedResults, 31 | res, 32 | ) 33 | } 34 | 35 | for i := 0; i < packetsBurstable; i++ { 36 | add(result{ 37 | allowed: true, 38 | text: "initial burst", 39 | }) 40 | } 41 | 42 | add(result{ 43 | allowed: false, 44 | text: "after burst", 45 | }) 46 | 47 | add(result{ 48 | allowed: true, 49 | wait: nano(time.Second.Nanoseconds() / packetsPerSecond), 50 | text: "filling tokens for single packet", 51 | }) 52 | 53 | add(result{ 54 | allowed: false, 55 | text: "not having refilled enough", 56 | }) 57 | 58 | add(result{ 59 | allowed: true, 60 | wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), 61 | text: "filling tokens for two packet burst", 62 | }) 63 | 64 | add(result{ 65 | allowed: true, 66 | text: "second packet in 2 packet burst", 67 | }) 68 | 69 | add(result{ 70 | allowed: false, 71 | text: "packet following 2 packet burst", 72 | }) 73 | 74 | ips := []netip.Addr{ 75 | netip.MustParseAddr("127.0.0.1"), 76 | netip.MustParseAddr("192.168.1.1"), 77 | netip.MustParseAddr("172.167.2.3"), 78 | netip.MustParseAddr("97.231.252.215"), 79 | netip.MustParseAddr("248.97.91.167"), 80 | netip.MustParseAddr("188.208.233.47"), 81 | netip.MustParseAddr("104.2.183.179"), 82 | netip.MustParseAddr("72.129.46.120"), 83 | netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), 84 | netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), 85 | netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), 86 | 
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), 87 | netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), 88 | netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), 89 | } 90 | 91 | now := time.Now() 92 | rate.timeNow = func() time.Time { 93 | return now 94 | } 95 | defer func() { 96 | // Lock to avoid data race with cleanup goroutine from Init. 97 | rate.mu.Lock() 98 | defer rate.mu.Unlock() 99 | 100 | rate.timeNow = time.Now 101 | }() 102 | timeSleep := func(d time.Duration) { 103 | now = now.Add(d + 1) 104 | rate.cleanup() 105 | } 106 | 107 | rate.Init() 108 | defer rate.Close() 109 | 110 | for i, res := range expectedResults { 111 | timeSleep(res.wait) 112 | for _, ip := range ips { 113 | allowed := rate.Allow(ip) 114 | if allowed != res.allowed { 115 | t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /replay/replay.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. 7 | package replay 8 | 9 | type block uint64 10 | 11 | const ( 12 | blockBitLog = 6 // 1<<6 == 64 bits 13 | blockBits = 1 << blockBitLog // must be power of 2 14 | ringBlocks = 1 << 7 // must be power of 2 15 | windowSize = (ringBlocks - 1) * blockBits 16 | blockMask = ringBlocks - 1 17 | bitMask = blockBits - 1 18 | ) 19 | 20 | // A Filter rejects replayed messages by checking if message counter value is 21 | // within a sliding window of previously received messages. 22 | // The zero value for Filter is an empty filter ready to use. 23 | // Filters are unsafe for concurrent use. 
24 | type Filter struct { 25 | last uint64 26 | ring [ringBlocks]block 27 | } 28 | 29 | // Reset resets the filter to empty state. 30 | func (f *Filter) Reset() { 31 | f.last = 0 32 | f.ring[0] = 0 33 | } 34 | 35 | // ValidateCounter checks if the counter should be accepted. 36 | // Overlimit counters (>= limit) are always rejected. 37 | func (f *Filter) ValidateCounter(counter, limit uint64) bool { 38 | if counter >= limit { 39 | return false 40 | } 41 | indexBlock := counter >> blockBitLog 42 | if counter > f.last { // move window forward 43 | current := f.last >> blockBitLog 44 | diff := indexBlock - current 45 | if diff > ringBlocks { 46 | diff = ringBlocks // cap diff to clear the whole ring 47 | } 48 | for i := current + 1; i <= current+diff; i++ { 49 | f.ring[i&blockMask] = 0 50 | } 51 | f.last = counter 52 | } else if f.last-counter > windowSize { // behind current window 53 | return false 54 | } 55 | // check and set bit 56 | indexBlock &= blockMask 57 | indexBit := counter & bitMask 58 | old := f.ring[indexBlock] 59 | new := old | 1< 0; i-- { 91 | T(i, true) 92 | } 93 | 94 | t.Log("Bulk test 4") 95 | filter.Reset() 96 | testNumber = 0 97 | for i := uint64(windowSize + 2); i > 1; i-- { 98 | T(i, true) 99 | } 100 | T(0, false) 101 | 102 | t.Log("Bulk test 5") 103 | filter.Reset() 104 | testNumber = 0 105 | for i := uint64(windowSize); i > 0; i-- { 106 | T(i, true) 107 | } 108 | T(windowSize+1, true) 109 | T(0, false) 110 | 111 | t.Log("Bulk test 6") 112 | filter.Reset() 113 | testNumber = 0 114 | for i := uint64(windowSize); i > 0; i-- { 115 | T(i, true) 116 | } 117 | T(0, true) 118 | T(windowSize+1, true) 119 | } 120 | -------------------------------------------------------------------------------- /rwcancel/rwcancel.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !wasm 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | // Package rwcancel implements cancelable read/write operations on 9 | // a file descriptor. 10 | package rwcancel 11 | 12 | import ( 13 | "errors" 14 | "os" 15 | "syscall" 16 | 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type RWCancel struct { 21 | fd int 22 | closingReader *os.File 23 | closingWriter *os.File 24 | } 25 | 26 | func NewRWCancel(fd int) (*RWCancel, error) { 27 | err := unix.SetNonblock(fd, true) 28 | if err != nil { 29 | return nil, err 30 | } 31 | rwcancel := RWCancel{fd: fd} 32 | 33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &rwcancel, nil 39 | } 40 | 41 | func RetryAfterError(err error) bool { 42 | return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) 43 | } 44 | 45 | func (rw *RWCancel) ReadyRead() bool { 46 | closeFd := int32(rw.closingReader.Fd()) 47 | 48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} 49 | var err error 50 | for { 51 | _, err = unix.Poll(pollFds, -1) 52 | if err == nil || !RetryAfterError(err) { 53 | break 54 | } 55 | } 56 | if err != nil { 57 | return false 58 | } 59 | if pollFds[1].Revents != 0 { 60 | return false 61 | } 62 | return pollFds[0].Revents != 0 63 | } 64 | 65 | func (rw *RWCancel) ReadyWrite() bool { 66 | closeFd := int32(rw.closingReader.Fd()) 67 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} 68 | var err error 69 | for { 70 | _, err = unix.Poll(pollFds, -1) 71 | if err == nil || !RetryAfterError(err) { 72 | break 73 | } 74 | } 75 | if err != nil { 76 | return false 77 | } 78 | 79 | if pollFds[1].Revents != 0 { 80 | return false 81 | } 82 | return pollFds[0].Revents != 0 83 | } 84 | 85 | func (rw *RWCancel) Read(p []byte) (n int, err error) { 86 | for { 87 | n, err := unix.Read(rw.fd, p) 88 | if err == nil || !RetryAfterError(err) { 89 | return n, err 90 | } 91 | if 
!rw.ReadyRead() { 92 | return 0, os.ErrClosed 93 | } 94 | } 95 | } 96 | 97 | func (rw *RWCancel) Write(p []byte) (n int, err error) { 98 | for { 99 | n, err := unix.Write(rw.fd, p) 100 | if err == nil || !RetryAfterError(err) { 101 | return n, err 102 | } 103 | if !rw.ReadyWrite() { 104 | return 0, os.ErrClosed 105 | } 106 | } 107 | } 108 | 109 | func (rw *RWCancel) Cancel() (err error) { 110 | _, err = rw.closingWriter.Write([]byte{0}) 111 | return 112 | } 113 | 114 | func (rw *RWCancel) Close() { 115 | rw.closingReader.Close() 116 | rw.closingWriter.Close() 117 | } 118 | -------------------------------------------------------------------------------- /rwcancel/rwcancel_stub.go: -------------------------------------------------------------------------------- 1 | //go:build windows || wasm 2 | 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rwcancel 6 | 7 | type RWCancel struct{} 8 | 9 | func (*RWCancel) Cancel() {} 10 | -------------------------------------------------------------------------------- /tai64n/tai64n.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "time" 12 | ) 13 | 14 | const ( 15 | TimestampSize = 12 16 | base = uint64(0x400000000000000a) 17 | whitenerMask = uint32(0x1000000 - 1) 18 | ) 19 | 20 | type Timestamp [TimestampSize]byte 21 | 22 | func stamp(t time.Time) Timestamp { 23 | var tai64n Timestamp 24 | secs := base + uint64(t.Unix()) 25 | nano := uint32(t.Nanosecond()) &^ whitenerMask 26 | binary.BigEndian.PutUint64(tai64n[:], secs) 27 | binary.BigEndian.PutUint32(tai64n[8:], nano) 28 | return tai64n 29 | } 30 | 31 | func Now() Timestamp { 32 | return stamp(time.Now()) 33 | } 34 | 35 | func (t1 Timestamp) After(t2 Timestamp) bool { 36 | return bytes.Compare(t1[:], t2[:]) > 0 37 | } 38 | 39 | func (t Timestamp) String() string { 40 | return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() 41 | } 42 | -------------------------------------------------------------------------------- /tai64n/tai64n_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Test that timestamps are monotonic as required by Wireguard and that 14 | // nanosecond-level information is whitened to prevent side channel attacks. 15 | func TestMonotonic(t *testing.T) { 16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern 17 | // Whitening should reduce timestamp granularity 18 | // to more than 10 but fewer than 20 milliseconds. 
19 | tests := []struct { 20 | name string 21 | t1, t2 time.Time 22 | wantAfter bool 23 | }{ 24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, 25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, 26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, 27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, 28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, 29 | } 30 | 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2) 34 | got := ts2.After(ts1) 35 | if got != tt.wantAfter { 36 | t.Errorf("after = %v; want %v", got, tt.wantAfter) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /tun/alignment_windows_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestRateJugglerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime. 
26 | func TestRateJugglerAlignment(t *testing.T) { 27 | var r rateJuggler 28 | 29 | typ := reflect.TypeOf(&r).Elem() 30 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) 42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) 43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) 44 | } 45 | 46 | // TestNativeTunAlignment checks that atomically-accessed fields are 47 | // aligned to 64-bit boundaries, as required by the atomic package. 48 | // 49 | // Unfortunately, violating this rule on 32-bit platforms results in a 50 | // hard segfault at runtime. 51 | func TestNativeTunAlignment(t *testing.T) { 52 | var tun NativeTun 53 | 54 | typ := reflect.TypeOf(&tun).Elem() 55 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 56 | for i := 0; i < typ.NumField(); i++ { 57 | field := typ.Field(i) 58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 59 | field.Name, 60 | field.Offset, 61 | field.Type.Size(), 62 | field.Type.Align(), 63 | ) 64 | } 65 | 66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) 67 | } 68 | -------------------------------------------------------------------------------- /tun/checksum.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import "encoding/binary" 4 | 5 | // TODO: Explore SIMD and/or other assembly optimizations. 6 | // TODO: Test native endian loads. See RFC 1071 section 2 part B. 
7 | func checksumNoFold(b []byte, initial uint64) uint64 { 8 | ac := initial 9 | 10 | for len(b) >= 128 { 11 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 12 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 13 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 14 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 15 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 16 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 17 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 18 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 19 | ac += uint64(binary.BigEndian.Uint32(b[32:36])) 20 | ac += uint64(binary.BigEndian.Uint32(b[36:40])) 21 | ac += uint64(binary.BigEndian.Uint32(b[40:44])) 22 | ac += uint64(binary.BigEndian.Uint32(b[44:48])) 23 | ac += uint64(binary.BigEndian.Uint32(b[48:52])) 24 | ac += uint64(binary.BigEndian.Uint32(b[52:56])) 25 | ac += uint64(binary.BigEndian.Uint32(b[56:60])) 26 | ac += uint64(binary.BigEndian.Uint32(b[60:64])) 27 | ac += uint64(binary.BigEndian.Uint32(b[64:68])) 28 | ac += uint64(binary.BigEndian.Uint32(b[68:72])) 29 | ac += uint64(binary.BigEndian.Uint32(b[72:76])) 30 | ac += uint64(binary.BigEndian.Uint32(b[76:80])) 31 | ac += uint64(binary.BigEndian.Uint32(b[80:84])) 32 | ac += uint64(binary.BigEndian.Uint32(b[84:88])) 33 | ac += uint64(binary.BigEndian.Uint32(b[88:92])) 34 | ac += uint64(binary.BigEndian.Uint32(b[92:96])) 35 | ac += uint64(binary.BigEndian.Uint32(b[96:100])) 36 | ac += uint64(binary.BigEndian.Uint32(b[100:104])) 37 | ac += uint64(binary.BigEndian.Uint32(b[104:108])) 38 | ac += uint64(binary.BigEndian.Uint32(b[108:112])) 39 | ac += uint64(binary.BigEndian.Uint32(b[112:116])) 40 | ac += uint64(binary.BigEndian.Uint32(b[116:120])) 41 | ac += uint64(binary.BigEndian.Uint32(b[120:124])) 42 | ac += uint64(binary.BigEndian.Uint32(b[124:128])) 43 | b = b[128:] 44 | } 45 | if len(b) >= 64 { 46 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 47 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 48 | ac += 
uint64(binary.BigEndian.Uint32(b[8:12])) 49 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 50 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 51 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 52 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 53 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 54 | ac += uint64(binary.BigEndian.Uint32(b[32:36])) 55 | ac += uint64(binary.BigEndian.Uint32(b[36:40])) 56 | ac += uint64(binary.BigEndian.Uint32(b[40:44])) 57 | ac += uint64(binary.BigEndian.Uint32(b[44:48])) 58 | ac += uint64(binary.BigEndian.Uint32(b[48:52])) 59 | ac += uint64(binary.BigEndian.Uint32(b[52:56])) 60 | ac += uint64(binary.BigEndian.Uint32(b[56:60])) 61 | ac += uint64(binary.BigEndian.Uint32(b[60:64])) 62 | b = b[64:] 63 | } 64 | if len(b) >= 32 { 65 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 66 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 67 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 68 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 69 | ac += uint64(binary.BigEndian.Uint32(b[16:20])) 70 | ac += uint64(binary.BigEndian.Uint32(b[20:24])) 71 | ac += uint64(binary.BigEndian.Uint32(b[24:28])) 72 | ac += uint64(binary.BigEndian.Uint32(b[28:32])) 73 | b = b[32:] 74 | } 75 | if len(b) >= 16 { 76 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 77 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 78 | ac += uint64(binary.BigEndian.Uint32(b[8:12])) 79 | ac += uint64(binary.BigEndian.Uint32(b[12:16])) 80 | b = b[16:] 81 | } 82 | if len(b) >= 8 { 83 | ac += uint64(binary.BigEndian.Uint32(b[:4])) 84 | ac += uint64(binary.BigEndian.Uint32(b[4:8])) 85 | b = b[8:] 86 | } 87 | if len(b) >= 4 { 88 | ac += uint64(binary.BigEndian.Uint32(b)) 89 | b = b[4:] 90 | } 91 | if len(b) >= 2 { 92 | ac += uint64(binary.BigEndian.Uint16(b)) 93 | b = b[2:] 94 | } 95 | if len(b) == 1 { 96 | ac += uint64(b[0]) << 8 97 | } 98 | 99 | return ac 100 | } 101 | 102 | func checksum(b []byte, initial uint64) uint16 { 103 | ac := checksumNoFold(b, initial) 104 | ac = (ac >> 
16) + (ac & 0xffff) 105 | ac = (ac >> 16) + (ac & 0xffff) 106 | ac = (ac >> 16) + (ac & 0xffff) 107 | ac = (ac >> 16) + (ac & 0xffff) 108 | return uint16(ac) 109 | } 110 | 111 | func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { 112 | sum := checksumNoFold(srcAddr, 0) 113 | sum = checksumNoFold(dstAddr, sum) 114 | sum = checksumNoFold([]byte{0, protocol}, sum) 115 | tmp := make([]byte, 2) 116 | binary.BigEndian.PutUint16(tmp, totalLen) 117 | return checksumNoFold(tmp, sum) 118 | } 119 | -------------------------------------------------------------------------------- /tun/checksum_test.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | ) 8 | 9 | func BenchmarkChecksum(b *testing.B) { 10 | lengths := []int{ 11 | 64, 12 | 128, 13 | 256, 14 | 512, 15 | 1024, 16 | 1500, 17 | 2048, 18 | 4096, 19 | 8192, 20 | 9000, 21 | 9001, 22 | } 23 | 24 | for _, length := range lengths { 25 | b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { 26 | buf := make([]byte, length) 27 | rng := rand.New(rand.NewSource(1)) 28 | rng.Read(buf) 29 | b.ResetTimer() 30 | for i := 0; i < b.N; i++ { 31 | checksum(buf, 0) 32 | } 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /tun/errors.go: -------------------------------------------------------------------------------- 1 | package tun 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ( 8 | // ErrTooManySegments is returned by Device.Read() when segmentation 9 | // overflows the length of supplied buffers. This error should not cause 10 | // reads to cease. 
11 | ErrTooManySegments = errors.New("too many segments") 12 | ) 13 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_client.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "io" 12 | "log" 13 | "net/http" 14 | "net/netip" 15 | 16 | "github.com/amnezia-vpn/amneziawg-go/conn" 17 | "github.com/amnezia-vpn/amneziawg-go/device" 18 | "github.com/amnezia-vpn/amneziawg-go/tun/netstack" 19 | ) 20 | 21 | func main() { 22 | tun, tnet, err := netstack.CreateNetTUN( 23 | []netip.Addr{netip.MustParseAddr("192.168.4.28")}, 24 | []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 25 | 1420) 26 | if err != nil { 27 | log.Panic(err) 28 | } 29 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 30 | err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379 31 | public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28 32 | allowed_ip=0.0.0.0/0 33 | endpoint=127.0.0.1:58120 34 | `) 35 | err = dev.Up() 36 | if err != nil { 37 | log.Panic(err) 38 | } 39 | 40 | client := http.Client{ 41 | Transport: &http.Transport{ 42 | DialContext: tnet.DialContext, 43 | }, 44 | } 45 | resp, err := client.Get("http://192.168.4.29/") 46 | if err != nil { 47 | log.Panic(err) 48 | } 49 | body, err := io.ReadAll(resp.Body) 50 | if err != nil { 51 | log.Panic(err) 52 | } 53 | log.Println(string(body)) 54 | } 55 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_server.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. 
All Rights Reserved. 6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "io" 12 | "log" 13 | "net" 14 | "net/http" 15 | "net/netip" 16 | 17 | "github.com/amnezia-vpn/amneziawg-go/conn" 18 | "github.com/amnezia-vpn/amneziawg-go/device" 19 | "github.com/amnezia-vpn/amneziawg-go/tun/netstack" 20 | ) 21 | 22 | func main() { 23 | tun, tnet, err := netstack.CreateNetTUN( 24 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 25 | []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 26 | 1420, 27 | ) 28 | if err != nil { 29 | log.Panic(err) 30 | } 31 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 32 | dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641 33 | listen_port=58120 34 | public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c 35 | allowed_ip=192.168.4.28/32 36 | persistent_keepalive_interval=25 37 | `) 38 | dev.Up() 39 | listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) 40 | if err != nil { 41 | log.Panicln(err) 42 | } 43 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 44 | log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) 45 | io.WriteString(writer, "Hello from userspace TCP!") 46 | }) 47 | err = http.Serve(listener, nil) 48 | if err != nil { 49 | log.Panicln(err) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /tun/netstack/examples/ping_client.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "bytes" 12 | "log" 13 | "math/rand" 14 | "net/netip" 15 | "time" 16 | 17 | "golang.org/x/net/icmp" 18 | "golang.org/x/net/ipv4" 19 | 20 | "github.com/amnezia-vpn/amneziawg-go/conn" 21 | "github.com/amnezia-vpn/amneziawg-go/device" 22 | "github.com/amnezia-vpn/amneziawg-go/tun/netstack" 23 | ) 24 | 25 | func main() { 26 | tun, tnet, err := netstack.CreateNetTUN( 27 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 28 | []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 29 | 1420) 30 | if err != nil { 31 | log.Panic(err) 32 | } 33 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 34 | dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 35 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 36 | endpoint=163.172.161.0:12912 37 | allowed_ip=0.0.0.0/0 38 | `) 39 | err = dev.Up() 40 | if err != nil { 41 | log.Panic(err) 42 | } 43 | 44 | socket, err := tnet.Dial("ping4", "zx2c4.com") 45 | if err != nil { 46 | log.Panic(err) 47 | } 48 | requestPing := icmp.Echo{ 49 | Seq: rand.Intn(1 << 16), 50 | Data: []byte("gopher burrow"), 51 | } 52 | icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) 53 | socket.SetReadDeadline(time.Now().Add(time.Second * 10)) 54 | start := time.Now() 55 | _, err = socket.Write(icmpBytes) 56 | if err != nil { 57 | log.Panic(err) 58 | } 59 | n, err := socket.Read(icmpBytes[:]) 60 | if err != nil { 61 | log.Panic(err) 62 | } 63 | replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) 64 | if err != nil { 65 | log.Panic(err) 66 | } 67 | replyPing, ok := replyPacket.Body.(*icmp.Echo) 68 | if !ok { 69 | log.Panicf("invalid reply type: %v", replyPacket) 70 | } 71 | if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { 72 | log.Panicf("invalid ping reply: %v", replyPing) 73 | } 74 | log.Printf("Ping latency: %v", 
time.Since(start)) 75 | } 76 | -------------------------------------------------------------------------------- /tun/operateonfd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package tun 9 | 10 | import ( 11 | "fmt" 12 | ) 13 | 14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { 15 | sysconn, err := tun.tunFile.SyscallConn() 16 | if err != nil { 17 | tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) 18 | return 19 | } 20 | err = sysconn.Control(fn) 21 | if err != nil { 22 | tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tun/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "os" 10 | ) 11 | 12 | type Event int 13 | 14 | const ( 15 | EventUp = 1 << iota 16 | EventDown 17 | EventMTUUpdate 18 | ) 19 | 20 | type Device interface { 21 | // File returns the file descriptor of the device. 22 | File() *os.File 23 | 24 | // Read one or more packets from the Device (without any additional headers). 25 | // On a successful read it returns the number of packets read, and sets 26 | // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). 27 | // A nonzero offset can be used to instruct the Device on where to begin 28 | // reading into each element of the bufs slice. 29 | Read(bufs [][]byte, sizes []int, offset int) (n int, err error) 30 | 31 | // Write one or more packets to the device (without any additional headers). 32 | // On a successful write it returns the number of packets written. 
A nonzero 33 | // offset can be used to instruct the Device on where to begin writing from 34 | // each packet contained within the bufs slice. 35 | Write(bufs [][]byte, offset int) (int, error) 36 | 37 | // MTU returns the MTU of the Device. 38 | MTU() (int, error) 39 | 40 | // Name returns the current name of the Device. 41 | Name() (string, error) 42 | 43 | // Events returns a channel of type Event, which is fed Device events. 44 | Events() <-chan Event 45 | 46 | // Close stops the Device and closes the Event channel. 47 | Close() error 48 | 49 | // BatchSize returns the preferred/max number of packets that can be read or 50 | // written in a single read/write call. BatchSize must not change over the 51 | // lifetime of a Device. 52 | BatchSize() int 53 | } 54 | -------------------------------------------------------------------------------- /tun/tun_darwin.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "sync" 15 | "syscall" 16 | "time" 17 | "unsafe" 18 | 19 | "golang.org/x/sys/unix" 20 | ) 21 | 22 | const utunControlName = "com.apple.net.utun_control" 23 | 24 | type NativeTun struct { 25 | name string 26 | tunFile *os.File 27 | events chan Event 28 | errors chan error 29 | routeSocket int 30 | closeOnce sync.Once 31 | } 32 | 33 | func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { 34 | for i := 0; i < 20; i++ { 35 | iface, err = net.InterfaceByIndex(index) 36 | if err != nil && errors.Is(err, unix.ENOMEM) { 37 | time.Sleep(time.Duration(i) * time.Second / 3) 38 | continue 39 | } 40 | return iface, err 41 | } 42 | return nil, err 43 | } 44 | 45 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 46 | var ( 47 | statusUp bool 48 | statusMTU int 49 | ) 50 | 51 | defer close(tun.events) 52 | 53 | data := make([]byte, os.Getpagesize()) 54 | for { 55 | retry: 56 | n, err := unix.Read(tun.routeSocket, data) 57 | if err != nil { 58 | if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR { 59 | goto retry 60 | } 61 | tun.errors <- err 62 | return 63 | } 64 | 65 | if n < 14 { 66 | continue 67 | } 68 | 69 | if data[3 /* type */] != unix.RTM_IFINFO { 70 | continue 71 | } 72 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) 73 | if ifindex != tunIfindex { 74 | continue 75 | } 76 | 77 | iface, err := retryInterfaceByIndex(ifindex) 78 | if err != nil { 79 | tun.errors <- err 80 | return 81 | } 82 | 83 | // Up / Down event 84 | up := (iface.Flags & net.FlagUp) != 0 85 | if up != statusUp && up { 86 | tun.events <- EventUp 87 | } 88 | if up != statusUp && !up { 89 | tun.events <- EventDown 90 | } 91 | statusUp = up 92 | 93 | // MTU changes 94 | if iface.MTU != statusMTU { 95 | tun.events <- EventMTUUpdate 96 | } 97 | statusMTU = iface.MTU 98 | } 99 | } 100 | 101 | func CreateTUN(name string, mtu int) (Device, error) { 
102 | ifIndex := -1 103 | if name != "utun" { 104 | _, err := fmt.Sscanf(name, "utun%d", &ifIndex) 105 | if err != nil || ifIndex < 0 { 106 | return nil, fmt.Errorf("Interface name must be utun[0-9]*") 107 | } 108 | } 109 | 110 | fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | ctlInfo := &unix.CtlInfo{} 116 | copy(ctlInfo.Name[:], []byte(utunControlName)) 117 | err = unix.IoctlCtlInfo(fd, ctlInfo) 118 | if err != nil { 119 | unix.Close(fd) 120 | return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) 121 | } 122 | 123 | sc := &unix.SockaddrCtl{ 124 | ID: ctlInfo.Id, 125 | Unit: uint32(ifIndex) + 1, 126 | } 127 | 128 | err = unix.Connect(fd, sc) 129 | if err != nil { 130 | unix.Close(fd) 131 | return nil, err 132 | } 133 | 134 | err = unix.SetNonblock(fd, true) 135 | if err != nil { 136 | unix.Close(fd) 137 | return nil, err 138 | } 139 | tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) 140 | 141 | if err == nil && name == "utun" { 142 | fname := os.Getenv("WG_TUN_NAME_FILE") 143 | if fname != "" { 144 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 145 | } 146 | } 147 | 148 | return tun, err 149 | } 150 | 151 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 152 | tun := &NativeTun{ 153 | tunFile: file, 154 | events: make(chan Event, 10), 155 | errors: make(chan error, 5), 156 | } 157 | 158 | name, err := tun.Name() 159 | if err != nil { 160 | tun.tunFile.Close() 161 | return nil, err 162 | } 163 | 164 | tunIfindex, err := func() (int, error) { 165 | iface, err := net.InterfaceByName(name) 166 | if err != nil { 167 | return -1, err 168 | } 169 | return iface.Index, nil 170 | }() 171 | if err != nil { 172 | tun.tunFile.Close() 173 | return nil, err 174 | } 175 | 176 | tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) 177 | if err != nil { 178 | tun.tunFile.Close() 179 | return nil, err 180 | } 181 | 182 | go 
tun.routineRouteListener(tunIfindex) 183 | 184 | if mtu > 0 { 185 | err = tun.setMTU(mtu) 186 | if err != nil { 187 | tun.Close() 188 | return nil, err 189 | } 190 | } 191 | 192 | return tun, nil 193 | } 194 | 195 | func (tun *NativeTun) Name() (string, error) { 196 | var err error 197 | tun.operateOnFd(func(fd uintptr) { 198 | tun.name, err = unix.GetsockoptString( 199 | int(fd), 200 | 2, /* #define SYSPROTO_CONTROL 2 */ 201 | 2, /* #define UTUN_OPT_IFNAME 2 */ 202 | ) 203 | }) 204 | 205 | if err != nil { 206 | return "", fmt.Errorf("GetSockoptString: %w", err) 207 | } 208 | 209 | return tun.name, nil 210 | } 211 | 212 | func (tun *NativeTun) File() *os.File { 213 | return tun.tunFile 214 | } 215 | 216 | func (tun *NativeTun) Events() <-chan Event { 217 | return tun.events 218 | } 219 | 220 | func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 221 | // TODO: the BSDs look very similar in Read() and Write(). They should be 222 | // collapsed, with platform-specific files containing the varying parts of 223 | // their implementations. 
224 | select { 225 | case err := <-tun.errors: 226 | return 0, err 227 | default: 228 | buf := bufs[0][offset-4:] 229 | n, err := tun.tunFile.Read(buf[:]) 230 | if n < 4 { 231 | return 0, err 232 | } 233 | sizes[0] = n - 4 234 | return 1, err 235 | } 236 | } 237 | 238 | func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { 239 | if offset < 4 { 240 | return 0, io.ErrShortBuffer 241 | } 242 | for i, buf := range bufs { 243 | buf = buf[offset-4:] 244 | buf[0] = 0x00 245 | buf[1] = 0x00 246 | buf[2] = 0x00 247 | switch buf[4] >> 4 { 248 | case 4: 249 | buf[3] = unix.AF_INET 250 | case 6: 251 | buf[3] = unix.AF_INET6 252 | default: 253 | return i, unix.EAFNOSUPPORT 254 | } 255 | if _, err := tun.tunFile.Write(buf); err != nil { 256 | return i, err 257 | } 258 | } 259 | return len(bufs), nil 260 | } 261 | 262 | func (tun *NativeTun) Close() error { 263 | var err1, err2 error 264 | tun.closeOnce.Do(func() { 265 | err1 = tun.tunFile.Close() 266 | if tun.routeSocket != -1 { 267 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 268 | err2 = unix.Close(tun.routeSocket) 269 | } else if tun.events != nil { 270 | close(tun.events) 271 | } 272 | }) 273 | if err1 != nil { 274 | return err1 275 | } 276 | return err2 277 | } 278 | 279 | func (tun *NativeTun) setMTU(n int) error { 280 | fd, err := socketCloexec( 281 | unix.AF_INET, 282 | unix.SOCK_DGRAM, 283 | 0, 284 | ) 285 | if err != nil { 286 | return err 287 | } 288 | 289 | defer unix.Close(fd) 290 | 291 | var ifr unix.IfreqMTU 292 | copy(ifr.Name[:], tun.name) 293 | ifr.MTU = int32(n) 294 | err = unix.IoctlSetIfreqMTU(fd, &ifr) 295 | if err != nil { 296 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) 297 | } 298 | 299 | return nil 300 | } 301 | 302 | func (tun *NativeTun) MTU() (int, error) { 303 | fd, err := socketCloexec( 304 | unix.AF_INET, 305 | unix.SOCK_DGRAM, 306 | 0, 307 | ) 308 | if err != nil { 309 | return 0, err 310 | } 311 | 312 | defer unix.Close(fd) 313 | 314 | ifr, err := 
unix.IoctlGetIfreqMTU(fd, tun.name) 315 | if err != nil { 316 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) 317 | } 318 | 319 | return int(ifr.MTU), nil 320 | } 321 | 322 | func (tun *NativeTun) BatchSize() int { 323 | return 1 324 | } 325 | 326 | func socketCloexec(family, sotype, proto int) (fd int, err error) { 327 | // See go/src/net/sys_cloexec.go for background. 328 | syscall.ForkLock.RLock() 329 | defer syscall.ForkLock.RUnlock() 330 | 331 | fd, err = unix.Socket(family, sotype, proto) 332 | if err == nil { 333 | unix.CloseOnExec(fd) 334 | } 335 | return 336 | } 337 | -------------------------------------------------------------------------------- /tun/tun_openbsd.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "sync" 15 | "syscall" 16 | "unsafe" 17 | 18 | "golang.org/x/sys/unix" 19 | ) 20 | 21 | // Structure for iface mtu get/set ioctls 22 | type ifreq_mtu struct { 23 | Name [unix.IFNAMSIZ]byte 24 | MTU uint32 25 | Pad0 [12]byte 26 | } 27 | 28 | const _TUNSIFMODE = 0x8004745d 29 | 30 | type NativeTun struct { 31 | name string 32 | tunFile *os.File 33 | events chan Event 34 | errors chan error 35 | routeSocket int 36 | closeOnce sync.Once 37 | } 38 | 39 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 40 | var ( 41 | statusUp bool 42 | statusMTU int 43 | ) 44 | 45 | defer close(tun.events) 46 | 47 | check := func() bool { 48 | iface, err := net.InterfaceByIndex(tunIfindex) 49 | if err != nil { 50 | tun.errors <- err 51 | return true 52 | } 53 | 54 | // Up / Down event 55 | up := (iface.Flags & net.FlagUp) != 0 56 | if up != statusUp && up { 57 | tun.events <- EventUp 58 | } 59 | if up != statusUp && !up { 60 | tun.events <- EventDown 61 | } 62 | statusUp = up 63 | 64 | // MTU 
changes 65 | if iface.MTU != statusMTU { 66 | tun.events <- EventMTUUpdate 67 | } 68 | statusMTU = iface.MTU 69 | return false 70 | } 71 | 72 | if check() { 73 | return 74 | } 75 | 76 | data := make([]byte, os.Getpagesize()) 77 | for { 78 | n, err := unix.Read(tun.routeSocket, data) 79 | if err != nil { 80 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 81 | continue 82 | } 83 | tun.errors <- err 84 | return 85 | } 86 | 87 | if n < 8 { 88 | continue 89 | } 90 | 91 | if data[3 /* type */] != unix.RTM_IFINFO { 92 | continue 93 | } 94 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */]))) 95 | if ifindex != tunIfindex { 96 | continue 97 | } 98 | if check() { 99 | return 100 | } 101 | } 102 | } 103 | 104 | func CreateTUN(name string, mtu int) (Device, error) { 105 | ifIndex := -1 106 | if name != "tun" { 107 | _, err := fmt.Sscanf(name, "tun%d", &ifIndex) 108 | if err != nil || ifIndex < 0 { 109 | return nil, fmt.Errorf("Interface name must be tun[0-9]*") 110 | } 111 | } 112 | 113 | var tunfile *os.File 114 | var err error 115 | 116 | if ifIndex != -1 { 117 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 118 | } else { 119 | for ifIndex = 0; ifIndex < 256; ifIndex++ { 120 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 121 | if err == nil || !errors.Is(err, syscall.EBUSY) { 122 | break 123 | } 124 | } 125 | } 126 | 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | tun, err := CreateTUNFromFile(tunfile, mtu) 132 | 133 | if err == nil && name == "tun" { 134 | fname := os.Getenv("WG_TUN_NAME_FILE") 135 | if fname != "" { 136 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 137 | } 138 | } 139 | 140 | return tun, err 141 | } 142 | 143 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 144 | tun := &NativeTun{ 145 | tunFile: file, 146 | events: make(chan Event, 10), 147 | errors: make(chan error, 
1), 148 | } 149 | 150 | name, err := tun.Name() 151 | if err != nil { 152 | tun.tunFile.Close() 153 | return nil, err 154 | } 155 | 156 | tunIfindex, err := func() (int, error) { 157 | iface, err := net.InterfaceByName(name) 158 | if err != nil { 159 | return -1, err 160 | } 161 | return iface.Index, nil 162 | }() 163 | if err != nil { 164 | tun.tunFile.Close() 165 | return nil, err 166 | } 167 | 168 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) 169 | if err != nil { 170 | tun.tunFile.Close() 171 | return nil, err 172 | } 173 | 174 | go tun.routineRouteListener(tunIfindex) 175 | 176 | currentMTU, err := tun.MTU() 177 | if err != nil || currentMTU != mtu { 178 | err = tun.setMTU(mtu) 179 | if err != nil { 180 | tun.Close() 181 | return nil, err 182 | } 183 | } 184 | 185 | return tun, nil 186 | } 187 | 188 | func (tun *NativeTun) Name() (string, error) { 189 | gostat, err := tun.tunFile.Stat() 190 | if err != nil { 191 | tun.name = "" 192 | return "", err 193 | } 194 | stat := gostat.Sys().(*syscall.Stat_t) 195 | tun.name = fmt.Sprintf("tun%d", stat.Rdev%256) 196 | return tun.name, nil 197 | } 198 | 199 | func (tun *NativeTun) File() *os.File { 200 | return tun.tunFile 201 | } 202 | 203 | func (tun *NativeTun) Events() <-chan Event { 204 | return tun.events 205 | } 206 | 207 | func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 208 | select { 209 | case err := <-tun.errors: 210 | return 0, err 211 | default: 212 | buf := bufs[0][offset-4:] 213 | n, err := tun.tunFile.Read(buf[:]) 214 | if n < 4 { 215 | return 0, err 216 | } 217 | sizes[0] = n - 4 218 | return 1, err 219 | } 220 | } 221 | 222 | func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { 223 | if offset < 4 { 224 | return 0, io.ErrShortBuffer 225 | } 226 | for i, buf := range bufs { 227 | buf = buf[offset-4:] 228 | buf[0] = 0x00 229 | buf[1] = 0x00 230 | buf[2] = 0x00 231 | switch buf[4] >> 4 { 232 | case 4: 
233 | buf[3] = unix.AF_INET 234 | case 6: 235 | buf[3] = unix.AF_INET6 236 | default: 237 | return i, unix.EAFNOSUPPORT 238 | } 239 | if _, err := tun.tunFile.Write(buf); err != nil { 240 | return i, err 241 | } 242 | } 243 | return len(bufs), nil 244 | } 245 | 246 | func (tun *NativeTun) Close() error { 247 | var err1, err2 error 248 | tun.closeOnce.Do(func() { 249 | err1 = tun.tunFile.Close() 250 | if tun.routeSocket != -1 { 251 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 252 | err2 = unix.Close(tun.routeSocket) 253 | tun.routeSocket = -1 254 | } else if tun.events != nil { 255 | close(tun.events) 256 | } 257 | }) 258 | if err1 != nil { 259 | return err1 260 | } 261 | return err2 262 | } 263 | 264 | func (tun *NativeTun) setMTU(n int) error { 265 | // open datagram socket 266 | 267 | var fd int 268 | 269 | fd, err := unix.Socket( 270 | unix.AF_INET, 271 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 272 | 0, 273 | ) 274 | if err != nil { 275 | return err 276 | } 277 | 278 | defer unix.Close(fd) 279 | 280 | // do ioctl call 281 | 282 | var ifr ifreq_mtu 283 | copy(ifr.Name[:], tun.name) 284 | ifr.MTU = uint32(n) 285 | 286 | _, _, errno := unix.Syscall( 287 | unix.SYS_IOCTL, 288 | uintptr(fd), 289 | uintptr(unix.SIOCSIFMTU), 290 | uintptr(unsafe.Pointer(&ifr)), 291 | ) 292 | 293 | if errno != 0 { 294 | return fmt.Errorf("failed to set MTU on %s", tun.name) 295 | } 296 | 297 | return nil 298 | } 299 | 300 | func (tun *NativeTun) MTU() (int, error) { 301 | // open datagram socket 302 | 303 | fd, err := unix.Socket( 304 | unix.AF_INET, 305 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 306 | 0, 307 | ) 308 | if err != nil { 309 | return 0, err 310 | } 311 | 312 | defer unix.Close(fd) 313 | 314 | // do ioctl call 315 | var ifr ifreq_mtu 316 | copy(ifr.Name[:], tun.name) 317 | 318 | _, _, errno := unix.Syscall( 319 | unix.SYS_IOCTL, 320 | uintptr(fd), 321 | uintptr(unix.SIOCGIFMTU), 322 | uintptr(unsafe.Pointer(&ifr)), 323 | ) 324 | if errno != 0 { 325 | return 0, fmt.Errorf("failed 
to get MTU on %s", tun.name) 326 | } 327 | 328 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil 329 | } 330 | 331 | func (tun *NativeTun) BatchSize() int { 332 | return 1 333 | } 334 | -------------------------------------------------------------------------------- /tun/tun_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "os" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | _ "unsafe" 16 | 17 | "golang.org/x/sys/windows" 18 | "golang.zx2c4.com/wintun" 19 | ) 20 | 21 | const ( 22 | rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 23 | spinloopRateThreshold = 800000000 / 8 // 800mbps 24 | spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 25 | ) 26 | 27 | type rateJuggler struct { 28 | current atomic.Uint64 29 | nextByteCount atomic.Uint64 30 | nextStartTime atomic.Int64 31 | changing atomic.Bool 32 | } 33 | 34 | type NativeTun struct { 35 | wt *wintun.Adapter 36 | name string 37 | handle windows.Handle 38 | rate rateJuggler 39 | session wintun.Session 40 | readWait windows.Handle 41 | events chan Event 42 | running sync.WaitGroup 43 | closeOnce sync.Once 44 | close atomic.Bool 45 | forcedMTU int 46 | outSizes []int 47 | } 48 | 49 | var ( 50 | WintunTunnelType = "WireGuard" 51 | WintunStaticRequestedGUID *windows.GUID 52 | ) 53 | 54 | //go:linkname procyield runtime.procyield 55 | func procyield(cycles uint32) 56 | 57 | //go:linkname nanotime runtime.nanotime 58 | func nanotime() int64 59 | 60 | // CreateTUN creates a Wintun interface with the given name. Should a Wintun 61 | // interface with the same name exist, it is reused. 
62 | func CreateTUN(ifname string, mtu int) (Device, error) { 63 | return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 | } 65 | 66 | // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 67 | // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 68 | func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 69 | wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 70 | if err != nil { 71 | return nil, fmt.Errorf("Error creating interface: %w", err) 72 | } 73 | 74 | forcedMTU := 1420 75 | if mtu > 0 { 76 | forcedMTU = mtu 77 | } 78 | 79 | tun := &NativeTun{ 80 | wt: wt, 81 | name: ifname, 82 | handle: windows.InvalidHandle, 83 | events: make(chan Event, 10), 84 | forcedMTU: forcedMTU, 85 | } 86 | 87 | tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 88 | if err != nil { 89 | tun.wt.Close() 90 | close(tun.events) 91 | return nil, fmt.Errorf("Error starting session: %w", err) 92 | } 93 | tun.readWait = tun.session.ReadWaitEvent() 94 | return tun, nil 95 | } 96 | 97 | func (tun *NativeTun) Name() (string, error) { 98 | return tun.name, nil 99 | } 100 | 101 | func (tun *NativeTun) File() *os.File { 102 | return nil 103 | } 104 | 105 | func (tun *NativeTun) Events() <-chan Event { 106 | return tun.events 107 | } 108 | 109 | func (tun *NativeTun) Close() error { 110 | var err error 111 | tun.closeOnce.Do(func() { 112 | tun.close.Store(true) 113 | windows.SetEvent(tun.readWait) 114 | tun.running.Wait() 115 | tun.session.End() 116 | if tun.wt != nil { 117 | tun.wt.Close() 118 | } 119 | close(tun.events) 120 | }) 121 | return err 122 | } 123 | 124 | func (tun *NativeTun) MTU() (int, error) { 125 | return tun.forcedMTU, nil 126 | } 127 | 128 | // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 
129 | func (tun *NativeTun) ForceMTU(mtu int) { 130 | if tun.close.Load() { 131 | return 132 | } 133 | update := tun.forcedMTU != mtu 134 | tun.forcedMTU = mtu 135 | if update { 136 | tun.events <- EventMTUUpdate 137 | } 138 | } 139 | 140 | func (tun *NativeTun) BatchSize() int { 141 | // TODO: implement batching with wintun 142 | return 1 143 | } 144 | 145 | // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 146 | 147 | func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 148 | tun.running.Add(1) 149 | defer tun.running.Done() 150 | retry: 151 | if tun.close.Load() { 152 | return 0, os.ErrClosed 153 | } 154 | start := nanotime() 155 | shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 156 | for { 157 | if tun.close.Load() { 158 | return 0, os.ErrClosed 159 | } 160 | packet, err := tun.session.ReceivePacket() 161 | switch err { 162 | case nil: 163 | n := copy(bufs[0][offset:], packet) 164 | sizes[0] = n 165 | tun.session.ReleaseReceivePacket(packet) 166 | tun.rate.update(uint64(n)) 167 | return 1, nil 168 | case windows.ERROR_NO_MORE_ITEMS: 169 | if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 170 | windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 171 | goto retry 172 | } 173 | procyield(1) 174 | continue 175 | case windows.ERROR_HANDLE_EOF: 176 | return 0, os.ErrClosed 177 | case windows.ERROR_INVALID_DATA: 178 | return 0, errors.New("Send ring corrupt") 179 | } 180 | return 0, fmt.Errorf("Read failed: %w", err) 181 | } 182 | } 183 | 184 | func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { 185 | tun.running.Add(1) 186 | defer tun.running.Done() 187 | if tun.close.Load() { 188 | return 0, os.ErrClosed 189 | } 190 | 191 | for i, buf := range bufs { 192 | packetSize := len(buf) - offset 193 | tun.rate.update(uint64(packetSize)) 194 | 195 | packet, err := 
tun.session.AllocateSendPacket(packetSize) 196 | switch err { 197 | case nil: 198 | // TODO: Explore options to eliminate this copy. 199 | copy(packet, buf[offset:]) 200 | tun.session.SendPacket(packet) 201 | continue 202 | case windows.ERROR_HANDLE_EOF: 203 | return i, os.ErrClosed 204 | case windows.ERROR_BUFFER_OVERFLOW: 205 | continue // Dropping when ring is full. 206 | default: 207 | return i, fmt.Errorf("Write failed: %w", err) 208 | } 209 | } 210 | return len(bufs), nil 211 | } 212 | 213 | // LUID returns Windows interface instance ID. 214 | func (tun *NativeTun) LUID() uint64 { 215 | tun.running.Add(1) 216 | defer tun.running.Done() 217 | if tun.close.Load() { 218 | return 0 219 | } 220 | return tun.wt.LUID() 221 | } 222 | 223 | // RunningVersion returns the running version of the Wintun driver. 224 | func (tun *NativeTun) RunningVersion() (version uint32, err error) { 225 | return wintun.RunningVersion() 226 | } 227 | 228 | func (rate *rateJuggler) update(packetLen uint64) { 229 | now := nanotime() 230 | total := rate.nextByteCount.Add(packetLen) 231 | period := uint64(now - rate.nextStartTime.Load()) 232 | if period >= rateMeasurementGranularity { 233 | if !rate.changing.CompareAndSwap(false, true) { 234 | return 235 | } 236 | rate.nextStartTime.Store(now) 237 | rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 238 | rate.nextByteCount.Store(0) 239 | rate.changing.Store(false) 240 | } 241 | } 242 | -------------------------------------------------------------------------------- /tun/tuntest/tuntest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tuntest 7 | 8 | import ( 9 | "encoding/binary" 10 | "io" 11 | "net/netip" 12 | "os" 13 | 14 | "github.com/amnezia-vpn/amneziawg-go/tun" 15 | ) 16 | 17 | func Ping(dst, src netip.Addr) []byte { 18 | localPort := uint16(1337) 19 | seq := uint16(0) 20 | 21 | payload := make([]byte, 4) 22 | binary.BigEndian.PutUint16(payload[0:], localPort) 23 | binary.BigEndian.PutUint16(payload[2:], seq) 24 | 25 | return genICMPv4(payload, dst, src) 26 | } 27 | 28 | // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. 29 | func checksum(buf []byte, initial uint16) uint16 { 30 | v := uint32(initial) 31 | for i := 0; i < len(buf)-1; i += 2 { 32 | v += uint32(binary.BigEndian.Uint16(buf[i:])) 33 | } 34 | if len(buf)%2 == 1 { 35 | v += uint32(buf[len(buf)-1]) << 8 36 | } 37 | for v > 0xffff { 38 | v = (v >> 16) + (v & 0xffff) 39 | } 40 | return ^uint16(v) 41 | } 42 | 43 | func genICMPv4(payload []byte, dst, src netip.Addr) []byte { 44 | const ( 45 | icmpv4ProtocolNumber = 1 46 | icmpv4Echo = 8 47 | icmpv4ChecksumOffset = 2 48 | icmpv4Size = 8 49 | ipv4Size = 20 50 | ipv4TotalLenOffset = 2 51 | ipv4ChecksumOffset = 10 52 | ttl = 65 53 | headerSize = ipv4Size + icmpv4Size 54 | ) 55 | 56 | pkt := make([]byte, headerSize+len(payload)) 57 | 58 | ip := pkt[0:ipv4Size] 59 | icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] 60 | 61 | // https://tools.ietf.org/html/rfc792 62 | icmpv4[0] = icmpv4Echo // type 63 | icmpv4[1] = 0 // code 64 | chksum := ^checksum(icmpv4, checksum(payload, 0)) 65 | binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) 66 | 67 | // https://tools.ietf.org/html/rfc760 section 3.1 68 | length := uint16(len(pkt)) 69 | ip[0] = (4 << 4) | (ipv4Size / 4) 70 | binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) 71 | ip[8] = ttl 72 | ip[9] = icmpv4ProtocolNumber 73 | copy(ip[12:], src.AsSlice()) 74 | copy(ip[16:], dst.AsSlice()) 75 | chksum = ^checksum(ip[:], 0) 76 | 
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) 77 | 78 | copy(pkt[headerSize:], payload) 79 | return pkt 80 | } 81 | 82 | type ChannelTUN struct { 83 | Inbound chan []byte // incoming packets, closed on TUN close 84 | Outbound chan []byte // outbound packets, blocks forever on TUN close 85 | 86 | closed chan struct{} 87 | events chan tun.Event 88 | tun chTun 89 | } 90 | 91 | func NewChannelTUN() *ChannelTUN { 92 | c := &ChannelTUN{ 93 | Inbound: make(chan []byte), 94 | Outbound: make(chan []byte), 95 | closed: make(chan struct{}), 96 | events: make(chan tun.Event, 1), 97 | } 98 | c.tun.c = c 99 | c.events <- tun.EventUp 100 | return c 101 | } 102 | 103 | func (c *ChannelTUN) TUN() tun.Device { 104 | return &c.tun 105 | } 106 | 107 | type chTun struct { 108 | c *ChannelTUN 109 | } 110 | 111 | func (t *chTun) File() *os.File { return nil } 112 | 113 | func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { 114 | select { 115 | case <-t.c.closed: 116 | return 0, os.ErrClosed 117 | case msg := <-t.c.Outbound: 118 | n := copy(packets[0][offset:], msg) 119 | sizes[0] = n 120 | return 1, nil 121 | } 122 | } 123 | 124 | // Write is called by the wireguard device to deliver a packet for routing. 
125 | func (t *chTun) Write(packets [][]byte, offset int) (int, error) { 126 | if offset == -1 { 127 | close(t.c.closed) 128 | close(t.c.events) 129 | return 0, io.EOF 130 | } 131 | for i, data := range packets { 132 | msg := make([]byte, len(data)-offset) 133 | copy(msg, data[offset:]) 134 | select { 135 | case <-t.c.closed: 136 | return i, os.ErrClosed 137 | case t.c.Inbound <- msg: 138 | } 139 | } 140 | return len(packets), nil 141 | } 142 | 143 | func (t *chTun) BatchSize() int { 144 | return 1 145 | } 146 | 147 | const DefaultMTU = 1420 148 | 149 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } 150 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } 151 | func (t *chTun) Events() <-chan tun.Event { return t.c.events } 152 | func (t *chTun) Close() error { 153 | t.Write(nil, -1) 154 | return nil 155 | } 156 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const Version = "0.0.20230223" 4 | --------------------------------------------------------------------------------