├── .gitignore ├── version.go ├── CODEOWNERS ├── rwcancel ├── rwcancel_stub.go └── rwcancel.go ├── device ├── race_disabled_test.go ├── race_enabled_test.go ├── sticky_default.go ├── ip.go ├── queueconstants_android.go ├── queueconstants_windows.go ├── devicestate_string.go ├── queueconstants_default.go ├── mobilequirks.go ├── queueconstants_ios.go ├── endpoint_test.go ├── keypair.go ├── bind_test.go ├── constants.go ├── tun.go ├── logger.go ├── noise-types.go ├── pools_test.go ├── indextable.go ├── pools.go ├── kdf_test.go ├── noise-helpers.go ├── allowedips_rand_test.go ├── channels.go ├── noise_test.go ├── statemanager_test.go ├── cookie.go ├── sticky_linux.go ├── cookie_test.go ├── allowedips_test.go ├── peer.go ├── timers.go └── allowedips.go ├── conn ├── default.go ├── mark_default.go ├── boundif_android.go ├── mark_unix.go ├── server_name_utils │ ├── consistent_hash_test.go │ ├── consistent_hash.go │ └── server_name_utils.go ├── bindtest │ └── bindtest.go ├── conn.go ├── tcp_tls_utils.go └── bind_std.go ├── ipc ├── uapi_js.go ├── uapi_unix.go ├── uapi_windows.go ├── uapi_linux.go ├── uapi_bsd.go └── namedpipe │ └── file.go ├── tun ├── operateonfd.go ├── tun.go ├── netstack │ └── examples │ │ ├── http_client.go │ │ ├── http_server.go │ │ └── ping_client.go ├── alignment_windows_test.go ├── tuntest │ └── tuntest.go ├── tun_windows.go ├── tun_openbsd.go └── tun_darwin.go ├── .gitlab-ci.yml ├── Makefile ├── go.mod ├── tai64n ├── tai64n.go └── tai64n_test.go ├── LICENSE ├── format_test.go ├── replay ├── replay.go └── replay_test.go ├── main_windows.go ├── ratelimiter ├── ratelimiter_test.go └── ratelimiter.go ├── README.md ├── go.sum └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | wireguard-go 2 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const 
Version = "0.0.20230223" 4 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # ownership: loose 2 | * @ProtonVPN/groups/android-developers 3 | -------------------------------------------------------------------------------- /rwcancel/rwcancel_stub.go: -------------------------------------------------------------------------------- 1 | //go:build windows || js 2 | 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rwcancel 6 | 7 | type RWCancel struct{} 8 | 9 | func (*RWCancel) Cancel() {} 10 | -------------------------------------------------------------------------------- /device/race_disabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build !race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = false 11 | -------------------------------------------------------------------------------- /device/race_enabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = true 11 | -------------------------------------------------------------------------------- /conn/default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | 10 | func NewDefaultBind() Bind { return NewStdNetBind() } 11 | -------------------------------------------------------------------------------- /conn/mark_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !openbsd && !freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | func (bind *StdNetBind) SetMark(mark uint32) error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /device/sticky_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package device 4 | 5 | import ( 6 | "golang.zx2c4.com/wireguard/conn" 7 | "golang.zx2c4.com/wireguard/rwcancel" 8 | ) 9 | 10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 11 | return nil, nil 12 | } 13 | -------------------------------------------------------------------------------- /ipc/uapi_js.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | // Made up sentinel error codes for the js/wasm platform. 9 | const ( 10 | IpcErrorIO = 1 11 | IpcErrorInvalid = 2 12 | IpcErrorPortInUse = 3 13 | IpcErrorUnknown = 4 14 | IpcErrorProtocol = 5 15 | ) 16 | -------------------------------------------------------------------------------- /device/ip.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "net" 10 | ) 11 | 12 | const ( 13 | IPv4offsetTotalLength = 2 14 | IPv4offsetSrc = 12 15 | IPv4offsetDst = IPv4offsetSrc + net.IPv4len 16 | ) 17 | 18 | const ( 19 | IPv6offsetPayloadLength = 4 20 | IPv6offsetSrc = 8 21 | IPv6offsetDst = IPv6offsetSrc + net.IPv6len 22 | ) 23 | -------------------------------------------------------------------------------- /device/queueconstants_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | /* Reduce memory consumption for Android */ 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = 2200 16 | PreallocatedBuffersPerPool = 4096 17 | ) 18 | -------------------------------------------------------------------------------- /device/queueconstants_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | const ( 9 | QueueStagedSize = 128 10 | QueueOutboundSize = 1024 11 | QueueInboundSize = 1024 12 | QueueHandshakeSize = 1024 13 | MaxSegmentSize = 2048 - 32 // fixed per-segment buffer size; NOT the largest possible UDP datagram ((1 << 16) - 1) 14 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 15 | ) 16 | -------------------------------------------------------------------------------- /device/devicestate_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
2 | 3 | package device 4 | 5 | import "strconv" 6 | 7 | const _deviceState_name = "DownUpClosed" 8 | 9 | var _deviceState_index = [...]uint8{0, 4, 6, 12} 10 | 11 | func (i deviceState) String() string { 12 | if i >= deviceState(len(_deviceState_index)-1) { 13 | return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" 14 | } 15 | return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] 16 | } 17 | -------------------------------------------------------------------------------- /device/queueconstants_default.go: -------------------------------------------------------------------------------- 1 | //go:build !android && !ios && !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram 16 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 17 | ) 18 | -------------------------------------------------------------------------------- /tun/operateonfd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package tun 9 | 10 | import ( 11 | "fmt" 12 | ) 13 | 14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { 15 | sysconn, err := tun.tunFile.SyscallConn() 16 | if err != nil { 17 | tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) 18 | return 19 | } 20 | err = sysconn.Control(fn) 21 | if err != nil { 22 | tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /device/mobilequirks.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, 9 | // though it will try to deal with it, and race maybe, if called after. 10 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { 11 | device.net.brokenRoaming = true 12 | device.peers.RLock() 13 | for _, peer := range device.peers.keyMap { 14 | peer.Lock() 15 | peer.disableRoaming = peer.endpoint != nil 16 | peer.Unlock() 17 | } 18 | device.peers.RUnlock() 19 | } 20 | -------------------------------------------------------------------------------- /device/queueconstants_ios.go: -------------------------------------------------------------------------------- 1 | //go:build ios 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | // Fit within memory limits for iOS's Network Extension API, which has stricter requirements. 11 | // These are vars instead of consts, because heavier network extensions might want to reduce 12 | // them further. 
13 | var ( 14 | QueueStagedSize = 128 15 | QueueOutboundSize = 1024 16 | QueueInboundSize = 1024 17 | QueueHandshakeSize = 1024 18 | PreallocatedBuffersPerPool uint32 = 1024 19 | ) 20 | 21 | const MaxSegmentSize = 1700 22 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | default: 2 | image: ${PROTON_CI_REGISTRY}/android-shared/docker-android/oci-ndk:v2.1.7 3 | before_script: 4 | - if [[ -f /load-env.sh ]]; then source /load-env.sh; fi 5 | 6 | sync-wireguard-go: 7 | tags: 8 | - shared-medium 9 | rules: 10 | - if: $CI_COMMIT_BRANCH == "release/android" 11 | when: manual 12 | - if: '$OPENSOURCE_GO' 13 | when: always 14 | - when: never 15 | allow_failure: true 16 | before_script: 17 | - !reference [ default, before_script ] 18 | - apt update && apt-get install -y connect-proxy 19 | script: 20 | - git clone "$CI_REPOSITORY_URL" --branch "$CI_COMMIT_BRANCH" _APP_CLONE; 21 | - cd _APP_CLONE 22 | - git remote add public git@github.com:ProtonVPN/wireguard-go.git 23 | - git push public "$CI_COMMIT_BRANCH" -f 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX ?= /usr 2 | DESTDIR ?= 3 | BINDIR ?= $(PREFIX)/bin 4 | export GO111MODULE := on 5 | 6 | all: generate-version-and-build 7 | 8 | MAKEFLAGS += --no-print-directory 9 | 10 | generate-version-and-build: 11 | @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ 12 | tag="$$(git describe --dirty 2>/dev/null)" && \ 13 | ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ 14 | [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ 15 | echo "$$ver" > version.go && \ 16 | git update-index --assume-unchanged version.go || true 17 | @$(MAKE) wireguard-go 18 | 19 | wireguard-go: $(wildcard *.go) $(wildcard */*.go) 20 | go build -v -o 
"$@" 21 | 22 | install: wireguard-go 23 | @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" 24 | 25 | test: 26 | go test ./... 27 | 28 | clean: 29 | rm -f wireguard-go 30 | 31 | .PHONY: all clean test install generate-version-and-build 32 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module golang.zx2c4.com/wireguard 2 | 3 | go 1.23.4 4 | 5 | require ( 6 | github.com/refraction-networking/utls v1.6.7 7 | github.com/stretchr/testify v1.8.0 8 | golang.org/x/crypto v0.21.0 9 | golang.org/x/net v0.23.0 10 | golang.org/x/sys v0.18.0 11 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 12 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 13 | ) 14 | 15 | require ( 16 | github.com/google/btree v1.0.1 // indirect 17 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect 18 | ) 19 | 20 | require ( 21 | github.com/andybalholm/brotli v1.0.6 // indirect 22 | github.com/cloudflare/circl v1.3.7 // indirect 23 | github.com/davecgh/go-spew v1.1.1 // indirect 24 | github.com/klauspost/compress v1.17.4 // indirect 25 | github.com/kr/pretty v0.3.1 // indirect 26 | github.com/pmezard/go-difflib v1.0.0 // indirect 27 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 28 | gopkg.in/yaml.v3 v3.0.1 // indirect 29 | ) 30 | -------------------------------------------------------------------------------- /conn/boundif_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package conn 7 | // PeekLookAtSocketFd4 returns the raw file descriptor of bind.ipv4, or -1 and an error if it cannot be obtained. 8 | func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { 9 | sysconn, err := bind.ipv4.SyscallConn() 10 | if err != nil { 11 | return -1, err 12 | } 13 | err = sysconn.Control(func(f uintptr) { 14 | fd = int(f) 15 | }) 16 | if err != nil { 17 | return -1, err 18 | } 19 | return 20 | } 21 | // PeekLookAtSocketFd6 returns the raw file descriptor of bind.ipv6, or -1 and an error if it cannot be obtained. 22 | func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { 23 | sysconn, err := bind.ipv6.SyscallConn() 24 | if err != nil { 25 | return -1, err 26 | } 27 | err = sysconn.Control(func(f uintptr) { 28 | fd = int(f) 29 | }) 30 | if err != nil { 31 | return -1, err 32 | } 33 | return 34 | } 35 | // PeekLookAtSocketFd4 always returns fd -1 for StdNetBindTcp; the named result err is never assigned, so the error is nil. NOTE(review): callers must check fd < 0, not only err. 36 | func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) { 37 | return -1, err 38 | } 39 | // PeekLookAtSocketFd6 always returns fd -1 for StdNetBindTcp with a nil error (err is the unassigned named result). 40 | func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) { 41 | return -1, err 42 | } 43 | -------------------------------------------------------------------------------- /tun/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "os" 10 | ) 11 | 12 | type Event int 13 | 14 | const ( 15 | EventUp = 1 << iota 16 | EventDown 17 | EventMTUUpdate 18 | ) 19 | 20 | type Device interface { 21 | File() *os.File // returns the file descriptor of the device 22 | Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) 23 | Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) 24 | Flush() error // flush all previous writes to the device 25 | MTU() (int, error) // returns the MTU of the device 26 | Name() (string, error) // fetches and returns the current name 27 | Events() <-chan Event // returns a constant channel of events related to the device 28 | Close() error // stops the device and closes the event channel 29 | } 30 | -------------------------------------------------------------------------------- /tai64n/tai64n.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "time" 12 | ) 13 | 14 | const ( 15 | TimestampSize = 12 16 | base = uint64(0x400000000000000a) 17 | whitenerMask = uint32(0x1000000 - 1) 18 | ) 19 | 20 | type Timestamp [TimestampSize]byte 21 | 22 | func stamp(t time.Time) Timestamp { 23 | var tai64n Timestamp 24 | secs := base + uint64(t.Unix()) 25 | nano := uint32(t.Nanosecond()) &^ whitenerMask 26 | binary.BigEndian.PutUint64(tai64n[:], secs) 27 | binary.BigEndian.PutUint32(tai64n[8:], nano) 28 | return tai64n 29 | } 30 | 31 | func Now() Timestamp { 32 | return stamp(time.Now()) 33 | } 34 | 35 | func (t1 Timestamp) After(t2 Timestamp) bool { 36 | return bytes.Compare(t1[:], t2[:]) > 0 37 | } 38 | 39 | func (t Timestamp) String() string { 40 | return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() 41 | } 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of 2 | this software and associated documentation files (the "Software"), to deal in 3 | the Software without restriction, including without limitation the rights to 4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 5 | of the Software, and to permit persons to whom the Software is furnished to do 6 | so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all 9 | copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 18 | -------------------------------------------------------------------------------- /device/endpoint_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net/netip" 11 | ) 12 | 13 | type DummyEndpoint struct { 14 | src, dst netip.Addr 15 | } 16 | 17 | func CreateDummyEndpoint() (*DummyEndpoint, error) { 18 | var src, dst [16]byte 19 | if _, err := rand.Read(src[:]); err != nil { 20 | return nil, err 21 | } 22 | _, err := rand.Read(dst[:]) 23 | return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err 24 | } 25 | 26 | func (e *DummyEndpoint) ClearSrc() {} 27 | 28 | func (e *DummyEndpoint) SrcToString() string { 29 | return netip.AddrPortFrom(e.SrcIP(), 1000).String() 30 | } 31 | 32 | func (e *DummyEndpoint) DstToString() string { 33 | return netip.AddrPortFrom(e.DstIP(), 1000).String() 34 | } 35 | 36 | func (e *DummyEndpoint) DstToBytes() []byte { 37 | out := e.DstIP().AsSlice() 38 | out = append(out, byte(1000&0xff)) 39 | out = append(out, byte((1000>>8)&0xff)) 40 | return out 41 | } 42 | 43 | func (e *DummyEndpoint) DstIP() netip.Addr { 44 | return e.dst 45 | } 46 | 47 | func (e *DummyEndpoint) SrcIP() netip.Addr { 48 | return e.src 49 | } 50 | -------------------------------------------------------------------------------- /device/keypair.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/cipher" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "golang.zx2c4.com/wireguard/replay" 15 | ) 16 | 17 | /* Due to limitations in Go and /x/crypto there is currently 18 | * no way to ensure that key material is securely erased in memory. 19 | * 20 | * Since this may harm the forward secrecy property, 21 | * we plan to resolve this issue whenever Go allows us to do so. 22 | */ 23 | 24 | type Keypair struct { 25 | sendNonce atomic.Uint64 26 | send cipher.AEAD 27 | receive cipher.AEAD 28 | replayFilter replay.Filter 29 | isInitiator bool 30 | created time.Time 31 | localIndex uint32 32 | remoteIndex uint32 33 | } 34 | 35 | type Keypairs struct { 36 | sync.RWMutex 37 | current *Keypair 38 | previous *Keypair 39 | next atomic.Pointer[Keypair] 40 | } 41 | 42 | func (kp *Keypairs) Current() *Keypair { 43 | kp.RLock() 44 | defer kp.RUnlock() 45 | return kp.current 46 | } 47 | 48 | func (device *Device) DeleteKeypair(key *Keypair) { 49 | if key != nil { 50 | device.indexTable.Delete(key.localIndex) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /format_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "go/format" 10 | "io/fs" 11 | "os" 12 | "path/filepath" 13 | "runtime" 14 | "sync" 15 | "testing" 16 | ) 17 | 18 | func TestFormatting(t *testing.T) { 19 | var wg sync.WaitGroup 20 | filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { 21 | if err != nil { 22 | t.Errorf("unable to walk %s: %v", path, err) 23 | return nil 24 | } 25 | if d.IsDir() || filepath.Ext(path) != ".go" { 26 | return nil 27 | } 28 | wg.Add(1) 29 | go func(path string) { 30 | defer wg.Done() 31 | src, err := os.ReadFile(path) 32 | if err != nil { 33 | t.Errorf("unable to read %s: %v", path, err) 34 | return 35 | } 36 | if runtime.GOOS == "windows" { 37 | src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) 38 | } 39 | formatted, err := format.Source(src) 40 | if err != nil { 41 | t.Errorf("unable to format %s: %v", path, err) 42 | return 43 | } 44 | if !bytes.Equal(src, formatted) { 45 | t.Errorf("unformatted code: %s", path) 46 | } 47 | }(path) 48 | return nil 49 | }) 50 | wg.Wait() 51 | } 52 | -------------------------------------------------------------------------------- /device/bind_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "errors" 10 | 11 | "golang.zx2c4.com/wireguard/conn" 12 | ) 13 | 14 | type DummyDatagram struct { 15 | msg []byte 16 | endpoint conn.Endpoint 17 | } 18 | 19 | type DummyBind struct { 20 | in6 chan DummyDatagram 21 | in4 chan DummyDatagram 22 | closed bool 23 | } 24 | 25 | func (b *DummyBind) SetMark(v uint32) error { 26 | return nil 27 | } 28 | 29 | func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { 30 | datagram, ok := <-b.in6 31 | if !ok { 32 | return 0, nil, errors.New("closed") 33 | } 34 | copy(buff, datagram.msg) 35 | return len(datagram.msg), datagram.endpoint, nil 36 | } 37 | 38 | func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { 39 | datagram, ok := <-b.in4 40 | if !ok { 41 | return 0, nil, errors.New("closed") 42 | } 43 | copy(buff, datagram.msg) 44 | return len(datagram.msg), datagram.endpoint, nil 45 | } 46 | 47 | func (b *DummyBind) Close() error { 48 | close(b.in6) 49 | close(b.in4) 50 | b.closed = true 51 | return nil 52 | } 53 | 54 | func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /tai64n/tai64n_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Test that timestamps are monotonic as required by Wireguard and that 14 | // nanosecond-level information is whitened to prevent side channel attacks. 15 | func TestMonotonic(t *testing.T) { 16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern 17 | // Whitening should reduce timestamp granularity 18 | // to more than 10 but fewer than 20 milliseconds. 
19 | tests := []struct { 20 | name string 21 | t1, t2 time.Time 22 | wantAfter bool 23 | }{ 24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, 25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, 26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, 27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, 28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, 29 | } 30 | 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2) 34 | got := ts2.After(ts1) 35 | if got != tt.wantAfter { 36 | t.Errorf("after = %v; want %v", got, tt.wantAfter) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /device/constants.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "time" 10 | ) 11 | 12 | /* Specification constants */ 13 | 14 | const ( 15 | RekeyAfterMessages = (1 << 60) 16 | RejectAfterMessages = (1 << 64) - (1 << 13) - 1 17 | RekeyAfterTime = time.Second * 120 18 | RekeyAttemptTime = time.Second * 90 19 | RekeyTimeout = time.Second * 5 20 | MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ 21 | RekeyTimeoutJitterMaxMs = 334 22 | RejectAfterTime = time.Second * 180 23 | KeepaliveTimeout = time.Second * 10 24 | CookieRefreshTime = time.Second * 120 25 | HandshakeInitationRate = time.Second / 50 26 | PaddingMultiple = 16 27 | ) 28 | 29 | const ( 30 | MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) 31 | MaxMessageSize = MaxSegmentSize // maximum size of transport message 32 | MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content 33 | ) 34 | 35 | /* Implementation constants */ 36 | 37 | const ( 38 | UnderLoadAfterTime = time.Second // how long does the device remain under load after detected 39 | MaxPeers = 1 << 16 // maximum number of configured peers 40 | ) 41 | -------------------------------------------------------------------------------- /device/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | 11 | "golang.zx2c4.com/wireguard/tun" 12 | ) 13 | 14 | const DefaultMTU = 1420 15 | 16 | func (device *Device) RoutineTUNEventReader() { 17 | device.log.Verbosef("Routine: event worker - started") 18 | 19 | for event := range device.tun.device.Events() { 20 | if event&tun.EventMTUUpdate != 0 { 21 | mtu, err := device.tun.device.MTU() 22 | if err != nil { 23 | device.log.Errorf("Failed to load updated MTU of device: %v", err) 24 | continue 25 | } 26 | if mtu < 0 { 27 | device.log.Errorf("MTU not updated to negative value: %v", mtu) 28 | continue 29 | } 30 | var tooLarge string 31 | if mtu > MaxContentSize { 32 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) 33 | mtu = MaxContentSize 34 | } 35 | old := device.tun.mtu.Swap(int32(mtu)) 36 | if int(old) != mtu { 37 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) 38 | } 39 | } 40 | 41 | if event&tun.EventUp != 0 { 42 | device.log.Verbosef("Interface up requested") 43 | device.Up() 44 | } 45 | 46 | if event&tun.EventDown != 0 { 47 | device.log.Verbosef("Interface down requested") 48 | device.Down() 49 | } 50 | } 51 | 52 | device.log.Verbosef("Routine: event worker - stopped") 53 | } 54 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_client.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
7 | */ 8 | 9 | package main 10 | 11 | import ( 12 | "io" 13 | "log" 14 | "net/http" 15 | "net/netip" 16 | 17 | "golang.zx2c4.com/wireguard/conn" 18 | "golang.zx2c4.com/wireguard/device" 19 | "golang.zx2c4.com/wireguard/tun/netstack" 20 | ) 21 | 22 | func main() { 23 | tun, tnet, err := netstack.CreateNetTUN( 24 | []netip.Addr{netip.MustParseAddr("192.168.4.28")}, 25 | []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 26 | 1420) 27 | if err != nil { 28 | log.Panic(err) 29 | } 30 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 31 | err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379 32 | public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28 33 | allowed_ip=0.0.0.0/0 34 | endpoint=127.0.0.1:58120 35 | `) 36 | err = dev.Up() 37 | if err != nil { 38 | log.Panic(err) 39 | } 40 | 41 | client := http.Client{ 42 | Transport: &http.Transport{ 43 | DialContext: tnet.DialContext, 44 | }, 45 | } 46 | resp, err := client.Get("http://192.168.4.29/") 47 | if err != nil { 48 | log.Panic(err) 49 | } 50 | body, err := io.ReadAll(resp.Body) 51 | if err != nil { 52 | log.Panic(err) 53 | } 54 | log.Println(string(body)) 55 | } 56 | -------------------------------------------------------------------------------- /device/logger.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "golang.zx2c4.com/wireguard/conn" 10 | "log" 11 | "os" 12 | ) 13 | 14 | // A Logger provides logging for a Device. 15 | // The functions are Printf-style functions. 16 | // They must be safe for concurrent use. 17 | // They do not require a trailing newline in the format. 18 | // If nil, that level of logging will be silent. 
19 | type Logger struct { 20 | conn.Logger 21 | } 22 | 23 | // Log levels for use with NewLogger. 24 | const ( 25 | LogLevelSilent = iota 26 | LogLevelError 27 | LogLevelVerbose 28 | ) 29 | 30 | // Function for use in Logger for discarding logged lines. 31 | func DiscardLogf(format string, args ...any) {} 32 | 33 | // NewLogger constructs a Logger that writes to stdout. 34 | // It logs at the specified log level and above. 35 | // It decorates log lines with the log level, date, time, and prepend. 36 | func NewLogger(level int, prepend string) *Logger { 37 | logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}} 38 | logf := func(prefix string) func(string, ...any) { 39 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf 40 | } 41 | if level >= LogLevelVerbose { 42 | logger.Verbosef = logf("DEBUG") 43 | } 44 | if level >= LogLevelError { 45 | logger.Errorf = logf("ERROR") 46 | } 47 | return logger 48 | } 49 | -------------------------------------------------------------------------------- /conn/mark_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || openbsd || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "runtime" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | var fwmarkIoctl int 17 | 18 | func init() { 19 | switch runtime.GOOS { 20 | case "linux", "android": 21 | fwmarkIoctl = 36 /* unix.SO_MARK */ 22 | case "freebsd": 23 | fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ 24 | case "openbsd": 25 | fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ 26 | } 27 | } 28 | 29 | func (bind *StdNetBind) SetMark(mark uint32) error { 30 | var operr error 31 | if fwmarkIoctl == 0 { 32 | return nil 33 | } 34 | if bind.ipv4 != nil { 35 | fd, err := bind.ipv4.SyscallConn() 36 | if err != nil { 37 | return err 38 | } 39 | err = fd.Control(func(fd uintptr) { 40 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 41 | }) 42 | if err == nil { 43 | err = operr 44 | } 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | if bind.ipv6 != nil { 50 | fd, err := bind.ipv6.SyscallConn() 51 | if err != nil { 52 | return err 53 | } 54 | err = fd.Control(func(fd uintptr) { 55 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 56 | }) 57 | if err == nil { 58 | err = operr 59 | } 60 | if err != nil { 61 | return err 62 | } 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_server.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
7 | */ 8 | 9 | package main 10 | 11 | import ( 12 | "io" 13 | "log" 14 | "net" 15 | "net/http" 16 | "net/netip" 17 | 18 | "golang.zx2c4.com/wireguard/conn" 19 | "golang.zx2c4.com/wireguard/device" 20 | "golang.zx2c4.com/wireguard/tun/netstack" 21 | ) 22 | 23 | func main() { 24 | tun, tnet, err := netstack.CreateNetTUN( 25 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 26 | []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 27 | 1420, 28 | ) 29 | if err != nil { 30 | log.Panic(err) 31 | } 32 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 33 | dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641 34 | listen_port=58120 35 | public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c 36 | allowed_ip=192.168.4.28/32 37 | persistent_keepalive_interval=25 38 | `) 39 | dev.Up() 40 | listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) 41 | if err != nil { 42 | log.Panicln(err) 43 | } 44 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 45 | log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) 46 | io.WriteString(writer, "Hello from userspace TCP!") 47 | }) 48 | err = http.Serve(listener, nil) 49 | if err != nil { 50 | log.Panicln(err) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /ipc/uapi_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "fmt" 13 | "net" 14 | "os" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | const ( 20 | IpcErrorIO = -int64(unix.EIO) 21 | IpcErrorProtocol = -int64(unix.EPROTO) 22 | IpcErrorInvalid = -int64(unix.EINVAL) 23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE) 24 | IpcErrorUnknown = -55 // ENOANO 25 | ) 26 | 27 | // socketDirectory is variable because it is modified by a linker 28 | // flag in wireguard-android. 29 | var socketDirectory = "/var/run/wireguard" 30 | 31 | func sockPath(iface string) string { 32 | return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) 33 | } 34 | 35 | func UAPIOpen(name string) (*os.File, error) { 36 | if err := os.MkdirAll(socketDirectory, 0o755); err != nil { 37 | return nil, err 38 | } 39 | 40 | socketPath := sockPath(name) 41 | addr, err := net.ResolveUnixAddr("unix", socketPath) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | oldUmask := unix.Umask(0o077) 47 | defer unix.Umask(oldUmask) 48 | 49 | listener, err := net.ListenUnix("unix", addr) 50 | if err == nil { 51 | return listener.File() 52 | } 53 | 54 | // Test socket, if not in use cleanup and try again. 55 | if _, err := net.Dial("unix", socketPath); err == nil { 56 | return nil, errors.New("unix socket in use") 57 | } 58 | if err := os.Remove(socketPath); err != nil { 59 | return nil, err 60 | } 61 | listener, err = net.ListenUnix("unix", addr) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return listener.File() 66 | } 67 | -------------------------------------------------------------------------------- /conn/server_name_utils/consistent_hash_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025. Proton AG 3 | * 4 | * This file is part of ProtonVPN. 
5 | * 6 | * ProtonVPN is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * ProtonVPN is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with ProtonVPN. If not, see . 18 | */ 19 | 20 | package server_name_utils 21 | 22 | import ( 23 | "math" 24 | "strconv" 25 | "testing" 26 | 27 | "github.com/stretchr/testify/assert" 28 | ) 29 | 30 | var maxMinus5 = uint32(math.MaxUint32 - 5) 31 | 32 | func testHash(value string) uint32 { 33 | switch value { 34 | case "max-5": 35 | return maxMinus5 36 | default: 37 | hash, _ := strconv.ParseUint(value, 10, 32) 38 | return uint32(hash) 39 | } 40 | } 41 | 42 | func TestConsistentHash(t *testing.T) { 43 | assert := assert.New(t) 44 | 45 | values := []string{"70", "max-5", "10"} 46 | hashedValues := sortValuesByHash(values, testHash) 47 | assert.Equal([]HashedValue{{"10", 10}, {"70", 70}, {"max-5", maxMinus5}}, hashedValues) 48 | assert.Equal("70", findClosestValue("68", hashedValues, testHash)) 49 | assert.Equal("70", findClosestValue("72", hashedValues, testHash)) 50 | assert.Equal("max-5", findClosestValue("2", hashedValues, testHash)) 51 | assert.Equal("10", findClosestValue("6", hashedValues, testHash)) 52 | } -------------------------------------------------------------------------------- /device/noise-types.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/subtle" 10 | "encoding/hex" 11 | "errors" 12 | ) 13 | 14 | const ( 15 | NoisePublicKeySize = 32 16 | NoisePrivateKeySize = 32 17 | NoisePresharedKeySize = 32 18 | ) 19 | 20 | type ( 21 | NoisePublicKey [NoisePublicKeySize]byte 22 | NoisePrivateKey [NoisePrivateKeySize]byte 23 | NoisePresharedKey [NoisePresharedKeySize]byte 24 | NoiseNonce uint64 // padded to 12-bytes 25 | ) 26 | 27 | func loadExactHex(dst []byte, src string) error { 28 | slice, err := hex.DecodeString(src) 29 | if err != nil { 30 | return err 31 | } 32 | if len(slice) != len(dst) { 33 | return errors.New("hex string does not fit the slice") 34 | } 35 | copy(dst, slice) 36 | return nil 37 | } 38 | 39 | func (key NoisePrivateKey) IsZero() bool { 40 | var zero NoisePrivateKey 41 | return key.Equals(zero) 42 | } 43 | 44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { 45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 46 | } 47 | 48 | func (key *NoisePrivateKey) FromHex(src string) (err error) { 49 | err = loadExactHex(key[:], src) 50 | key.clamp() 51 | return 52 | } 53 | 54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { 55 | err = loadExactHex(key[:], src) 56 | if key.IsZero() { 57 | return 58 | } 59 | key.clamp() 60 | return 61 | } 62 | 63 | func (key *NoisePublicKey) FromHex(src string) error { 64 | return loadExactHex(key[:], src) 65 | } 66 | 67 | func (key NoisePublicKey) IsZero() bool { 68 | var zero NoisePublicKey 69 | return key.Equals(zero) 70 | } 71 | 72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { 73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 74 | } 75 | 76 | func (key *NoisePresharedKey) FromHex(src string) error { 77 | return loadExactHex(key[:], src) 78 | } 79 | -------------------------------------------------------------------------------- /replay/replay.go: -------------------------------------------------------------------------------- 1 | /* 
SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. 7 | package replay 8 | 9 | type block uint64 10 | 11 | const ( 12 | blockBitLog = 6 // 1<<6 == 64 bits 13 | blockBits = 1 << blockBitLog // must be power of 2 14 | ringBlocks = 1 << 7 // must be power of 2 15 | windowSize = (ringBlocks - 1) * blockBits 16 | blockMask = ringBlocks - 1 17 | bitMask = blockBits - 1 18 | ) 19 | 20 | // A Filter rejects replayed messages by checking if message counter value is 21 | // within a sliding window of previously received messages. 22 | // The zero value for Filter is an empty filter ready to use. 23 | // Filters are unsafe for concurrent use. 24 | type Filter struct { 25 | last uint64 26 | ring [ringBlocks]block 27 | } 28 | 29 | // Reset resets the filter to empty state. 30 | func (f *Filter) Reset() { 31 | f.last = 0 32 | f.ring[0] = 0 33 | } 34 | 35 | // ValidateCounter checks if the counter should be accepted. 36 | // Overlimit counters (>= limit) are always rejected. 
37 | func (f *Filter) ValidateCounter(counter, limit uint64) bool { 38 | if counter >= limit { 39 | return false 40 | } 41 | indexBlock := counter >> blockBitLog 42 | if counter > f.last { // move window forward 43 | current := f.last >> blockBitLog 44 | diff := indexBlock - current 45 | if diff > ringBlocks { 46 | diff = ringBlocks // cap diff to clear the whole ring 47 | } 48 | for i := current + 1; i <= current+diff; i++ { 49 | f.ring[i&blockMask] = 0 50 | } 51 | f.last = counter 52 | } else if f.last-counter > windowSize { // behind current window 53 | return false 54 | } 55 | // check and set bit 56 | indexBlock &= blockMask 57 | indexBit := counter & bitMask 58 | old := f.ring[indexBlock] 59 | new := old | 1< p.max { 38 | t.Errorf("count (%d) > max (%d)", count, p.max) 39 | } 40 | for { 41 | old := max.Load() 42 | if count <= old { 43 | break 44 | } 45 | if max.CompareAndSwap(old, count) { 46 | break 47 | } 48 | } 49 | } 50 | for i := 0; i < workers; i++ { 51 | go func() { 52 | defer wg.Done() 53 | for trials.Add(-1) > 0 { 54 | updateMax() 55 | x := p.Get() 56 | updateMax() 57 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 58 | updateMax() 59 | p.Put(x) 60 | updateMax() 61 | } 62 | }() 63 | } 64 | wg.Wait() 65 | if max.Load() != p.max { 66 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) 67 | } 68 | } 69 | 70 | func BenchmarkWaitPool(b *testing.B) { 71 | var wg sync.WaitGroup 72 | var trials atomic.Int32 73 | trials.Store(int32(b.N)) 74 | workers := runtime.NumCPU() + 2 75 | if workers-4 <= 0 { 76 | b.Skip("Not enough cores") 77 | } 78 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 79 | wg.Add(workers) 80 | b.ResetTimer() 81 | for i := 0; i < workers; i++ { 82 | go func() { 83 | defer wg.Done() 84 | for trials.Add(-1) > 0 { 85 | x := p.Get() 86 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 87 | p.Put(x) 88 | } 89 | }() 90 | } 91 | wg.Wait() 92 | } 93 | 
-------------------------------------------------------------------------------- /device/indextable.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/rand" 10 | "encoding/binary" 11 | "sync" 12 | ) 13 | 14 | type IndexTableEntry struct { 15 | peer *Peer 16 | handshake *Handshake 17 | keypair *Keypair 18 | } 19 | 20 | type IndexTable struct { 21 | sync.RWMutex 22 | table map[uint32]IndexTableEntry 23 | } 24 | 25 | func randUint32() (uint32, error) { 26 | var integer [4]byte 27 | _, err := rand.Read(integer[:]) 28 | // Arbitrary endianness; both are intrinsified by the Go compiler. 29 | return binary.LittleEndian.Uint32(integer[:]), err 30 | } 31 | 32 | func (table *IndexTable) Init() { 33 | table.Lock() 34 | defer table.Unlock() 35 | table.table = make(map[uint32]IndexTableEntry) 36 | } 37 | 38 | func (table *IndexTable) Delete(index uint32) { 39 | table.Lock() 40 | defer table.Unlock() 41 | delete(table.table, index) 42 | } 43 | 44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { 45 | table.Lock() 46 | defer table.Unlock() 47 | entry, ok := table.table[index] 48 | if !ok { 49 | return 50 | } 51 | table.table[index] = IndexTableEntry{ 52 | peer: entry.peer, 53 | keypair: keypair, 54 | handshake: nil, 55 | } 56 | } 57 | 58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { 59 | for { 60 | // generate random index 61 | 62 | index, err := randUint32() 63 | if err != nil { 64 | return index, err 65 | } 66 | 67 | // check if index used 68 | 69 | table.RLock() 70 | _, ok := table.table[index] 71 | table.RUnlock() 72 | if ok { 73 | continue 74 | } 75 | 76 | // check again while locked 77 | 78 | table.Lock() 79 | _, found := table.table[index] 80 | if found { 81 | table.Unlock() 82 | continue 83 
| } 84 | table.table[index] = IndexTableEntry{ 85 | peer: peer, 86 | handshake: handshake, 87 | keypair: nil, 88 | } 89 | table.Unlock() 90 | return index, nil 91 | } 92 | } 93 | 94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry { 95 | table.RLock() 96 | defer table.RUnlock() 97 | return table.table[id] 98 | } 99 | -------------------------------------------------------------------------------- /tun/alignment_windows_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestRateJugglerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime. 
26 | func TestRateJugglerAlignment(t *testing.T) { 27 | var r rateJuggler 28 | 29 | typ := reflect.TypeOf(&r).Elem() 30 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) 42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) 43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) 44 | } 45 | 46 | // TestNativeTunAlignment checks that atomically-accessed fields are 47 | // aligned to 64-bit boundaries, as required by the atomic package. 48 | // 49 | // Unfortunately, violating this rule on 32-bit platforms results in a 50 | // hard segfault at runtime. 51 | func TestNativeTunAlignment(t *testing.T) { 52 | var tun NativeTun 53 | 54 | typ := reflect.TypeOf(&tun).Elem() 55 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 56 | for i := 0; i < typ.NumField(); i++ { 57 | field := typ.Field(i) 58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 59 | field.Name, 60 | field.Offset, 61 | field.Type.Size(), 62 | field.Type.Align(), 63 | ) 64 | } 65 | 66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) 67 | } 68 | -------------------------------------------------------------------------------- /device/pools.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync" 10 | ) 11 | 12 | type WaitPool struct { 13 | pool sync.Pool 14 | cond sync.Cond 15 | lock sync.Mutex 16 | count uint32 // Get calls not yet Put back 17 | max uint32 18 | } 19 | 20 | func NewWaitPool(max uint32, new func() any) *WaitPool { 21 | p := &WaitPool{pool: sync.Pool{New: new}, max: max} 22 | p.cond = sync.Cond{L: &p.lock} 23 | return p 24 | } 25 | 26 | func (p *WaitPool) Get() any { 27 | if p.max != 0 { 28 | p.lock.Lock() 29 | for p.count >= p.max { 30 | p.cond.Wait() 31 | } 32 | p.count++ 33 | p.lock.Unlock() 34 | } 35 | return p.pool.Get() 36 | } 37 | 38 | func (p *WaitPool) Put(x any) { 39 | p.pool.Put(x) 40 | if p.max == 0 { 41 | return 42 | } 43 | p.lock.Lock() 44 | defer p.lock.Unlock() 45 | p.count-- 46 | p.cond.Signal() 47 | } 48 | 49 | func (device *Device) PopulatePools() { 50 | device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 51 | return new([MaxMessageSize]byte) 52 | }) 53 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 54 | return new(QueueInboundElement) 55 | }) 56 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 57 | return new(QueueOutboundElement) 58 | }) 59 | } 60 | 61 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 62 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 63 | } 64 | 65 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 66 | device.pool.messageBuffers.Put(msg) 67 | } 68 | 69 | func (device *Device) GetInboundElement() *QueueInboundElement { 70 | return device.pool.inboundElements.Get().(*QueueInboundElement) 71 | } 72 | 73 | func (device *Device) PutInboundElement(elem *QueueInboundElement) { 74 | elem.clearPointers() 75 | device.pool.inboundElements.Put(elem) 76 | } 77 | 78 | func (device *Device) GetOutboundElement() *QueueOutboundElement { 79 | return device.pool.outboundElements.Get().(*QueueOutboundElement) 
80 | } 81 | 82 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 83 | elem.clearPointers() 84 | device.pool.outboundElements.Put(elem) 85 | } 86 | -------------------------------------------------------------------------------- /main_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package main 7 | 8 | import ( 9 | "fmt" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | "golang.zx2c4.com/wireguard/conn" 15 | "golang.zx2c4.com/wireguard/device" 16 | "golang.zx2c4.com/wireguard/ipc" 17 | 18 | "golang.zx2c4.com/wireguard/tun" 19 | ) 20 | 21 | const ( 22 | ExitSetupSuccess = 0 23 | ExitSetupFailed = 1 24 | ) 25 | 26 | func main() { 27 | if len(os.Args) != 2 { 28 | os.Exit(ExitSetupFailed) 29 | } 30 | interfaceName := os.Args[1] 31 | 32 | fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. 
For a real WireGuard for Windows client, the repo you want is , which includes this code as a module.") 33 | 34 | logger := device.NewLogger( 35 | device.LogLevelVerbose, 36 | fmt.Sprintf("(%s) ", interfaceName), 37 | ) 38 | logger.Verbosef("Starting wireguard-go version %s", Version) 39 | 40 | tun, err := tun.CreateTUN(interfaceName, 0) 41 | if err == nil { 42 | realInterfaceName, err2 := tun.Name() 43 | if err2 == nil { 44 | interfaceName = realInterfaceName 45 | } 46 | } else { 47 | logger.Errorf("Failed to create TUN device: %v", err) 48 | os.Exit(ExitSetupFailed) 49 | } 50 | 51 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 52 | err = device.Up() 53 | if err != nil { 54 | logger.Errorf("Failed to bring up device: %v", err) 55 | os.Exit(ExitSetupFailed) 56 | } 57 | logger.Verbosef("Device started") 58 | 59 | uapi, err := ipc.UAPIListen(interfaceName) 60 | if err != nil { 61 | logger.Errorf("Failed to listen on uapi socket: %v", err) 62 | os.Exit(ExitSetupFailed) 63 | } 64 | 65 | errs := make(chan error) 66 | term := make(chan os.Signal, 1) 67 | 68 | go func() { 69 | for { 70 | conn, err := uapi.Accept() 71 | if err != nil { 72 | errs <- err 73 | return 74 | } 75 | go device.IpcHandle(conn) 76 | } 77 | }() 78 | logger.Verbosef("UAPI listener started") 79 | 80 | // wait for program to terminate 81 | 82 | signal.Notify(term, os.Interrupt) 83 | signal.Notify(term, os.Kill) 84 | signal.Notify(term, syscall.SIGTERM) 85 | 86 | select { 87 | case <-term: 88 | case <-errs: 89 | case <-device.Wait(): 90 | } 91 | 92 | // clean up 93 | 94 | uapi.Close() 95 | device.Close() 96 | 97 | logger.Verbosef("Shutting down") 98 | } 99 | -------------------------------------------------------------------------------- /device/kdf_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "encoding/hex" 10 | "testing" 11 | 12 | "golang.org/x/crypto/blake2s" 13 | ) 14 | 15 | type KDFTest struct { 16 | key string 17 | input string 18 | t0 string 19 | t1 string 20 | t2 string 21 | } 22 | 23 | func assertEquals(t *testing.T, a, b string) { 24 | if a != b { 25 | t.Fatal("expected", a, "=", b) 26 | } 27 | } 28 | 29 | func TestKDF(t *testing.T) { 30 | tests := []KDFTest{ 31 | { 32 | key: "746573742d6b6579", 33 | input: "746573742d696e707574", 34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 37 | }, 38 | { 39 | key: "776972656775617264", 40 | input: "776972656775617264", 41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 44 | }, 45 | { 46 | key: "", 47 | input: "", 48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 51 | }, 52 | } 53 | 54 | var t0, t1, t2 [blake2s.Size]byte 55 | 56 | for _, test := range tests { 57 | key, _ := hex.DecodeString(test.key) 58 | input, _ := hex.DecodeString(test.input) 59 | KDF3(&t0, &t1, &t2, key, input) 60 | t0s := hex.EncodeToString(t0[:]) 61 | t1s := hex.EncodeToString(t1[:]) 62 | t2s := hex.EncodeToString(t2[:]) 63 | assertEquals(t, t0s, test.t0) 64 | assertEquals(t, t1s, test.t1) 65 | assertEquals(t, t2s, test.t2) 66 | } 67 | 68 | for _, test := range tests { 69 | key, _ := hex.DecodeString(test.key) 70 | input, _ := hex.DecodeString(test.input) 71 | KDF2(&t0, &t1, key, input) 72 | t0s := hex.EncodeToString(t0[:]) 73 | t1s := 
hex.EncodeToString(t1[:]) 74 | assertEquals(t, t0s, test.t0) 75 | assertEquals(t, t1s, test.t1) 76 | } 77 | 78 | for _, test := range tests { 79 | key, _ := hex.DecodeString(test.key) 80 | input, _ := hex.DecodeString(test.input) 81 | KDF1(&t0, key, input) 82 | t0s := hex.EncodeToString(t0[:]) 83 | assertEquals(t, t0s, test.t0) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /rwcancel/rwcancel.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !js 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | // Package rwcancel implements cancelable read/write operations on 9 | // a file descriptor. 10 | package rwcancel 11 | 12 | import ( 13 | "errors" 14 | "os" 15 | "syscall" 16 | 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type RWCancel struct { 21 | fd int 22 | closingReader *os.File 23 | closingWriter *os.File 24 | } 25 | 26 | func NewRWCancel(fd int) (*RWCancel, error) { 27 | err := unix.SetNonblock(fd, true) 28 | if err != nil { 29 | return nil, err 30 | } 31 | rwcancel := RWCancel{fd: fd} 32 | 33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &rwcancel, nil 39 | } 40 | 41 | func RetryAfterError(err error) bool { 42 | return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) 43 | } 44 | 45 | func (rw *RWCancel) ReadyRead() bool { 46 | closeFd := int32(rw.closingReader.Fd()) 47 | 48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} 49 | var err error 50 | for { 51 | _, err = unix.Poll(pollFds, -1) 52 | if err == nil || !RetryAfterError(err) { 53 | break 54 | } 55 | } 56 | if err != nil { 57 | return false 58 | } 59 | if pollFds[1].Revents != 0 { 60 | return false 61 | } 62 | return pollFds[0].Revents != 0 63 | } 64 | 65 | func (rw 
*RWCancel) ReadyWrite() bool { 66 | closeFd := int32(rw.closingReader.Fd()) 67 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} 68 | var err error 69 | for { 70 | _, err = unix.Poll(pollFds, -1) 71 | if err == nil || !RetryAfterError(err) { 72 | break 73 | } 74 | } 75 | if err != nil { 76 | return false 77 | } 78 | 79 | if pollFds[1].Revents != 0 { 80 | return false 81 | } 82 | return pollFds[0].Revents != 0 83 | } 84 | 85 | func (rw *RWCancel) Read(p []byte) (n int, err error) { 86 | for { 87 | n, err := unix.Read(rw.fd, p) 88 | if err == nil || !RetryAfterError(err) { 89 | return n, err 90 | } 91 | if !rw.ReadyRead() { 92 | return 0, os.ErrClosed 93 | } 94 | } 95 | } 96 | 97 | func (rw *RWCancel) Write(p []byte) (n int, err error) { 98 | for { 99 | n, err := unix.Write(rw.fd, p) 100 | if err == nil || !RetryAfterError(err) { 101 | return n, err 102 | } 103 | if !rw.ReadyWrite() { 104 | return 0, os.ErrClosed 105 | } 106 | } 107 | } 108 | 109 | func (rw *RWCancel) Cancel() (err error) { 110 | _, err = rw.closingWriter.Write([]byte{0}) 111 | return 112 | } 113 | 114 | func (rw *RWCancel) Close() { 115 | rw.closingReader.Close() 116 | rw.closingWriter.Close() 117 | } 118 | -------------------------------------------------------------------------------- /device/noise-helpers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/subtle" 12 | "errors" 13 | "hash" 14 | 15 | "golang.org/x/crypto/blake2s" 16 | "golang.org/x/crypto/curve25519" 17 | ) 18 | 19 | /* KDF related functions. 
20 | * HMAC-based Key Derivation Function (HKDF) 21 | * https://tools.ietf.org/html/rfc5869 22 | */ 23 | 24 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { 25 | mac := hmac.New(func() hash.Hash { 26 | h, _ := blake2s.New256(nil) 27 | return h 28 | }, key) 29 | mac.Write(in0) 30 | mac.Sum(sum[:0]) 31 | } 32 | 33 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { 34 | mac := hmac.New(func() hash.Hash { 35 | h, _ := blake2s.New256(nil) 36 | return h 37 | }, key) 38 | mac.Write(in0) 39 | mac.Write(in1) 40 | mac.Sum(sum[:0]) 41 | } 42 | 43 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { 44 | HMAC1(t0, key, input) 45 | HMAC1(t0, t0[:], []byte{0x1}) 46 | } 47 | 48 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { 49 | var prk [blake2s.Size]byte 50 | HMAC1(&prk, key, input) 51 | HMAC1(t0, prk[:], []byte{0x1}) 52 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 53 | setZero(prk[:]) 54 | } 55 | 56 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { 57 | var prk [blake2s.Size]byte 58 | HMAC1(&prk, key, input) 59 | HMAC1(t0, prk[:], []byte{0x1}) 60 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 61 | HMAC2(t2, prk[:], t1[:], []byte{0x3}) 62 | setZero(prk[:]) 63 | } 64 | 65 | func isZero(val []byte) bool { 66 | acc := 1 67 | for _, b := range val { 68 | acc &= subtle.ConstantTimeByteEq(b, 0) 69 | } 70 | return acc == 1 71 | } 72 | 73 | /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ 74 | func setZero(arr []byte) { 75 | for i := range arr { 76 | arr[i] = 0 77 | } 78 | } 79 | 80 | func (sk *NoisePrivateKey) clamp() { 81 | sk[0] &= 248 82 | sk[31] = (sk[31] & 127) | 64 83 | } 84 | 85 | func newPrivateKey() (sk NoisePrivateKey, err error) { 86 | _, err = rand.Read(sk[:]) 87 | sk.clamp() 88 | return 89 | } 90 | 91 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { 92 | apk := (*[NoisePublicKeySize]byte)(&pk) 93 | ask := (*[NoisePrivateKeySize]byte)(sk) 94 | 
curve25519.ScalarBaseMult(apk, ask) 95 | return 96 | } 97 | 98 | var errInvalidPublicKey = errors.New("invalid public key") 99 | 100 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { 101 | apk := (*[NoisePublicKeySize]byte)(&pk) 102 | ask := (*[NoisePrivateKeySize]byte)(sk) 103 | curve25519.ScalarMult(&ss, ask, apk) 104 | if isZero(ss[:]) { 105 | return ss, errInvalidPublicKey 106 | } 107 | return ss, nil 108 | } 109 | -------------------------------------------------------------------------------- /ipc/uapi_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | "os" 11 | 12 | "golang.org/x/sys/unix" 13 | "golang.zx2c4.com/wireguard/rwcancel" 14 | ) 15 | 16 | type UAPIListener struct { 17 | listener net.Listener // unix socket listener 18 | connNew chan net.Conn 19 | connErr chan error 20 | inotifyFd int 21 | inotifyRWCancel *rwcancel.RWCancel 22 | } 23 | 24 | func (l *UAPIListener) Accept() (net.Conn, error) { 25 | for { 26 | select { 27 | case conn := <-l.connNew: 28 | return conn, nil 29 | 30 | case err := <-l.connErr: 31 | return nil, err 32 | } 33 | } 34 | } 35 | 36 | func (l *UAPIListener) Close() error { 37 | err1 := unix.Close(l.inotifyFd) 38 | err2 := l.inotifyRWCancel.Cancel() 39 | err3 := l.listener.Close() 40 | if err1 != nil { 41 | return err1 42 | } 43 | if err2 != nil { 44 | return err2 45 | } 46 | return err3 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 54 | // wrap file in listener 55 | 56 | listener, err := net.FileListener(file) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | if unixListener, ok := listener.(*net.UnixListener); ok { 62 | 
unixListener.SetUnlinkOnClose(true) 63 | } 64 | 65 | uapi := &UAPIListener{ 66 | listener: listener, 67 | connNew: make(chan net.Conn, 1), 68 | connErr: make(chan error, 1), 69 | } 70 | 71 | // watch for deletion of socket 72 | 73 | socketPath := sockPath(name) 74 | 75 | uapi.inotifyFd, err = unix.InotifyInit() 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | _, err = unix.InotifyAddWatch( 81 | uapi.inotifyFd, 82 | socketPath, 83 | unix.IN_ATTRIB| 84 | unix.IN_DELETE| 85 | unix.IN_DELETE_SELF, 86 | ) 87 | 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) 93 | if err != nil { 94 | unix.Close(uapi.inotifyFd) 95 | return nil, err 96 | } 97 | 98 | go func(l *UAPIListener) { 99 | var buff [0]byte 100 | for { 101 | defer uapi.inotifyRWCancel.Close() 102 | // start with lstat to avoid race condition 103 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 104 | l.connErr <- err 105 | return 106 | } 107 | _, err := uapi.inotifyRWCancel.Read(buff[:]) 108 | if err != nil { 109 | l.connErr <- err 110 | return 111 | } 112 | } 113 | }(uapi) 114 | 115 | // watch for new connections 116 | 117 | go func(l *UAPIListener) { 118 | for { 119 | conn, err := l.listener.Accept() 120 | if err != nil { 121 | l.connErr <- err 122 | break 123 | } 124 | l.connNew <- conn 125 | } 126 | }(uapi) 127 | 128 | return uapi, nil 129 | } 130 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | type result struct { 15 | allowed bool 16 | text string 17 | wait time.Duration 18 | } 19 | 20 | func TestRatelimiter(t *testing.T) { 21 | var rate Ratelimiter 22 | var expectedResults []result 23 | 24 | nano := func(nano int64) time.Duration { 25 | return time.Nanosecond * time.Duration(nano) 26 | } 27 | 28 | add := func(res result) { 29 | expectedResults = append( 30 | expectedResults, 31 | res, 32 | ) 33 | } 34 | 35 | for i := 0; i < packetsBurstable; i++ { 36 | add(result{ 37 | allowed: true, 38 | text: "initial burst", 39 | }) 40 | } 41 | 42 | add(result{ 43 | allowed: false, 44 | text: "after burst", 45 | }) 46 | 47 | add(result{ 48 | allowed: true, 49 | wait: nano(time.Second.Nanoseconds() / packetsPerSecond), 50 | text: "filling tokens for single packet", 51 | }) 52 | 53 | add(result{ 54 | allowed: false, 55 | text: "not having refilled enough", 56 | }) 57 | 58 | add(result{ 59 | allowed: true, 60 | wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), 61 | text: "filling tokens for two packet burst", 62 | }) 63 | 64 | add(result{ 65 | allowed: true, 66 | text: "second packet in 2 packet burst", 67 | }) 68 | 69 | add(result{ 70 | allowed: false, 71 | text: "packet following 2 packet burst", 72 | }) 73 | 74 | ips := []netip.Addr{ 75 | netip.MustParseAddr("127.0.0.1"), 76 | netip.MustParseAddr("192.168.1.1"), 77 | netip.MustParseAddr("172.167.2.3"), 78 | netip.MustParseAddr("97.231.252.215"), 79 | netip.MustParseAddr("248.97.91.167"), 80 | netip.MustParseAddr("188.208.233.47"), 81 | netip.MustParseAddr("104.2.183.179"), 82 | netip.MustParseAddr("72.129.46.120"), 83 | netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), 84 | netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), 85 | netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), 86 | netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), 87 | 
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), 88 | netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), 89 | } 90 | 91 | now := time.Now() 92 | rate.timeNow = func() time.Time { 93 | return now 94 | } 95 | defer func() { 96 | // Lock to avoid data race with cleanup goroutine from Init. 97 | rate.mu.Lock() 98 | defer rate.mu.Unlock() 99 | 100 | rate.timeNow = time.Now 101 | }() 102 | timeSleep := func(d time.Duration) { 103 | now = now.Add(d + 1) 104 | rate.cleanup() 105 | } 106 | 107 | rate.Init() 108 | defer rate.Close() 109 | 110 | for i, res := range expectedResults { 111 | timeSleep(res.wait) 112 | for _, ip := range ips { 113 | allowed := rate.Allow(ip) 114 | if allowed != res.allowed { 115 | t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /ipc/uapi_bsd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
package ipc

import (
	"errors"
	"net"
	"os"
	"unsafe"

	"golang.org/x/sys/unix"
)

// UAPIListener is a net.Listener for the UAPI unix socket on darwin/BSD.
// Besides forwarding accepted connections, it watches socketDirectory with
// kqueue and makes Accept return an error once the socket file is gone.
type UAPIListener struct {
	listener net.Listener // unix socket listener
	connNew  chan net.Conn
	connErr  chan error
	kqueueFd int // kqueue instance used by the watcher goroutine
	keventFd int // read-only descriptor of socketDirectory being watched
}

// Accept blocks until either a new connection or an error (from the accept
// loop or the socket watcher) is available, and returns it.
func (l *UAPIListener) Accept() (net.Conn, error) {
	for {
		select {
		case conn := <-l.connNew:
			return conn, nil

		case err := <-l.connErr:
			return nil, err
		}
	}
}

// Close closes the kqueue and watched-directory descriptors and the
// underlying listener, returning the first error encountered.
func (l *UAPIListener) Close() error {
	err1 := unix.Close(l.kqueueFd)
	err2 := unix.Close(l.keventFd)
	err3 := l.listener.Close()
	if err1 != nil {
		return err1
	}
	if err2 != nil {
		return err2
	}
	return err3
}

// Addr returns the underlying listener's address.
func (l *UAPIListener) Addr() net.Addr {
	return l.listener.Addr()
}

// UAPIListen wraps file (handed to net.FileListener, so it is expected to be
// a listening socket) in a UAPIListener and starts two goroutines: one that
// accepts connections and one that reports deletion of sockPath(name).
func UAPIListen(name string, file *os.File) (net.Listener, error) {
	// wrap file in listener

	listener, err := net.FileListener(file)
	if err != nil {
		return nil, err
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	if unixListener, ok := listener.(*net.UnixListener); ok {
		unixListener.SetUnlinkOnClose(true)
	}

	socketPath := sockPath(name)

	// watch for deletion of socket

	uapi.kqueueFd, err = unix.Kqueue()
	if err != nil {
		return nil, err
	}
	uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0)
	if err != nil {
		unix.Close(uapi.kqueueFd)
		return nil, err
	}

	go func(l *UAPIListener) {
		// One-shot NOTE_WRITE watch on the directory; it is re-armed on
		// every Kevent call below because the same change list is passed.
		event := unix.Kevent_t{
			Filter: unix.EVFILT_VNODE,
			Flags:  unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT,
			Fflags: unix.NOTE_WRITE,
		}
		// Allow this assignment to work with both the 32-bit and 64-bit version
		// of the above struct. If you know another way, please submit a patch.
		*(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd)
		events := make([]unix.Kevent_t, 1)
		// n and kerr start as "one event, no error" so the first pass
		// skips the result check and goes straight to Kevent.
		n := 1
		var kerr error
		for {
			// start with lstat to avoid race condition
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
				l.connErr <- err
				return
			}
			// Inspect the previous Kevent result: EINTR is retried; any
			// other error, or a count other than one event, stops watching.
			if (kerr != nil || n != 1) && kerr != unix.EINTR {
				if kerr != nil {
					l.connErr <- kerr
				} else {
					l.connErr <- errors.New("kqueue returned empty")
				}
				return
			}
			// Blocks (nil timeout) until the directory changes.
			n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil)
		}
	}(uapi)

	// watch for new connections

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}

// -----------------------------------------------------------------------------
// /ratelimiter/ratelimiter.go:
// -----------------------------------------------------------------------------

// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
package ratelimiter

import (
	"net/netip"
	"sync"
	"time"
)

// Token-bucket parameters. Tokens are measured in nanoseconds: Allow credits
// each entry with the nanoseconds elapsed since its last packet.
const (
	packetsPerSecond   = 20
	packetsBurstable   = 5
	garbageCollectTime = time.Second                     // idle time after which cleanup drops an entry
	packetCost         = 1000000000 / packetsPerSecond  // tokens (ns) charged per allowed packet
	maxTokens          = packetCost * packetsBurstable  // bucket capacity: packetsBurstable packets
)

// RatelimiterEntry is the token bucket for one source address.
type RatelimiterEntry struct {
	mu       sync.Mutex
	lastTime time.Time
	tokens   int64
}

// Ratelimiter is a per-address token-bucket limiter with a background
// garbage collector started by Init.
type Ratelimiter struct {
	mu      sync.RWMutex
	timeNow func() time.Time

	stopReset chan struct{} // send to reset, close to stop
	table     map[netip.Addr]*RatelimiterEntry
}

// Close stops the garbage-collection goroutine started by Init.
//
// NOTE(review): stopReset is closed but not cleared, so a second Close (or an
// Init after Close) closes an already-closed channel and panics, and an
// Allow that inserts the first table entry after Close sends on a closed
// channel and panics.
func (rate *Ratelimiter) Close() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	if rate.stopReset != nil {
		close(rate.stopReset)
	}
}

// Init resets the table and (re)starts the garbage-collection goroutine,
// stopping any goroutine started by a previous Init. timeNow defaults to
// time.Now unless already set (tests inject a fake clock).
func (rate *Ratelimiter) Init() {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	if rate.timeNow == nil {
		rate.timeNow = time.Now
	}

	// stop any ongoing garbage collection routine
	if rate.stopReset != nil {
		close(rate.stopReset)
	}

	rate.stopReset = make(chan struct{})
	rate.table = make(map[netip.Addr]*RatelimiterEntry)

	stopReset := rate.stopReset // store in case Init is called again.

	// Start garbage collection routine.
	// The ticker starts stopped; a send on stopReset (re)starts it, a close
	// ends the goroutine, and it stops itself once cleanup empties the table.
	go func() {
		ticker := time.NewTicker(time.Second)
		ticker.Stop()
		for {
			select {
			case _, ok := <-stopReset:
				ticker.Stop()
				if !ok {
					return
				}
				ticker = time.NewTicker(time.Second)
			case <-ticker.C:
				if rate.cleanup() {
					ticker.Stop()
				}
			}
		}
	}()
}

// cleanup deletes entries idle for longer than garbageCollectTime and
// reports whether the table is now empty.
func (rate *Ratelimiter) cleanup() (empty bool) {
	rate.mu.Lock()
	defer rate.mu.Unlock()

	for key, entry := range rate.table {
		entry.mu.Lock()
		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
			delete(rate.table, key)
		}
		entry.mu.Unlock()
	}

	return len(rate.table) == 0
}

// Allow reports whether a packet from ip may be processed now, charging
// packetCost tokens from ip's bucket when it is. The first packet from an
// unknown address is always allowed.
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
	var entry *RatelimiterEntry
	// lookup entry
	rate.mu.RLock()
	entry = rate.table[ip]
	rate.mu.RUnlock()

	// make new entry if not found
	if entry == nil {
		entry = new(RatelimiterEntry)
		entry.tokens = maxTokens - packetCost
		entry.lastTime = rate.timeNow()
		rate.mu.Lock()
		rate.table[ip] = entry
		// First entry after the table emptied: wake the GC goroutine so it
		// restarts its ticker.
		// NOTE(review): this unbuffered send happens while holding rate.mu. If
		// the GC goroutine instead picks up a stale tick and blocks in
		// cleanup() on rate.mu, neither side can proceed. Confirm whether
		// Ticker.Stop semantics for the targeted Go version rule this out.
		if len(rate.table) == 1 {
			rate.stopReset <- struct{}{}
		}
		rate.mu.Unlock()
		return true
	}

	// add tokens to entry
	entry.mu.Lock()
	now := rate.timeNow()
	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
	entry.lastTime = now
	if entry.tokens > maxTokens {
		entry.tokens = maxTokens
	}

	// subtract cost of packet
	if entry.tokens > packetCost {
		entry.tokens -= packetCost
		entry.mu.Unlock()
		return true
	}
	entry.mu.Unlock()
	return false
}

// -----------------------------------------------------------------------------
// /replay/replay_test.go:
// -----------------------------------------------------------------------------

// SPDX-License-Identifier: MIT
//
// Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package replay 7 | 8 | import ( 9 | "testing" 10 | ) 11 | 12 | /* Ported from the linux kernel implementation 13 | * 14 | * 15 | */ 16 | 17 | const RejectAfterMessages = 1<<64 - 1<<13 - 1 18 | 19 | func TestReplay(t *testing.T) { 20 | var filter Filter 21 | 22 | const T_LIM = windowSize + 1 23 | 24 | testNumber := 0 25 | T := func(n uint64, expected bool) { 26 | testNumber++ 27 | if filter.ValidateCounter(n, RejectAfterMessages) != expected { 28 | t.Fatal("Test", testNumber, "failed", n, expected) 29 | } 30 | } 31 | 32 | filter.Reset() 33 | 34 | T(0, true) /* 1 */ 35 | T(1, true) /* 2 */ 36 | T(1, false) /* 3 */ 37 | T(9, true) /* 4 */ 38 | T(8, true) /* 5 */ 39 | T(7, true) /* 6 */ 40 | T(7, false) /* 7 */ 41 | T(T_LIM, true) /* 8 */ 42 | T(T_LIM-1, true) /* 9 */ 43 | T(T_LIM-1, false) /* 10 */ 44 | T(T_LIM-2, true) /* 11 */ 45 | T(2, true) /* 12 */ 46 | T(2, false) /* 13 */ 47 | T(T_LIM+16, true) /* 14 */ 48 | T(3, false) /* 15 */ 49 | T(T_LIM+16, false) /* 16 */ 50 | T(T_LIM*4, true) /* 17 */ 51 | T(T_LIM*4-(T_LIM-1), true) /* 18 */ 52 | T(10, false) /* 19 */ 53 | T(T_LIM*4-T_LIM, false) /* 20 */ 54 | T(T_LIM*4-(T_LIM+1), false) /* 21 */ 55 | T(T_LIM*4-(T_LIM-2), true) /* 22 */ 56 | T(T_LIM*4+1-T_LIM, false) /* 23 */ 57 | T(0, false) /* 24 */ 58 | T(RejectAfterMessages, false) /* 25 */ 59 | T(RejectAfterMessages-1, true) /* 26 */ 60 | T(RejectAfterMessages, false) /* 27 */ 61 | T(RejectAfterMessages-1, false) /* 28 */ 62 | T(RejectAfterMessages-2, true) /* 29 */ 63 | T(RejectAfterMessages+1, false) /* 30 */ 64 | T(RejectAfterMessages+2, false) /* 31 */ 65 | T(RejectAfterMessages-2, false) /* 32 */ 66 | T(RejectAfterMessages-3, true) /* 33 */ 67 | T(0, false) /* 34 */ 68 | 69 | t.Log("Bulk test 1") 70 | filter.Reset() 71 | testNumber = 0 72 | for i := uint64(1); i <= windowSize; i++ { 73 | T(i, true) 74 | } 75 | T(0, true) 76 | T(0, false) 77 | 78 | t.Log("Bulk test 2") 79 | filter.Reset() 80 | testNumber = 0 81 | for i := uint64(2); i <= 
windowSize+1; i++ { 82 | T(i, true) 83 | } 84 | T(1, true) 85 | T(0, false) 86 | 87 | t.Log("Bulk test 3") 88 | filter.Reset() 89 | testNumber = 0 90 | for i := uint64(windowSize + 1); i > 0; i-- { 91 | T(i, true) 92 | } 93 | 94 | t.Log("Bulk test 4") 95 | filter.Reset() 96 | testNumber = 0 97 | for i := uint64(windowSize + 2); i > 1; i-- { 98 | T(i, true) 99 | } 100 | T(0, false) 101 | 102 | t.Log("Bulk test 5") 103 | filter.Reset() 104 | testNumber = 0 105 | for i := uint64(windowSize); i > 0; i-- { 106 | T(i, true) 107 | } 108 | T(windowSize+1, true) 109 | T(0, false) 110 | 111 | t.Log("Bulk test 6") 112 | filter.Reset() 113 | testNumber = 0 114 | for i := uint64(windowSize); i > 0; i-- { 115 | T(i, true) 116 | } 117 | T(0, true) 118 | T(windowSize+1, true) 119 | } 120 | -------------------------------------------------------------------------------- /conn/server_name_utils/consistent_hash.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025. Proton AG 3 | * 4 | * This file is part of ProtonVPN. 5 | * 6 | * ProtonVPN is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * ProtonVPN is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with ProtonVPN. If not, see . 18 | */ 19 | 20 | package server_name_utils 21 | 22 | import ( 23 | "hash/crc32" 24 | "math" 25 | "sort" 26 | ) 27 | 28 | // Set of utilities to implement consistent hashing where map one set of strings (keys) to another (values). 
29 | // Set of values will change over time but mapping will remain largely stable. 30 | // Usage: 31 | // hashedValues := sortValuesByHash(values, crc32Hash) 32 | // value := findClosestValue(key, hashedValues, crc32Hash) 33 | // when set of values changes (e.g. new one is added), only mappings that will change are for keys for which 34 | // new value is the closest one (its uint32 hash is closest to key's hash). 35 | 36 | type HashedValue struct { 37 | value string 38 | hash uint32 39 | } 40 | 41 | // Picks value out of sortedValuesWithHashes which hash is closest to hash of value. For distance 42 | // calculation, uint32 is forming a ring where 0 is next to math.MaxUint32. Closer of clockwise and 43 | // counter-clockwise distance is picked. 44 | func findClosestValue(key string, sortedValuesWithHashes []HashedValue, hashFun func(string) uint32) string { 45 | n := len(sortedValuesWithHashes) 46 | if n == 0 { 47 | return "" 48 | } else if n == 1 { 49 | return sortedValuesWithHashes[0].value 50 | } 51 | 52 | keyHash := hashFun(key) 53 | i := sort.Search(n, func(i int) bool { 54 | return sortedValuesWithHashes[i].hash >= keyHash 55 | }) 56 | 57 | if i <= 0 || i >= n { 58 | // If it's smaller than first or larger than last, return closest 59 | // between first and last 60 | return closerValue(keyHash, sortedValuesWithHashes[0], sortedValuesWithHashes[n-1]) 61 | } else { 62 | return closerValue(keyHash, sortedValuesWithHashes[i-1], sortedValuesWithHashes[i]) 63 | } 64 | } 65 | 66 | func sortValuesByHash(values []string, hashFun func(string) uint32) []HashedValue { 67 | hashedValues := make([]HashedValue, len(values)) 68 | for i, domain := range values { 69 | hashedValues[i] = HashedValue{domain, hashFun(domain)} 70 | } 71 | sort.Slice(hashedValues, func(i, j int) bool { 72 | return hashedValues[i].hash < hashedValues[j].hash 73 | }) 74 | return hashedValues 75 | } 76 | 77 | func crc32Hash(s string) uint32 { 78 | return crc32.ChecksumIEEE([]byte(s)) 79 | } 80 | 81 | 
func ringDistance(a, b uint32) int64 { 82 | var fa = int64(a) 83 | var fb = int64(b) 84 | var large = max(fa, fb) 85 | var small = min(fa, fb) 86 | // Take smaller of clockwise and counter-clockwise distance 87 | return min(large - small, small - large + math.MaxUint32) 88 | } 89 | 90 | func closerValue(hash uint32, a HashedValue, b HashedValue) string { 91 | var da = ringDistance(hash, a.hash) 92 | var db = ringDistance(hash, b.hash) 93 | if da < db { 94 | return a.value 95 | } else { 96 | return b.value 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /device/allowedips_rand_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "sort" 13 | "testing" 14 | ) 15 | 16 | const ( 17 | NumberOfPeers = 100 18 | NumberOfPeerRemovals = 4 19 | NumberOfAddresses = 250 20 | NumberOfTests = 10000 21 | ) 22 | 23 | type SlowNode struct { 24 | peer *Peer 25 | cidr uint8 26 | bits []byte 27 | } 28 | 29 | type SlowRouter []*SlowNode 30 | 31 | func (r SlowRouter) Len() int { 32 | return len(r) 33 | } 34 | 35 | func (r SlowRouter) Less(i, j int) bool { 36 | return r[i].cidr > r[j].cidr 37 | } 38 | 39 | func (r SlowRouter) Swap(i, j int) { 40 | r[i], r[j] = r[j], r[i] 41 | } 42 | 43 | func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 44 | for _, t := range r { 45 | if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 46 | t.peer = peer 47 | t.bits = addr 48 | return r 49 | } 50 | } 51 | r = append(r, &SlowNode{ 52 | cidr: cidr, 53 | bits: addr, 54 | peer: peer, 55 | }) 56 | sort.Sort(r) 57 | return r 58 | } 59 | 60 | func (r SlowRouter) Lookup(addr []byte) *Peer { 61 | for _, t := range r { 62 | common := commonBits(t.bits, addr) 63 | if common >= t.cidr { 64 | return 
t.peer 65 | } 66 | } 67 | return nil 68 | } 69 | 70 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 71 | n := 0 72 | for _, x := range r { 73 | if x.peer != peer { 74 | r[n] = x 75 | n++ 76 | } 77 | } 78 | return r[:n] 79 | } 80 | 81 | func TestTrieRandom(t *testing.T) { 82 | var slow4, slow6 SlowRouter 83 | var peers []*Peer 84 | var allowedIPs AllowedIPs 85 | 86 | rand.Seed(1) 87 | 88 | for n := 0; n < NumberOfPeers; n++ { 89 | peers = append(peers, &Peer{}) 90 | } 91 | 92 | for n := 0; n < NumberOfAddresses; n++ { 93 | var addr4 [4]byte 94 | rand.Read(addr4[:]) 95 | cidr := uint8(rand.Intn(32) + 1) 96 | index := rand.Intn(NumberOfPeers) 97 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) 98 | slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 99 | 100 | var addr6 [16]byte 101 | rand.Read(addr6[:]) 102 | cidr = uint8(rand.Intn(128) + 1) 103 | index = rand.Intn(NumberOfPeers) 104 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) 105 | slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 106 | } 107 | 108 | var p int 109 | for p = 0; ; p++ { 110 | for n := 0; n < NumberOfTests; n++ { 111 | var addr4 [4]byte 112 | rand.Read(addr4[:]) 113 | peer1 := slow4.Lookup(addr4[:]) 114 | peer2 := allowedIPs.Lookup(addr4[:]) 115 | if peer1 != peer2 { 116 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 117 | } 118 | 119 | var addr6 [16]byte 120 | rand.Read(addr6[:]) 121 | peer1 = slow6.Lookup(addr6[:]) 122 | peer2 = allowedIPs.Lookup(addr6[:]) 123 | if peer1 != peer2 { 124 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 125 | } 126 | } 127 | if p >= len(peers) || p >= NumberOfPeerRemovals { 128 | break 129 | } 130 | allowedIPs.RemoveByPeer(peers[p]) 131 | slow4 = slow4.RemoveByPeer(peers[p]) 132 | slow6 = slow6.RemoveByPeer(peers[p]) 133 | } 134 | for ; p < len(peers); 
p++ { 135 | allowedIPs.RemoveByPeer(peers[p]) 136 | } 137 | 138 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 139 | t.Error("Failed to remove all nodes from trie by peer") 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /conn/bindtest/bindtest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package bindtest 7 | 8 | import ( 9 | "fmt" 10 | "math/rand" 11 | "net" 12 | "net/netip" 13 | "os" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type ChannelBind struct { 19 | rx4, tx4 *chan []byte 20 | rx6, tx6 *chan []byte 21 | closeSignal chan bool 22 | source4, source6 ChannelEndpoint 23 | target4, target6 ChannelEndpoint 24 | } 25 | 26 | type ChannelEndpoint uint16 27 | 28 | var ( 29 | _ conn.Bind = (*ChannelBind)(nil) 30 | _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 | ) 32 | 33 | func NewChannelBinds() [2]conn.Bind { 34 | arx4 := make(chan []byte, 8192) 35 | brx4 := make(chan []byte, 8192) 36 | arx6 := make(chan []byte, 8192) 37 | brx6 := make(chan []byte, 8192) 38 | var binds [2]ChannelBind 39 | binds[0].rx4 = &arx4 40 | binds[0].tx4 = &brx4 41 | binds[1].rx4 = &brx4 42 | binds[1].tx4 = &arx4 43 | binds[0].rx6 = &arx6 44 | binds[0].tx6 = &brx6 45 | binds[1].rx6 = &brx6 46 | binds[1].tx6 = &arx6 47 | binds[0].target4 = ChannelEndpoint(1) 48 | binds[1].target4 = ChannelEndpoint(2) 49 | binds[0].target6 = ChannelEndpoint(3) 50 | binds[1].target6 = ChannelEndpoint(4) 51 | binds[0].source4 = binds[1].target4 52 | binds[0].source6 = binds[1].target6 53 | binds[1].source4 = binds[0].target4 54 | binds[1].source6 = binds[0].target6 55 | return [2]conn.Bind{&binds[0], &binds[1]} 56 | } 57 | 58 | func (c ChannelEndpoint) ClearSrc() {} 59 | 60 | func (c ChannelEndpoint) SrcToString() string { return "" } 61 | 62 | func (c ChannelEndpoint) 
DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 | 64 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 | 66 | func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 | 68 | func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 | 70 | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 | c.closeSignal = make(chan bool) 72 | fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 | fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 | if rand.Uint32()&1 == 0 { 75 | return fns, uint16(c.source4), nil 76 | } else { 77 | return fns, uint16(c.source6), nil 78 | } 79 | } 80 | 81 | func (c *ChannelBind) Close() error { 82 | if c.closeSignal != nil { 83 | select { 84 | case <-c.closeSignal: 85 | default: 86 | close(c.closeSignal) 87 | } 88 | } 89 | return nil 90 | } 91 | 92 | func (c *ChannelBind) SetMark(mark uint32) error { return nil } 93 | 94 | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 95 | return func(b []byte) (n int, ep conn.Endpoint, err error) { 96 | select { 97 | case <-c.closeSignal: 98 | return 0, nil, net.ErrClosed 99 | case rx := <-ch: 100 | return copy(b, rx), c.target6, nil 101 | } 102 | } 103 | } 104 | 105 | func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { 106 | select { 107 | case <-c.closeSignal: 108 | return net.ErrClosed 109 | default: 110 | bc := make([]byte, len(b)) 111 | copy(bc, b) 112 | if ep.(ChannelEndpoint) == c.target4 { 113 | *c.tx4 <- bc 114 | } else if ep.(ChannelEndpoint) == c.target6 { 115 | *c.tx6 <- bc 116 | } else { 117 | return os.ErrInvalid 118 | } 119 | } 120 | return nil 121 | } 122 | 123 | func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 124 | addr, err := netip.ParseAddrPort(s) 125 | if err != nil { 126 | return nil, err 127 | } 128 | return ChannelEndpoint(addr.Port()), nil 129 | } 130 | 
// -----------------------------------------------------------------------------
// /device/channels.go:
// -----------------------------------------------------------------------------

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"runtime"
	"sync"
)

// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
// Every additional writer must call wg.Add(1).
// Every completed writer must call wg.Done().
// When no further writers will be added,
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
	c  chan *QueueOutboundElement
	wg sync.WaitGroup
}

// newOutboundQueue returns a queue holding its initial reference; a helper
// goroutine closes q.c once the reference count drops to zero.
func newOutboundQueue() *outboundQueue {
	q := &outboundQueue{
		c: make(chan *QueueOutboundElement, QueueOutboundSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}

// An inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
	c  chan *QueueInboundElement
	wg sync.WaitGroup
}

// newInboundQueue mirrors newOutboundQueue for inbound elements.
func newInboundQueue() *inboundQueue {
	q := &inboundQueue{
		c: make(chan *QueueInboundElement, QueueInboundSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}

// A handshakeQueue is similar to an outboundQueue; see those docs.
type handshakeQueue struct {
	c  chan QueueHandshakeElement
	wg sync.WaitGroup
}

// newHandshakeQueue mirrors newOutboundQueue for handshake elements.
func newHandshakeQueue() *handshakeQueue {
	q := &handshakeQueue{
		c: make(chan QueueHandshakeElement, QueueHandshakeSize),
	}
	q.wg.Add(1)
	go func() {
		q.wg.Wait()
		close(q.c)
	}()
	return q
}

type autodrainingInboundQueue struct {
	c chan *QueueInboundElement
}

// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil value.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
	q := &autodrainingInboundQueue{
		c: make(chan *QueueInboundElement, QueueInboundSize),
	}
	runtime.SetFinalizer(q, device.flushInboundQueue)
	return q
}

// flushInboundQueue drains whatever is currently buffered in q.c without
// blocking, returning each element's buffer and the element itself to the
// device pools. It is installed as q's finalizer.
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
	for {
		select {
		case elem := <-q.c:
			// NOTE(review): the element is locked before recycling, presumably
			// to follow the element-lock protocol used by the workers; confirm.
			elem.Lock()
			device.PutMessageBuffer(elem.buffer)
			device.PutInboundElement(elem)
		default:
			return
		}
	}
}

type autodrainingOutboundQueue struct {
	c chan *QueueOutboundElement
}

// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
// It is useful in cases in which it is hard to manage the lifetime of the channel.
// The returned channel must not be closed. Senders should signal shutdown using
// some other means, such as sending a sentinel nil value.
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
	q := &autodrainingOutboundQueue{
		c: make(chan *QueueOutboundElement, QueueOutboundSize),
	}
	runtime.SetFinalizer(q, device.flushOutboundQueue)
	return q
}

// flushOutboundQueue drains whatever is currently buffered in q.c without
// blocking, returning each element's buffer and the element itself to the
// device pools. It is installed as q's finalizer.
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
	for {
		select {
		case elem := <-q.c:
			elem.Lock()
			device.PutMessageBuffer(elem.buffer)
			device.PutOutboundElement(elem)
		default:
			return
		}
	}
}

// -----------------------------------------------------------------------------
// /README.md:
// -----------------------------------------------------------------------------

# Go Implementation of [WireGuard](https://www.wireguard.com/)

This is an implementation of WireGuard in Go.

## Usage

Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:

```
$ wireguard-go wg0
```

This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.

To run wireguard-go without forking to the background, pass `-f` or `--foreground`:

```
$ wireguard-go -f wg0
```

When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.

To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
24 | 25 | ## Platforms 26 | 27 | ### Linux 28 | 29 | This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions. 30 | 31 | ### macOS 32 | 33 | This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. 34 | 35 | ### Windows 36 | 37 | This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module. 38 | 39 | ### FreeBSD 40 | 41 | This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`. 42 | 43 | ### OpenBSD 44 | 45 | This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. 46 | 47 | ## Building 48 | 49 | This requires an installation of [go](https://golang.org) ≥ 1.18. 50 | 51 | ``` 52 | $ git clone https://git.zx2c4.com/wireguard-go 53 | $ cd wireguard-go 54 | $ make 55 | ``` 56 | 57 | ## License 58 | 59 | Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
60 | 61 | Permission is hereby granted, free of charge, to any person obtaining a copy of 62 | this software and associated documentation files (the "Software"), to deal in 63 | the Software without restriction, including without limitation the rights to 64 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 65 | of the Software, and to permit persons to whom the Software is furnished to do 66 | so, subject to the following conditions: 67 | 68 | The above copyright notice and this permission notice shall be included in all 69 | copies or substantial portions of the Software. 70 | 71 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 72 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 73 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 74 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 75 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 76 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 77 | SOFTWARE. 
78 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= 2 | github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= 3 | github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= 4 | github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= 5 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= 10 | github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= 11 | github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= 12 | github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= 13 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 14 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 15 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 16 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 17 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B2MR1K67ULZM= 
21 | github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0= 22 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 23 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 24 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 25 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 26 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 27 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 28 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 29 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 30 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 31 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= 32 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 33 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 34 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 35 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= 36 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 37 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= 38 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= 39 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 40 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 41 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 42 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 43 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 44 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 45 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= 46 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= 47 | -------------------------------------------------------------------------------- /tun/tuntest/tuntest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tuntest 7 | 8 | import ( 9 | "encoding/binary" 10 | "io" 11 | "net/netip" 12 | "os" 13 | 14 | "golang.zx2c4.com/wireguard/tun" 15 | ) 16 | 17 | func Ping(dst, src netip.Addr) []byte { 18 | localPort := uint16(1337) 19 | seq := uint16(0) 20 | 21 | payload := make([]byte, 4) 22 | binary.BigEndian.PutUint16(payload[0:], localPort) 23 | binary.BigEndian.PutUint16(payload[2:], seq) 24 | 25 | return genICMPv4(payload, dst, src) 26 | } 27 | 28 | // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. 
29 | func checksum(buf []byte, initial uint16) uint16 { 30 | v := uint32(initial) 31 | for i := 0; i < len(buf)-1; i += 2 { 32 | v += uint32(binary.BigEndian.Uint16(buf[i:])) 33 | } 34 | if len(buf)%2 == 1 { 35 | v += uint32(buf[len(buf)-1]) << 8 36 | } 37 | for v > 0xffff { 38 | v = (v >> 16) + (v & 0xffff) 39 | } 40 | return ^uint16(v) 41 | } 42 | 43 | func genICMPv4(payload []byte, dst, src netip.Addr) []byte { 44 | const ( 45 | icmpv4ProtocolNumber = 1 46 | icmpv4Echo = 8 47 | icmpv4ChecksumOffset = 2 48 | icmpv4Size = 8 49 | ipv4Size = 20 50 | ipv4TotalLenOffset = 2 51 | ipv4ChecksumOffset = 10 52 | ttl = 65 53 | headerSize = ipv4Size + icmpv4Size 54 | ) 55 | 56 | pkt := make([]byte, headerSize+len(payload)) 57 | 58 | ip := pkt[0:ipv4Size] 59 | icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] 60 | 61 | // https://tools.ietf.org/html/rfc792 62 | icmpv4[0] = icmpv4Echo // type 63 | icmpv4[1] = 0 // code 64 | chksum := ^checksum(icmpv4, checksum(payload, 0)) 65 | binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) 66 | 67 | // https://tools.ietf.org/html/rfc760 section 3.1 68 | length := uint16(len(pkt)) 69 | ip[0] = (4 << 4) | (ipv4Size / 4) 70 | binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) 71 | ip[8] = ttl 72 | ip[9] = icmpv4ProtocolNumber 73 | copy(ip[12:], src.AsSlice()) 74 | copy(ip[16:], dst.AsSlice()) 75 | chksum = ^checksum(ip[:], 0) 76 | binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) 77 | 78 | copy(pkt[headerSize:], payload) 79 | return pkt 80 | } 81 | 82 | type ChannelTUN struct { 83 | Inbound chan []byte // incoming packets, closed on TUN close 84 | Outbound chan []byte // outbound packets, blocks forever on TUN close 85 | 86 | closed chan struct{} 87 | events chan tun.Event 88 | tun chTun 89 | } 90 | 91 | func NewChannelTUN() *ChannelTUN { 92 | c := &ChannelTUN{ 93 | Inbound: make(chan []byte), 94 | Outbound: make(chan []byte), 95 | closed: make(chan struct{}), 96 | events: make(chan tun.Event, 1), 97 | } 98 | 
c.tun.c = c 99 | c.events <- tun.EventUp 100 | return c 101 | } 102 | 103 | func (c *ChannelTUN) TUN() tun.Device { 104 | return &c.tun 105 | } 106 | 107 | type chTun struct { 108 | c *ChannelTUN 109 | } 110 | 111 | func (t *chTun) File() *os.File { return nil } 112 | 113 | func (t *chTun) Read(data []byte, offset int) (int, error) { 114 | select { 115 | case <-t.c.closed: 116 | return 0, os.ErrClosed 117 | case msg := <-t.c.Outbound: 118 | return copy(data[offset:], msg), nil 119 | } 120 | } 121 | 122 | // Write is called by the wireguard device to deliver a packet for routing. 123 | func (t *chTun) Write(data []byte, offset int) (int, error) { 124 | if offset == -1 { 125 | close(t.c.closed) 126 | close(t.c.events) 127 | return 0, io.EOF 128 | } 129 | msg := make([]byte, len(data)-offset) 130 | copy(msg, data[offset:]) 131 | select { 132 | case <-t.c.closed: 133 | return 0, os.ErrClosed 134 | case t.c.Inbound <- msg: 135 | return len(data) - offset, nil 136 | } 137 | } 138 | 139 | const DefaultMTU = 1420 140 | 141 | func (t *chTun) Flush() error { return nil } 142 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } 143 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } 144 | func (t *chTun) Events() <-chan tun.Event { return t.c.events } 145 | func (t *chTun) Close() error { 146 | t.Write(nil, -1) 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /conn/conn.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package conn implements WireGuard's network connections. 7 | package conn 8 | 9 | import ( 10 | "errors" 11 | "fmt" 12 | "net/netip" 13 | "reflect" 14 | "runtime" 15 | "strings" 16 | ) 17 | 18 | // A ReceiveFunc receives a single inbound packet from the network. 19 | // It writes the data into b. 
n is the length of the packet. 20 | // ep is the remote endpoint. 21 | type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) 22 | 23 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. 24 | // 25 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, 26 | // depending on the platform-specific implementation. 27 | type Bind interface { 28 | // Open puts the Bind into a listening state on a given port and reports the actual 29 | // port that it bound to. Passing zero results in a random selection. 30 | // fns is the set of functions that will be called to receive packets. 31 | Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) 32 | 33 | // Close closes the Bind listener. 34 | // All fns returned by Open must return net.ErrClosed after a call to Close. 35 | Close() error 36 | 37 | // SetMark sets the mark for each packet sent through this Bind. 38 | // This mark is passed to the kernel as the socket option SO_MARK. 39 | SetMark(mark uint32) error 40 | 41 | // Send writes a packet b to address ep. 42 | Send(b []byte, ep Endpoint) error 43 | 44 | // ParseEndpoint creates a new endpoint from a string. 45 | ParseEndpoint(s string) (Endpoint, error) 46 | } 47 | 48 | type Logger struct { 49 | Verbosef func(format string, args ...any) 50 | Errorf func(format string, args ...any) 51 | } 52 | 53 | // BindSocketToInterface is implemented by Bind objects that support being 54 | // tied to a single network interface. Used by wireguard-windows. 55 | type BindSocketToInterface interface { 56 | BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error 57 | BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error 58 | } 59 | 60 | // PeekLookAtSocketFd is implemented by Bind objects that support having their 61 | // file descriptor peeked at. Used by wireguard-android. 
62 | type PeekLookAtSocketFd interface { 63 | PeekLookAtSocketFd4() (fd int, err error) 64 | PeekLookAtSocketFd6() (fd int, err error) 65 | } 66 | 67 | // An Endpoint maintains the source/destination caching for a peer. 68 | // 69 | // dst: the remote address of a peer ("endpoint" in uapi terminology) 70 | // src: the local address from which datagrams originate going to the peer 71 | type Endpoint interface { 72 | ClearSrc() // clears the source address 73 | SrcToString() string // returns the local source address (ip:port) 74 | DstToString() string // returns the destination address (ip:port) 75 | DstToBytes() []byte // used for mac2 cookie calculations 76 | DstIP() netip.Addr 77 | SrcIP() netip.Addr 78 | } 79 | 80 | var ( 81 | ErrBindAlreadyOpen = errors.New("bind is already open") 82 | ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") 83 | ) 84 | 85 | func (fn ReceiveFunc) PrettyName() string { 86 | name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() 87 | // 0. cheese/taco.beansIPv6.func12.func21218-fm 88 | name = strings.TrimSuffix(name, "-fm") 89 | // 1. cheese/taco.beansIPv6.func12.func21218 90 | if idx := strings.LastIndexByte(name, '/'); idx != -1 { 91 | name = name[idx+1:] 92 | // 2. taco.beansIPv6.func12.func21218 93 | } 94 | for { 95 | var idx int 96 | for idx = len(name) - 1; idx >= 0; idx-- { 97 | if name[idx] < '0' || name[idx] > '9' { 98 | break 99 | } 100 | } 101 | if idx == len(name)-1 { 102 | break 103 | } 104 | const dotFunc = ".func" 105 | if !strings.HasSuffix(name[:idx+1], dotFunc) { 106 | break 107 | } 108 | name = name[:idx+1-len(dotFunc)] 109 | // 3. taco.beansIPv6.func12 110 | // 4. taco.beansIPv6 111 | } 112 | if idx := strings.LastIndexByte(name, '.'); idx != -1 { 113 | name = name[idx+1:] 114 | // 5. 
beansIPv6 115 | } 116 | if name == "" { 117 | return fmt.Sprintf("%p", fn) 118 | } 119 | if strings.HasSuffix(name, "IPv4") { 120 | return "v4" 121 | } 122 | if strings.HasSuffix(name, "IPv6") { 123 | return "v6" 124 | } 125 | return name 126 | } 127 | -------------------------------------------------------------------------------- /device/noise_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "testing" 12 | 13 | "golang.zx2c4.com/wireguard/conn" 14 | "golang.zx2c4.com/wireguard/tun/tuntest" 15 | ) 16 | 17 | func TestCurveWrappers(t *testing.T) { 18 | sk1, err := newPrivateKey() 19 | assertNil(t, err) 20 | 21 | sk2, err := newPrivateKey() 22 | assertNil(t, err) 23 | 24 | pk1 := sk1.publicKey() 25 | pk2 := sk2.publicKey() 26 | 27 | ss1, err1 := sk1.sharedSecret(pk2) 28 | ss2, err2 := sk2.sharedSecret(pk1) 29 | 30 | if ss1 != ss2 || err1 != nil || err2 != nil { 31 | t.Fatal("Failed to compute shared secet") 32 | } 33 | } 34 | 35 | func randDevice(t *testing.T) *Device { 36 | sk, err := newPrivateKey() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | tun := tuntest.NewChannelTUN() 41 | logger := NewLogger(LogLevelError, "") 42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState)) 43 | device.SetPrivateKey(sk) 44 | return device 45 | } 46 | 47 | func assertNil(t *testing.T, err error) { 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | } 52 | 53 | func assertEqual(t *testing.T, a, b []byte) { 54 | if !bytes.Equal(a, b) { 55 | t.Fatal(a, "!=", b) 56 | } 57 | } 58 | 59 | func TestNoiseHandshake(t *testing.T) { 60 | dev1 := randDevice(t) 61 | dev2 := randDevice(t) 62 | 63 | defer dev1.Close() 64 | defer dev2.Close() 65 | 66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 | 
if err != nil { 68 | t.Fatal(err) 69 | } 70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | peer1.Start() 75 | peer2.Start() 76 | 77 | assertEqual( 78 | t, 79 | peer1.handshake.precomputedStaticStatic[:], 80 | peer2.handshake.precomputedStaticStatic[:], 81 | ) 82 | 83 | /* simulate handshake */ 84 | 85 | // initiation message 86 | 87 | t.Log("exchange initiation message") 88 | 89 | msg1, err := dev1.CreateMessageInitiation(peer2) 90 | assertNil(t, err) 91 | 92 | packet := make([]byte, 0, 256) 93 | writer := bytes.NewBuffer(packet) 94 | err = binary.Write(writer, binary.LittleEndian, msg1) 95 | assertNil(t, err) 96 | peer := dev2.ConsumeMessageInitiation(msg1) 97 | if peer == nil { 98 | t.Fatal("handshake failed at initiation message") 99 | } 100 | 101 | assertEqual( 102 | t, 103 | peer1.handshake.chainKey[:], 104 | peer2.handshake.chainKey[:], 105 | ) 106 | 107 | assertEqual( 108 | t, 109 | peer1.handshake.hash[:], 110 | peer2.handshake.hash[:], 111 | ) 112 | 113 | // response message 114 | 115 | t.Log("exchange response message") 116 | 117 | msg2, err := dev2.CreateMessageResponse(peer1) 118 | assertNil(t, err) 119 | 120 | peer = dev1.ConsumeMessageResponse(msg2) 121 | if peer == nil { 122 | t.Fatal("handshake failed at response message") 123 | } 124 | 125 | assertEqual( 126 | t, 127 | peer1.handshake.chainKey[:], 128 | peer2.handshake.chainKey[:], 129 | ) 130 | 131 | assertEqual( 132 | t, 133 | peer1.handshake.hash[:], 134 | peer2.handshake.hash[:], 135 | ) 136 | 137 | // key pairs 138 | 139 | t.Log("deriving keys") 140 | 141 | err = peer1.BeginSymmetricSession() 142 | if err != nil { 143 | t.Fatal("failed to derive keypair for peer 1", err) 144 | } 145 | 146 | err = peer2.BeginSymmetricSession() 147 | if err != nil { 148 | t.Fatal("failed to derive keypair for peer 2", err) 149 | } 150 | 151 | key1 := peer1.keypairs.next.Load() 152 | key2 := peer2.keypairs.current 153 | 154 | // 
encrypting / decryption test 155 | 156 | t.Log("test key pairs") 157 | 158 | func() { 159 | testMsg := []byte("wireguard test message 1") 160 | var err error 161 | var out []byte 162 | var nonce [12]byte 163 | out = key1.send.Seal(out, nonce[:], testMsg, nil) 164 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 165 | assertNil(t, err) 166 | assertEqual(t, out, testMsg) 167 | }() 168 | 169 | func() { 170 | testMsg := []byte("wireguard test message 2") 171 | var err error 172 | var out []byte 173 | var nonce [12]byte 174 | out = key2.send.Seal(out, nonce[:], testMsg, nil) 175 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 176 | assertNil(t, err) 177 | assertEqual(t, out, testMsg) 178 | }() 179 | } 180 | -------------------------------------------------------------------------------- /conn/tcp_tls_utils.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022. Proton AG 3 | * 4 | * This file is part of ProtonVPN. 5 | * 6 | * ProtonVPN is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * ProtonVPN is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with ProtonVPN. If not, see . 
18 | */ 19 | 20 | package conn 21 | 22 | import ( 23 | "bytes" 24 | "encoding/binary" 25 | "errors" 26 | ) 27 | 28 | var wgDataPrefix = []byte{4, 0, 0, 0} 29 | var wgDataHeaderSize = 16 30 | var wgDataPrefixSize = 8 // Wireguard data header without counter 31 | 32 | var tunSafeHeaderSize = 2 33 | var tunSafeNormalType = uint8(0b00) 34 | var tunSafeDataType = uint8(0b10) 35 | 36 | type TunSafeData struct { 37 | wgSendPrefix []byte 38 | wgSendCount uint64 39 | wgRecvPrefix []byte 40 | wgRecvCount uint64 41 | } 42 | 43 | 44 | func NewTunSafeData() *TunSafeData { 45 | return &TunSafeData{ 46 | wgRecvPrefix: make([]byte, 8), 47 | wgSendPrefix: make([]byte, 8), 48 | } 49 | } 50 | 51 | // Returns (type, size) 52 | func parseTunSafeHeader(header []byte) (byte, int) { 53 | tunSafeType := header[0] >> 6 54 | size := (int(header[0])&0b00111111)<<8 | int(header[1]) 55 | return tunSafeType, size 56 | } 57 | 58 | func (tunSafe *TunSafeData) clear() { 59 | tunSafe.wgSendCount = 0 60 | tunSafe.wgRecvCount = 0 61 | } 62 | 63 | func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) { 64 | buffer := new(bytes.Buffer) 65 | buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount)) 66 | buffer.Write(tunSafe.wgRecvPrefix) 67 | _ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount) 68 | copy(wgPacket, buffer.Bytes()) 69 | } 70 | 71 | func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) { 72 | var wgPacket []byte 73 | offset := 0 74 | switch tunSafeType { 75 | case tunSafeNormalType: 76 | wgPacket = make([]byte, payloadSize) 77 | case tunSafeDataType: 78 | offset = wgDataHeaderSize 79 | wgPacket = make([]byte, payloadSize+offset) 80 | tunSafe.writeWgHeader(wgPacket) 81 | default: 82 | return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type") 83 | } 84 | return wgPacket, offset, nil 85 | } 86 | 87 | func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) { 88 | if tunSafeType == 
tunSafeNormalType { 89 | isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix) 90 | if isWgDataPacket { 91 | copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize]) 92 | countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize]) 93 | _ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount) 94 | } 95 | } 96 | tunSafe.wgRecvCount++ 97 | } 98 | 99 | func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte { 100 | wgLen := len(wgPacket) 101 | if wgLen < wgDataHeaderSize { 102 | return wgToTunSafeNormal(wgPacket) 103 | } 104 | wgPrefix := wgPacket[:wgDataPrefixSize] 105 | var wgCount uint64 106 | _ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount) 107 | prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix) 108 | if prefixMatch && wgCount == tunSafe.wgSendCount+1 { 109 | tunSafe.wgSendCount += 1 110 | return wgToTunSafeData(wgPacket) 111 | } else { 112 | isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix) 113 | if isWgDataPacket { 114 | tunSafe.wgSendPrefix = wgPrefix 115 | tunSafe.wgSendCount = wgCount 116 | } 117 | return wgToTunSafeNormal(wgPacket) 118 | } 119 | } 120 | 121 | func wgToTunSafeNormal(wgPacket []byte) []byte { 122 | payloadSize := len(wgPacket) 123 | result := make([]byte, payloadSize+tunSafeHeaderSize) 124 | 125 | // Tunsafe normal header 126 | result[0] = uint8(payloadSize >> 8) 127 | result[1] = uint8(payloadSize & 0xff) 128 | 129 | // Full packet 130 | copy(result[tunSafeHeaderSize:], wgPacket) 131 | 132 | return result 133 | } 134 | 135 | func wgToTunSafeData(wgPacket []byte) []byte { 136 | payloadSize := len(wgPacket) - wgDataHeaderSize 137 | result := make([]byte, payloadSize+tunSafeHeaderSize) 138 | 139 | // TunSafe data header 140 | result[0] = uint8(0b10<<6 | payloadSize>>8) 141 | result[1] = uint8(payloadSize & 0xff) 142 | 143 | // Packet without header 144 | copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:]) 145 | 146 | 
return result 147 | } 148 | -------------------------------------------------------------------------------- /device/statemanager_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022. Proton AG 3 | * 4 | * This file is part of ProtonVPN. 5 | * 6 | * ProtonVPN is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * ProtonVPN is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with ProtonVPN. If not, see . 18 | */ 19 | 20 | package device 21 | 22 | import ( 23 | "errors" 24 | "github.com/stretchr/testify/assert" 25 | "testing" 26 | "time" 27 | ) 28 | 29 | var timeMs int64 = 0 30 | var mockDevice MockDevice 31 | var manager *WireGuardStateManager 32 | var lastState WireGuardState 33 | 34 | type MockDevice struct { 35 | isUp bool 36 | upCount int 37 | } 38 | 39 | func (dev *MockDevice) Up() error { 40 | dev.isUp = true 41 | dev.upCount++ 42 | return nil 43 | } 44 | 45 | func (dev *MockDevice) Down() error { 46 | dev.isUp = false 47 | return nil 48 | } 49 | 50 | func setup() { 51 | timeMs = 0 52 | timeNow = func() time.Time { return time.UnixMilli(timeMs) } 53 | mockDevice.isUp = false 54 | 55 | manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), false) 56 | manager.Start(&mockDevice) 57 | lastState = WireGuardDisabled 58 | go func() { 59 | for lastState != -1 { 60 | lastState = manager.GetState() 61 | } 62 | }() 63 | } 64 | 65 | func setdown() { 66 | manager.Close() 67 | } 68 | 69 | func 
TestWireGuardStateManager_shouldRestart(t *testing.T) { 70 | assert := assert.New(t) 71 | setup() 72 | defer setdown() 73 | 74 | assert.Equal(initialRestartDelay, manager.nextRestartDelay) 75 | 76 | assert.Equal(false, manager.shouldRestart()) 77 | timeMs += initialRestartDelay.Milliseconds() 78 | assert.Equal(false, manager.shouldRestart()) 79 | timeMs += 1 80 | assert.Equal(true, manager.shouldRestart()) 81 | 82 | assert.Equal(2*initialRestartDelay, manager.nextRestartDelay) 83 | assert.Equal(false, manager.shouldRestart()) 84 | timeMs += 2 * initialRestartDelay.Milliseconds() 85 | assert.Equal(false, manager.shouldRestart()) 86 | timeMs += 1 87 | assert.Equal(true, manager.shouldRestart()) 88 | 89 | timeMs += resetRestartDelay.Milliseconds() + 1 90 | assert.Equal(true, manager.shouldRestart()) 91 | assert.Equal(initialRestartDelay, manager.nextRestartDelay) 92 | } 93 | 94 | func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) { 95 | assert := assert.New(t) 96 | setup() 97 | defer setdown() 98 | 99 | assert.Equal(false, mockDevice.isUp) 100 | manager.SetNetworkAvailable(true) 101 | time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking 102 | assert.Equal(true, mockDevice.isUp) 103 | assert.Equal(WireGuardConnecting, lastState) 104 | manager.SetNetworkAvailable(false) 105 | time.Sleep(time.Millisecond) 106 | assert.Equal(WireGuardWaitingForNetwork, lastState) 107 | assert.Equal(false, mockDevice.isUp) 108 | } 109 | 110 | func TestWireGuardStateManager_happyConnectionPath(t *testing.T) { 111 | assert := assert.New(t) 112 | setup() 113 | defer setdown() 114 | 115 | manager.SetNetworkAvailable(true) 116 | time.Sleep(time.Millisecond) 117 | manager.HandshakeStateChan <- HandshakeSuccess 118 | time.Sleep(time.Millisecond) 119 | assert.Equal(WireGuardConnected, lastState) 120 | assert.Equal(true, mockDevice.isUp) 121 | } 122 | 123 | func TestWireGuardStateManager_handshakeFailCausesRestart(t 
*testing.T) { 124 | assert := assert.New(t) 125 | setup() 126 | defer setdown() 127 | 128 | manager.SetNetworkAvailable(true) 129 | time.Sleep(time.Millisecond) 130 | manager.HandshakeStateChan <- HandshakeFail 131 | time.Sleep(time.Millisecond) 132 | assert.Equal(WireGuardError, lastState) 133 | timeMs += initialRestartDelay.Milliseconds() + 1 134 | manager.HandshakeStateChan <- HandshakeFail 135 | time.Sleep(time.Millisecond) 136 | assert.Equal(WireGuardConnecting, lastState) 137 | assert.Equal(2, mockDevice.upCount) 138 | } 139 | 140 | func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) { 141 | assert := assert.New(t) 142 | setup() 143 | defer setdown() 144 | 145 | manager.SetNetworkAvailable(true) 146 | timeMs += initialRestartDelay.Milliseconds() + 1 147 | time.Sleep(time.Millisecond) 148 | manager.SocketErrChan <- errors.New("broken pipe") 149 | time.Sleep(time.Millisecond) 150 | assert.Equal(WireGuardConnecting, lastState) 151 | assert.Equal(2, mockDevice.upCount) 152 | } 153 | -------------------------------------------------------------------------------- /conn/server_name_utils/server_name_utils.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025. Proton AG 3 | * 4 | * This file is part of ProtonVPN. 5 | * 6 | * ProtonVPN is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * ProtonVPN is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU General Public License 17 | * along with ProtonVPN. If not, see . 
18 | */ 19 | 20 | package server_name_utils 21 | 22 | import ( 23 | cryptoRand "crypto/rand" 24 | "math/big" 25 | "math/rand" 26 | "time" 27 | ) 28 | // ServerNameStrategy selects how ServerNameFor chooses a server name. 29 | type ServerNameStrategy int 30 | 31 | const ( 32 | ServerNameRandom ServerNameStrategy = iota // random label under a random TLD 33 | ServerNameTop // entry of domains selected from the address 34 | ) 35 | // topLevelDomains are the TLDs that randomServerName appends to its random label. 36 | var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"} 37 | // domains are the candidate server names for the ServerNameTop strategy. 38 | var domains = []string{ 39 | "accounts.google.com", 40 | "activity.windows.com", 41 | "analytics.apis.mcafee.com", 42 | "android.apis.google.com", 43 | "android.googleapis.com", 44 | "api.account.samsung.com", 45 | "api.accounts.firefox.com", 46 | "api.accuweather.com", 47 | "api.amazon.com", 48 | "api.browser.yandex.net", 49 | "api.ipify.org", 50 | "api.onedrive.com", 51 | "api.reasonsecurity.com", 52 | "api.samsungcloud.com", 53 | "api.sec.intl.miui.com", 54 | "api.vk.com", 55 | "api.weather.com", 56 | "app-site-association.cdn-apple.com", 57 | "apps.mzstatic.com", 58 | "assets.msn.com", 59 | "backup.googleapis.com", 60 | "brave-core-ext.s3.brave.com", 61 | "caldav.calendar.yahoo.com", 62 | "cc-api-data.adobe.io", 63 | "cdn.ampproject.org", 64 | "cdn.cookielaw.org", 65 | "client.wns.windows.com", 66 | "cloudflare.com", 67 | "cloudflare-dns.com", 68 | "cloudflare-ech.com", 69 | "config.extension.grammarly.com", 70 | "connectivitycheck.android.com", 71 | "connectivitycheck.gstatic.com", 72 | "courier.push.apple.com", 73 | "crl.globalsign.com", 74 | "dc1-file.ksn.kaspersky-labs.com", 75 | "dl.google.com", 76 | "dns.google", 77 | "dns.quad9.net", 78 | "doh.cleanbrowsing.org", 79 | "doh.dns.apple.com", 80 | "doh.opendns.com", 81 | "doh.pub", 82 | "ds.kaspersky.com", 83 | "ecs.office.com", 84 | "edge.microsoft.com", 85 | "events.gfe.nvidia.com", 86 | "excess.duolingo.com", 87 | "firefox.settings.services.mozilla.com", 88 | "fonts.googleapis.com", 89 | "fonts.gstatic.com", 90 | "gateway-asset.icloud-content.com", 91 | "gateway.icloud.com", 92 | "gdmf.apple.com", 93 |
"github.com", 94 | "go.microsoft.com", 95 | "go-updater.brave.com", 96 | "graph.microsoft.com", 97 | "gs-loc.apple.com", 98 | "gtglobal.intl.miui.com", 99 | "hcaptcha.com", 100 | "imap.gmail.com", 101 | "imap-mail.outlook.com", 102 | "imap.mail.yahoo.com", 103 | "in.appcenter.ms", 104 | "ipmcdn.avast.com", 105 | "itunes.apple.com", 106 | "loc.map.baidu.com", 107 | "login.live.com", 108 | "login.microsoftonline.com", 109 | "m.media-amazon.com", 110 | "mobile.events.data.microsoft.com", 111 | "mozilla.cloudflare-dns.com", 112 | "mtalk.google.com", 113 | "nimbus.bitdefender.net", 114 | "ocsp2.apple.com", 115 | "outlook.office365.com", 116 | "play-fe.googleapis.com", 117 | "play.googleapis.com", 118 | "play.samsungcloud.com", 119 | "raw.githubusercontent.com", 120 | "s3.amazonaws.com", 121 | "safebrowsing.googleapis.com", 122 | "s.alicdn.com", 123 | "self.events.data.microsoft.com", 124 | "settings-win.data.microsoft.com", 125 | "setup.icloud.com", 126 | "sirius.mwbsys.com", 127 | "spoc.norton.com", 128 | "ssl.gstatic.com", 129 | "translate.goo", // NOTE(review): possibly a typo for "translate.goog" — confirm (editing it also shifts the ServerNameTop mapping) 130 | "unpkg.com", 131 | "update.googleapis.com", 132 | "weatherapi.intl.xiaomi.com", 133 | "weatherkit.apple.com", 134 | "westus-0.in.applicationinsights.azure.com", 135 | "www.googleapis.com", 136 | "www.gstatic.com", 137 | "www.msftconnecttest.com", 138 | "www.msftncsi.com", 139 | "www.ntppool.org", 140 | "www.pool.ntp.org", 141 | "www.recaptcha.net", 142 | } 143 | // domainsSortedByHashes is domains prepared by sortValuesByHash with crc32Hash, used by serverNameForAddr. 144 | var domainsSortedByHashes []HashedValue = sortValuesByHash(domains, crc32Hash) 145 | // ServerNameFor returns a server name for addr: ServerNameTop selects an entry of domains from addr, while ServerNameRandom and any unrecognized strategy return a newly generated random name. 146 | func ServerNameFor(strategy ServerNameStrategy, addr string) string { 147 | switch strategy { 148 | case ServerNameTop: 149 | return serverNameForAddr(addr) 150 | case ServerNameRandom: 151 | return randomServerName() 152 | default: 153 | return randomServerName() 154 | } 155 | } 156 | // serverNameForAddr returns the entry that findClosestValue selects for addr from domainsSortedByHashes. 157 | func serverNameForAddr(addr string) string { 158 | return findClosestValue(addr, domainsSortedByHashes, crc32Hash) 159 | } 160 | // randomServerName returns a random lowercase label of 3 to 12 letters, followed by "." and a random entry of topLevelDomains. 161 | func randomServerName() string { 162 |
charNum := int('z') - int('a') + 1 163 | size := 3 + randInt(10) 164 | name := make([]byte, size) 165 | for i := range name { 166 | name[i] = byte(int('a') + randInt(charNum)) 167 | } 168 | return string(name) + "." + randItem(topLevelDomains) 169 | } 170 | // randItem returns a random element of list; it panics if list is empty. 171 | func randItem(list []string) string { 172 | return list[randInt(len(list))] 173 | } 174 | // randInt returns a uniformly random int in [0, n), preferring crypto/rand; it panics if n <= 0. 175 | func randInt(n int) int { 176 | size, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n))) 177 | if err == nil { 178 | return int(size.Int64()) 179 | } 180 | // crypto/rand failed: draw from a private time-seeded source instead of reseeding math/rand's process-wide source (rand.Seed is deprecated). 181 | return rand.New(rand.NewSource(time.Now().UnixNano())).Intn(n) 182 | } 183 | -------------------------------------------------------------------------------- /device/cookie.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "sync" 12 | "time" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/chacha20poly1305" 16 | ) 17 | // CookieChecker validates the mac1/mac2 fields of received messages and creates cookie replies; its embedded RWMutex guards its state. 18 | type CookieChecker struct { 19 | sync.RWMutex 20 | mac1 struct { 21 | key [blake2s.Size]byte 22 | } 23 | mac2 struct { 24 | secret [blake2s.Size]byte 25 | secretSet time.Time 26 | encryptionKey [chacha20poly1305.KeySize]byte 27 | } 28 | } 29 | // CookieGenerator writes mac1/mac2 into outgoing messages and stores cookies taken from cookie replies; its embedded RWMutex guards its state. 30 | type CookieGenerator struct { 31 | sync.RWMutex 32 | mac1 struct { 33 | key [blake2s.Size]byte 34 | } 35 | mac2 struct { 36 | cookie [blake2s.Size128]byte 37 | cookieSet time.Time 38 | hasLastMAC1 bool 39 | lastMAC1 [blake2s.Size128]byte 40 | encryptionKey [chacha20poly1305.KeySize]byte 41 | } 42 | } 43 | // Init derives the mac1 key and the cookie encryption key from pk and marks the cookie secret as expired. 44 | func (st *CookieChecker) Init(pk NoisePublicKey) { 45 | st.Lock() 46 | defer st.Unlock() 47 | 48 | // mac1 state 49 | 50 | func() { 51 | hash, _ := blake2s.New256(nil) 52 | hash.Write([]byte(WGLabelMAC1)) 53 | hash.Write(pk[:]) 54 | hash.Sum(st.mac1.key[:0]) 55 | }() 56 | 57 | // mac2 state 58 | 59 | func() { 60 | hash, _ :=
blake2s.New256(nil) 61 | hash.Write([]byte(WGLabelCookie)) 62 | hash.Write(pk[:]) 63 | hash.Sum(st.mac2.encryptionKey[:0]) 64 | }() 65 | 66 | st.mac2.secretSet = time.Time{} 67 | } 68 | // CheckMAC1 reports whether mac1 (the 16 bytes just before the trailing 16-byte mac2) is valid for msg. NOTE(review): assumes len(msg) >= 32; shorter input panics when slicing. 69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool { 70 | st.RLock() 71 | defer st.RUnlock() 72 | 73 | size := len(msg) 74 | smac2 := size - blake2s.Size128 75 | smac1 := smac2 - blake2s.Size128 76 | 77 | var mac1 [blake2s.Size128]byte 78 | 79 | mac, _ := blake2s.New128(st.mac1.key[:]) 80 | mac.Write(msg[:smac1]) 81 | mac.Sum(mac1[:0]) 82 | 83 | return hmac.Equal(mac1[:], msg[smac1:smac2]) 84 | } 85 | // CheckMAC2 reports whether the trailing mac2 of msg is valid for the address bytes src; it always fails while the cookie secret is older than CookieRefreshTime. 86 | func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { 87 | st.RLock() 88 | defer st.RUnlock() 89 | 90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 91 | return false 92 | } 93 | 94 | // derive cookie key 95 | 96 | var cookie [blake2s.Size128]byte 97 | func() { 98 | mac, _ := blake2s.New128(st.mac2.secret[:]) 99 | mac.Write(src) 100 | mac.Sum(cookie[:0]) 101 | }() 102 | 103 | // calculate mac of packet (including mac1) 104 | 105 | smac2 := len(msg) - blake2s.Size128 106 | 107 | var mac2 [blake2s.Size128]byte 108 | func() { 109 | mac, _ := blake2s.New128(cookie[:]) 110 | mac.Write(msg[:smac2]) 111 | mac.Sum(mac2[:0]) 112 | }() 113 | 114 | return hmac.Equal(mac2[:], msg[smac2:]) 115 | } 116 | // CreateReply builds a cookie reply for receiver index recv carrying a cookie derived from src, encrypted with msg's mac1 as associated data; the cookie secret is regenerated first when older than CookieRefreshTime. 117 | func (st *CookieChecker) CreateReply( 118 | msg []byte, 119 | recv uint32, 120 | src []byte, 121 | ) (*MessageCookieReply, error) { 122 | st.RLock() 123 | 124 | // refresh cookie secret 125 | 126 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 127 | st.RUnlock() 128 | st.Lock() 129 | _, err := rand.Read(st.mac2.secret[:]) 130 | if err != nil { 131 | st.Unlock() 132 | return nil, err 133 | } 134 | st.mac2.secretSet = time.Now() 135 | st.Unlock() 136 | st.RLock() 137 | } 138 | 139 | // derive cookie 140 | 141 | var cookie [blake2s.Size128]byte 142 | func() { 143 | mac, _ := blake2s.New128(st.mac2.secret[:]) 144 | mac.Write(src) 145 | mac.Sum(cookie[:0]) 146 |
}() 147 | 148 | // encrypt cookie 149 | 150 | size := len(msg) 151 | 152 | smac2 := size - blake2s.Size128 153 | smac1 := smac2 - blake2s.Size128 154 | 155 | reply := new(MessageCookieReply) 156 | reply.Type = MessageCookieReplyType 157 | reply.Receiver = recv 158 | 159 | _, err := rand.Read(reply.Nonce[:]) 160 | if err != nil { 161 | st.RUnlock() 162 | return nil, err 163 | } 164 | 165 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 166 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) 167 | 168 | st.RUnlock() 169 | 170 | return reply, nil 171 | } 172 | // Init derives the mac1 key and the cookie decryption key from pk and marks any stored cookie as expired. 173 | func (st *CookieGenerator) Init(pk NoisePublicKey) { 174 | st.Lock() 175 | defer st.Unlock() 176 | 177 | func() { 178 | hash, _ := blake2s.New256(nil) 179 | hash.Write([]byte(WGLabelMAC1)) 180 | hash.Write(pk[:]) 181 | hash.Sum(st.mac1.key[:0]) 182 | }() 183 | 184 | func() { 185 | hash, _ := blake2s.New256(nil) 186 | hash.Write([]byte(WGLabelCookie)) 187 | hash.Write(pk[:]) 188 | hash.Sum(st.mac2.encryptionKey[:0]) 189 | }() 190 | 191 | st.mac2.cookieSet = time.Time{} 192 | } 193 | // ConsumeReply decrypts the cookie in msg, using the last mac1 sent as associated data, and stores it; it reports whether that succeeded. 194 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { 195 | st.Lock() 196 | defer st.Unlock() 197 | 198 | if !st.mac2.hasLastMAC1 { 199 | return false 200 | } 201 | 202 | var cookie [blake2s.Size128]byte 203 | 204 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 205 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) 206 | if err != nil { 207 | return false 208 | } 209 | 210 | st.mac2.cookieSet = time.Now() 211 | st.mac2.cookie = cookie 212 | return true 213 | } 214 | // AddMacs writes mac1 into msg and remembers it for ConsumeReply; mac2 is written only while a cookie younger than CookieRefreshTime is held, otherwise those bytes are left untouched. 215 | func (st *CookieGenerator) AddMacs(msg []byte) { 216 | size := len(msg) 217 | 218 | smac2 := size - blake2s.Size128 219 | smac1 := smac2 - blake2s.Size128 220 | 221 | mac1 := msg[smac1:smac2] 222 | mac2 := msg[smac2:] 223 | 224 | st.Lock() 225 | defer st.Unlock() 226 | 227 | // set mac1 228 | 229 | func() { 230 | mac, _ :=
blake2s.New128(st.mac1.key[:]) 231 | mac.Write(msg[:smac1]) 232 | mac.Sum(mac1[:0]) 233 | }() 234 | copy(st.mac2.lastMAC1[:], mac1) 235 | st.mac2.hasLastMAC1 = true 236 | 237 | // set mac2 238 | 239 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime { 240 | return 241 | } 242 | 243 | func() { 244 | mac, _ := blake2s.New128(st.mac2.cookie[:]) 245 | mac.Write(msg[:smac2]) 246 | mac.Sum(mac2[:0]) 247 | }() 248 | } 249 | -------------------------------------------------------------------------------- /device/sticky_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This implements userspace semantics of "sticky sockets", modeled after 6 | * WireGuard's kernelspace implementation. This is more or less a straight port 7 | * of the sticky-sockets.c example code: 8 | * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 | * 10 | * Currently there is no way to achieve this within the net package: 11 | * See e.g. https://github.com/golang/go/issues/17930 12 | * So this code remains platform dependent.
13 | */ 14 | 15 | package device 16 | 17 | import ( 18 | "sync" 19 | "unsafe" 20 | 21 | "golang.org/x/sys/unix" 22 | 23 | "golang.zx2c4.com/wireguard/conn" 24 | "golang.zx2c4.com/wireguard/rwcancel" 25 | ) 26 | // startRouteListener starts routineRouteListener on a new netlink route socket when bind is a *conn.LinuxSocketBind and returns its RWCancel; for any other bind it returns (nil, nil). 27 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 | if _, ok := bind.(*conn.LinuxSocketBind); !ok { 29 | return nil, nil 30 | } 31 | 32 | netlinkSock, err := createNetlinkRouteSocket() 33 | if err != nil { 34 | return nil, err 35 | } 36 | netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 37 | if err != nil { 38 | unix.Close(netlinkSock) 39 | return nil, err 40 | } 41 | 42 | go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 43 | 44 | return netlinkCancel, nil 45 | } 46 | // routineRouteListener watches route changes so a peer's cached source address is cleared when its route's output interface differs from the cached one. It runs until a read fails or is cancelled, then closes netlinkCancel and netlinkSock. 47 | func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 48 | type peerEndpointPtr struct { 49 | peer *Peer 50 | endpoint *conn.Endpoint 51 | } 52 | var reqPeer map[uint32]peerEndpointPtr 53 | var reqPeerLock sync.Mutex 54 | 55 | defer netlinkCancel.Close() 56 | defer unix.Close(netlinkSock) 57 | 58 | for msg := make([]byte, 1<<16); ; { 59 | var err error 60 | var msgn int 61 | for { 62 | msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 63 | if err == nil || !rwcancel.RetryAfterError(err) { 64 | break 65 | } 66 | if !netlinkCancel.ReadyRead() { 67 | return 68 | } 69 | } 70 | if err != nil { 71 | return 72 | } 73 | 74 | for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 75 | 76 | hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 77 | 78 | if uint(hdr.Len) > uint(len(remain)) { // NOTE(review): hdr.Len == 0 is not rejected, so remain would never advance — presumably the kernel never sends that; confirm 79 | break 80 | } 81 | 82 | switch hdr.Type { 83 | case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 84 | if hdr.Seq <= MaxPeers && hdr.Seq > 0 { // replies to our own RTM_GETROUTE queries carry the Seq assigned below (starting at 1) 85 | if uint(len(remain)) < uint(hdr.Len) { 86 | break 87 | } 88 | if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 89 | attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 90 | for { 91 | if uint(len(attr)) <
uint(unix.SizeofRtAttr) { 92 | break 93 | } 94 | attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 95 | if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 96 | break 97 | } 98 | if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 99 | ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 100 | reqPeerLock.Lock() 101 | if reqPeer == nil { 102 | reqPeerLock.Unlock() 103 | break 104 | } 105 | pePtr, ok := reqPeer[hdr.Seq] 106 | reqPeerLock.Unlock() 107 | if !ok { 108 | break 109 | } 110 | pePtr.peer.Lock() 111 | if &pePtr.peer.endpoint != pePtr.endpoint { // NOTE(review): always false — pePtr.endpoint was stored as &peer.endpoint of this same peer, so a replaced endpoint is not detected here 112 | pePtr.peer.Unlock() 113 | break 114 | } 115 | if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { // NOTE(review): this type assertion panics if peer.endpoint no longer holds a *conn.LinuxSocketEndpoint (e.g. nil); the identity check above cannot catch that 116 | pePtr.peer.Unlock() 117 | break 118 | } 119 | pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() 120 | pePtr.peer.Unlock() 121 | } 122 | attr = attr[attrhdr.Len:] 123 | } 124 | } 125 | break 126 | } 127 | reqPeerLock.Lock() // any other route message: forget outstanding queries and re-query routes for every eligible peer 128 | reqPeer = make(map[uint32]peerEndpointPtr) 129 | reqPeerLock.Unlock() 130 | go func() { // NOTE(review): one goroutine per route notification, with no coalescing of bursts 131 | device.peers.RLock() 132 | i := uint32(1) 133 | for _, peer := range device.peers.keyMap { 134 | peer.RLock() 135 | if peer.endpoint == nil { 136 | peer.RUnlock() 137 | continue 138 | } 139 | nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) 140 | if nativeEP == nil { 141 | peer.RUnlock() 142 | continue 143 | } 144 | if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { 145 | peer.RUnlock() 146 | break // NOTE(review): stops querying all remaining peers at the first IPv6 or ifindex-less endpoint; `continue` may have been intended — confirm 147 | } 148 | nlmsg := struct { 149 | hdr unix.NlMsghdr 150 | msg unix.RtMsg 151 | dsthdr unix.RtAttr 152 | dst [4]byte 153 | srchdr unix.RtAttr 154 | src [4]byte 155 | markhdr unix.RtAttr 156 | mark uint32 157 | }{ 158 | unix.NlMsghdr{ 159 | Type: uint16(unix.RTM_GETROUTE), 160 | Flags: unix.NLM_F_REQUEST, 161 | Seq: i, 162 | }, 163 | unix.RtMsg{ 164 | Family: unix.AF_INET, 165 | Dst_len: 32, 166 | Src_len: 32, 167 | }, 168 | unix.RtAttr{ 169 | Len: 8, 170 | Type: unix.RTA_DST, 171 | }, 172 |
nativeEP.Dst4().Addr, 173 | unix.RtAttr{ 174 | Len: 8, 175 | Type: unix.RTA_SRC, 176 | }, 177 | nativeEP.Src4().Src, 178 | unix.RtAttr{ 179 | Len: 8, 180 | Type: unix.RTA_MARK, 181 | }, 182 | device.net.fwmark, 183 | } 184 | nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 185 | reqPeerLock.Lock() 186 | reqPeer[i] = peerEndpointPtr{ 187 | peer: peer, 188 | endpoint: &peer.endpoint, 189 | } 190 | reqPeerLock.Unlock() 191 | peer.RUnlock() 192 | i++ 193 | _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 194 | if err != nil { 195 | break 196 | } 197 | } 198 | device.peers.RUnlock() 199 | }() 200 | } 201 | remain = remain[hdr.Len:] 202 | } 203 | } 204 | } 205 | // createNetlinkRouteSocket opens a close-on-exec NETLINK_ROUTE socket bound to the RTMGRP_IPV4_ROUTE group, i.e. IPv4 route change notifications. 206 | func createNetlinkRouteSocket() (int, error) { 207 | sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 208 | if err != nil { 209 | return -1, err 210 | } 211 | saddr := &unix.SockaddrNetlink{ 212 | Family: unix.AF_NETLINK, 213 | Groups: unix.RTMGRP_IPV4_ROUTE, 214 | } 215 | err = unix.Bind(sock, saddr) 216 | if err != nil { 217 | unix.Close(sock) 218 | return -1, err 219 | } 220 | return sock, nil 221 | } 222 | -------------------------------------------------------------------------------- /conn/bind_std.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "net/netip" 13 | "sync" 14 | "syscall" 15 | ) 16 | 17 | // StdNetBind is meant to be a temporary solution on platforms for which 18 | // the sticky socket / source caching behavior has not yet been implemented. 19 | // It uses Go's net package to implement networking. 20 | // See LinuxSocketBind for a proper implementation on the Linux platform.
21 | type StdNetBind struct { 22 | mu sync.Mutex // protects following fields 23 | ipv4 *net.UDPConn 24 | ipv6 *net.UDPConn 25 | blackhole4 bool 26 | blackhole6 bool 27 | 28 | protectSocket func(fd int) int 29 | } 30 | 31 | func NewStdNetBind(protectSocket func(fd int) int) Bind { 32 | return &StdNetBind{protectSocket: protectSocket} 33 | } 34 | 35 | type StdNetEndpoint netip.AddrPort 36 | 37 | var ( 38 | _ Bind = (*StdNetBind)(nil) 39 | _ Endpoint = StdNetEndpoint{} 40 | ) 41 | 42 | func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 43 | e, err := netip.ParseAddrPort(s) 44 | return asEndpoint(e), err 45 | } 46 | 47 | func (StdNetEndpoint) ClearSrc() {} 48 | 49 | func (e StdNetEndpoint) DstIP() netip.Addr { 50 | return (netip.AddrPort)(e).Addr() 51 | } 52 | 53 | func (e StdNetEndpoint) SrcIP() netip.Addr { 54 | return netip.Addr{} // not supported 55 | } 56 | 57 | func (e StdNetEndpoint) DstToBytes() []byte { 58 | b, _ := (netip.AddrPort)(e).MarshalBinary() 59 | return b 60 | } 61 | 62 | func (e StdNetEndpoint) DstToString() string { 63 | return (netip.AddrPort)(e).String() 64 | } 65 | 66 | func (e StdNetEndpoint) SrcToString() string { 67 | return "" 68 | } 69 | 70 | func listenNet(network string, port int) (*net.UDPConn, int, error) { 71 | conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 72 | if err != nil { 73 | return nil, 0, err 74 | } 75 | 76 | // Retrieve port. 77 | laddr := conn.LocalAddr() 78 | uaddr, err := net.ResolveUDPAddr( 79 | laddr.Network(), 80 | laddr.String(), 81 | ) 82 | if err != nil { 83 | return nil, 0, err 84 | } 85 | return conn, uaddr.Port, nil 86 | } 87 | 88 | func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 89 | bind.mu.Lock() 90 | defer bind.mu.Unlock() 91 | 92 | var err error 93 | var tries int 94 | 95 | if bind.ipv4 != nil || bind.ipv6 != nil { 96 | return nil, 0, ErrBindAlreadyOpen 97 | } 98 | 99 | // Attempt to open ipv4 and ipv6 listeners on the same port. 
100 | // If uport is 0, we can retry on failure. 101 | again: 102 | port := int(uport) 103 | var ipv4, ipv6 *net.UDPConn 104 | 105 | ipv4, port, err = listenNet("udp4", port) 106 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 107 | return nil, 0, err 108 | } 109 | 110 | // Listen on the same port as we're using for ipv4. 111 | ipv6, port, err = listenNet("udp6", port) 112 | if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 113 | ipv4.Close() 114 | tries++ 115 | goto again 116 | } 117 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 118 | ipv4.Close() 119 | return nil, 0, err 120 | } 121 | var fns []ReceiveFunc 122 | if ipv4 != nil { 123 | fns = append(fns, bind.makeReceiveIPv4(ipv4)) 124 | bind.ipv4 = ipv4 125 | } 126 | if ipv6 != nil { 127 | fns = append(fns, bind.makeReceiveIPv6(ipv6)) 128 | bind.ipv6 = ipv6 129 | } 130 | if len(fns) == 0 { 131 | return nil, 0, syscall.EAFNOSUPPORT 132 | } 133 | 134 | err = bind.protectSockets() 135 | if err != nil { 136 | return nil, 0, err 137 | } 138 | 139 | return fns, uint16(port), nil 140 | } 141 | 142 | func (bind *StdNetBind) protectSockets() error { 143 | if bind.ipv4 != nil { 144 | fd, err := bind.PeekLookAtSocketFd4() 145 | if err != nil { 146 | return err 147 | } 148 | status := bind.protectSocket(fd) 149 | if status < 0 { 150 | return fmt.Errorf("Failed to protect socket: status=%d", status) 151 | } 152 | } 153 | if bind.ipv6 != nil { 154 | fd, err := bind.PeekLookAtSocketFd6() 155 | if err != nil { 156 | return err 157 | } 158 | status := bind.protectSocket(fd) 159 | if status < 0 { 160 | return fmt.Errorf("Failed to protect socket: status=%d", status) 161 | } 162 | } 163 | return nil 164 | } 165 | 166 | func (bind *StdNetBind) Close() error { 167 | bind.mu.Lock() 168 | defer bind.mu.Unlock() 169 | 170 | var err1, err2 error 171 | if bind.ipv4 != nil { 172 | err1 = bind.ipv4.Close() 173 | bind.ipv4 = nil 174 | } 175 | if bind.ipv6 != nil { 176 | err2 = bind.ipv6.Close() 177 | 
bind.ipv6 = nil 178 | } 179 | bind.blackhole4 = false 180 | bind.blackhole6 = false 181 | if err1 != nil { 182 | return err1 183 | } 184 | return err2 185 | } 186 | 187 | func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 188 | return func(buff []byte) (int, Endpoint, error) { 189 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 190 | return n, asEndpoint(endpoint), err 191 | } 192 | } 193 | 194 | func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 195 | return func(buff []byte) (int, Endpoint, error) { 196 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 197 | return n, asEndpoint(endpoint), err 198 | } 199 | } 200 | 201 | func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 202 | var err error 203 | nend, ok := endpoint.(StdNetEndpoint) 204 | if !ok { 205 | return ErrWrongEndpointType 206 | } 207 | addrPort := netip.AddrPort(nend) 208 | 209 | bind.mu.Lock() 210 | blackhole := bind.blackhole4 211 | conn := bind.ipv4 212 | if addrPort.Addr().Is6() { 213 | blackhole = bind.blackhole6 214 | conn = bind.ipv6 215 | } 216 | bind.mu.Unlock() 217 | 218 | if blackhole { 219 | return nil 220 | } 221 | if conn == nil { 222 | return syscall.EAFNOSUPPORT 223 | } 224 | _, err = conn.WriteToUDPAddrPort(buff, addrPort) 225 | return err 226 | } 227 | 228 | // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. 229 | // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, 230 | // but Endpoints are immutable, so we can re-use them. 231 | var endpointPool = sync.Pool{ 232 | New: func() any { 233 | return make(map[netip.AddrPort]Endpoint) 234 | }, 235 | } 236 | 237 | // asEndpoint returns an Endpoint containing ap. 
238 | func asEndpoint(ap netip.AddrPort) Endpoint { 239 | m := endpointPool.Get().(map[netip.AddrPort]Endpoint) 240 | defer endpointPool.Put(m) 241 | e, ok := m[ap] 242 | if !ok { 243 | e = Endpoint(StdNetEndpoint(ap)) 244 | m[ap] = e 245 | } 246 | return e 247 | } 248 | -------------------------------------------------------------------------------- /tun/tun_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "os" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | _ "unsafe" 16 | 17 | "golang.org/x/sys/windows" 18 | 19 | "golang.zx2c4.com/wintun" 20 | ) 21 | 22 | const ( 23 | rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 24 | spinloopRateThreshold = 800000000 / 8 // 800mbps 25 | spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 26 | ) 27 | 28 | type rateJuggler struct { 29 | current atomic.Uint64 30 | nextByteCount atomic.Uint64 31 | nextStartTime atomic.Int64 32 | changing atomic.Bool 33 | } 34 | 35 | type NativeTun struct { 36 | wt *wintun.Adapter 37 | name string 38 | handle windows.Handle 39 | rate rateJuggler 40 | session wintun.Session 41 | readWait windows.Handle 42 | events chan Event 43 | running sync.WaitGroup 44 | closeOnce sync.Once 45 | close atomic.Bool 46 | forcedMTU int 47 | } 48 | 49 | var ( 50 | WintunTunnelType = "WireGuard" 51 | WintunStaticRequestedGUID *windows.GUID 52 | ) 53 | 54 | //go:linkname procyield runtime.procyield 55 | func procyield(cycles uint32) 56 | 57 | //go:linkname nanotime runtime.nanotime 58 | func nanotime() int64 59 | 60 | // CreateTUN creates a Wintun interface with the given name. Should a Wintun 61 | // interface with the same name exist, it is reused. 
62 | func CreateTUN(ifname string, mtu int) (Device, error) { 63 | return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 | } 65 | 66 | // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 67 | // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 68 | // An mtu <= 0 selects the default MTU of 1420. 69 | func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 70 | wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 71 | if err != nil { 72 | return nil, fmt.Errorf("Error creating interface: %w", err) 73 | } 74 | 75 | forcedMTU := 1420 76 | if mtu > 0 { 77 | forcedMTU = mtu 78 | } 79 | 80 | tun := &NativeTun{ 81 | wt: wt, 82 | name: ifname, 83 | handle: windows.InvalidHandle, 84 | events: make(chan Event, 10), 85 | forcedMTU: forcedMTU, 86 | } 87 | 88 | tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 89 | if err != nil { 90 | tun.wt.Close() 91 | close(tun.events) 92 | return nil, fmt.Errorf("Error starting session: %w", err) 93 | } 94 | tun.readWait = tun.session.ReadWaitEvent() 95 | return tun, nil 96 | } 97 | 98 | func (tun *NativeTun) Name() (string, error) { 99 | return tun.name, nil 100 | } 101 | 102 | // File always returns nil for a Wintun device. 103 | func (tun *NativeTun) File() *os.File { 104 | return nil 105 | } 106 | 107 | func (tun *NativeTun) Events() <-chan Event { 108 | return tun.events 109 | } 110 | 111 | // Close shuts the device down once; later calls do nothing. It marks the device 112 | // closed, wakes a Read blocked on readWait, waits for in-flight Read/Write/LUID 113 | // calls to return, then ends the session and closes the adapter and the events 114 | // channel. It always returns nil. 115 | func (tun *NativeTun) Close() error { 116 | tun.closeOnce.Do(func() { 117 | tun.close.Store(true) 118 | windows.SetEvent(tun.readWait) 119 | tun.running.Wait() 120 | tun.session.End() 121 | if tun.wt != nil { 122 | tun.wt.Close() 123 | } 124 | close(tun.events) 125 | }) 126 | return nil 127 | } 128 | 129 | func (tun *NativeTun) MTU() (int, error) { 130 | return tun.forcedMTU, nil 131 | } 132 | 133 | // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 134 | // ForceMTU sets the MTU reported by MTU and emits EventMTUUpdate when it changed. 135 | // NOTE(review): forcedMTU is read and written without synchronization. 136 | func (tun *NativeTun) ForceMTU(mtu int) { 137 | // After Close the events channel is closed; sending on it would panic. 138 | if tun.close.Load() { 139 | return 140 | } 141 | update := tun.forcedMTU != mtu 142 | tun.forcedMTU = mtu 143 | if update { 144 | tun.events <- EventMTUUpdate 145 | } 146 | } 147 | 148 | // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 149 |
150 | // Read copies the next packet from the Wintun ring into buff[offset:] and returns 151 | // its size. While recent throughput is at least spinloopRateThreshold it busy-polls 152 | // for up to spinloopDuration before blocking on readWait. 153 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 154 | tun.running.Add(1) 155 | defer tun.running.Done() 156 | retry: 157 | if tun.close.Load() { 158 | return 0, os.ErrClosed 159 | } 160 | start := nanotime() 161 | shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 162 | for { 163 | if tun.close.Load() { 164 | return 0, os.ErrClosed 165 | } 166 | packet, err := tun.session.ReceivePacket() 167 | switch err { 168 | case nil: 169 | packetSize := len(packet) 170 | copy(buff[offset:], packet) // NOTE(review): truncates silently if buff[offset:] is shorter than the packet, yet packetSize is still returned 171 | tun.session.ReleaseReceivePacket(packet) 172 | tun.rate.update(uint64(packetSize)) 173 | return packetSize, nil 174 | case windows.ERROR_NO_MORE_ITEMS: 175 | if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 176 | windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 177 | goto retry 178 | } 179 | procyield(1) 180 | continue 181 | case windows.ERROR_HANDLE_EOF: 182 | return 0, os.ErrClosed 183 | case windows.ERROR_INVALID_DATA: 184 | return 0, errors.New("Send ring corrupt") 185 | } 186 | return 0, fmt.Errorf("Read failed: %w", err) 187 | } 188 | } 189 | 190 | // Flush is a no-op: Write hands every packet to the send ring immediately. 191 | func (tun *NativeTun) Flush() error { 192 | return nil 193 | } 194 | 195 | // Write sends buff[offset:] as one packet. If the send ring is full the packet 196 | // is dropped and (0, nil) is returned. 197 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 198 | tun.running.Add(1) 199 | defer tun.running.Done() 200 | if tun.close.Load() { 201 | return 0, os.ErrClosed 202 | } 203 | 204 | packetSize := len(buff) - offset 205 | tun.rate.update(uint64(packetSize)) 206 | 207 | packet, err := tun.session.AllocateSendPacket(packetSize) 208 | if err == nil { 209 | copy(packet, buff[offset:]) 210 | tun.session.SendPacket(packet) 211 | return packetSize, nil 212 | } 213 | switch err { 214 | case windows.ERROR_HANDLE_EOF: 215 | return 0, os.ErrClosed 216 | case windows.ERROR_BUFFER_OVERFLOW: 217 | return 0, nil // Dropping when ring is full. 218 | } 219 | return 0, fmt.Errorf("Write failed: %w", err) 220 | }
221 | 222 | // LUID returns Windows interface instance ID. 223 | func (tun *NativeTun) LUID() uint64 { 224 | tun.running.Add(1) 225 | defer tun.running.Done() 226 | if tun.close.Load() { 227 | return 0 228 | } 229 | return tun.wt.LUID() 230 | } 231 | 232 | // RunningVersion returns the running version of the Wintun driver. 233 | func (tun *NativeTun) RunningVersion() (version uint32, err error) { 234 | return wintun.RunningVersion() 235 | } 236 | 237 | // update adds packetLen to the current measurement window. Once the window is at 238 | // least rateMeasurementGranularity old, one caller (guarded by changing) publishes 239 | // the window's bytes per second to current and starts a new window. 240 | func (rate *rateJuggler) update(packetLen uint64) { 241 | now := nanotime() 242 | total := rate.nextByteCount.Add(packetLen) 243 | period := uint64(now - rate.nextStartTime.Load()) 244 | if period >= rateMeasurementGranularity { 245 | if !rate.changing.CompareAndSwap(false, true) { 246 | return 247 | } 248 | rate.nextStartTime.Store(now) 249 | rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) 250 | rate.nextByteCount.Store(0) 251 | rate.changing.Store(false) 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /device/cookie_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "testing" 10 | ) 11 | 12 | func TestCookieMAC1(t *testing.T) { 13 | // setup generator / checker 14 | 15 | var ( 16 | generator CookieGenerator 17 | checker CookieChecker 18 | ) 19 | 20 | sk, err := newPrivateKey() 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | pk := sk.publicKey() 25 | 26 | generator.Init(pk) 27 | checker.Init(pk) 28 | 29 | // check mac1 30 | 31 | src := []byte{192, 168, 13, 37, 10, 10, 10} 32 | 33 | checkMAC1 := func(msg []byte) { 34 | generator.AddMacs(msg) 35 | if !checker.CheckMAC1(msg) { 36 | t.Fatal("MAC1 generation/verification failed") 37 | } 38 | if checker.CheckMAC2(msg, src) { 39 | t.Fatal("MAC2 generation/verification failed") 40 | } 41 | } 42 | 43 | checkMAC1([]byte{ 44 | 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, 45 | 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, 46 | 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, 47 | 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, 48 | 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, 49 | 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, 50 | 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, 51 | 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, 52 | 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, 53 | 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, 54 | 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, 55 | 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, 56 | 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, 57 | 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, 58 | 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, 59 | 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, 60 | }) 61 | 62 | checkMAC1([]byte{ 63 | 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, 64 | 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, 65 | 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, 66 | 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, 67 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 68 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, 69 | 0xb8, 0x57, 0x33, 0x45, 0x6e, 
0x8b, 0x09, 0x2b, 70 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, 71 | }) 72 | 73 | checkMAC1([]byte{ 74 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 75 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, 76 | 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, 77 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, 78 | }) 79 | 80 | // exchange cookie reply 81 | 82 | func() { 83 | msg := []byte{ 84 | 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, 85 | 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, 86 | 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, 87 | 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, 88 | 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, 89 | 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, 90 | 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, 91 | 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, 92 | 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, 93 | 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, 94 | 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, 95 | 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, 96 | 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, 97 | 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, 98 | 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, 99 | 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, 100 | } 101 | generator.AddMacs(msg) 102 | reply, err := checker.CreateReply(msg, 1377, src) 103 | if err != nil { 104 | t.Fatal("Failed to create cookie reply:", err) 105 | } 106 | if !generator.ConsumeReply(reply) { 107 | t.Fatal("Failed to consume cookie reply") 108 | } 109 | }() 110 | 111 | // check mac2 112 | 113 | checkMAC2 := func(msg []byte) { 114 | generator.AddMacs(msg) 115 | 116 | if !checker.CheckMAC1(msg) { 117 | t.Fatal("MAC1 generation/verification failed") 118 | } 119 | if !checker.CheckMAC2(msg, src) { 120 | t.Fatal("MAC2 generation/verification failed") 121 | } 122 | 123 | msg[5] ^= 0x20 124 | 125 | if checker.CheckMAC1(msg) { 126 | t.Fatal("MAC1 generation/verification failed") 127 | } 128 | if checker.CheckMAC2(msg, src) 
{ 129 | t.Fatal("MAC2 generation/verification failed") 130 | } 131 | 132 | msg[5] ^= 0x20 133 | 134 | srcBad1 := []byte{192, 168, 13, 37, 40, 1} 135 | if checker.CheckMAC2(msg, srcBad1) { 136 | t.Fatal("MAC2 generation/verification failed") 137 | } 138 | 139 | srcBad2 := []byte{192, 168, 13, 38, 40, 1} 140 | if checker.CheckMAC2(msg, srcBad2) { 141 | t.Fatal("MAC2 generation/verification failed") 142 | } 143 | } 144 | 145 | checkMAC2([]byte{ 146 | 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, 147 | 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, 148 | 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, 149 | 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, 150 | 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, 151 | 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, 152 | 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, 153 | 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, 154 | }) 155 | 156 | checkMAC2([]byte{ 157 | 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, 158 | 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, 159 | 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, 160 | 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, 161 | 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, 162 | 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, 163 | 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, 164 | 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, 165 | 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, 166 | 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, 167 | 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, 168 | 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, 169 | 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, 170 | 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, 171 | 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, 172 | 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, 173 | 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, 174 | 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, 175 | 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, 176 | 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, 
177 | 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, 178 | 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, 179 | 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, 180 | 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, 181 | 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, 182 | 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, 183 | 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, 184 | 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, 185 | 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, 186 | 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, 187 | 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, 188 | 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, 189 | }) 190 | } 191 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "fmt" 12 | "os" 13 | "os/signal" 14 | "runtime" 15 | "strconv" 16 | "syscall" 17 | 18 | "golang.zx2c4.com/wireguard/conn" 19 | "golang.zx2c4.com/wireguard/device" 20 | "golang.zx2c4.com/wireguard/ipc" 21 | "golang.zx2c4.com/wireguard/tun" 22 | ) 23 | 24 | const ( 25 | ExitSetupSuccess = 0 26 | ExitSetupFailed = 1 27 | ) 28 | 29 | const ( 30 | ENV_WG_TUN_FD = "WG_TUN_FD" 31 | ENV_WG_UAPI_FD = "WG_UAPI_FD" 32 | ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" 33 | ) 34 | 35 | func printUsage() { 36 | fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) 37 | } 38 | 39 | func warning() { 40 | switch runtime.GOOS { 41 | case "linux", "freebsd", "openbsd": 42 | if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { 43 | return 44 | } 45 | default: 46 | return 47 | } 48 | 49 | fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐") 50 | fmt.Fprintln(os.Stderr, "│ │") 51 | fmt.Fprintln(os.Stderr, "│ Running wireguard-go 
is not required because this │") 52 | fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │") 53 | fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") 54 | fmt.Fprintln(os.Stderr, "│ please visit: │") 55 | fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │") 56 | fmt.Fprintln(os.Stderr, "│ │") 57 | fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘") 58 | } 59 | 60 | func main() { 61 | if len(os.Args) == 2 && os.Args[1] == "--version" { 62 | fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH) 63 | return 64 | } 65 | 66 | warning() 67 | 68 | var foreground bool 69 | var interfaceName string 70 | if len(os.Args) < 2 || len(os.Args) > 3 { 71 | printUsage() 72 | return 73 | } 74 | 75 | switch os.Args[1] { 76 | 77 | case "-f", "--foreground": 78 | foreground = true 79 | if len(os.Args) != 3 { 80 | printUsage() 81 | return 82 | } 83 | interfaceName = os.Args[2] 84 | 85 | default: 86 | foreground = false 87 | if len(os.Args) != 2 { 88 | printUsage() 89 | return 90 | } 91 | interfaceName = os.Args[1] 92 | } 93 | 94 | if !foreground { 95 | foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" 96 | } 97 | 98 | // get log level from LOG_LEVEL (unset or unrecognized values fall back to error) 99 | 100 | logLevel := func() int { 101 | switch os.Getenv("LOG_LEVEL") { 102 | case "verbose", "debug": 103 | return device.LogLevelVerbose 104 | case "error": 105 | return device.LogLevelError 106 | case "silent": 107 | return device.LogLevelSilent 108 | } 109 | return device.LogLevelError 110 | }() 111 | 112 | // open TUN device (or use supplied fd) 113 | 114 | tun, err := func() (tun.Device, error) { 115 | tunFdStr := os.Getenv(ENV_WG_TUN_FD) 116 | if tunFdStr == "" { 117 | return tun.CreateTUN(interfaceName, device.DefaultMTU) 118 | } 119 | 120 | // construct tun device from supplied fd 121
| 122 | fd, err := strconv.ParseUint(tunFdStr, 10, 32) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | err = syscall.SetNonblock(int(fd), true) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | file := os.NewFile(uintptr(fd), "") 133 | return tun.CreateTUNFromFile(file, device.DefaultMTU) 134 | }() 135 | 136 | if err == nil { 137 | realInterfaceName, err2 := tun.Name() 138 | if err2 == nil { 139 | interfaceName = realInterfaceName 140 | } 141 | } 142 | 143 | logger := device.NewLogger( 144 | logLevel, 145 | fmt.Sprintf("(%s) ", interfaceName), 146 | ) 147 | 148 | logger.Verbosef("Starting wireguard-go version %s", Version) 149 | 150 | if err != nil { 151 | logger.Errorf("Failed to create TUN device: %v", err) 152 | os.Exit(ExitSetupFailed) 153 | } 154 | 155 | // open UAPI file (or use supplied fd) 156 | 157 | fileUAPI, err := func() (*os.File, error) { 158 | uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) 159 | if uapiFdStr == "" { 160 | return ipc.UAPIOpen(interfaceName) 161 | } 162 | 163 | // use supplied fd 164 | 165 | fd, err := strconv.ParseUint(uapiFdStr, 10, 32) 166 | if err != nil { 167 | return nil, err 168 | } 169 | 170 | return os.NewFile(uintptr(fd), ""), nil 171 | }() 172 | if err != nil { 173 | logger.Errorf("UAPI listen error: %v", err) 174 | os.Exit(ExitSetupFailed) 175 | return 176 | } 177 | // daemonize the process 178 | 179 | if !foreground { 180 | env := os.Environ() 181 | env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) 182 | env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) 183 | env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) 184 | files := [3]*os.File{} 185 | if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent { 186 | files[0], _ = os.Open(os.DevNull) 187 | files[1] = os.Stdout 188 | files[2] = os.Stderr 189 | } else { 190 | files[0], _ = os.Open(os.DevNull) 191 | files[1], _ = os.Open(os.DevNull) 192 | files[2], _ = os.Open(os.DevNull) 193 | } 194 | attr := &os.ProcAttr{ 195 
| Files: []*os.File{ 196 | files[0], // stdin 197 | files[1], // stdout 198 | files[2], // stderr 199 | tun.File(), 200 | fileUAPI, 201 | }, 202 | Dir: ".", 203 | Env: env, 204 | } 205 | 206 | path, err := os.Executable() 207 | if err != nil { 208 | logger.Errorf("Failed to determine executable: %v", err) 209 | os.Exit(ExitSetupFailed) 210 | } 211 | 212 | process, err := os.StartProcess( 213 | path, 214 | os.Args, 215 | attr, 216 | ) 217 | if err != nil { 218 | logger.Errorf("Failed to daemonize: %v", err) 219 | os.Exit(ExitSetupFailed) 220 | } 221 | process.Release() 222 | return 223 | } 224 | 225 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 226 | 227 | logger.Verbosef("Device started") 228 | 229 | errs := make(chan error) 230 | term := make(chan os.Signal, 1) 231 | 232 | uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) 233 | if err != nil { 234 | logger.Errorf("Failed to listen on uapi socket: %v", err) 235 | os.Exit(ExitSetupFailed) 236 | } 237 | 238 | go func() { 239 | for { 240 | conn, err := uapi.Accept() 241 | if err != nil { 242 | errs <- err 243 | return 244 | } 245 | go device.IpcHandle(conn) 246 | } 247 | }() 248 | 249 | logger.Verbosef("UAPI listener started") 250 | 251 | // wait for program to terminate 252 | 253 | signal.Notify(term, syscall.SIGTERM) 254 | signal.Notify(term, os.Interrupt) 255 | 256 | select { 257 | case <-term: 258 | case <-errs: 259 | case <-device.Wait(): 260 | } 261 | 262 | // clean up 263 | 264 | uapi.Close() 265 | device.Close() 266 | 267 | logger.Verbosef("Shutting down") 268 | } 269 | -------------------------------------------------------------------------------- /device/allowedips_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "testing" 13 | ) 14 | 15 | type testPairCommonBits struct { 16 | s1 []byte 17 | s2 []byte 18 | match uint8 19 | } 20 | 21 | func TestCommonBits(t *testing.T) { 22 | tests := []testPairCommonBits{ 23 | {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, 24 | {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, 25 | {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, 26 | {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, 27 | {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, 28 | } 29 | 30 | for _, p := range tests { 31 | v := commonBits(p.s1, p.s2) 32 | if v != p.match { 33 | t.Error( 34 | "For slice", p.s1, p.s2, 35 | "expected match", p.match, 36 | ",but got", v, 37 | ) 38 | } 39 | } 40 | } 41 | 42 | // benchmarkTrie inserts addressNumber random prefixes of addressLength bytes 43 | // (net.IPv4len or net.IPv6len), spread over peerNumber peers, then times 44 | // lookups of random addressLength-byte keys. 45 | func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { 46 | var trie *trieEntry 47 | var peers []*Peer 48 | root := parentIndirection{&trie, 2} 49 | rand.Seed(1) 50 | for n := 0; n < peerNumber; n++ { 51 | peers = append(peers, &Peer{}) 52 | } 53 | 54 | for n := 0; n < addressNumber; n++ { 55 | addr := make([]byte, addressLength) // fresh key each time, in case insert retains it 56 | rand.Read(addr) 57 | cidr := uint8(rand.Uint32() % uint32(addressLength*8)) 58 | index := rand.Int() % peerNumber 59 | root.insert(addr, cidr, peers[index]) 60 | } 61 | 62 | lookupAddr := make([]byte, addressLength) 63 | b.ResetTimer() // exclude trie construction from the timing 64 | for n := 0; n < b.N; n++ { 65 | rand.Read(lookupAddr) 66 | trie.lookup(lookupAddr) 67 | } 68 | } 69 | 70 | func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { 71 | benchmarkTrie(100, 1000, net.IPv4len, b) 72 | } 73 | 74 | func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { 75 | benchmarkTrie(10, 10, net.IPv4len, b) 76 | } 77 | 78 | func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { 79 | benchmarkTrie(100, 1000, net.IPv6len, b) 80 | } 81 | 82 | func BenchmarkTrieIPv6Peers10Addresses10(b
*testing.B) { 83 | benchmarkTrie(10, 10, net.IPv6len, b) 84 | } 85 | 86 | /* Test ported from kernel implementation: 87 | * selftest/allowedips.h 88 | */ 89 | func TestTrieIPv4(t *testing.T) { 90 | a := &Peer{} 91 | b := &Peer{} 92 | c := &Peer{} 93 | d := &Peer{} 94 | e := &Peer{} 95 | g := &Peer{} 96 | h := &Peer{} 97 | 98 | var allowedIPs AllowedIPs 99 | 100 | insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { 101 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) 102 | } 103 | 104 | assertEQ := func(peer *Peer, a, b, c, d byte) { 105 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 106 | if p != peer { 107 | t.Error("Assert EQ failed") 108 | } 109 | } 110 | 111 | assertNEQ := func(peer *Peer, a, b, c, d byte) { 112 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 113 | if p == peer { 114 | t.Error("Assert NEQ failed") 115 | } 116 | } 117 | 118 | insert(a, 192, 168, 4, 0, 24) 119 | insert(b, 192, 168, 4, 4, 32) 120 | insert(c, 192, 168, 0, 0, 16) 121 | insert(d, 192, 95, 5, 64, 27) 122 | insert(c, 192, 95, 5, 65, 27) 123 | insert(e, 0, 0, 0, 0, 0) 124 | insert(g, 64, 15, 112, 0, 20) 125 | insert(h, 64, 15, 123, 211, 25) 126 | insert(a, 10, 0, 0, 0, 25) 127 | insert(b, 10, 0, 0, 128, 25) 128 | insert(a, 10, 1, 0, 0, 30) 129 | insert(b, 10, 1, 0, 4, 30) 130 | insert(c, 10, 1, 0, 8, 29) 131 | insert(d, 10, 1, 0, 16, 29) 132 | 133 | assertEQ(a, 192, 168, 4, 20) 134 | assertEQ(a, 192, 168, 4, 0) 135 | assertEQ(b, 192, 168, 4, 4) 136 | assertEQ(c, 192, 168, 200, 182) 137 | assertEQ(c, 192, 95, 5, 68) 138 | assertEQ(e, 192, 95, 5, 96) 139 | assertEQ(g, 64, 15, 116, 26) 140 | assertEQ(g, 64, 15, 127, 3) 141 | 142 | insert(a, 1, 0, 0, 0, 32) 143 | insert(a, 64, 0, 0, 0, 32) 144 | insert(a, 128, 0, 0, 0, 32) 145 | insert(a, 192, 0, 0, 0, 32) 146 | insert(a, 255, 0, 0, 0, 32) 147 | 148 | assertEQ(a, 1, 0, 0, 0) 149 | assertEQ(a, 64, 0, 0, 0) 150 | assertEQ(a, 128, 0, 0, 0) 151 | assertEQ(a, 192, 0, 0, 0) 152 | assertEQ(a, 255, 0, 0, 
0) 153 | 154 | allowedIPs.RemoveByPeer(a) 155 | 156 | assertNEQ(a, 1, 0, 0, 0) 157 | assertNEQ(a, 64, 0, 0, 0) 158 | assertNEQ(a, 128, 0, 0, 0) 159 | assertNEQ(a, 192, 0, 0, 0) 160 | assertNEQ(a, 255, 0, 0, 0) 161 | 162 | allowedIPs.RemoveByPeer(a) 163 | allowedIPs.RemoveByPeer(b) 164 | allowedIPs.RemoveByPeer(c) 165 | allowedIPs.RemoveByPeer(d) 166 | allowedIPs.RemoveByPeer(e) 167 | allowedIPs.RemoveByPeer(g) 168 | allowedIPs.RemoveByPeer(h) 169 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 170 | t.Error("Expected removing all the peers to empty trie, but it did not") 171 | } 172 | 173 | insert(a, 192, 168, 0, 0, 16) 174 | insert(a, 192, 168, 0, 0, 24) 175 | 176 | allowedIPs.RemoveByPeer(a) 177 | 178 | assertNEQ(a, 192, 168, 0, 1) 179 | } 180 | 181 | /* Test ported from kernel implementation: 182 | * selftest/allowedips.h 183 | */ 184 | func TestTrieIPv6(t *testing.T) { 185 | a := &Peer{} 186 | b := &Peer{} 187 | c := &Peer{} 188 | d := &Peer{} 189 | e := &Peer{} 190 | f := &Peer{} 191 | g := &Peer{} 192 | h := &Peer{} 193 | 194 | var allowedIPs AllowedIPs 195 | 196 | expand := func(a uint32) []byte { 197 | var out [4]byte 198 | out[0] = byte(a >> 24 & 0xff) 199 | out[1] = byte(a >> 16 & 0xff) 200 | out[2] = byte(a >> 8 & 0xff) 201 | out[3] = byte(a & 0xff) 202 | return out[:] 203 | } 204 | 205 | insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { 206 | var addr []byte 207 | addr = append(addr, expand(a)...) 208 | addr = append(addr, expand(b)...) 209 | addr = append(addr, expand(c)...) 210 | addr = append(addr, expand(d)...) 211 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) 212 | } 213 | 214 | assertEQ := func(peer *Peer, a, b, c, d uint32) { 215 | var addr []byte 216 | addr = append(addr, expand(a)...) 217 | addr = append(addr, expand(b)...) 218 | addr = append(addr, expand(c)...) 219 | addr = append(addr, expand(d)...) 
220 | p := allowedIPs.Lookup(addr) 221 | if p != peer { 222 | t.Error("Assert EQ failed") 223 | } 224 | } 225 | 226 | insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) 227 | insert(c, 0x26075300, 0x60006b00, 0, 0, 64) 228 | insert(e, 0, 0, 0, 0, 0) 229 | insert(f, 0, 0, 0, 0, 0) 230 | insert(g, 0x24046800, 0, 0, 0, 32) 231 | insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) 232 | insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) 233 | insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) 234 | insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) 235 | 236 | assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) 237 | assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) 238 | assertEQ(f, 0x26075300, 0x60006b01, 0, 0) 239 | assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) 240 | assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) 241 | assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) 242 | assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) 243 | assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) 244 | assertEQ(h, 0x24046800, 0x40040800, 0, 0) 245 | assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) 246 | assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) 247 | } 248 | -------------------------------------------------------------------------------- /tun/tun_openbsd.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "unsafe" 16 | 17 | "golang.org/x/net/ipv6" 18 | "golang.org/x/sys/unix" 19 | ) 20 | 21 | // Structure for iface mtu get/set ioctls 22 | type ifreq_mtu struct { 23 | Name [unix.IFNAMSIZ]byte 24 | MTU uint32 25 | Pad0 [12]byte 26 | } 27 | 28 | const _TUNSIFMODE = 0x8004745d 29 | 30 | type NativeTun struct { 31 | name string 32 | tunFile *os.File 33 | events chan Event 34 | errors chan error 35 | routeSocket int 36 | closeOnce sync.Once 37 | } 38 | 39 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 40 | var ( 41 | statusUp bool 42 | statusMTU int 43 | ) 44 | 45 | defer close(tun.events) 46 | 47 | check := func() bool { 48 | iface, err := net.InterfaceByIndex(tunIfindex) 49 | if err != nil { 50 | tun.errors <- err 51 | return true 52 | } 53 | 54 | // Up / Down event 55 | up := (iface.Flags & net.FlagUp) != 0 56 | if up != statusUp && up { 57 | tun.events <- EventUp 58 | } 59 | if up != statusUp && !up { 60 | tun.events <- EventDown 61 | } 62 | statusUp = up 63 | 64 | // MTU changes 65 | if iface.MTU != statusMTU { 66 | tun.events <- EventMTUUpdate 67 | } 68 | statusMTU = iface.MTU 69 | return false 70 | } 71 | 72 | if check() { 73 | return 74 | } 75 | 76 | data := make([]byte, os.Getpagesize()) 77 | for { 78 | n, err := unix.Read(tun.routeSocket, data) 79 | if err != nil { 80 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 81 | continue 82 | } 83 | tun.errors <- err 84 | return 85 | } 86 | 87 | if n < 8 { 88 | continue 89 | } 90 | 91 | if data[3 /* type */] != unix.RTM_IFINFO { 92 | continue 93 | } 94 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */]))) 95 | if ifindex != tunIfindex { 96 | continue 97 | } 98 | if check() { 99 | return 100 | } 101 | } 102 | } 103 | 104 | func CreateTUN(name string, mtu int) (Device, error) { 105 | ifIndex := -1 106 | if name != "tun" { 107 | _, err := 
fmt.Sscanf(name, "tun%d", &ifIndex) 108 | if err != nil || ifIndex < 0 { 109 | return nil, fmt.Errorf("Interface name must be tun[0-9]*") 110 | } 111 | } 112 | 113 | var tunfile *os.File 114 | var err error 115 | 116 | if ifIndex != -1 { 117 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 118 | } else { 119 | for ifIndex = 0; ifIndex < 256; ifIndex++ { 120 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 121 | if err == nil || !errors.Is(err, syscall.EBUSY) { 122 | break 123 | } 124 | } 125 | } 126 | 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | tun, err := CreateTUNFromFile(tunfile, mtu) 132 | 133 | if err == nil && name == "tun" { 134 | fname := os.Getenv("WG_TUN_NAME_FILE") 135 | if fname != "" { 136 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 137 | } 138 | } 139 | 140 | return tun, err 141 | } 142 | 143 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 144 | tun := &NativeTun{ 145 | tunFile: file, 146 | events: make(chan Event, 10), 147 | errors: make(chan error, 1), 148 | } 149 | 150 | name, err := tun.Name() 151 | if err != nil { 152 | tun.tunFile.Close() 153 | return nil, err 154 | } 155 | 156 | tunIfindex, err := func() (int, error) { 157 | iface, err := net.InterfaceByName(name) 158 | if err != nil { 159 | return -1, err 160 | } 161 | return iface.Index, nil 162 | }() 163 | if err != nil { 164 | tun.tunFile.Close() 165 | return nil, err 166 | } 167 | 168 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) 169 | if err != nil { 170 | tun.tunFile.Close() 171 | return nil, err 172 | } 173 | 174 | go tun.routineRouteListener(tunIfindex) 175 | 176 | currentMTU, err := tun.MTU() 177 | if err != nil || currentMTU != mtu { 178 | err = tun.setMTU(mtu) 179 | if err != nil { 180 | tun.Close() 181 | return nil, err 182 | } 183 | } 184 | 185 | return tun, nil 186 | } 187 | 
188 | func (tun *NativeTun) Name() (string, error) { 189 | gostat, err := tun.tunFile.Stat() 190 | if err != nil { 191 | tun.name = "" 192 | return "", err 193 | } 194 | stat := gostat.Sys().(*syscall.Stat_t) 195 | tun.name = fmt.Sprintf("tun%d", stat.Rdev%256) 196 | return tun.name, nil 197 | } 198 | 199 | func (tun *NativeTun) File() *os.File { 200 | return tun.tunFile 201 | } 202 | 203 | func (tun *NativeTun) Events() <-chan Event { 204 | return tun.events 205 | } 206 | 207 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 208 | select { 209 | case err := <-tun.errors: 210 | return 0, err 211 | default: 212 | buff := buff[offset-4:] 213 | n, err := tun.tunFile.Read(buff[:]) 214 | if n < 4 { 215 | return 0, err 216 | } 217 | return n - 4, err 218 | } 219 | } 220 | 221 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 222 | // reserve space for header 223 | 224 | buff = buff[offset-4:] 225 | 226 | // add packet information header 227 | 228 | buff[0] = 0x00 229 | buff[1] = 0x00 230 | buff[2] = 0x00 231 | 232 | if buff[4]>>4 == ipv6.Version { 233 | buff[3] = unix.AF_INET6 234 | } else { 235 | buff[3] = unix.AF_INET 236 | } 237 | 238 | // write 239 | 240 | return tun.tunFile.Write(buff) 241 | } 242 | 243 | func (tun *NativeTun) Flush() error { 244 | // TODO: can flushing be implemented by buffering and using sendmmsg? 
245 | return nil 246 | } 247 | 248 | func (tun *NativeTun) Close() error { 249 | var err1, err2 error 250 | tun.closeOnce.Do(func() { 251 | err1 = tun.tunFile.Close() 252 | if tun.routeSocket != -1 { 253 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 254 | err2 = unix.Close(tun.routeSocket) 255 | tun.routeSocket = -1 256 | } else if tun.events != nil { 257 | close(tun.events) 258 | } 259 | }) 260 | if err1 != nil { 261 | return err1 262 | } 263 | return err2 264 | } 265 | 266 | func (tun *NativeTun) setMTU(n int) error { 267 | // open datagram socket 268 | 269 | var fd int 270 | 271 | fd, err := unix.Socket( 272 | unix.AF_INET, 273 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 274 | 0, 275 | ) 276 | if err != nil { 277 | return err 278 | } 279 | 280 | defer unix.Close(fd) 281 | 282 | // do ioctl call 283 | 284 | var ifr ifreq_mtu 285 | copy(ifr.Name[:], tun.name) 286 | ifr.MTU = uint32(n) 287 | 288 | _, _, errno := unix.Syscall( 289 | unix.SYS_IOCTL, 290 | uintptr(fd), 291 | uintptr(unix.SIOCSIFMTU), 292 | uintptr(unsafe.Pointer(&ifr)), 293 | ) 294 | 295 | if errno != 0 { 296 | return fmt.Errorf("failed to set MTU on %s", tun.name) 297 | } 298 | 299 | return nil 300 | } 301 | 302 | func (tun *NativeTun) MTU() (int, error) { 303 | // open datagram socket 304 | 305 | fd, err := unix.Socket( 306 | unix.AF_INET, 307 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 308 | 0, 309 | ) 310 | if err != nil { 311 | return 0, err 312 | } 313 | 314 | defer unix.Close(fd) 315 | 316 | // do ioctl call 317 | var ifr ifreq_mtu 318 | copy(ifr.Name[:], tun.name) 319 | 320 | _, _, errno := unix.Syscall( 321 | unix.SYS_IOCTL, 322 | uintptr(fd), 323 | uintptr(unix.SIOCGIFMTU), 324 | uintptr(unsafe.Pointer(&ifr)), 325 | ) 326 | if errno != 0 { 327 | return 0, fmt.Errorf("failed to get MTU on %s", tun.name) 328 | } 329 | 330 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil 331 | } 332 | -------------------------------------------------------------------------------- /tun/tun_darwin.go: 
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "time" 16 | "unsafe" 17 | 18 | "golang.org/x/net/ipv6" 19 | "golang.org/x/sys/unix" 20 | ) 21 | 22 | const utunControlName = "com.apple.net.utun_control" 23 | 24 | type NativeTun struct { 25 | name string 26 | tunFile *os.File 27 | events chan Event 28 | errors chan error 29 | routeSocket int 30 | closeOnce sync.Once 31 | } 32 | 33 | func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { 34 | for i := 0; i < 20; i++ { 35 | iface, err = net.InterfaceByIndex(index) 36 | if err != nil && errors.Is(err, syscall.ENOMEM) { 37 | time.Sleep(time.Duration(i) * time.Second / 3) 38 | continue 39 | } 40 | return iface, err 41 | } 42 | return nil, err 43 | } 44 | 45 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 46 | var ( 47 | statusUp bool 48 | statusMTU int 49 | ) 50 | 51 | defer close(tun.events) 52 | 53 | data := make([]byte, os.Getpagesize()) 54 | for { 55 | retry: 56 | n, err := unix.Read(tun.routeSocket, data) 57 | if err != nil { 58 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 59 | goto retry 60 | } 61 | tun.errors <- err 62 | return 63 | } 64 | 65 | if n < 14 { 66 | continue 67 | } 68 | 69 | if data[3 /* type */] != unix.RTM_IFINFO { 70 | continue 71 | } 72 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) 73 | if ifindex != tunIfindex { 74 | continue 75 | } 76 | 77 | iface, err := retryInterfaceByIndex(ifindex) 78 | if err != nil { 79 | tun.errors <- err 80 | return 81 | } 82 | 83 | // Up / Down event 84 | up := (iface.Flags & net.FlagUp) != 0 85 | if up != statusUp && up { 86 | tun.events <- EventUp 87 | } 88 | if up != statusUp && !up { 89 | tun.events <- EventDown 90 | } 91 | 
statusUp = up 92 | 93 | // MTU changes 94 | if iface.MTU != statusMTU { 95 | tun.events <- EventMTUUpdate 96 | } 97 | statusMTU = iface.MTU 98 | } 99 | } 100 | 101 | func CreateTUN(name string, mtu int) (Device, error) { 102 | ifIndex := -1 103 | if name != "utun" { 104 | _, err := fmt.Sscanf(name, "utun%d", &ifIndex) 105 | if err != nil || ifIndex < 0 { 106 | return nil, fmt.Errorf("Interface name must be utun[0-9]*") 107 | } 108 | } 109 | 110 | fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | ctlInfo := &unix.CtlInfo{} 116 | copy(ctlInfo.Name[:], []byte(utunControlName)) 117 | err = unix.IoctlCtlInfo(fd, ctlInfo) 118 | if err != nil { 119 | unix.Close(fd) 120 | return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) 121 | } 122 | 123 | sc := &unix.SockaddrCtl{ 124 | ID: ctlInfo.Id, 125 | Unit: uint32(ifIndex) + 1, 126 | } 127 | 128 | err = unix.Connect(fd, sc) 129 | if err != nil { 130 | unix.Close(fd) 131 | return nil, err 132 | } 133 | 134 | err = unix.SetNonblock(fd, true) 135 | if err != nil { 136 | unix.Close(fd) 137 | return nil, err 138 | } 139 | tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) 140 | 141 | if err == nil && name == "utun" { 142 | fname := os.Getenv("WG_TUN_NAME_FILE") 143 | if fname != "" { 144 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 145 | } 146 | } 147 | 148 | return tun, err 149 | } 150 | 151 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 152 | tun := &NativeTun{ 153 | tunFile: file, 154 | events: make(chan Event, 10), 155 | errors: make(chan error, 5), 156 | } 157 | 158 | name, err := tun.Name() 159 | if err != nil { 160 | tun.tunFile.Close() 161 | return nil, err 162 | } 163 | 164 | tunIfindex, err := func() (int, error) { 165 | iface, err := net.InterfaceByName(name) 166 | if err != nil { 167 | return -1, err 168 | } 169 | return iface.Index, nil 170 | }() 171 | if err != nil { 172 | tun.tunFile.Close() 173 
| return nil, err 174 | } 175 | 176 | tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) 177 | if err != nil { 178 | tun.tunFile.Close() 179 | return nil, err 180 | } 181 | 182 | go tun.routineRouteListener(tunIfindex) 183 | 184 | if mtu > 0 { 185 | err = tun.setMTU(mtu) 186 | if err != nil { 187 | tun.Close() 188 | return nil, err 189 | } 190 | } 191 | 192 | return tun, nil 193 | } 194 | 195 | func (tun *NativeTun) Name() (string, error) { 196 | var err error 197 | tun.operateOnFd(func(fd uintptr) { 198 | tun.name, err = unix.GetsockoptString( 199 | int(fd), 200 | 2, /* #define SYSPROTO_CONTROL 2 */ 201 | 2, /* #define UTUN_OPT_IFNAME 2 */ 202 | ) 203 | }) 204 | 205 | if err != nil { 206 | return "", fmt.Errorf("GetSockoptString: %w", err) 207 | } 208 | 209 | return tun.name, nil 210 | } 211 | 212 | func (tun *NativeTun) File() *os.File { 213 | return tun.tunFile 214 | } 215 | 216 | func (tun *NativeTun) Events() <-chan Event { 217 | return tun.events 218 | } 219 | 220 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 221 | select { 222 | case err := <-tun.errors: 223 | return 0, err 224 | default: 225 | buff := buff[offset-4:] 226 | n, err := tun.tunFile.Read(buff[:]) 227 | if n < 4 { 228 | return 0, err 229 | } 230 | return n - 4, err 231 | } 232 | } 233 | 234 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 235 | // reserve space for header 236 | 237 | buff = buff[offset-4:] 238 | 239 | // add packet information header 240 | 241 | buff[0] = 0x00 242 | buff[1] = 0x00 243 | buff[2] = 0x00 244 | 245 | if buff[4]>>4 == ipv6.Version { 246 | buff[3] = unix.AF_INET6 247 | } else { 248 | buff[3] = unix.AF_INET 249 | } 250 | 251 | // write 252 | 253 | return tun.tunFile.Write(buff) 254 | } 255 | 256 | func (tun *NativeTun) Flush() error { 257 | // TODO: can flushing be implemented by buffering and using sendmmsg? 
258 | return nil 259 | } 260 | 261 | func (tun *NativeTun) Close() error { 262 | var err1, err2 error 263 | tun.closeOnce.Do(func() { 264 | err1 = tun.tunFile.Close() 265 | if tun.routeSocket != -1 { 266 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 267 | err2 = unix.Close(tun.routeSocket) 268 | } else if tun.events != nil { 269 | close(tun.events) 270 | } 271 | }) 272 | if err1 != nil { 273 | return err1 274 | } 275 | return err2 276 | } 277 | 278 | func (tun *NativeTun) setMTU(n int) error { 279 | fd, err := socketCloexec( 280 | unix.AF_INET, 281 | unix.SOCK_DGRAM, 282 | 0, 283 | ) 284 | if err != nil { 285 | return err 286 | } 287 | 288 | defer unix.Close(fd) 289 | 290 | var ifr unix.IfreqMTU 291 | copy(ifr.Name[:], tun.name) 292 | ifr.MTU = int32(n) 293 | err = unix.IoctlSetIfreqMTU(fd, &ifr) 294 | if err != nil { 295 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) 296 | } 297 | 298 | return nil 299 | } 300 | 301 | func (tun *NativeTun) MTU() (int, error) { 302 | fd, err := socketCloexec( 303 | unix.AF_INET, 304 | unix.SOCK_DGRAM, 305 | 0, 306 | ) 307 | if err != nil { 308 | return 0, err 309 | } 310 | 311 | defer unix.Close(fd) 312 | 313 | ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) 314 | if err != nil { 315 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) 316 | } 317 | 318 | return int(ifr.MTU), nil 319 | } 320 | 321 | func socketCloexec(family, sotype, proto int) (fd int, err error) { 322 | // See go/src/net/sys_cloexec.go for background. 323 | syscall.ForkLock.RLock() 324 | defer syscall.ForkLock.RUnlock() 325 | 326 | fd, err = unix.Socket(family, sotype, proto) 327 | if err == nil { 328 | unix.CloseOnExec(fd) 329 | } 330 | return 331 | } 332 | -------------------------------------------------------------------------------- /ipc/namedpipe/file.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The Go Authors. All rights reserved. 
2 | // Copyright 2015 Microsoft 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file. 5 | 6 | //go:build windows 7 | // +build windows 8 | 9 | package namedpipe 10 | 11 | import ( 12 | "io" 13 | "os" 14 | "runtime" 15 | "sync" 16 | "sync/atomic" 17 | "time" 18 | "unsafe" 19 | 20 | "golang.org/x/sys/windows" 21 | ) 22 | 23 | type timeoutChan chan struct{} 24 | 25 | var ( 26 | ioInitOnce sync.Once 27 | ioCompletionPort windows.Handle 28 | ) 29 | 30 | // ioResult contains the result of an asynchronous IO operation 31 | type ioResult struct { 32 | bytes uint32 33 | err error 34 | } 35 | 36 | // ioOperation represents an outstanding asynchronous Win32 IO 37 | type ioOperation struct { 38 | o windows.Overlapped 39 | ch chan ioResult 40 | } 41 | 42 | func initIo() { 43 | h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 44 | if err != nil { 45 | panic(err) 46 | } 47 | ioCompletionPort = h 48 | go ioCompletionProcessor(h) 49 | } 50 | 51 | // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. 52 | // It takes ownership of this handle and will close it if it is garbage collected. 
53 | type file struct { 54 | handle windows.Handle 55 | wg sync.WaitGroup 56 | wgLock sync.RWMutex 57 | closing atomic.Bool 58 | socket bool 59 | readDeadline deadlineHandler 60 | writeDeadline deadlineHandler 61 | } 62 | 63 | type deadlineHandler struct { 64 | setLock sync.Mutex 65 | channel timeoutChan 66 | channelLock sync.RWMutex 67 | timer *time.Timer 68 | timedout atomic.Bool 69 | } 70 | 71 | // makeFile makes a new file from an existing file handle 72 | func makeFile(h windows.Handle) (*file, error) { 73 | f := &file{handle: h} 74 | ioInitOnce.Do(initIo) 75 | _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) 76 | if err != nil { 77 | return nil, err 78 | } 79 | err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) 80 | if err != nil { 81 | return nil, err 82 | } 83 | f.readDeadline.channel = make(timeoutChan) 84 | f.writeDeadline.channel = make(timeoutChan) 85 | return f, nil 86 | } 87 | 88 | // closeHandle cancels outstanding IO, waits for it, then closes the Win32 handle; only the first call has any effect. 89 | func (f *file) closeHandle() { 90 | f.wgLock.Lock() 91 | // Atomically mark the file as closing (under wgLock, so prepareIo cannot race a wg.Add past it). 92 | alreadyClosing := f.closing.Swap(true) 93 | f.wgLock.Unlock() 94 | if alreadyClosing { 95 | return 96 | } 97 | // cancel all IO and wait for it to complete 98 | windows.CancelIoEx(f.handle, nil) 99 | f.wg.Wait() 100 | // at this point, no new IO can start 101 | windows.Close(f.handle) 102 | f.handle = 0 103 | } 104 | 105 | // Close closes a file. 106 | func (f *file) Close() error { 107 | f.closeHandle() 108 | return nil 109 | } 110 | 111 | // prepareIo prepares for a new IO operation. 112 | // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
113 | func (f *file) prepareIo() (*ioOperation, error) { 114 | f.wgLock.RLock() 115 | if f.closing.Load() { 116 | f.wgLock.RUnlock() 117 | return nil, os.ErrClosed 118 | } 119 | f.wg.Add(1) 120 | f.wgLock.RUnlock() 121 | c := &ioOperation{} 122 | c.ch = make(chan ioResult) 123 | return c, nil 124 | } 125 | 126 | // ioCompletionProcessor processes completed async IOs forever 127 | func ioCompletionProcessor(h windows.Handle) { 128 | for { 129 | var bytes uint32 130 | var key uintptr 131 | var op *ioOperation 132 | err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) 133 | if op == nil { 134 | panic(err) 135 | } 136 | op.ch <- ioResult{bytes, err} 137 | } 138 | } 139 | 140 | // asyncIo processes the return value from ReadFile or WriteFile, blocking until 141 | // the operation has actually completed. 142 | func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { 143 | if err != windows.ERROR_IO_PENDING { 144 | return int(bytes), err 145 | } 146 | 147 | if f.closing.Load() { 148 | windows.CancelIoEx(f.handle, &c.o) 149 | } 150 | 151 | var timeout timeoutChan 152 | if d != nil { 153 | d.channelLock.Lock() 154 | timeout = d.channel 155 | d.channelLock.Unlock() 156 | } 157 | 158 | var r ioResult 159 | select { 160 | case r = <-c.ch: 161 | err = r.err 162 | if err == windows.ERROR_OPERATION_ABORTED { 163 | if f.closing.Load() { 164 | err = os.ErrClosed 165 | } 166 | } else if err != nil && f.socket { 167 | // err is from Win32. Query the overlapped structure to get the winsock error. 
168 | var bytes, flags uint32 169 | err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) 170 | } 171 | case <-timeout: 172 | windows.CancelIoEx(f.handle, &c.o) 173 | r = <-c.ch 174 | err = r.err 175 | if err == windows.ERROR_OPERATION_ABORTED { 176 | err = os.ErrDeadlineExceeded 177 | } 178 | } 179 | 180 | // runtime.KeepAlive is needed, as c is passed via native 181 | // code to ioCompletionProcessor, c must remain alive 182 | // until the channel read is complete. 183 | runtime.KeepAlive(c) 184 | return int(r.bytes), err 185 | } 186 | 187 | // Read reads from a file handle. 188 | func (f *file) Read(b []byte) (int, error) { 189 | c, err := f.prepareIo() 190 | if err != nil { 191 | return 0, err 192 | } 193 | defer f.wg.Done() 194 | 195 | if f.readDeadline.timedout.Load() { 196 | return 0, os.ErrDeadlineExceeded 197 | } 198 | 199 | var bytes uint32 200 | err = windows.ReadFile(f.handle, b, &bytes, &c.o) 201 | n, err := f.asyncIo(c, &f.readDeadline, bytes, err) 202 | runtime.KeepAlive(b) 203 | 204 | // Handle EOF conditions. 205 | if err == nil && n == 0 && len(b) != 0 { 206 | return 0, io.EOF 207 | } else if err == windows.ERROR_BROKEN_PIPE { 208 | return 0, io.EOF 209 | } else { 210 | return n, err 211 | } 212 | } 213 | 214 | // Write writes to a file handle. 
215 | func (f *file) Write(b []byte) (int, error) { 216 | c, err := f.prepareIo() 217 | if err != nil { 218 | return 0, err 219 | } 220 | defer f.wg.Done() 221 | 222 | if f.writeDeadline.timedout.Load() { 223 | return 0, os.ErrDeadlineExceeded 224 | } 225 | 226 | var bytes uint32 227 | err = windows.WriteFile(f.handle, b, &bytes, &c.o) 228 | n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) 229 | runtime.KeepAlive(b) 230 | return n, err 231 | } 232 | 233 | func (f *file) SetReadDeadline(deadline time.Time) error { 234 | return f.readDeadline.set(deadline) 235 | } 236 | 237 | func (f *file) SetWriteDeadline(deadline time.Time) error { 238 | return f.writeDeadline.set(deadline) 239 | } 240 | 241 | func (f *file) Flush() error { 242 | return windows.FlushFileBuffers(f.handle) 243 | } 244 | 245 | func (f *file) Fd() uintptr { 246 | return uintptr(f.handle) 247 | } 248 | 249 | func (d *deadlineHandler) set(deadline time.Time) error { 250 | d.setLock.Lock() 251 | defer d.setLock.Unlock() 252 | 253 | if d.timer != nil { 254 | if !d.timer.Stop() { 255 | <-d.channel 256 | } 257 | d.timer = nil 258 | } 259 | d.timedout.Store(false) 260 | 261 | select { 262 | case <-d.channel: 263 | d.channelLock.Lock() 264 | d.channel = make(chan struct{}) 265 | d.channelLock.Unlock() 266 | default: 267 | } 268 | 269 | if deadline.IsZero() { 270 | return nil 271 | } 272 | 273 | timeoutIO := func() { 274 | d.timedout.Store(true) 275 | close(d.channel) 276 | } 277 | 278 | now := time.Now() 279 | duration := deadline.Sub(now) 280 | if deadline.After(now) { 281 | // Deadline is in the future, set a timer to wait 282 | d.timer = time.AfterFunc(duration, timeoutIO) 283 | } else { 284 | // Deadline is in the past. Cancel all pending IO now. 
285 | timeoutIO() 286 | } 287 | return nil 288 | } 289 | -------------------------------------------------------------------------------- /device/peer.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "errors" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type Peer struct { 19 | isRunning atomic.Bool 20 | sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer 21 | keypairs Keypairs 22 | handshake Handshake 23 | device *Device 24 | endpoint conn.Endpoint 25 | stopping sync.WaitGroup // routines pending stop 26 | txBytes atomic.Uint64 // bytes send to peer (endpoint) 27 | rxBytes atomic.Uint64 // bytes received from peer 28 | lastHandshakeNano atomic.Int64 // nano seconds since epoch 29 | 30 | disableRoaming bool 31 | 32 | timers struct { 33 | retransmitHandshake *Timer 34 | sendKeepalive *Timer 35 | newHandshake *Timer 36 | zeroKeyMaterial *Timer 37 | persistentKeepalive *Timer 38 | handshakeAttempts atomic.Uint32 39 | needAnotherKeepalive atomic.Bool 40 | sentLastMinuteHandshake atomic.Bool 41 | } 42 | 43 | state struct { 44 | sync.Mutex // protects against concurrent Start/Stop 45 | } 46 | 47 | queue struct { 48 | staged chan *QueueOutboundElement // staged packets before a handshake is available 49 | outbound *autodrainingOutboundQueue // sequential ordering of udp transmission 50 | inbound *autodrainingInboundQueue // sequential ordering of tun writing 51 | } 52 | 53 | cookieGenerator CookieGenerator 54 | trieEntries list.List 55 | persistentKeepaliveInterval atomic.Uint32 56 | } 57 | 58 | func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { 59 | if device.isClosed() { 60 | return nil, errors.New("device closed") 61 | } 62 | 63 | // lock resources 
64 | device.staticIdentity.RLock() 65 | defer device.staticIdentity.RUnlock() 66 | 67 | device.peers.Lock() 68 | defer device.peers.Unlock() 69 | 70 | // check if over limit 71 | if len(device.peers.keyMap) >= MaxPeers { 72 | return nil, errors.New("too many peers") 73 | } 74 | 75 | // create peer 76 | peer := new(Peer) 77 | peer.Lock() 78 | defer peer.Unlock() 79 | 80 | peer.cookieGenerator.Init(pk) 81 | peer.device = device 82 | peer.queue.outbound = newAutodrainingOutboundQueue(device) 83 | peer.queue.inbound = newAutodrainingInboundQueue(device) 84 | peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) 85 | 86 | // map public key 87 | _, ok := device.peers.keyMap[pk] 88 | if ok { 89 | return nil, errors.New("adding existing peer") 90 | } 91 | 92 | // pre-compute DH 93 | handshake := &peer.handshake 94 | handshake.mutex.Lock() 95 | handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) 96 | handshake.remoteStatic = pk 97 | handshake.mutex.Unlock() 98 | 99 | // reset endpoint 100 | peer.endpoint = nil 101 | 102 | // init timers 103 | peer.timersInit() 104 | 105 | // add 106 | device.peers.keyMap[pk] = peer 107 | 108 | return peer, nil 109 | } 110 | 111 | func (peer *Peer) SendBuffer(buffer []byte) error { 112 | peer.device.net.RLock() 113 | defer peer.device.net.RUnlock() 114 | 115 | if peer.device.isClosed() { 116 | return nil 117 | } 118 | 119 | peer.RLock() 120 | defer peer.RUnlock() 121 | 122 | if peer.endpoint == nil { 123 | return errors.New("no known endpoint for peer") 124 | } 125 | 126 | err := peer.device.net.bind.Send(buffer, peer.endpoint) 127 | if err == nil { 128 | peer.txBytes.Add(uint64(len(buffer))) 129 | } 130 | return err 131 | } 132 | 133 | func (peer *Peer) String() string { 134 | // The awful goo that follows is identical to: 135 | // 136 | // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) 137 | // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] 138 
| // return fmt.Sprintf("peer(%s)", abbreviatedKey) 139 | // 140 | // except that it is considerably more efficient. 141 | src := peer.handshake.remoteStatic 142 | b64 := func(input byte) byte { 143 | return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) 144 | } 145 | b := []byte("peer(____…____)") 146 | const first = len("peer(") 147 | const second = len("peer(____…") 148 | b[first+0] = b64((src[0] >> 2) & 63) 149 | b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) 150 | b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) 151 | b[first+3] = b64(src[2] & 63) 152 | b[second+0] = b64(src[29] & 63) 153 | b[second+1] = b64((src[30] >> 2) & 63) 154 | b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) 155 | b[second+3] = b64((src[31] << 2) & 63) 156 | return string(b) 157 | } 158 | 159 | func (peer *Peer) Start() { 160 | // should never start a peer on a closed device 161 | if peer.device.isClosed() { 162 | return 163 | } 164 | 165 | // prevent simultaneous start/stop operations 166 | peer.state.Lock() 167 | defer peer.state.Unlock() 168 | 169 | if peer.isRunning.Load() { 170 | return 171 | } 172 | 173 | device := peer.device 174 | device.log.Verbosef("%v - Starting", peer) 175 | 176 | // reset routine state 177 | peer.stopping.Wait() 178 | peer.stopping.Add(2) 179 | 180 | peer.handshake.mutex.Lock() 181 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 182 | peer.handshake.mutex.Unlock() 183 | 184 | peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes 185 | 186 | peer.timersStart() 187 | 188 | device.flushInboundQueue(peer.queue.inbound) 189 | device.flushOutboundQueue(peer.queue.outbound) 190 | go peer.RoutineSequentialSender() 191 | go peer.RoutineSequentialReceiver() 192 | 193 | peer.isRunning.Store(true) 194 | } 195 | 196 | func (peer *Peer) ZeroAndFlushAll() { 197 | device := peer.device 198 | 199 
| // clear key pairs 200 | 201 | keypairs := &peer.keypairs 202 | keypairs.Lock() 203 | device.DeleteKeypair(keypairs.previous) 204 | device.DeleteKeypair(keypairs.current) 205 | device.DeleteKeypair(keypairs.next.Load()) 206 | keypairs.previous = nil 207 | keypairs.current = nil 208 | keypairs.next.Store(nil) 209 | keypairs.Unlock() 210 | 211 | // clear handshake state 212 | 213 | handshake := &peer.handshake 214 | handshake.mutex.Lock() 215 | device.indexTable.Delete(handshake.localIndex) 216 | handshake.Clear() 217 | handshake.mutex.Unlock() 218 | 219 | peer.FlushStagedPackets() 220 | } 221 | 222 | func (peer *Peer) ExpireCurrentKeypairs() { 223 | handshake := &peer.handshake 224 | handshake.mutex.Lock() 225 | peer.device.indexTable.Delete(handshake.localIndex) 226 | handshake.Clear() 227 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 228 | handshake.mutex.Unlock() 229 | 230 | keypairs := &peer.keypairs 231 | keypairs.Lock() 232 | if keypairs.current != nil { 233 | keypairs.current.sendNonce.Store(RejectAfterMessages) 234 | } 235 | if next := keypairs.next.Load(); next != nil { 236 | next.sendNonce.Store(RejectAfterMessages) 237 | } 238 | keypairs.Unlock() 239 | } 240 | 241 | func (peer *Peer) Stop() { 242 | peer.state.Lock() 243 | defer peer.state.Unlock() 244 | 245 | if !peer.isRunning.Swap(false) { 246 | return 247 | } 248 | 249 | peer.device.log.Verbosef("%v - Stopping", peer) 250 | 251 | peer.timersStop() 252 | // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. 
253 | peer.queue.inbound.c <- nil 254 | peer.queue.outbound.c <- nil 255 | peer.stopping.Wait() 256 | peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us 257 | 258 | peer.ZeroAndFlushAll() 259 | } 260 | 261 | func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { 262 | if peer.disableRoaming { 263 | return 264 | } 265 | peer.Lock() 266 | peer.endpoint = endpoint 267 | peer.Unlock() 268 | } 269 | -------------------------------------------------------------------------------- /device/timers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This is based heavily on timers.c from the kernel implementation. 6 | */ 7 | 8 | package device 9 | 10 | import ( 11 | "sync" 12 | "time" 13 | _ "unsafe" 14 | ) 15 | 16 | //go:linkname fastrandn runtime.fastrandn 17 | func fastrandn(n uint32) uint32 18 | 19 | // A Timer manages time-based aspects of the WireGuard protocol. 20 | // Timer roughly copies the interface of the Linux kernel's struct timer_list. 
21 | type Timer struct { 22 | *time.Timer 23 | modifyingLock sync.RWMutex 24 | runningLock sync.Mutex 25 | isPending bool 26 | } 27 | 28 | func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { 29 | timer := &Timer{} 30 | timer.Timer = time.AfterFunc(time.Hour, func() { 31 | timer.runningLock.Lock() 32 | defer timer.runningLock.Unlock() 33 | 34 | timer.modifyingLock.Lock() 35 | if !timer.isPending { 36 | timer.modifyingLock.Unlock() 37 | return 38 | } 39 | timer.isPending = false 40 | timer.modifyingLock.Unlock() 41 | 42 | expirationFunction(peer) 43 | }) 44 | timer.Stop() 45 | return timer 46 | } 47 | 48 | func (timer *Timer) Mod(d time.Duration) { 49 | timer.modifyingLock.Lock() 50 | timer.isPending = true 51 | timer.Reset(d) 52 | timer.modifyingLock.Unlock() 53 | } 54 | 55 | func (timer *Timer) Del() { 56 | timer.modifyingLock.Lock() 57 | timer.isPending = false 58 | timer.Stop() 59 | timer.modifyingLock.Unlock() 60 | } 61 | 62 | func (timer *Timer) DelSync() { 63 | timer.Del() 64 | timer.runningLock.Lock() 65 | timer.Del() 66 | timer.runningLock.Unlock() 67 | } 68 | 69 | func (timer *Timer) IsPending() bool { 70 | timer.modifyingLock.RLock() 71 | defer timer.modifyingLock.RUnlock() 72 | return timer.isPending 73 | } 74 | 75 | func (peer *Peer) timersActive() bool { 76 | return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() 77 | } 78 | 79 | func expiredRetransmitHandshake(peer *Peer) { 80 | if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { 81 | peer.device.UpdateHandshakeState(HandshakeFail) 82 | peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) 83 | 84 | if peer.timersActive() { 85 | peer.timers.sendKeepalive.Del() 86 | } 87 | 88 | /* We drop all packets without a keypair and don't try again, 89 | * if we try unsuccessfully for too long to make a handshake. 
90 | */ 91 | peer.FlushStagedPackets() 92 | 93 | /* We set a timer for destroying any residue that might be left 94 | * of a partial exchange. 95 | */ 96 | if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { 97 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 98 | } 99 | } else { 100 | peer.timers.handshakeAttempts.Add(1) 101 | peer.device.UpdateHandshakeState(HandshakeFail) 102 | peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) 103 | 104 | /* We clear the endpoint address src address, in case this is the cause of trouble. */ 105 | peer.Lock() 106 | if peer.endpoint != nil { 107 | peer.endpoint.ClearSrc() 108 | } 109 | peer.Unlock() 110 | 111 | peer.SendHandshakeInitiation(true) 112 | } 113 | } 114 | 115 | func expiredSendKeepalive(peer *Peer) { 116 | peer.SendKeepalive() 117 | if peer.timers.needAnotherKeepalive.Load() { 118 | peer.timers.needAnotherKeepalive.Store(false) 119 | if peer.timersActive() { 120 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 121 | } 122 | } 123 | } 124 | 125 | func expiredNewHandshake(peer *Peer) { 126 | peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) 127 | /* We clear the endpoint address src address, in case this is the cause of trouble. 
*/ 128 | peer.Lock() 129 | if peer.endpoint != nil { 130 | peer.endpoint.ClearSrc() 131 | } 132 | peer.Unlock() 133 | peer.SendHandshakeInitiation(false) 134 | } 135 | 136 | func expiredZeroKeyMaterial(peer *Peer) { 137 | peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) 138 | peer.ZeroAndFlushAll() 139 | } 140 | 141 | func expiredPersistentKeepalive(peer *Peer) { 142 | if peer.persistentKeepaliveInterval.Load() > 0 { 143 | peer.SendKeepalive() 144 | } 145 | } 146 | 147 | /* Should be called after an authenticated data packet is sent. */ 148 | func (peer *Peer) timersDataSent() { 149 | if peer.timersActive() && !peer.timers.newHandshake.IsPending() { 150 | peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 151 | } 152 | } 153 | 154 | /* Should be called after an authenticated data packet is received. */ 155 | func (peer *Peer) timersDataReceived() { 156 | if peer.timersActive() { 157 | if !peer.timers.sendKeepalive.IsPending() { 158 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 159 | } else { 160 | peer.timers.needAnotherKeepalive.Store(true) 161 | } 162 | } 163 | } 164 | 165 | /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ 166 | func (peer *Peer) timersAnyAuthenticatedPacketSent() { 167 | if peer.timersActive() { 168 | peer.timers.sendKeepalive.Del() 169 | } 170 | } 171 | 172 | /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ 173 | func (peer *Peer) timersAnyAuthenticatedPacketReceived() { 174 | if peer.timersActive() { 175 | peer.timers.newHandshake.Del() 176 | } 177 | } 178 | 179 | /* Should be called after a handshake initiation message is sent. 
*/ 180 | func (peer *Peer) timersHandshakeInitiated() { 181 | if peer.timersActive() { 182 | peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 183 | } 184 | } 185 | 186 | /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ 187 | func (peer *Peer) timersHandshakeComplete() { 188 | if peer.timersActive() { 189 | peer.timers.retransmitHandshake.Del() 190 | } 191 | peer.timers.handshakeAttempts.Store(0) 192 | peer.timers.sentLastMinuteHandshake.Store(false) 193 | peer.lastHandshakeNano.Store(time.Now().UnixNano()) 194 | } 195 | 196 | /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ 197 | func (peer *Peer) timersSessionDerived() { 198 | if peer.timersActive() { 199 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 200 | } 201 | } 202 | 203 | /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. 
*/ 204 | func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { 205 | keepalive := peer.persistentKeepaliveInterval.Load() 206 | if keepalive > 0 && peer.timersActive() { 207 | peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) 208 | } 209 | } 210 | 211 | func (peer *Peer) timersInit() { 212 | peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) 213 | peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) 214 | peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) 215 | peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) 216 | peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) 217 | } 218 | 219 | func (peer *Peer) timersStart() { 220 | peer.timers.handshakeAttempts.Store(0) 221 | peer.timers.sentLastMinuteHandshake.Store(false) 222 | peer.timers.needAnotherKeepalive.Store(false) 223 | } 224 | 225 | func (peer *Peer) timersStop() { 226 | peer.timers.retransmitHandshake.DelSync() 227 | peer.timers.sendKeepalive.DelSync() 228 | peer.timers.newHandshake.DelSync() 229 | peer.timers.zeroKeyMaterial.DelSync() 230 | peer.timers.persistentKeepalive.DelSync() 231 | } 232 | -------------------------------------------------------------------------------- /device/allowedips.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "encoding/binary" 11 | "errors" 12 | "math/bits" 13 | "net" 14 | "net/netip" 15 | "sync" 16 | "unsafe" 17 | ) 18 | 19 | type parentIndirection struct { 20 | parentBit **trieEntry 21 | parentBitType uint8 22 | } 23 | 24 | type trieEntry struct { 25 | peer *Peer 26 | child [2]*trieEntry 27 | parent parentIndirection 28 | cidr uint8 29 | bitAtByte uint8 30 | bitAtShift uint8 31 | bits []byte 32 | perPeerElem *list.Element 33 | } 34 | 35 | func commonBits(ip1, ip2 []byte) uint8 { 36 | size := len(ip1) 37 | if size == net.IPv4len { 38 | a := binary.BigEndian.Uint32(ip1) 39 | b := binary.BigEndian.Uint32(ip2) 40 | x := a ^ b 41 | return uint8(bits.LeadingZeros32(x)) 42 | } else if size == net.IPv6len { 43 | a := binary.BigEndian.Uint64(ip1) 44 | b := binary.BigEndian.Uint64(ip2) 45 | x := a ^ b 46 | if x != 0 { 47 | return uint8(bits.LeadingZeros64(x)) 48 | } 49 | a = binary.BigEndian.Uint64(ip1[8:]) 50 | b = binary.BigEndian.Uint64(ip2[8:]) 51 | x = a ^ b 52 | return 64 + uint8(bits.LeadingZeros64(x)) 53 | } else { 54 | panic("Wrong size bit string") 55 | } 56 | } 57 | 58 | func (node *trieEntry) addToPeerEntries() { 59 | node.perPeerElem = node.peer.trieEntries.PushBack(node) 60 | } 61 | 62 | func (node *trieEntry) removeFromPeerEntries() { 63 | if node.perPeerElem != nil { 64 | node.peer.trieEntries.Remove(node.perPeerElem) 65 | node.perPeerElem = nil 66 | } 67 | } 68 | 69 | func (node *trieEntry) choose(ip []byte) byte { 70 | return (ip[node.bitAtByte] >> node.bitAtShift) & 1 71 | } 72 | 73 | func (node *trieEntry) maskSelf() { 74 | mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 75 | for i := 0; i < len(mask); i++ { 76 | node.bits[i] &= mask[i] 77 | } 78 | } 79 | 80 | func (node *trieEntry) zeroizePointers() { 81 | // Make the garbage collector's life slightly easier 82 | node.peer = nil 83 | node.child[0] = nil 84 | node.child[1] = nil 85 | node.parent.parentBit = nil 86 | } 87 | 88 | 
// nodePlacement walks down from node towards ip. It returns the deepest node
// whose prefix contains ip and whose cidr does not exceed cidr, or nil if even
// the starting node does not qualify. exact reports whether that node's prefix
// length equals cidr.
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
	// commonBits >= node.cidr means ip lies inside node's (masked) prefix.
	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
		parent = node
		if parent.cidr == cidr {
			exact = true
			return
		}
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return
}

// insert maps ip/cidr to peer in the trie rooted at *trie.parentBit.
//
// If a node for exactly that prefix exists, it is re-assigned to peer.
// Otherwise a new node is linked in. When the new prefix and the existing
// subtree diverge, a peer-less branch node is created at their common prefix
// length. A new node keeps ip as its bits and masks it in place, so the caller
// must not reuse ip afterwards.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// Prefix already present: move it to the new peer's entry list.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree that newNode must go above or beside:
	// the root when no node contains ip, else the child of node towards ip.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		// Free slot under node: hang the new node there directly.
		if down == nil {
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// From here on cidr is min(common prefix with down, requested cidr).
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	// newNode's prefix also covers down: splice newNode between parent and down.
	if newNode.cidr == cidr {
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// The prefixes diverge before newNode.cidr bits. Create a peer-less branch
	// node at the common length (with its own copy of the bits, since it is
	// masked separately) holding down and newNode as its two children.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}

// lookup returns the peer of the longest prefix in the trie that contains ip,
// or nil if none does.
func (node *trieEntry) lookup(ip []byte) *Peer {
	var found *Peer
	size := uint8(len(ip))
	for node != nil && commonBits(node.bits, ip) >= node.cidr {
		if node.peer != nil {
			found = node.peer
		}
		// A full-length prefix (e.g. /32 or /128): choose would index past the
		// end of ip, and nothing can lie below it anyway.
		if node.bitAtByte == size {
			break
		}
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return found
}

// AllowedIPs maps IP prefixes to peers, with one binary trie per address
// family, guarded by mutex.
type AllowedIPs struct {
	IPv4  *trieEntry
	IPv6  *trieEntry
	mutex sync.RWMutex
}

// EntriesForPeer calls cb with each prefix assigned to peer, stopping early
// when cb returns false. The table's read lock is held while cb runs, so cb
// must not call Insert or RemoveByPeer (they take the write lock and would
// deadlock).
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
	table.mutex.RLock()
	defer table.mutex.RUnlock()

	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
		node := elem.Value.(*trieEntry)
		a, _ := netip.AddrFromSlice(node.bits)
		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
			return
		}
	}
}

// RemoveByPeer removes every prefix assigned to peer. Emptied nodes are
// spliced out of the trie, and so are peer-less branch nodes left with a
// single child.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	// next is captured before node is unlinked from the peer's list.
	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// With two children the node must stay, now as a peer-less branch.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Replace node in its parent's slot with its only child (or nil).
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Done if node had a child, or if it hung directly off a table root
		// (parentBitType 2) and so has no parent node to collapse.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// node was a leaf stored in child[parentBitType] of a parent node.
		// Recover that parent from the slot's address by subtracting the slot's
		// offset within trieEntry.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		// A parent that carries a peer stays.
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// A peer-less parent now has at most the sibling left: splice it out.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}

// Insert assigns prefix to peer, replacing any previous owner of exactly that
// prefix. Is6 is checked first, so IPv4-mapped IPv6 addresses go into the IPv6
// trie. An invalid address panics.
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	// ip is a fresh local array, so the trie may keep (and mask) ip[:].
	if prefix.Addr().Is6() {
		ip := prefix.Addr().As16()
		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
	} else if prefix.Addr().Is4() {
		ip := prefix.Addr().As4()
		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
	} else {
		panic(errors.New("inserting unknown address type"))
	}
}

// Lookup returns the peer owning the longest matching prefix for ip, or nil.
// ip must be 4 or 16 bytes long; any other length panics.
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
	table.mutex.RLock()
	defer table.mutex.RUnlock()
	switch len(ip) {
	case net.IPv6len:
		return table.IPv6.lookup(ip)
	case net.IPv4len:
		return table.IPv4.lookup(ip)
	default:
		panic(errors.New("looking up unknown address type"))
	}
}

--------------------------------------------------------------------------------