├── .gitignore
├── version.go
├── CODEOWNERS
├── rwcancel
├── rwcancel_stub.go
└── rwcancel.go
├── device
├── race_disabled_test.go
├── race_enabled_test.go
├── sticky_default.go
├── ip.go
├── queueconstants_android.go
├── queueconstants_windows.go
├── devicestate_string.go
├── queueconstants_default.go
├── mobilequirks.go
├── queueconstants_ios.go
├── endpoint_test.go
├── keypair.go
├── bind_test.go
├── constants.go
├── tun.go
├── logger.go
├── noise-types.go
├── pools_test.go
├── indextable.go
├── pools.go
├── kdf_test.go
├── noise-helpers.go
├── allowedips_rand_test.go
├── channels.go
├── noise_test.go
├── statemanager_test.go
├── cookie.go
├── sticky_linux.go
├── cookie_test.go
├── allowedips_test.go
├── peer.go
├── timers.go
└── allowedips.go
├── conn
├── default.go
├── mark_default.go
├── boundif_android.go
├── mark_unix.go
├── server_name_utils
│ ├── consistent_hash_test.go
│ ├── consistent_hash.go
│ └── server_name_utils.go
├── bindtest
│ └── bindtest.go
├── conn.go
├── tcp_tls_utils.go
└── bind_std.go
├── ipc
├── uapi_js.go
├── uapi_unix.go
├── uapi_windows.go
├── uapi_linux.go
├── uapi_bsd.go
└── namedpipe
│ └── file.go
├── tun
├── operateonfd.go
├── tun.go
├── netstack
│ └── examples
│ │ ├── http_client.go
│ │ ├── http_server.go
│ │ └── ping_client.go
├── alignment_windows_test.go
├── tuntest
│ └── tuntest.go
├── tun_windows.go
├── tun_openbsd.go
└── tun_darwin.go
├── .gitlab-ci.yml
├── Makefile
├── go.mod
├── tai64n
├── tai64n.go
└── tai64n_test.go
├── LICENSE
├── format_test.go
├── replay
├── replay.go
└── replay_test.go
├── main_windows.go
├── ratelimiter
├── ratelimiter_test.go
└── ratelimiter.go
├── README.md
├── go.sum
└── main.go
/.gitignore:
--------------------------------------------------------------------------------
1 | wireguard-go
2 |
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | package main
2 |
// Version is the wireguard-go version string. The Makefile's
// generate-version-and-build target rewrites this file from
// `git describe --dirty` when that command succeeds and its output differs
// from the current contents, so hand edits here may be overwritten.
const Version = "0.0.20230223"
4 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # ownership: loose
2 | * @ProtonVPN/groups/android-developers
3 |
--------------------------------------------------------------------------------
/rwcancel/rwcancel_stub.go:
--------------------------------------------------------------------------------
1 | //go:build windows || js
2 |
3 | // SPDX-License-Identifier: MIT
4 |
5 | package rwcancel
6 |
// RWCancel is an empty placeholder used on platforms where cancellable
// reads are not implemented by this package.
type RWCancel struct{}

// Cancel is a no-op on these platforms; it is safe to call on a nil pointer.
func (*RWCancel) Cancel() {}
10 |
--------------------------------------------------------------------------------
/device/race_disabled_test.go:
--------------------------------------------------------------------------------
1 | //go:build !race
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package device
9 |
// raceEnabled reports whether the race detector is active. This file is
// only compiled under the !race build constraint, so it is always false.
const (
	raceEnabled = false
)
11 |
--------------------------------------------------------------------------------
/device/race_enabled_test.go:
--------------------------------------------------------------------------------
1 | //go:build race
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package device
9 |
// raceEnabled reports whether the race detector is active. This file is
// only compiled under the race build constraint, so it is always true.
const (
	raceEnabled = true
)
11 |
--------------------------------------------------------------------------------
/conn/default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux && !windows
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package conn
9 |
10 | func NewDefaultBind() Bind { return NewStdNetBind() }
11 |
--------------------------------------------------------------------------------
/conn/mark_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux && !openbsd && !freebsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package conn
9 |
10 | func (bind *StdNetBind) SetMark(mark uint32) error {
11 | return nil
12 | }
13 |
--------------------------------------------------------------------------------
/device/sticky_default.go:
--------------------------------------------------------------------------------
1 | //go:build !linux
2 |
3 | package device
4 |
5 | import (
6 | "golang.zx2c4.com/wireguard/conn"
7 | "golang.zx2c4.com/wireguard/rwcancel"
8 | )
9 |
10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
11 | return nil, nil
12 | }
13 |
--------------------------------------------------------------------------------
/ipc/uapi_js.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package ipc
7 |
// Made-up sentinel error codes for the js/wasm platform, numbered
// consecutively starting at 1.
const (
	IpcErrorIO = iota + 1
	IpcErrorInvalid
	IpcErrorPortInUse
	IpcErrorUnknown
	IpcErrorProtocol
)
16 |
--------------------------------------------------------------------------------
/device/ip.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "net"
10 | )
11 |
// Byte offsets of fields within an IPv4 header. The destination address
// immediately follows the 4-byte source address.
const (
	IPv4offsetSrc         = 12
	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
	IPv4offsetTotalLength = 2
)
17 |
// Byte offsets of fields within an IPv6 header. The destination address
// immediately follows the 16-byte source address.
const (
	IPv6offsetSrc           = 8
	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
	IPv6offsetPayloadLength = 4
)
23 |
--------------------------------------------------------------------------------
/device/queueconstants_android.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
// Queue and buffer sizes for Android, reduced to lower memory consumption.
const (
	QueueStagedSize            = 128
	QueueOutboundSize          = 1 << 10
	QueueInboundSize           = 1 << 10
	QueueHandshakeSize         = 1 << 10
	MaxSegmentSize             = 2200
	PreallocatedBuffersPerPool = 1 << 12
)
18 |
--------------------------------------------------------------------------------
/device/queueconstants_windows.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
// Queue and buffer sizes for Windows.
const (
	QueueStagedSize    = 128
	QueueOutboundSize  = 1024
	QueueInboundSize   = 1024
	QueueHandshakeSize = 1024

	// MaxSegmentSize is 2016 bytes. Unlike the default build's (1<<16)-1,
	// this is not the largest possible UDP datagram.
	MaxSegmentSize = 2048 - 32

	// PreallocatedBuffersPerPool of 0 disables preallocation and allows
	// unbounded memory growth.
	PreallocatedBuffersPerPool = 0
)
16 |
--------------------------------------------------------------------------------
/device/devicestate_string.go:
--------------------------------------------------------------------------------
1 | // Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
2 |
3 | package device
4 |
5 | import "strconv"
6 |
// _deviceState_name is all deviceState names concatenated; generated by
// stringer, so regenerate rather than hand-edit.
const _deviceState_name = "DownUpClosed"

// _deviceState_index holds the start offset of each name in
// _deviceState_name, plus a final end offset.
var _deviceState_index = [...]uint8{0, 4, 6, 12}

// String returns the state's name ("Down", "Up", "Closed"), or
// "deviceState(N)" for values at or beyond the number of known states.
func (i deviceState) String() string {
	if i >= deviceState(len(_deviceState_index)-1) {
		return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}
17 |
--------------------------------------------------------------------------------
/device/queueconstants_default.go:
--------------------------------------------------------------------------------
1 | //go:build !android && !ios && !windows
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package device
9 |
// Queue and buffer sizes for platforms other than android, ios and windows.
const (
	QueueStagedSize    = 128
	QueueOutboundSize  = 1024
	QueueInboundSize   = 1024
	QueueHandshakeSize = 1024

	// MaxSegmentSize is the largest possible UDP datagram (65535 bytes).
	MaxSegmentSize = 1<<16 - 1

	// PreallocatedBuffersPerPool of 0 disables preallocation and allows
	// unbounded memory growth.
	PreallocatedBuffersPerPool = 0
)
18 |
--------------------------------------------------------------------------------
/tun/operateonfd.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || freebsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package tun
9 |
10 | import (
11 | "fmt"
12 | )
13 |
14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) {
15 | sysconn, err := tun.tunFile.SyscallConn()
16 | if err != nil {
17 | tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error())
18 | return
19 | }
20 | err = sysconn.Control(fn)
21 | if err != nil {
22 | tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error())
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/device/mobilequirks.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
9 | // though it will try to deal with it, and race maybe, if called after.
10 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
11 | device.net.brokenRoaming = true
12 | device.peers.RLock()
13 | for _, peer := range device.peers.keyMap {
14 | peer.Lock()
15 | peer.disableRoaming = peer.endpoint != nil
16 | peer.Unlock()
17 | }
18 | device.peers.RUnlock()
19 | }
20 |
--------------------------------------------------------------------------------
/device/queueconstants_ios.go:
--------------------------------------------------------------------------------
1 | //go:build ios
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package device
9 |
// Queue sizes chosen to fit within the memory limits of iOS's Network
// Extension API, which has stricter requirements than other platforms.
// They are variables rather than constants so that heavier network
// extensions can reduce them further.
var (
	QueueStagedSize    = 128
	QueueOutboundSize  = 1024
	QueueInboundSize   = 1024
	QueueHandshakeSize = 1024

	PreallocatedBuffersPerPool uint32 = 1024
)

// MaxSegmentSize bounds transport message size on iOS (MaxMessageSize is
// defined from it).
const MaxSegmentSize = 1700
22 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | default:
2 | image: ${PROTON_CI_REGISTRY}/android-shared/docker-android/oci-ndk:v2.1.7
3 | before_script:
4 | - if [[ -f /load-env.sh ]]; then source /load-env.sh; fi
5 |
6 | sync-wireguard-go:
7 | tags:
8 | - shared-medium
9 | rules:
10 | - if: $CI_COMMIT_BRANCH == "release/android"
11 | when: manual
12 | - if: '$OPENSOURCE_GO'
13 | when: always
14 | - when: never
15 | allow_failure: true
16 | before_script:
17 | - !reference [ default, before_script ]
18 | - apt update && apt-get install -y connect-proxy
19 | script:
20 | - git clone "$CI_REPOSITORY_URL" --branch "$CI_COMMIT_BRANCH" _APP_CLONE;
21 | - cd _APP_CLONE
22 | - git remote add public git@github.com:ProtonVPN/wireguard-go.git
23 | - git push public "$CI_COMMIT_BRANCH" -f
24 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PREFIX ?= /usr
2 | DESTDIR ?=
3 | BINDIR ?= $(PREFIX)/bin
4 | export GO111MODULE := on
5 |
6 | all: generate-version-and-build
7 |
8 | MAKEFLAGS += --no-print-directory
9 |
10 | generate-version-and-build:
11 | @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
12 | tag="$$(git describe --dirty 2>/dev/null)" && \
13 | ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
14 | [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
15 | echo "$$ver" > version.go && \
16 | git update-index --assume-unchanged version.go || true
17 | @$(MAKE) wireguard-go
18 |
19 | wireguard-go: $(wildcard *.go) $(wildcard */*.go)
20 | go build -v -o "$@"
21 |
22 | install: wireguard-go
23 | @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
24 |
25 | test:
26 | go test ./...
27 |
28 | clean:
29 | rm -f wireguard-go
30 |
31 | .PHONY: all clean test install generate-version-and-build
32 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module golang.zx2c4.com/wireguard
2 |
3 | go 1.23.4
4 |
5 | require (
6 | github.com/refraction-networking/utls v1.6.7
7 | github.com/stretchr/testify v1.8.0
8 | golang.org/x/crypto v0.21.0
9 | golang.org/x/net v0.23.0
10 | golang.org/x/sys v0.18.0
11 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
12 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
13 | )
14 |
15 | require (
16 | github.com/google/btree v1.0.1 // indirect
17 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
18 | )
19 |
20 | require (
21 | github.com/andybalholm/brotli v1.0.6 // indirect
22 | github.com/cloudflare/circl v1.3.7 // indirect
23 | github.com/davecgh/go-spew v1.1.1 // indirect
24 | github.com/klauspost/compress v1.17.4 // indirect
25 | github.com/kr/pretty v0.3.1 // indirect
26 | github.com/pmezard/go-difflib v1.0.0 // indirect
27 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
28 | gopkg.in/yaml.v3 v3.0.1 // indirect
29 | )
30 |
--------------------------------------------------------------------------------
/conn/boundif_android.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package conn
7 |
8 | func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
9 | sysconn, err := bind.ipv4.SyscallConn()
10 | if err != nil {
11 | return -1, err
12 | }
13 | err = sysconn.Control(func(f uintptr) {
14 | fd = int(f)
15 | })
16 | if err != nil {
17 | return -1, err
18 | }
19 | return
20 | }
21 |
22 | func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
23 | sysconn, err := bind.ipv6.SyscallConn()
24 | if err != nil {
25 | return -1, err
26 | }
27 | err = sysconn.Control(func(f uintptr) {
28 | fd = int(f)
29 | })
30 | if err != nil {
31 | return -1, err
32 | }
33 | return
34 | }
35 |
36 | func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) {
37 | return -1, err
38 | }
39 |
40 | func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) {
41 | return -1, err
42 | }
43 |
--------------------------------------------------------------------------------
/tun/tun.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tun
7 |
8 | import (
9 | "os"
10 | )
11 |
// Event is a bit set of notifications delivered on a Device's Events channel.
type Event int

// Event bits. These are untyped integer constants, so they can be tested
// against an Event with & as well as used as plain ints.
const (
	EventUp        = 1 << 0
	EventDown      = 1 << 1
	EventMTUUpdate = 1 << 2
)

// Device is the interface a TUN device implementation provides.
type Device interface {
	// File returns the underlying *os.File of the device.
	File() *os.File
	// Read reads a packet from the device, without any additional headers.
	// NOTE(review): the int argument is presumably an offset into the
	// buffer; confirm against the implementations.
	Read([]byte, int) (int, error)
	// Write writes a packet to the device, without any additional headers.
	Write([]byte, int) (int, error)
	// Flush flushes all previous writes to the device.
	Flush() error
	// MTU returns the MTU of the device.
	MTU() (int, error)
	// Name fetches and returns the current name of the device.
	Name() (string, error)
	// Events returns a constant channel of events related to the device.
	Events() <-chan Event
	// Close stops the device and closes the event channel.
	Close() error
}
30 |
--------------------------------------------------------------------------------
/tai64n/tai64n.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tai64n
7 |
8 | import (
9 | "bytes"
10 | "encoding/binary"
11 | "time"
12 | )
13 |
const (
	TimestampSize        = 12
	base          uint64 = 0x400000000000000a
	whitenerMask  uint32 = 1<<24 - 1
)

// Timestamp is a 12-byte big-endian TAI64N label: 8 bytes of seconds offset
// by base, followed by 4 bytes of (whitened) nanoseconds.
type Timestamp [TimestampSize]byte

// stamp encodes t as a Timestamp. The low 24 bits of the nanosecond field
// are cleared (whitened) so that sub-second timing is coarsened.
func stamp(t time.Time) Timestamp {
	var out Timestamp
	binary.BigEndian.PutUint64(out[:8], base+uint64(t.Unix()))
	binary.BigEndian.PutUint32(out[8:], uint32(t.Nanosecond())&^whitenerMask)
	return out
}

// Now returns the whitened Timestamp for the current time.
func Now() Timestamp {
	return stamp(time.Now())
}

// After reports whether t is strictly later than other. Because the
// encoding is big-endian, byte-wise comparison matches time ordering.
func (t Timestamp) After(other Timestamp) bool {
	return bytes.Compare(t[:], other[:]) == 1
}

// String decodes the Timestamp back into a time.Time and formats it.
func (t Timestamp) String() string {
	secs := binary.BigEndian.Uint64(t[:8]) - base
	nanos := binary.BigEndian.Uint32(t[8:12])
	return time.Unix(int64(secs), int64(nanos)).String()
}
42 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Permission is hereby granted, free of charge, to any person obtaining a copy of
2 | this software and associated documentation files (the "Software"), to deal in
3 | the Software without restriction, including without limitation the rights to
4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
5 | of the Software, and to permit persons to whom the Software is furnished to do
6 | so, subject to the following conditions:
7 |
8 | The above copyright notice and this permission notice shall be included in all
9 | copies or substantial portions of the Software.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
17 | SOFTWARE.
18 |
--------------------------------------------------------------------------------
/device/endpoint_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "math/rand"
10 | "net/netip"
11 | )
12 |
// DummyEndpoint is a test endpoint with random 16-byte source and
// destination addresses; every string/byte form reports port 1000.
// NOTE(review): presumably used as a conn.Endpoint stand-in in tests.
type DummyEndpoint struct {
	src, dst netip.Addr
}

// CreateDummyEndpoint fills the source and then the destination address
// with random bytes. A failure reading the source aborts with a nil
// endpoint; a failure reading the destination is returned alongside the
// endpoint.
func CreateDummyEndpoint() (*DummyEndpoint, error) {
	var srcBytes, dstBytes [16]byte
	if _, err := rand.Read(srcBytes[:]); err != nil {
		return nil, err
	}
	_, err := rand.Read(dstBytes[:])
	ep := &DummyEndpoint{
		src: netip.AddrFrom16(srcBytes),
		dst: netip.AddrFrom16(dstBytes),
	}
	return ep, err
}

// ClearSrc does nothing for the dummy endpoint.
func (e *DummyEndpoint) ClearSrc() {}

// SrcToString formats the source address with port 1000.
func (e *DummyEndpoint) SrcToString() string {
	src := netip.AddrPortFrom(e.SrcIP(), 1000)
	return src.String()
}

// DstToString formats the destination address with port 1000.
func (e *DummyEndpoint) DstToString() string {
	dst := netip.AddrPortFrom(e.DstIP(), 1000)
	return dst.String()
}

// DstToBytes returns the destination address bytes followed by port 1000
// encoded low byte first.
func (e *DummyEndpoint) DstToBytes() []byte {
	const port = 1000
	return append(e.DstIP().AsSlice(), byte(port&0xff), byte((port>>8)&0xff))
}

// DstIP returns the destination address.
func (e *DummyEndpoint) DstIP() netip.Addr { return e.dst }

// SrcIP returns the source address.
func (e *DummyEndpoint) SrcIP() netip.Addr { return e.src }
50 |
--------------------------------------------------------------------------------
/device/keypair.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "crypto/cipher"
10 | "sync"
11 | "sync/atomic"
12 | "time"
13 |
14 | "golang.zx2c4.com/wireguard/replay"
15 | )
16 |
17 | /* Due to limitations in Go and /x/crypto there is currently
18 | * no way to ensure that key material is securely erased in memory.
19 | *
20 | * Since this may harm the forward secrecy property,
21 | * we plan to resolve this issue whenever Go allows us to do so.
22 | */
23 |
// Keypair holds the AEAD ciphers for one session together with the state
// needed to use them. Field order is left unchanged.
type Keypair struct {
	sendNonce    atomic.Uint64 // atomic counter, usable without a lock; presumably the next outgoing nonce
	send         cipher.AEAD   // AEAD for the sending direction
	receive      cipher.AEAD   // AEAD for the receiving direction
	replayFilter replay.Filter // presumably rejects replayed receive counters — see package replay
	isInitiator  bool          // presumably whether the local side initiated the handshake
	created      time.Time     // creation time; presumably compared against rekey/reject timeouts
	localIndex   uint32        // key under which this keypair is stored in the device's indexTable (DeleteKeypair removes it)
	remoteIndex  uint32        // presumably the index the remote side assigned to this session
}
34 |
// Keypairs groups the current, previous and next keypairs (presumably of a
// single peer). Current reads current under the embedded RWMutex's read
// lock; presumably the mutex guards previous as well. next is an
// atomic.Pointer, so it can be loaded without taking the lock.
type Keypairs struct {
	sync.RWMutex
	current  *Keypair
	previous *Keypair
	next     atomic.Pointer[Keypair]
}
41 |
42 | func (kp *Keypairs) Current() *Keypair {
43 | kp.RLock()
44 | defer kp.RUnlock()
45 | return kp.current
46 | }
47 |
48 | func (device *Device) DeleteKeypair(key *Keypair) {
49 | if key != nil {
50 | device.indexTable.Delete(key.localIndex)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/format_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 | package main
6 |
7 | import (
8 | "bytes"
9 | "go/format"
10 | "io/fs"
11 | "os"
12 | "path/filepath"
13 | "runtime"
14 | "sync"
15 | "testing"
16 | )
17 |
// TestFormatting walks the current directory tree and reports every .go
// file whose contents differ from gofmt output. Files are checked
// concurrently. On Windows, CRLF line endings are normalized to LF before
// comparison.
func TestFormatting(t *testing.T) {
	var pending sync.WaitGroup

	checkFile := func(path string) {
		defer pending.Done()
		src, err := os.ReadFile(path)
		if err != nil {
			t.Errorf("unable to read %s: %v", path, err)
			return
		}
		if runtime.GOOS == "windows" {
			src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
		}
		formatted, err := format.Source(src)
		switch {
		case err != nil:
			t.Errorf("unable to format %s: %v", path, err)
		case !bytes.Equal(src, formatted):
			t.Errorf("unformatted code: %s", path)
		}
	}

	filepath.WalkDir(".", func(path string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			t.Errorf("unable to walk %s: %v", path, walkErr)
		case !entry.IsDir() && filepath.Ext(path) == ".go":
			pending.Add(1)
			go checkFile(path)
		}
		return nil
	})
	pending.Wait()
}
52 |
--------------------------------------------------------------------------------
/device/bind_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "errors"
10 |
11 | "golang.zx2c4.com/wireguard/conn"
12 | )
13 |
14 | type DummyDatagram struct {
15 | msg []byte
16 | endpoint conn.Endpoint
17 | }
18 |
19 | type DummyBind struct {
20 | in6 chan DummyDatagram
21 | in4 chan DummyDatagram
22 | closed bool
23 | }
24 |
25 | func (b *DummyBind) SetMark(v uint32) error {
26 | return nil
27 | }
28 |
29 | func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
30 | datagram, ok := <-b.in6
31 | if !ok {
32 | return 0, nil, errors.New("closed")
33 | }
34 | copy(buff, datagram.msg)
35 | return len(datagram.msg), datagram.endpoint, nil
36 | }
37 |
38 | func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
39 | datagram, ok := <-b.in4
40 | if !ok {
41 | return 0, nil, errors.New("closed")
42 | }
43 | copy(buff, datagram.msg)
44 | return len(datagram.msg), datagram.endpoint, nil
45 | }
46 |
47 | func (b *DummyBind) Close() error {
48 | close(b.in6)
49 | close(b.in4)
50 | b.closed = true
51 | return nil
52 | }
53 |
54 | func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
55 | return nil
56 | }
57 |
--------------------------------------------------------------------------------
/tai64n/tai64n_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tai64n
7 |
8 | import (
9 | "testing"
10 | "time"
11 | )
12 |
13 | // Test that timestamps are monotonic as required by Wireguard and that
14 | // nanosecond-level information is whitened to prevent side channel attacks.
15 | func TestMonotonic(t *testing.T) {
16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
17 | // Whitening should reduce timestamp granularity
18 | // to more than 10 but fewer than 20 milliseconds.
19 | tests := []struct {
20 | name string
21 | t1, t2 time.Time
22 | wantAfter bool
23 | }{
24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
29 | }
30 |
31 | for _, tt := range tests {
32 | t.Run(tt.name, func(t *testing.T) {
33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
34 | got := ts2.After(ts1)
35 | if got != tt.wantAfter {
36 | t.Errorf("after = %v; want %v", got, tt.wantAfter)
37 | }
38 | })
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/device/constants.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "time"
10 | )
11 |
// Constants from the WireGuard protocol specification.
const (
	RekeyAfterMessages      = 1 << 60
	RejectAfterMessages     = 1<<64 - 1<<13 - 1
	RekeyAfterTime          = 120 * time.Second
	RekeyAttemptTime        = 90 * time.Second
	RekeyTimeout            = 5 * time.Second
	MaxTimerHandshakes      = 90 / 5 // RekeyAttemptTime / RekeyTimeout, kept as an untyped integer
	RekeyTimeoutJitterMaxMs = 334
	RejectAfterTime         = 180 * time.Second
	KeepaliveTimeout        = 10 * time.Second
	CookieRefreshTime       = 120 * time.Second
	HandshakeInitationRate  = time.Second / 50 // (sic) exported name kept as-is
	PaddingMultiple         = 16
)
28 |
29 | const (
30 | MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
31 | MaxMessageSize = MaxSegmentSize // maximum size of transport message
32 | MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
33 | )
34 |
// Implementation constants.
const (
	// UnderLoadAfterTime is how long the device remains under load after
	// load is detected.
	UnderLoadAfterTime = 1 * time.Second
	// MaxPeers is the maximum number of configured peers.
	MaxPeers = 1 << 16
)
41 |
--------------------------------------------------------------------------------
/device/tun.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "fmt"
10 |
11 | "golang.zx2c4.com/wireguard/tun"
12 | )
13 |
14 | const DefaultMTU = 1420
15 |
16 | func (device *Device) RoutineTUNEventReader() {
17 | device.log.Verbosef("Routine: event worker - started")
18 |
19 | for event := range device.tun.device.Events() {
20 | if event&tun.EventMTUUpdate != 0 {
21 | mtu, err := device.tun.device.MTU()
22 | if err != nil {
23 | device.log.Errorf("Failed to load updated MTU of device: %v", err)
24 | continue
25 | }
26 | if mtu < 0 {
27 | device.log.Errorf("MTU not updated to negative value: %v", mtu)
28 | continue
29 | }
30 | var tooLarge string
31 | if mtu > MaxContentSize {
32 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
33 | mtu = MaxContentSize
34 | }
35 | old := device.tun.mtu.Swap(int32(mtu))
36 | if int(old) != mtu {
37 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
38 | }
39 | }
40 |
41 | if event&tun.EventUp != 0 {
42 | device.log.Verbosef("Interface up requested")
43 | device.Up()
44 | }
45 |
46 | if event&tun.EventDown != 0 {
47 | device.log.Verbosef("Interface down requested")
48 | device.Down()
49 | }
50 | }
51 |
52 | device.log.Verbosef("Routine: event worker - stopped")
53 | }
54 |
--------------------------------------------------------------------------------
/tun/netstack/examples/http_client.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | /* SPDX-License-Identifier: MIT
5 | *
6 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
7 | */
8 |
9 | package main
10 |
11 | import (
12 | "io"
13 | "log"
14 | "net/http"
15 | "net/netip"
16 |
17 | "golang.zx2c4.com/wireguard/conn"
18 | "golang.zx2c4.com/wireguard/device"
19 | "golang.zx2c4.com/wireguard/tun/netstack"
20 | )
21 |
22 | func main() {
23 | tun, tnet, err := netstack.CreateNetTUN(
24 | []netip.Addr{netip.MustParseAddr("192.168.4.28")},
25 | []netip.Addr{netip.MustParseAddr("8.8.8.8")},
26 | 1420)
27 | if err != nil {
28 | log.Panic(err)
29 | }
30 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
31 | err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
32 | public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
33 | allowed_ip=0.0.0.0/0
34 | endpoint=127.0.0.1:58120
35 | `)
36 | err = dev.Up()
37 | if err != nil {
38 | log.Panic(err)
39 | }
40 |
41 | client := http.Client{
42 | Transport: &http.Transport{
43 | DialContext: tnet.DialContext,
44 | },
45 | }
46 | resp, err := client.Get("http://192.168.4.29/")
47 | if err != nil {
48 | log.Panic(err)
49 | }
50 | body, err := io.ReadAll(resp.Body)
51 | if err != nil {
52 | log.Panic(err)
53 | }
54 | log.Println(string(body))
55 | }
56 |
--------------------------------------------------------------------------------
/device/logger.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "golang.zx2c4.com/wireguard/conn"
10 | "log"
11 | "os"
12 | )
13 |
14 | // A Logger provides logging for a Device.
15 | // The functions are Printf-style functions.
16 | // They must be safe for concurrent use.
17 | // They do not require a trailing newline in the format.
18 | // If nil, that level of logging will be silent.
19 | type Logger struct {
20 | conn.Logger
21 | }
22 |
23 | // Log levels for use with NewLogger.
24 | const (
25 | LogLevelSilent = iota
26 | LogLevelError
27 | LogLevelVerbose
28 | )
29 |
30 | // Function for use in Logger for discarding logged lines.
31 | func DiscardLogf(format string, args ...any) {}
32 |
33 | // NewLogger constructs a Logger that writes to stdout.
34 | // It logs at the specified log level and above.
35 | // It decorates log lines with the log level, date, time, and prepend.
36 | func NewLogger(level int, prepend string) *Logger {
37 | logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}}
38 | logf := func(prefix string) func(string, ...any) {
39 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
40 | }
41 | if level >= LogLevelVerbose {
42 | logger.Verbosef = logf("DEBUG")
43 | }
44 | if level >= LogLevelError {
45 | logger.Errorf = logf("ERROR")
46 | }
47 | return logger
48 | }
49 |
--------------------------------------------------------------------------------
/conn/mark_unix.go:
--------------------------------------------------------------------------------
1 | //go:build linux || openbsd || freebsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package conn
9 |
10 | import (
11 | "runtime"
12 |
13 | "golang.org/x/sys/unix"
14 | )
15 |
// fwmarkIoctl is the SOL_SOCKET option number that SetMark uses to tag sockets
// on this platform; zero means marking is unsupported here.
var fwmarkIoctl int

func init() {
	var opt int
	switch runtime.GOOS {
	case "openbsd":
		opt = 0x1021 /* unix.SO_RTABLE */
	case "freebsd":
		opt = 0x1015 /* unix.SO_USER_COOKIE */
	case "linux", "android":
		opt = 36 /* unix.SO_MARK */
	}
	fwmarkIoctl = opt
}
28 |
29 | func (bind *StdNetBind) SetMark(mark uint32) error {
30 | var operr error
31 | if fwmarkIoctl == 0 {
32 | return nil
33 | }
34 | if bind.ipv4 != nil {
35 | fd, err := bind.ipv4.SyscallConn()
36 | if err != nil {
37 | return err
38 | }
39 | err = fd.Control(func(fd uintptr) {
40 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
41 | })
42 | if err == nil {
43 | err = operr
44 | }
45 | if err != nil {
46 | return err
47 | }
48 | }
49 | if bind.ipv6 != nil {
50 | fd, err := bind.ipv6.SyscallConn()
51 | if err != nil {
52 | return err
53 | }
54 | err = fd.Control(func(fd uintptr) {
55 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
56 | })
57 | if err == nil {
58 | err = operr
59 | }
60 | if err != nil {
61 | return err
62 | }
63 | }
64 | return nil
65 | }
66 |
--------------------------------------------------------------------------------
/tun/netstack/examples/http_server.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | /* SPDX-License-Identifier: MIT
5 | *
6 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
7 | */
8 |
9 | package main
10 |
11 | import (
12 | "io"
13 | "log"
14 | "net"
15 | "net/http"
16 | "net/netip"
17 |
18 | "golang.zx2c4.com/wireguard/conn"
19 | "golang.zx2c4.com/wireguard/device"
20 | "golang.zx2c4.com/wireguard/tun/netstack"
21 | )
22 |
23 | func main() {
24 | tun, tnet, err := netstack.CreateNetTUN(
25 | []netip.Addr{netip.MustParseAddr("192.168.4.29")},
26 | []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
27 | 1420,
28 | )
29 | if err != nil {
30 | log.Panic(err)
31 | }
32 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
33 | dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
34 | listen_port=58120
35 | public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
36 | allowed_ip=192.168.4.28/32
37 | persistent_keepalive_interval=25
38 | `)
39 | dev.Up()
40 | listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80})
41 | if err != nil {
42 | log.Panicln(err)
43 | }
44 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
45 | log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent())
46 | io.WriteString(writer, "Hello from userspace TCP!")
47 | })
48 | err = http.Serve(listener, nil)
49 | if err != nil {
50 | log.Panicln(err)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/ipc/uapi_unix.go:
--------------------------------------------------------------------------------
1 | //go:build linux || darwin || freebsd || openbsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package ipc
9 |
10 | import (
11 | "errors"
12 | "fmt"
13 | "net"
14 | "os"
15 |
16 | "golang.org/x/sys/unix"
17 | )
18 |
19 | const (
20 | IpcErrorIO = -int64(unix.EIO)
21 | IpcErrorProtocol = -int64(unix.EPROTO)
22 | IpcErrorInvalid = -int64(unix.EINVAL)
23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE)
24 | IpcErrorUnknown = -55 // ENOANO
25 | )
26 |
// socketDirectory holds the per-interface UAPI sockets. It is a variable, not
// a constant, because wireguard-android overrides it with a linker flag.
var socketDirectory = "/var/run/wireguard"

// sockPath returns the UAPI socket path for interface iface:
// socketDirectory + "/" + iface + ".sock".
func sockPath(iface string) string {
	const suffix = ".sock"
	return fmt.Sprintf("%s/%s%s", socketDirectory, iface, suffix)
}
34 |
35 | func UAPIOpen(name string) (*os.File, error) {
36 | if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
37 | return nil, err
38 | }
39 |
40 | socketPath := sockPath(name)
41 | addr, err := net.ResolveUnixAddr("unix", socketPath)
42 | if err != nil {
43 | return nil, err
44 | }
45 |
46 | oldUmask := unix.Umask(0o077)
47 | defer unix.Umask(oldUmask)
48 |
49 | listener, err := net.ListenUnix("unix", addr)
50 | if err == nil {
51 | return listener.File()
52 | }
53 |
54 | // Test socket, if not in use cleanup and try again.
55 | if _, err := net.Dial("unix", socketPath); err == nil {
56 | return nil, errors.New("unix socket in use")
57 | }
58 | if err := os.Remove(socketPath); err != nil {
59 | return nil, err
60 | }
61 | listener, err = net.ListenUnix("unix", addr)
62 | if err != nil {
63 | return nil, err
64 | }
65 | return listener.File()
66 | }
67 |
--------------------------------------------------------------------------------
/conn/server_name_utils/consistent_hash_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2025. Proton AG
3 | *
4 | * This file is part of ProtonVPN.
5 | *
6 | * ProtonVPN is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * ProtonVPN is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
18 | */
19 |
20 | package server_name_utils
21 |
22 | import (
23 | "math"
24 | "strconv"
25 | "testing"
26 |
27 | "github.com/stretchr/testify/assert"
28 | )
29 |
// maxMinus5 is the fake hash assigned to the "max-5" test value. It sits just
// below the top of the uint32 range, so the test can exercise wrap-around.
var maxMinus5 = uint32(math.MaxUint32 - 5)

// testHash is a deterministic stand-in hash for tests.
// "max-5" maps to maxMinus5. Any other value is parsed as a decimal uint32;
// a syntax error yields 0.
func testHash(value string) uint32 {
	if value == "max-5" {
		return maxMinus5
	}
	parsed, _ := strconv.ParseUint(value, 10, 32) // errors deliberately ignored in this test helper
	return uint32(parsed)
}
41 |
42 | func TestConsistentHash(t *testing.T) {
43 | assert := assert.New(t)
44 |
45 | values := []string{"70", "max-5", "10"}
46 | hashedValues := sortValuesByHash(values, testHash)
47 | assert.Equal([]HashedValue{{"10", 10}, {"70", 70}, {"max-5", maxMinus5}}, hashedValues)
48 | assert.Equal("70", findClosestValue("68", hashedValues, testHash))
49 | assert.Equal("70", findClosestValue("72", hashedValues, testHash))
50 | assert.Equal("max-5", findClosestValue("2", hashedValues, testHash))
51 | assert.Equal("10", findClosestValue("6", hashedValues, testHash))
52 | }
--------------------------------------------------------------------------------
/device/noise-types.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "crypto/subtle"
10 | "encoding/hex"
11 | "errors"
12 | )
13 |
// Sizes, in bytes, of the Curve25519 public and private keys and of the
// preshared key.
const (
	NoisePublicKeySize    = 32
	NoisePrivateKeySize   = 32
	NoisePresharedKeySize = 32
)

// NoisePublicKey is a raw Curve25519 public key.
type NoisePublicKey [NoisePublicKeySize]byte

// NoisePrivateKey is a raw Curve25519 private key.
type NoisePrivateKey [NoisePrivateKeySize]byte

// NoisePresharedKey is a raw symmetric preshared key.
type NoisePresharedKey [NoisePresharedKeySize]byte

// NoiseNonce is a 64-bit counter (padded to 12-bytes when used).
type NoiseNonce uint64
26 |
// loadExactHex hex-decodes src into dst. The decoded length must equal
// len(dst) exactly. On any error dst is left untouched.
func loadExactHex(dst []byte, src string) error {
	decoded, err := hex.DecodeString(src)
	switch {
	case err != nil:
		return err
	case len(decoded) != len(dst):
		return errors.New("hex string does not fit the slice")
	}
	copy(dst, decoded)
	return nil
}
38 |
39 | func (key NoisePrivateKey) IsZero() bool {
40 | var zero NoisePrivateKey
41 | return key.Equals(zero)
42 | }
43 |
44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
46 | }
47 |
48 | func (key *NoisePrivateKey) FromHex(src string) (err error) {
49 | err = loadExactHex(key[:], src)
50 | key.clamp()
51 | return
52 | }
53 |
54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
55 | err = loadExactHex(key[:], src)
56 | if key.IsZero() {
57 | return
58 | }
59 | key.clamp()
60 | return
61 | }
62 |
63 | func (key *NoisePublicKey) FromHex(src string) error {
64 | return loadExactHex(key[:], src)
65 | }
66 |
67 | func (key NoisePublicKey) IsZero() bool {
68 | var zero NoisePublicKey
69 | return key.Equals(zero)
70 | }
71 |
72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
74 | }
75 |
76 | func (key *NoisePresharedKey) FromHex(src string) error {
77 | return loadExactHex(key[:], src)
78 | }
79 |
--------------------------------------------------------------------------------
/replay/replay.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
7 | package replay
8 |
// block is one 64-bit word of the replay bitmap.
type block uint64

const (
	blockBitLog = 6                            // 1<<6 == 64 bits
	blockBits   = 1 << blockBitLog             // must be power of 2
	ringBlocks  = 1 << 7                       // must be power of 2
	windowSize  = (ringBlocks - 1) * blockBits // counters more than this behind last are rejected
	blockMask   = ringBlocks - 1
	bitMask     = blockBits - 1
)

// A Filter rejects replayed messages by checking if message counter value is
// within a sliding window of previously received messages.
// The zero value for Filter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type Filter struct {
	last uint64
	ring [ringBlocks]block
}

// Reset returns the filter to its empty state.
//
// Only ring[0] needs clearing. Once last is back at 0, ValidateCounter zeroes
// every later block as the window advances into it, before any of that
// block's bits are consulted.
func (f *Filter) Reset() {
	f.ring[0] = 0
	f.last = 0
}
34 |
35 | // ValidateCounter checks if the counter should be accepted.
36 | // Overlimit counters (>= limit) are always rejected.
37 | func (f *Filter) ValidateCounter(counter, limit uint64) bool {
38 | if counter >= limit {
39 | return false
40 | }
41 | indexBlock := counter >> blockBitLog
42 | if counter > f.last { // move window forward
43 | current := f.last >> blockBitLog
44 | diff := indexBlock - current
45 | if diff > ringBlocks {
46 | diff = ringBlocks // cap diff to clear the whole ring
47 | }
48 | for i := current + 1; i <= current+diff; i++ {
49 | f.ring[i&blockMask] = 0
50 | }
51 | f.last = counter
52 | } else if f.last-counter > windowSize { // behind current window
53 | return false
54 | }
55 | // check and set bit
56 | indexBlock &= blockMask
57 | indexBit := counter & bitMask
58 | old := f.ring[indexBlock]
59 | new := old | 1< p.max {
38 | t.Errorf("count (%d) > max (%d)", count, p.max)
39 | }
40 | for {
41 | old := max.Load()
42 | if count <= old {
43 | break
44 | }
45 | if max.CompareAndSwap(old, count) {
46 | break
47 | }
48 | }
49 | }
50 | for i := 0; i < workers; i++ {
51 | go func() {
52 | defer wg.Done()
53 | for trials.Add(-1) > 0 {
54 | updateMax()
55 | x := p.Get()
56 | updateMax()
57 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
58 | updateMax()
59 | p.Put(x)
60 | updateMax()
61 | }
62 | }()
63 | }
64 | wg.Wait()
65 | if max.Load() != p.max {
66 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
67 | }
68 | }
69 |
70 | func BenchmarkWaitPool(b *testing.B) {
71 | var wg sync.WaitGroup
72 | var trials atomic.Int32
73 | trials.Store(int32(b.N))
74 | workers := runtime.NumCPU() + 2
75 | if workers-4 <= 0 {
76 | b.Skip("Not enough cores")
77 | }
78 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
79 | wg.Add(workers)
80 | b.ResetTimer()
81 | for i := 0; i < workers; i++ {
82 | go func() {
83 | defer wg.Done()
84 | for trials.Add(-1) > 0 {
85 | x := p.Get()
86 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
87 | p.Put(x)
88 | }
89 | }()
90 | }
91 | wg.Wait()
92 | }
93 |
--------------------------------------------------------------------------------
/device/indextable.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "crypto/rand"
10 | "encoding/binary"
11 | "sync"
12 | )
13 |
14 | type IndexTableEntry struct {
15 | peer *Peer
16 | handshake *Handshake
17 | keypair *Keypair
18 | }
19 |
20 | type IndexTable struct {
21 | sync.RWMutex
22 | table map[uint32]IndexTableEntry
23 | }
24 |
// randUint32 returns a uniformly random uint32 from crypto/rand. Any read
// error is returned alongside whatever value the buffer decodes to.
func randUint32() (value uint32, err error) {
	var raw [4]byte
	_, err = rand.Read(raw[:])
	// Byte order is arbitrary for random data; both orders are compiler intrinsics.
	value = binary.LittleEndian.Uint32(raw[:])
	return value, err
}
31 |
32 | func (table *IndexTable) Init() {
33 | table.Lock()
34 | defer table.Unlock()
35 | table.table = make(map[uint32]IndexTableEntry)
36 | }
37 |
38 | func (table *IndexTable) Delete(index uint32) {
39 | table.Lock()
40 | defer table.Unlock()
41 | delete(table.table, index)
42 | }
43 |
44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
45 | table.Lock()
46 | defer table.Unlock()
47 | entry, ok := table.table[index]
48 | if !ok {
49 | return
50 | }
51 | table.table[index] = IndexTableEntry{
52 | peer: entry.peer,
53 | keypair: keypair,
54 | handshake: nil,
55 | }
56 | }
57 |
58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
59 | for {
60 | // generate random index
61 |
62 | index, err := randUint32()
63 | if err != nil {
64 | return index, err
65 | }
66 |
67 | // check if index used
68 |
69 | table.RLock()
70 | _, ok := table.table[index]
71 | table.RUnlock()
72 | if ok {
73 | continue
74 | }
75 |
76 | // check again while locked
77 |
78 | table.Lock()
79 | _, found := table.table[index]
80 | if found {
81 | table.Unlock()
82 | continue
83 | }
84 | table.table[index] = IndexTableEntry{
85 | peer: peer,
86 | handshake: handshake,
87 | keypair: nil,
88 | }
89 | table.Unlock()
90 | return index, nil
91 | }
92 | }
93 |
94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
95 | table.RLock()
96 | defer table.RUnlock()
97 | return table.table[id]
98 | }
99 |
--------------------------------------------------------------------------------
/tun/alignment_windows_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tun
7 |
8 | import (
9 | "reflect"
10 | "testing"
11 | "unsafe"
12 | )
13 |
// checkAlignment reports a test error when offset is not a multiple of 8,
// i.e. when the named field is not 64-bit aligned within its struct.
func checkAlignment(t *testing.T, name string, offset uintptr) {
	t.Helper()
	misalign := offset % 8
	if misalign == 0 {
		return
	}
	t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-misalign)
}
20 |
21 | // TestRateJugglerAlignment checks that atomically-accessed fields are
22 | // aligned to 64-bit boundaries, as required by the atomic package.
23 | //
24 | // Unfortunately, violating this rule on 32-bit platforms results in a
25 | // hard segfault at runtime.
26 | func TestRateJugglerAlignment(t *testing.T) {
27 | var r rateJuggler
28 |
29 | typ := reflect.TypeOf(&r).Elem()
30 | t.Logf("Peer type size: %d, with fields:", typ.Size())
31 | for i := 0; i < typ.NumField(); i++ {
32 | field := typ.Field(i)
33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
34 | field.Name,
35 | field.Offset,
36 | field.Type.Size(),
37 | field.Type.Align(),
38 | )
39 | }
40 |
41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
44 | }
45 |
46 | // TestNativeTunAlignment checks that atomically-accessed fields are
47 | // aligned to 64-bit boundaries, as required by the atomic package.
48 | //
49 | // Unfortunately, violating this rule on 32-bit platforms results in a
50 | // hard segfault at runtime.
51 | func TestNativeTunAlignment(t *testing.T) {
52 | var tun NativeTun
53 |
54 | typ := reflect.TypeOf(&tun).Elem()
55 | t.Logf("Peer type size: %d, with fields:", typ.Size())
56 | for i := 0; i < typ.NumField(); i++ {
57 | field := typ.Field(i)
58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
59 | field.Name,
60 | field.Offset,
61 | field.Type.Size(),
62 | field.Type.Align(),
63 | )
64 | }
65 |
66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
67 | }
68 |
--------------------------------------------------------------------------------
/device/pools.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "sync"
10 | )
11 |
// WaitPool wraps a sync.Pool. When max is non-zero, it also bounds how many
// items may be checked out at once: Get blocks while count >= max, until a Put
// returns an item. A max of zero means unbounded.
type WaitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count uint32 // Get calls not yet Put back
	max   uint32
}

// NewWaitPool returns a WaitPool that allocates items with new and allows at
// most max outstanding items (0 = no limit).
func NewWaitPool(max uint32, new func() any) *WaitPool {
	p := &WaitPool{max: max}
	p.pool.New = new
	p.cond.L = &p.lock
	return p
}

// Get returns an item from the pool, first waiting for a free slot when the
// pool is bounded.
func (p *WaitPool) Get() any {
	if p.max == 0 {
		return p.pool.Get()
	}
	p.lock.Lock()
	for p.count >= p.max {
		p.cond.Wait()
	}
	p.count++
	p.lock.Unlock()
	return p.pool.Get()
}

// Put returns x to the pool and, when bounded, wakes one waiting Get.
func (p *WaitPool) Put(x any) {
	p.pool.Put(x)
	if p.max != 0 {
		p.lock.Lock()
		p.count--
		p.cond.Signal()
		p.lock.Unlock()
	}
}
48 |
49 | func (device *Device) PopulatePools() {
50 | device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
51 | return new([MaxMessageSize]byte)
52 | })
53 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
54 | return new(QueueInboundElement)
55 | })
56 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
57 | return new(QueueOutboundElement)
58 | })
59 | }
60 |
61 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
62 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
63 | }
64 |
65 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
66 | device.pool.messageBuffers.Put(msg)
67 | }
68 |
69 | func (device *Device) GetInboundElement() *QueueInboundElement {
70 | return device.pool.inboundElements.Get().(*QueueInboundElement)
71 | }
72 |
73 | func (device *Device) PutInboundElement(elem *QueueInboundElement) {
74 | elem.clearPointers()
75 | device.pool.inboundElements.Put(elem)
76 | }
77 |
78 | func (device *Device) GetOutboundElement() *QueueOutboundElement {
79 | return device.pool.outboundElements.Get().(*QueueOutboundElement)
80 | }
81 |
82 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
83 | elem.clearPointers()
84 | device.pool.outboundElements.Put(elem)
85 | }
86 |
--------------------------------------------------------------------------------
/main_windows.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package main
7 |
8 | import (
9 | "fmt"
10 | "os"
11 | "os/signal"
12 | "syscall"
13 |
14 | "golang.zx2c4.com/wireguard/conn"
15 | "golang.zx2c4.com/wireguard/device"
16 | "golang.zx2c4.com/wireguard/ipc"
17 |
18 | "golang.zx2c4.com/wireguard/tun"
19 | )
20 |
21 | const (
22 | ExitSetupSuccess = 0
23 | ExitSetupFailed = 1
24 | )
25 |
26 | func main() {
27 | if len(os.Args) != 2 {
28 | os.Exit(ExitSetupFailed)
29 | }
30 | interfaceName := os.Args[1]
31 |
32 | fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is , which includes this code as a module.")
33 |
34 | logger := device.NewLogger(
35 | device.LogLevelVerbose,
36 | fmt.Sprintf("(%s) ", interfaceName),
37 | )
38 | logger.Verbosef("Starting wireguard-go version %s", Version)
39 |
40 | tun, err := tun.CreateTUN(interfaceName, 0)
41 | if err == nil {
42 | realInterfaceName, err2 := tun.Name()
43 | if err2 == nil {
44 | interfaceName = realInterfaceName
45 | }
46 | } else {
47 | logger.Errorf("Failed to create TUN device: %v", err)
48 | os.Exit(ExitSetupFailed)
49 | }
50 |
51 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
52 | err = device.Up()
53 | if err != nil {
54 | logger.Errorf("Failed to bring up device: %v", err)
55 | os.Exit(ExitSetupFailed)
56 | }
57 | logger.Verbosef("Device started")
58 |
59 | uapi, err := ipc.UAPIListen(interfaceName)
60 | if err != nil {
61 | logger.Errorf("Failed to listen on uapi socket: %v", err)
62 | os.Exit(ExitSetupFailed)
63 | }
64 |
65 | errs := make(chan error)
66 | term := make(chan os.Signal, 1)
67 |
68 | go func() {
69 | for {
70 | conn, err := uapi.Accept()
71 | if err != nil {
72 | errs <- err
73 | return
74 | }
75 | go device.IpcHandle(conn)
76 | }
77 | }()
78 | logger.Verbosef("UAPI listener started")
79 |
80 | // wait for program to terminate
81 |
82 | signal.Notify(term, os.Interrupt)
83 | signal.Notify(term, os.Kill)
84 | signal.Notify(term, syscall.SIGTERM)
85 |
86 | select {
87 | case <-term:
88 | case <-errs:
89 | case <-device.Wait():
90 | }
91 |
92 | // clean up
93 |
94 | uapi.Close()
95 | device.Close()
96 |
97 | logger.Verbosef("Shutting down")
98 | }
99 |
--------------------------------------------------------------------------------
/device/kdf_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "encoding/hex"
10 | "testing"
11 |
12 | "golang.org/x/crypto/blake2s"
13 | )
14 |
// KDFTest is one known-answer vector. It holds the hex-encoded key and input,
// plus the expected hex outputs t0, t1 and t2 of the KDF chain.
type KDFTest struct {
	key, input string
	t0, t1, t2 string
}
22 |
// assertEquals stops the test when a != b. It is marked as a test helper so
// failures are reported at the caller's line, not inside this function.
func assertEquals(t *testing.T, a, b string) {
	t.Helper()
	if a != b {
		t.Fatal("expected", a, "=", b)
	}
}
28 |
29 | func TestKDF(t *testing.T) {
30 | tests := []KDFTest{
31 | {
32 | key: "746573742d6b6579",
33 | input: "746573742d696e707574",
34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
37 | },
38 | {
39 | key: "776972656775617264",
40 | input: "776972656775617264",
41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
44 | },
45 | {
46 | key: "",
47 | input: "",
48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
51 | },
52 | }
53 |
54 | var t0, t1, t2 [blake2s.Size]byte
55 |
56 | for _, test := range tests {
57 | key, _ := hex.DecodeString(test.key)
58 | input, _ := hex.DecodeString(test.input)
59 | KDF3(&t0, &t1, &t2, key, input)
60 | t0s := hex.EncodeToString(t0[:])
61 | t1s := hex.EncodeToString(t1[:])
62 | t2s := hex.EncodeToString(t2[:])
63 | assertEquals(t, t0s, test.t0)
64 | assertEquals(t, t1s, test.t1)
65 | assertEquals(t, t2s, test.t2)
66 | }
67 |
68 | for _, test := range tests {
69 | key, _ := hex.DecodeString(test.key)
70 | input, _ := hex.DecodeString(test.input)
71 | KDF2(&t0, &t1, key, input)
72 | t0s := hex.EncodeToString(t0[:])
73 | t1s := hex.EncodeToString(t1[:])
74 | assertEquals(t, t0s, test.t0)
75 | assertEquals(t, t1s, test.t1)
76 | }
77 |
78 | for _, test := range tests {
79 | key, _ := hex.DecodeString(test.key)
80 | input, _ := hex.DecodeString(test.input)
81 | KDF1(&t0, key, input)
82 | t0s := hex.EncodeToString(t0[:])
83 | assertEquals(t, t0s, test.t0)
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/rwcancel/rwcancel.go:
--------------------------------------------------------------------------------
1 | //go:build !windows && !js
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | // Package rwcancel implements cancelable read/write operations on
9 | // a file descriptor.
10 | package rwcancel
11 |
12 | import (
13 | "errors"
14 | "os"
15 | "syscall"
16 |
17 | "golang.org/x/sys/unix"
18 | )
19 |
// RWCancel wraps a non-blocking file descriptor so that reads and writes
// waiting on it can be abandoned via Cancel. The pipe pair is the wake-up
// channel: Cancel writes to closingWriter, and the poll loops watch
// closingReader.
type RWCancel struct {
	fd                           int
	closingReader, closingWriter *os.File
}
25 |
26 | func NewRWCancel(fd int) (*RWCancel, error) {
27 | err := unix.SetNonblock(fd, true)
28 | if err != nil {
29 | return nil, err
30 | }
31 | rwcancel := RWCancel{fd: fd}
32 |
33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe()
34 | if err != nil {
35 | return nil, err
36 | }
37 |
38 | return &rwcancel, nil
39 | }
40 |
// RetryAfterError reports whether err, or any error it wraps, is EAGAIN or
// EINTR, meaning the operation should simply be retried.
func RetryAfterError(err error) bool {
	switch {
	case errors.Is(err, syscall.EAGAIN), errors.Is(err, syscall.EINTR):
		return true
	default:
		return false
	}
}
44 |
45 | func (rw *RWCancel) ReadyRead() bool {
46 | closeFd := int32(rw.closingReader.Fd())
47 |
48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
49 | var err error
50 | for {
51 | _, err = unix.Poll(pollFds, -1)
52 | if err == nil || !RetryAfterError(err) {
53 | break
54 | }
55 | }
56 | if err != nil {
57 | return false
58 | }
59 | if pollFds[1].Revents != 0 {
60 | return false
61 | }
62 | return pollFds[0].Revents != 0
63 | }
64 |
65 | func (rw *RWCancel) ReadyWrite() bool {
66 | closeFd := int32(rw.closingReader.Fd())
67 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
68 | var err error
69 | for {
70 | _, err = unix.Poll(pollFds, -1)
71 | if err == nil || !RetryAfterError(err) {
72 | break
73 | }
74 | }
75 | if err != nil {
76 | return false
77 | }
78 |
79 | if pollFds[1].Revents != 0 {
80 | return false
81 | }
82 | return pollFds[0].Revents != 0
83 | }
84 |
85 | func (rw *RWCancel) Read(p []byte) (n int, err error) {
86 | for {
87 | n, err := unix.Read(rw.fd, p)
88 | if err == nil || !RetryAfterError(err) {
89 | return n, err
90 | }
91 | if !rw.ReadyRead() {
92 | return 0, os.ErrClosed
93 | }
94 | }
95 | }
96 |
97 | func (rw *RWCancel) Write(p []byte) (n int, err error) {
98 | for {
99 | n, err := unix.Write(rw.fd, p)
100 | if err == nil || !RetryAfterError(err) {
101 | return n, err
102 | }
103 | if !rw.ReadyWrite() {
104 | return 0, os.ErrClosed
105 | }
106 | }
107 | }
108 |
109 | func (rw *RWCancel) Cancel() (err error) {
110 | _, err = rw.closingWriter.Write([]byte{0})
111 | return
112 | }
113 |
114 | func (rw *RWCancel) Close() {
115 | rw.closingReader.Close()
116 | rw.closingWriter.Close()
117 | }
118 |
--------------------------------------------------------------------------------
/device/noise-helpers.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "crypto/hmac"
10 | "crypto/rand"
11 | "crypto/subtle"
12 | "errors"
13 | "hash"
14 |
15 | "golang.org/x/crypto/blake2s"
16 | "golang.org/x/crypto/curve25519"
17 | )
18 |
19 | /* KDF related functions.
20 | * HMAC-based Key Derivation Function (HKDF)
21 | * https://tools.ietf.org/html/rfc5869
22 | */
23 |
24 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
25 | mac := hmac.New(func() hash.Hash {
26 | h, _ := blake2s.New256(nil)
27 | return h
28 | }, key)
29 | mac.Write(in0)
30 | mac.Sum(sum[:0])
31 | }
32 |
33 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
34 | mac := hmac.New(func() hash.Hash {
35 | h, _ := blake2s.New256(nil)
36 | return h
37 | }, key)
38 | mac.Write(in0)
39 | mac.Write(in1)
40 | mac.Sum(sum[:0])
41 | }
42 |
43 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
44 | HMAC1(t0, key, input)
45 | HMAC1(t0, t0[:], []byte{0x1})
46 | }
47 |
48 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
49 | var prk [blake2s.Size]byte
50 | HMAC1(&prk, key, input)
51 | HMAC1(t0, prk[:], []byte{0x1})
52 | HMAC2(t1, prk[:], t0[:], []byte{0x2})
53 | setZero(prk[:])
54 | }
55 |
56 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
57 | var prk [blake2s.Size]byte
58 | HMAC1(&prk, key, input)
59 | HMAC1(t0, prk[:], []byte{0x1})
60 | HMAC2(t1, prk[:], t0[:], []byte{0x2})
61 | HMAC2(t2, prk[:], t1[:], []byte{0x3})
62 | setZero(prk[:])
63 | }
64 |
// isZero reports whether every byte of val is zero. It examines all bytes
// regardless of content (constant-time byte compares); empty input is zero.
func isZero(val []byte) bool {
	allZero := 1
	for i := range val {
		allZero &= subtle.ConstantTimeByteEq(val[i], 0)
	}
	return allZero == 1
}

// setZero overwrites arr with zeros. It is not applied as pervasively as it
// ideally would be, because reliably scrubbing secrets is mostly impossible in
// Go at the moment.
func setZero(arr []byte) {
	for i := 0; i < len(arr); i++ {
		arr[i] = 0
	}
}
79 |
80 | func (sk *NoisePrivateKey) clamp() {
81 | sk[0] &= 248
82 | sk[31] = (sk[31] & 127) | 64
83 | }
84 |
85 | func newPrivateKey() (sk NoisePrivateKey, err error) {
86 | _, err = rand.Read(sk[:])
87 | sk.clamp()
88 | return
89 | }
90 |
91 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
92 | apk := (*[NoisePublicKeySize]byte)(&pk)
93 | ask := (*[NoisePrivateKeySize]byte)(sk)
94 | curve25519.ScalarBaseMult(apk, ask)
95 | return
96 | }
97 |
98 | var errInvalidPublicKey = errors.New("invalid public key")
99 |
100 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
101 | apk := (*[NoisePublicKeySize]byte)(&pk)
102 | ask := (*[NoisePrivateKeySize]byte)(sk)
103 | curve25519.ScalarMult(&ss, ask, apk)
104 | if isZero(ss[:]) {
105 | return ss, errInvalidPublicKey
106 | }
107 | return ss, nil
108 | }
109 |
--------------------------------------------------------------------------------
/ipc/uapi_linux.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package ipc
7 |
8 | import (
9 | "net"
10 | "os"
11 |
12 | "golang.org/x/sys/unix"
13 | "golang.zx2c4.com/wireguard/rwcancel"
14 | )
15 |
16 | type UAPIListener struct {
17 | listener net.Listener // unix socket listener
18 | connNew chan net.Conn
19 | connErr chan error
20 | inotifyFd int
21 | inotifyRWCancel *rwcancel.RWCancel
22 | }
23 |
24 | func (l *UAPIListener) Accept() (net.Conn, error) {
25 | for {
26 | select {
27 | case conn := <-l.connNew:
28 | return conn, nil
29 |
30 | case err := <-l.connErr:
31 | return nil, err
32 | }
33 | }
34 | }
35 |
36 | func (l *UAPIListener) Close() error {
37 | err1 := unix.Close(l.inotifyFd)
38 | err2 := l.inotifyRWCancel.Cancel()
39 | err3 := l.listener.Close()
40 | if err1 != nil {
41 | return err1
42 | }
43 | if err2 != nil {
44 | return err2
45 | }
46 | return err3
47 | }
48 |
// Addr reports the address of the wrapped unix socket listener.
func (l *UAPIListener) Addr() net.Addr {
	inner := l.listener
	return inner.Addr()
}
52 |
// UAPIListen wraps file, which must hold a listening socket, in a
// *UAPIListener for the interface called name.
//
// If the socket is a unix socket it is unlinked when the listener is closed.
// Two goroutines feed Accept: one forwards accepted connections (or the first
// accept error), the other watches sockPath(name) with inotify and reports an
// error once the path no longer exists (or the inotify read fails, e.g. after
// Close cancels it).
func UAPIListen(name string, file *os.File) (net.Listener, error) {
	// wrap file in listener

	listener, err := net.FileListener(file)
	if err != nil {
		return nil, err
	}

	if unixListener, ok := listener.(*net.UnixListener); ok {
		unixListener.SetUnlinkOnClose(true)
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	// watch for deletion of socket

	socketPath := sockPath(name)

	uapi.inotifyFd, err = unix.InotifyInit()
	if err != nil {
		return nil, err
	}

	_, err = unix.InotifyAddWatch(
		uapi.inotifyFd,
		socketPath,
		unix.IN_ATTRIB|
			unix.IN_DELETE|
			unix.IN_DELETE_SELF,
	)
	if err != nil {
		// Don't leak the inotify instance (matches the NewRWCancel path below).
		unix.Close(uapi.inotifyFd)
		return nil, err
	}

	uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd)
	if err != nil {
		unix.Close(uapi.inotifyFd)
		return nil, err
	}

	go func(l *UAPIListener) {
		// Deferred once, outside the loop: a defer inside the loop would queue
		// another Close on every inotify wakeup and run them all at exit.
		defer l.inotifyRWCancel.Close()
		var buf [0]byte
		for {
			// start with lstat to avoid race condition
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
				l.connErr <- err
				return
			}
			// Block until an inotify event arrives (or the read is cancelled),
			// then re-check the socket path.
			if _, err := l.inotifyRWCancel.Read(buf[:]); err != nil {
				l.connErr <- err
				return
			}
		}
	}(uapi)

	// watch for new connections

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}
130 |
--------------------------------------------------------------------------------
/ratelimiter/ratelimiter_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package ratelimiter
7 |
8 | import (
9 | "net/netip"
10 | "testing"
11 | "time"
12 | )
13 |
// result is one expected step of TestRatelimiter: the fake clock is advanced
// by wait (plus one nanosecond), then rate.Allow must return allowed for
// every test address.
type result struct {
	allowed bool          // expected rate.Allow result for every IP in this step
	text    string        // step description included in the failure message
	wait    time.Duration // fake-clock advance before this step
}
19 |
// TestRatelimiter drives a Ratelimiter with a fake clock through a burst,
// a refill of one packet, and a refill of two packets, checking the Allow
// verdict for a mix of IPv4 and IPv6 addresses at every step.
func TestRatelimiter(t *testing.T) {
	var rate Ratelimiter

	// perPacket is the refill time that earns exactly one packet's worth of tokens.
	perPacket := time.Nanosecond * time.Duration(time.Second.Nanoseconds()/packetsPerSecond)

	var expectedResults []result
	for i := 0; i < packetsBurstable; i++ {
		expectedResults = append(expectedResults, result{allowed: true, text: "initial burst"})
	}
	expectedResults = append(expectedResults,
		result{allowed: false, text: "after burst"},
		result{allowed: true, wait: perPacket, text: "filling tokens for single packet"},
		result{allowed: false, text: "not having refilled enough"},
		result{allowed: true, wait: 2 * perPacket, text: "filling tokens for two packet burst"},
		result{allowed: true, text: "second packet in 2 packet burst"},
		result{allowed: false, text: "packet following 2 packet burst"},
	)

	var ips []netip.Addr
	for _, literal := range []string{
		"127.0.0.1",
		"192.168.1.1",
		"172.167.2.3",
		"97.231.252.215",
		"248.97.91.167",
		"188.208.233.47",
		"104.2.183.179",
		"72.129.46.120",
		"2001:0db8:0a0b:12f0:0000:0000:0000:0001",
		"f5c2:818f:c052:655a:9860:b136:6894:25f0",
		"b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc",
		"a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918",
		"ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445",
		"3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4",
	} {
		ips = append(ips, netip.MustParseAddr(literal))
	}

	// The limiter reads this fake clock; it only moves when advance is called.
	clock := time.Now()
	rate.timeNow = func() time.Time {
		return clock
	}
	defer func() {
		// Lock to avoid data race with cleanup goroutine from Init.
		rate.mu.Lock()
		rate.timeNow = time.Now
		rate.mu.Unlock()
	}()
	advance := func(d time.Duration) {
		clock = clock.Add(d + 1)
		rate.cleanup()
	}

	rate.Init()
	defer rate.Close()

	for step, want := range expectedResults {
		advance(want.wait)
		for _, ip := range ips {
			if got := rate.Allow(ip); got != want.allowed {
				t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", step, want.text, ip, got, want.allowed)
			}
		}
	}
}
120 |
--------------------------------------------------------------------------------
/ipc/uapi_bsd.go:
--------------------------------------------------------------------------------
1 | //go:build darwin || freebsd || openbsd
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package ipc
9 |
10 | import (
11 | "errors"
12 | "net"
13 | "os"
14 | "unsafe"
15 |
16 | "golang.org/x/sys/unix"
17 | )
18 |
// UAPIListener is the net.Listener returned by UAPIListen. Accept delivers
// connections from the wrapped listener as well as errors reported by the
// background goroutines UAPIListen starts (accept failures, kqueue failures,
// or the socket path disappearing).
type UAPIListener struct {
	listener net.Listener  // unix socket listener
	connNew  chan net.Conn // connections accepted by the accept goroutine
	connErr  chan error    // errors from the accept and socket-watch goroutines
	kqueueFd int           // kqueue instance used to wait for directory changes
	keventFd int           // read-only descriptor on socketDirectory, watched via kqueueFd
}
26 |
// Accept blocks until either a new connection or an error is available from
// the background goroutines started by UAPIListen, and returns it.
func (l *UAPIListener) Accept() (net.Conn, error) {
	select {
	case err := <-l.connErr:
		return nil, err
	case conn := <-l.connNew:
		return conn, nil
	}
}
38 |
39 | func (l *UAPIListener) Close() error {
40 | err1 := unix.Close(l.kqueueFd)
41 | err2 := unix.Close(l.keventFd)
42 | err3 := l.listener.Close()
43 | if err1 != nil {
44 | return err1
45 | }
46 | if err2 != nil {
47 | return err2
48 | }
49 | return err3
50 | }
51 |
// Addr reports the address of the wrapped unix socket listener.
func (l *UAPIListener) Addr() net.Addr {
	inner := l.listener
	return inner.Addr()
}
55 |
// UAPIListen wraps file, which must hold a listening socket, in a
// *UAPIListener for the interface called name.
//
// If the socket is a unix socket it is unlinked when the listener is closed.
// Two goroutines feed Accept: one forwards accepted connections (or the first
// accept error); the other waits for writes to socketDirectory via kqueue and,
// after each wakeup, reports an error once sockPath(name) no longer exists.
//
// NOTE(review): socketDirectory and sockPath are defined elsewhere in this
// package; presumably the socket path lives inside socketDirectory — confirm.
func UAPIListen(name string, file *os.File) (net.Listener, error) {
	// wrap file in listener

	listener, err := net.FileListener(file)
	if err != nil {
		return nil, err
	}

	uapi := &UAPIListener{
		listener: listener,
		connNew:  make(chan net.Conn, 1),
		connErr:  make(chan error, 1),
	}

	// Remove the socket file from the filesystem when the listener is closed.
	if unixListener, ok := listener.(*net.UnixListener); ok {
		unixListener.SetUnlinkOnClose(true)
	}

	socketPath := sockPath(name)

	// watch for deletion of socket

	// Both descriptors are stored on uapi and closed by Close.
	uapi.kqueueFd, err = unix.Kqueue()
	if err != nil {
		return nil, err
	}
	uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0)
	if err != nil {
		unix.Close(uapi.kqueueFd)
		return nil, err
	}

	go func(l *UAPIListener) {
		// One-shot vnode filter on the directory descriptor, firing on
		// NOTE_WRITE. It is passed as the changelist on every Kevent call
		// below, so it is re-armed on each iteration.
		event := unix.Kevent_t{
			Filter: unix.EVFILT_VNODE,
			Flags:  unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT,
			Fflags: unix.NOTE_WRITE,
		}
		// Allow this assignment to work with both the 32-bit and 64-bit version
		// of the above struct. If you know another way, please submit a patch.
		*(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd)
		events := make([]unix.Kevent_t, 1)
		// n and kerr hold the previous Kevent result; they start out looking
		// like a success so the first iteration proceeds straight to waiting.
		n := 1
		var kerr error
		for {
			// start with lstat to avoid race condition
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
				l.connErr <- err
				return
			}
			// Give up on any Kevent error other than EINTR (presumably
			// including the failure after Close closes kqueueFd), or when
			// Kevent returned no event.
			if (kerr != nil || n != 1) && kerr != unix.EINTR {
				if kerr != nil {
					l.connErr <- kerr
				} else {
					l.connErr <- errors.New("kqueue returned empty")
				}
				return
			}
			// Block until the directory is written to.
			n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil)
		}
	}(uapi)

	// watch for new connections

	go func(l *UAPIListener) {
		for {
			conn, err := l.listener.Accept()
			if err != nil {
				l.connErr <- err
				break
			}
			l.connNew <- conn
		}
	}(uapi)

	return uapi, nil
}
133 |
--------------------------------------------------------------------------------
/ratelimiter/ratelimiter.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package ratelimiter
7 |
8 | import (
9 | "net/netip"
10 | "sync"
11 | "time"
12 | )
13 |
14 | const (
15 | packetsPerSecond = 20
16 | packetsBurstable = 5
17 | garbageCollectTime = time.Second
18 | packetCost = 1000000000 / packetsPerSecond
19 | maxTokens = packetCost * packetsBurstable
20 | )
21 |
22 | type RatelimiterEntry struct {
23 | mu sync.Mutex
24 | lastTime time.Time
25 | tokens int64
26 | }
27 |
28 | type Ratelimiter struct {
29 | mu sync.RWMutex
30 | timeNow func() time.Time
31 |
32 | stopReset chan struct{} // send to reset, close to stop
33 | table map[netip.Addr]*RatelimiterEntry
34 | }
35 |
36 | func (rate *Ratelimiter) Close() {
37 | rate.mu.Lock()
38 | defer rate.mu.Unlock()
39 |
40 | if rate.stopReset != nil {
41 | close(rate.stopReset)
42 | }
43 | }
44 |
45 | func (rate *Ratelimiter) Init() {
46 | rate.mu.Lock()
47 | defer rate.mu.Unlock()
48 |
49 | if rate.timeNow == nil {
50 | rate.timeNow = time.Now
51 | }
52 |
53 | // stop any ongoing garbage collection routine
54 | if rate.stopReset != nil {
55 | close(rate.stopReset)
56 | }
57 |
58 | rate.stopReset = make(chan struct{})
59 | rate.table = make(map[netip.Addr]*RatelimiterEntry)
60 |
61 | stopReset := rate.stopReset // store in case Init is called again.
62 |
63 | // Start garbage collection routine.
64 | go func() {
65 | ticker := time.NewTicker(time.Second)
66 | ticker.Stop()
67 | for {
68 | select {
69 | case _, ok := <-stopReset:
70 | ticker.Stop()
71 | if !ok {
72 | return
73 | }
74 | ticker = time.NewTicker(time.Second)
75 | case <-ticker.C:
76 | if rate.cleanup() {
77 | ticker.Stop()
78 | }
79 | }
80 | }
81 | }()
82 | }
83 |
84 | func (rate *Ratelimiter) cleanup() (empty bool) {
85 | rate.mu.Lock()
86 | defer rate.mu.Unlock()
87 |
88 | for key, entry := range rate.table {
89 | entry.mu.Lock()
90 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
91 | delete(rate.table, key)
92 | }
93 | entry.mu.Unlock()
94 | }
95 |
96 | return len(rate.table) == 0
97 | }
98 |
99 | func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
100 | var entry *RatelimiterEntry
101 | // lookup entry
102 | rate.mu.RLock()
103 | entry = rate.table[ip]
104 | rate.mu.RUnlock()
105 |
106 | // make new entry if not found
107 | if entry == nil {
108 | entry = new(RatelimiterEntry)
109 | entry.tokens = maxTokens - packetCost
110 | entry.lastTime = rate.timeNow()
111 | rate.mu.Lock()
112 | rate.table[ip] = entry
113 | if len(rate.table) == 1 {
114 | rate.stopReset <- struct{}{}
115 | }
116 | rate.mu.Unlock()
117 | return true
118 | }
119 |
120 | // add tokens to entry
121 | entry.mu.Lock()
122 | now := rate.timeNow()
123 | entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
124 | entry.lastTime = now
125 | if entry.tokens > maxTokens {
126 | entry.tokens = maxTokens
127 | }
128 |
129 | // subtract cost of packet
130 | if entry.tokens > packetCost {
131 | entry.tokens -= packetCost
132 | entry.mu.Unlock()
133 | return true
134 | }
135 | entry.mu.Unlock()
136 | return false
137 | }
138 |
--------------------------------------------------------------------------------
/replay/replay_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package replay
7 |
8 | import (
9 | "testing"
10 | )
11 |
// The test cases in this file are ported from the Linux kernel implementation.

// RejectAfterMessages is the counter limit passed to Filter.ValidateCounter
// by these tests: 2^64 - 2^13 - 1. Counters at or above it are expected to be
// rejected (see test cases 25, 30 and 31 in TestReplay).
const RejectAfterMessages = 1<<64 - 1<<13 - 1
18 |
// TestReplay feeds a fixed sequence of counters through a single Filter and
// checks each ValidateCounter verdict, then runs several bulk fills of the
// window in ascending and descending order.
func TestReplay(t *testing.T) {
	var filter Filter

	// T_LIM is one past the window size, used to probe the window edges.
	const T_LIM = windowSize + 1

	// T validates counter n and fails with the running case number if the
	// verdict differs from expected. Cases share filter state, so order matters.
	testNumber := 0
	T := func(n uint64, expected bool) {
		testNumber++
		if filter.ValidateCounter(n, RejectAfterMessages) != expected {
			t.Fatal("Test", testNumber, "failed", n, expected)
		}
	}

	filter.Reset()

	T(0, true)                      /* 1 */
	T(1, true)                      /* 2 */
	T(1, false)                     /* 3 */
	T(9, true)                      /* 4 */
	T(8, true)                      /* 5 */
	T(7, true)                      /* 6 */
	T(7, false)                     /* 7 */
	T(T_LIM, true)                  /* 8 */
	T(T_LIM-1, true)                /* 9 */
	T(T_LIM-1, false)               /* 10 */
	T(T_LIM-2, true)                /* 11 */
	T(2, true)                      /* 12 */
	T(2, false)                     /* 13 */
	T(T_LIM+16, true)               /* 14 */
	T(3, false)                     /* 15 */
	T(T_LIM+16, false)              /* 16 */
	T(T_LIM*4, true)                /* 17 */
	T(T_LIM*4-(T_LIM-1), true)      /* 18 */
	T(10, false)                    /* 19 */
	T(T_LIM*4-T_LIM, false)         /* 20 */
	T(T_LIM*4-(T_LIM+1), false)     /* 21 */
	T(T_LIM*4-(T_LIM-2), true)      /* 22 */
	T(T_LIM*4+1-T_LIM, false)       /* 23 */
	T(0, false)                     /* 24 */
	T(RejectAfterMessages, false)   /* 25 */
	T(RejectAfterMessages-1, true)  /* 26 */
	T(RejectAfterMessages, false)   /* 27 */
	T(RejectAfterMessages-1, false) /* 28 */
	T(RejectAfterMessages-2, true)  /* 29 */
	T(RejectAfterMessages+1, false) /* 30 */
	T(RejectAfterMessages+2, false) /* 31 */
	T(RejectAfterMessages-2, false) /* 32 */
	T(RejectAfterMessages-3, true)  /* 33 */
	T(0, false)                     /* 34 */

	// Bulk test 1: ascending fill of 1..windowSize; 0 is still accepted once.
	t.Log("Bulk test 1")
	filter.Reset()
	testNumber = 0
	for i := uint64(1); i <= windowSize; i++ {
		T(i, true)
	}
	T(0, true)
	T(0, false)

	// Bulk test 2: ascending fill of 2..windowSize+1; 1 is accepted, 0 is not.
	t.Log("Bulk test 2")
	filter.Reset()
	testNumber = 0
	for i := uint64(2); i <= windowSize+1; i++ {
		T(i, true)
	}
	T(1, true)
	T(0, false)

	// Bulk test 3: descending fill from windowSize+1 down to 1.
	t.Log("Bulk test 3")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize + 1); i > 0; i-- {
		T(i, true)
	}

	// Bulk test 4: descending fill from windowSize+2 down to 2; 0 is rejected.
	t.Log("Bulk test 4")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize + 2); i > 1; i-- {
		T(i, true)
	}
	T(0, false)

	// Bulk test 5: descending fill, then windowSize+1; 0 is rejected.
	t.Log("Bulk test 5")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize); i > 0; i-- {
		T(i, true)
	}
	T(windowSize+1, true)
	T(0, false)

	// Bulk test 6: descending fill, then 0, then windowSize+1.
	t.Log("Bulk test 6")
	filter.Reset()
	testNumber = 0
	for i := uint64(windowSize); i > 0; i-- {
		T(i, true)
	}
	T(0, true)
	T(windowSize+1, true)
}
120 |
--------------------------------------------------------------------------------
/conn/server_name_utils/consistent_hash.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2025. Proton AG
3 | *
4 | * This file is part of ProtonVPN.
5 | *
6 | * ProtonVPN is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * ProtonVPN is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with ProtonVPN. If not, see .
18 | */
19 |
20 | package server_name_utils
21 |
22 | import (
23 | "hash/crc32"
24 | "math"
25 | "sort"
26 | )
27 |
// Set of utilities implementing consistent hashing, which maps one set of
// strings (keys) onto another (values).
// The set of values may change over time while the mapping stays largely stable.
// Usage:
// hashedValues := sortValuesByHash(values, crc32Hash)
// value := findClosestValue(key, hashedValues, crc32Hash)
// When the set of values changes (e.g. a new one is added), the only mappings
// that change are those of keys for which the new value is now the closest one
// (its uint32 hash is closest to the key's hash).

// HashedValue pairs a value with its precomputed hash, as produced by
// sortValuesByHash.
type HashedValue struct {
	value string
	hash  uint32
}
40 |
// findClosestValue returns the value from sortedValuesWithHashes (which must
// be sorted by hash, as returned by sortValuesByHash) whose hash is closest to
// hashFun(key). Distances are measured on the uint32 ring, where 0 is next to
// math.MaxUint32, taking the shorter of the clockwise and counter-clockwise
// directions. It returns "" for an empty slice, and hashFun is not called for
// slices of fewer than two values.
func findClosestValue(key string, sortedValuesWithHashes []HashedValue, hashFun func(string) uint32) string {
	values := sortedValuesWithHashes
	switch len(values) {
	case 0:
		return ""
	case 1:
		return values[0].value
	}

	target := hashFun(key)
	// Index of the first value whose hash is >= target, or len(values) if none.
	idx := sort.Search(len(values), func(j int) bool {
		return values[j].hash >= target
	})

	if idx > 0 && idx < len(values) {
		return closerValue(target, values[idx-1], values[idx])
	}
	// target lies before the first hash or after the last one, so its two
	// neighbors are the ends of the ring: compare the first and the last.
	return closerValue(target, values[0], values[len(values)-1])
}
65 |
// sortValuesByHash hashes every value with hashFun and returns the pairs in
// ascending hash order, ready for findClosestValue.
func sortValuesByHash(values []string, hashFun func(string) uint32) []HashedValue {
	sorted := make([]HashedValue, 0, len(values))
	for _, v := range values {
		sorted = append(sorted, HashedValue{value: v, hash: hashFun(v)})
	}
	sort.Slice(sorted, func(a, b int) bool { return sorted[a].hash < sorted[b].hash })
	return sorted
}
76 |
// crc32Hash returns the CRC-32 (IEEE polynomial) checksum of s.
func crc32Hash(s string) uint32 {
	data := []byte(s)
	return crc32.Checksum(data, crc32.IEEETable)
}
80 |
// ringDistance returns the shortest distance between a and b on the uint32
// hash ring, which has 1<<32 positions and wraps from math.MaxUint32 back to
// 0 (so those two are adjacent, at distance 1). The result is in [0, 1<<31].
func ringDistance(a, b uint32) int64 {
	// Number of positions on the ring. The wrap-around distance must use this,
	// not math.MaxUint32, or every wrapped distance comes out one too small.
	const ringSize = int64(math.MaxUint32) + 1
	var fa = int64(a)
	var fb = int64(b)
	var large = max(fa, fb)
	var small = min(fa, fb)
	direct := large - small
	// Take smaller of clockwise and counter-clockwise distance
	return min(direct, ringSize-direct)
}
89 |
90 | func closerValue(hash uint32, a HashedValue, b HashedValue) string {
91 | var da = ringDistance(hash, a.hash)
92 | var db = ringDistance(hash, b.hash)
93 | if da < db {
94 | return a.value
95 | } else {
96 | return b.value
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/device/allowedips_rand_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "math/rand"
10 | "net"
11 | "net/netip"
12 | "sort"
13 | "testing"
14 | )
15 |
// Sizes for the randomized trie-versus-naive comparison in TestTrieRandom.
const (
	NumberOfPeers        = 100   // distinct peers prefixes are assigned to
	NumberOfAddresses    = 250   // random IPv4 and IPv6 prefixes inserted (each)
	NumberOfPeerRemovals = 4     // peers removed one by one, re-checking after each
	NumberOfTests        = 10000 // random lookups per family per round
)
22 |
// SlowNode is one prefix entry of SlowRouter, the naive reference router that
// TestTrieRandom checks AllowedIPs against.
type SlowNode struct {
	peer *Peer   // peer this prefix routes to
	cidr uint8   // prefix length in bits
	bits []byte  // prefix address bytes; matched against lookups via commonBits up to cidr bits
}
28 |
29 | type SlowRouter []*SlowNode
30 |
31 | func (r SlowRouter) Len() int {
32 | return len(r)
33 | }
34 |
35 | func (r SlowRouter) Less(i, j int) bool {
36 | return r[i].cidr > r[j].cidr
37 | }
38 |
39 | func (r SlowRouter) Swap(i, j int) {
40 | r[i], r[j] = r[j], r[i]
41 | }
42 |
// Insert routes addr/cidr to peer. If an entry with the same prefix length
// already covers addr, that entry is updated in place; otherwise a new entry is
// appended and the router re-sorted. The (possibly grown) router is returned.
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
	for _, node := range r {
		if node.cidr != cidr || commonBits(node.bits, addr) < cidr {
			continue
		}
		node.peer, node.bits = peer, addr
		return r
	}
	r = append(r, &SlowNode{peer: peer, cidr: cidr, bits: addr})
	sort.Sort(r)
	return r
}
59 |
60 | func (r SlowRouter) Lookup(addr []byte) *Peer {
61 | for _, t := range r {
62 | common := commonBits(t.bits, addr)
63 | if common >= t.cidr {
64 | return t.peer
65 | }
66 | }
67 | return nil
68 | }
69 |
70 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
71 | n := 0
72 | for _, x := range r {
73 | if x.peer != peer {
74 | r[n] = x
75 | n++
76 | }
77 | }
78 | return r[:n]
79 | }
80 |
// TestTrieRandom inserts the same random IPv4/IPv6 prefixes into AllowedIPs
// and the naive SlowRouters, then compares random lookups between them. It
// repeats the comparison after removing each of the first NumberOfPeerRemovals
// peers, and finally checks that removing every peer empties both tries.
func TestTrieRandom(t *testing.T) {
	var slow4, slow6 SlowRouter
	var peers []*Peer
	var allowedIPs AllowedIPs

	// A private, fixed-seed generator keeps the run reproducible without
	// touching the package-global source: rand.Seed and the top-level
	// rand.Read are deprecated since Go 1.20, and reseeding the global source
	// affects any other test in the package that uses it.
	rng := rand.New(rand.NewSource(1))

	for n := 0; n < NumberOfPeers; n++ {
		peers = append(peers, &Peer{})
	}

	for n := 0; n < NumberOfAddresses; n++ {
		var addr4 [4]byte
		rng.Read(addr4[:])
		cidr := uint8(rng.Intn(32) + 1)
		index := rng.Intn(NumberOfPeers)
		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
		slow4 = slow4.Insert(addr4[:], cidr, peers[index])

		var addr6 [16]byte
		rng.Read(addr6[:])
		cidr = uint8(rng.Intn(128) + 1)
		index = rng.Intn(NumberOfPeers)
		allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
		slow6 = slow6.Insert(addr6[:], cidr, peers[index])
	}

	var p int
	for p = 0; ; p++ {
		for n := 0; n < NumberOfTests; n++ {
			var addr4 [4]byte
			rng.Read(addr4[:])
			peer1 := slow4.Lookup(addr4[:])
			peer2 := allowedIPs.Lookup(addr4[:])
			if peer1 != peer2 {
				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
			}

			var addr6 [16]byte
			rng.Read(addr6[:])
			peer1 = slow6.Lookup(addr6[:])
			peer2 = allowedIPs.Lookup(addr6[:])
			if peer1 != peer2 {
				t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
			}
		}
		if p >= len(peers) || p >= NumberOfPeerRemovals {
			break
		}
		allowedIPs.RemoveByPeer(peers[p])
		slow4 = slow4.RemoveByPeer(peers[p])
		slow6 = slow6.RemoveByPeer(peers[p])
	}
	// Remove the remaining peers; afterwards both tries must be empty.
	for ; p < len(peers); p++ {
		allowedIPs.RemoveByPeer(peers[p])
	}

	if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
		t.Error("Failed to remove all nodes from trie by peer")
	}
}
142 |
--------------------------------------------------------------------------------
/conn/bindtest/bindtest.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package bindtest
7 |
8 | import (
9 | "fmt"
10 | "math/rand"
11 | "net"
12 | "net/netip"
13 | "os"
14 |
15 | "golang.zx2c4.com/wireguard/conn"
16 | )
17 |
// ChannelBind is an in-memory conn.Bind backed by Go channels, for tests.
// NewChannelBinds returns two of them wired back-to-back: each bind's tx
// channels are the other bind's rx channels.
type ChannelBind struct {
	rx4, tx4         *chan []byte    // first receive/transmit queue pair (the "4" queues)
	rx6, tx6         *chan []byte    // second receive/transmit queue pair (the "6" queues)
	closeSignal      chan bool       // created by Open, closed by Close
	source4, source6 ChannelEndpoint // ports Open may report; set to the other bind's targets
	target4, target6 ChannelEndpoint // endpoints Send maps to tx4 / tx6
}
25 |
// ChannelEndpoint identifies a ChannelBind queue by a small integer, which
// also serves as the port in DstToString and ParseEndpoint.
type ChannelEndpoint uint16

// Compile-time checks that the types implement the conn interfaces.
var (
	_ conn.Bind     = (*ChannelBind)(nil)
	_ conn.Endpoint = (*ChannelEndpoint)(nil)
)
32 |
33 | func NewChannelBinds() [2]conn.Bind {
34 | arx4 := make(chan []byte, 8192)
35 | brx4 := make(chan []byte, 8192)
36 | arx6 := make(chan []byte, 8192)
37 | brx6 := make(chan []byte, 8192)
38 | var binds [2]ChannelBind
39 | binds[0].rx4 = &arx4
40 | binds[0].tx4 = &brx4
41 | binds[1].rx4 = &brx4
42 | binds[1].tx4 = &arx4
43 | binds[0].rx6 = &arx6
44 | binds[0].tx6 = &brx6
45 | binds[1].rx6 = &brx6
46 | binds[1].tx6 = &arx6
47 | binds[0].target4 = ChannelEndpoint(1)
48 | binds[1].target4 = ChannelEndpoint(2)
49 | binds[0].target6 = ChannelEndpoint(3)
50 | binds[1].target6 = ChannelEndpoint(4)
51 | binds[0].source4 = binds[1].target4
52 | binds[0].source6 = binds[1].target6
53 | binds[1].source4 = binds[0].target4
54 | binds[1].source6 = binds[0].target6
55 | return [2]conn.Bind{&binds[0], &binds[1]}
56 | }
57 |
58 | func (c ChannelEndpoint) ClearSrc() {}
59 |
60 | func (c ChannelEndpoint) SrcToString() string { return "" }
61 |
62 | func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
63 |
64 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
65 |
66 | func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
67 |
68 | func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
69 |
// Open creates a fresh close signal and returns one receive function for each
// of the rx4 and rx6 queues. The requested port is ignored; the reported port
// is source4 or source6, chosen at random.
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
	c.closeSignal = make(chan bool)
	fns = []conn.ReceiveFunc{
		c.makeReceiveFunc(*c.rx4),
		c.makeReceiveFunc(*c.rx6),
	}
	actualPort = uint16(c.source6)
	if rand.Uint32()&1 == 0 {
		actualPort = uint16(c.source4)
	}
	return fns, actualPort, nil
}
80 |
81 | func (c *ChannelBind) Close() error {
82 | if c.closeSignal != nil {
83 | select {
84 | case <-c.closeSignal:
85 | default:
86 | close(c.closeSignal)
87 | }
88 | }
89 | return nil
90 | }
91 |
92 | func (c *ChannelBind) SetMark(mark uint32) error { return nil }
93 |
// makeReceiveFunc returns a conn.ReceiveFunc that blocks until a packet
// arrives on ch, copying it into b, or until the bind is closed, in which case
// it returns net.ErrClosed.
//
// NOTE(review): every packet is reported as coming from c.target6, even when
// ch is the rx4 queue — confirm this is intended for the test harness.
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
	receive := func(b []byte) (n int, ep conn.Endpoint, err error) {
		select {
		case rx := <-ch:
			n = copy(b, rx)
			return n, c.target6, nil
		case <-c.closeSignal:
			return 0, nil, net.ErrClosed
		}
	}
	return receive
}
104 |
105 | func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
106 | select {
107 | case <-c.closeSignal:
108 | return net.ErrClosed
109 | default:
110 | bc := make([]byte, len(b))
111 | copy(bc, b)
112 | if ep.(ChannelEndpoint) == c.target4 {
113 | *c.tx4 <- bc
114 | } else if ep.(ChannelEndpoint) == c.target6 {
115 | *c.tx6 <- bc
116 | } else {
117 | return os.ErrInvalid
118 | }
119 | }
120 | return nil
121 | }
122 |
// ParseEndpoint parses s as "addr:port" and returns the port as a
// ChannelEndpoint; the address part is validated but otherwise ignored.
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
	addrPort, parseErr := netip.ParseAddrPort(s)
	if parseErr != nil {
		return nil, parseErr
	}
	port := addrPort.Port()
	return ChannelEndpoint(port), nil
}
130 |
--------------------------------------------------------------------------------
/device/channels.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "runtime"
10 | "sync"
11 | )
12 |
13 | // An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
14 | // An outboundQueue is ref-counted using its wg field.
15 | // An outboundQueue created with newOutboundQueue has one reference.
16 | // Every additional writer must call wg.Add(1).
17 | // Every completed writer must call wg.Done().
18 | // When no further writers will be added,
19 | // call wg.Done to remove the initial reference.
20 | // When the refcount hits 0, the queue's channel is closed.
21 | type outboundQueue struct {
22 | c chan *QueueOutboundElement
23 | wg sync.WaitGroup
24 | }
25 |
26 | func newOutboundQueue() *outboundQueue {
27 | q := &outboundQueue{
28 | c: make(chan *QueueOutboundElement, QueueOutboundSize),
29 | }
30 | q.wg.Add(1)
31 | go func() {
32 | q.wg.Wait()
33 | close(q.c)
34 | }()
35 | return q
36 | }
37 |
38 | // A inboundQueue is similar to an outboundQueue; see those docs.
39 | type inboundQueue struct {
40 | c chan *QueueInboundElement
41 | wg sync.WaitGroup
42 | }
43 |
44 | func newInboundQueue() *inboundQueue {
45 | q := &inboundQueue{
46 | c: make(chan *QueueInboundElement, QueueInboundSize),
47 | }
48 | q.wg.Add(1)
49 | go func() {
50 | q.wg.Wait()
51 | close(q.c)
52 | }()
53 | return q
54 | }
55 |
56 | // A handshakeQueue is similar to an outboundQueue; see those docs.
57 | type handshakeQueue struct {
58 | c chan QueueHandshakeElement
59 | wg sync.WaitGroup
60 | }
61 |
62 | func newHandshakeQueue() *handshakeQueue {
63 | q := &handshakeQueue{
64 | c: make(chan QueueHandshakeElement, QueueHandshakeSize),
65 | }
66 | q.wg.Add(1)
67 | go func() {
68 | q.wg.Wait()
69 | close(q.c)
70 | }()
71 | return q
72 | }
73 |
74 | type autodrainingInboundQueue struct {
75 | c chan *QueueInboundElement
76 | }
77 |
78 | // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
79 | // It is useful in cases in which is it hard to manage the lifetime of the channel.
80 | // The returned channel must not be closed. Senders should signal shutdown using
81 | // some other means, such as sending a sentinel nil values.
82 | func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
83 | q := &autodrainingInboundQueue{
84 | c: make(chan *QueueInboundElement, QueueInboundSize),
85 | }
86 | runtime.SetFinalizer(q, device.flushInboundQueue)
87 | return q
88 | }
89 |
90 | func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
91 | for {
92 | select {
93 | case elem := <-q.c:
94 | elem.Lock()
95 | device.PutMessageBuffer(elem.buffer)
96 | device.PutInboundElement(elem)
97 | default:
98 | return
99 | }
100 | }
101 | }
102 |
103 | type autodrainingOutboundQueue struct {
104 | c chan *QueueOutboundElement
105 | }
106 |
107 | // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
108 | // It is useful in cases in which is it hard to manage the lifetime of the channel.
109 | // The returned channel must not be closed. Senders should signal shutdown using
110 | // some other means, such as sending a sentinel nil values.
111 | // All sends to the channel must be best-effort, because there may be no receivers.
112 | func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
113 | q := &autodrainingOutboundQueue{
114 | c: make(chan *QueueOutboundElement, QueueOutboundSize),
115 | }
116 | runtime.SetFinalizer(q, device.flushOutboundQueue)
117 | return q
118 | }
119 |
120 | func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
121 | for {
122 | select {
123 | case elem := <-q.c:
124 | elem.Lock()
125 | device.PutMessageBuffer(elem.buffer)
126 | device.PutOutboundElement(elem)
127 | default:
128 | return
129 | }
130 | }
131 | }
132 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Go Implementation of [WireGuard](https://www.wireguard.com/)
2 |
3 | This is an implementation of WireGuard in Go.
4 |
5 | ## Usage
6 |
7 | Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
8 |
9 | ```
10 | $ wireguard-go wg0
11 | ```
12 |
13 | This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
14 |
15 | To run wireguard-go without forking to the background, pass `-f` or `--foreground`:
16 |
17 | ```
18 | $ wireguard-go -f wg0
19 | ```
20 |
21 | When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
22 |
23 | To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
24 |
25 | ## Platforms
26 |
27 | ### Linux
28 |
29 | This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
30 |
31 | ### macOS
32 |
33 | This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
34 |
35 | ### Windows
36 |
37 | This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
38 |
39 | ### FreeBSD
40 |
41 | This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
42 |
43 | ### OpenBSD
44 |
45 | This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
46 |
47 | ## Building
48 |
49 | This requires an installation of [Go](https://go.dev) ≥ 1.18.
50 |
51 | ```
52 | $ git clone https://git.zx2c4.com/wireguard-go
53 | $ cd wireguard-go
54 | $ make
55 | ```
56 |
57 | ## License
58 |
59 | Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
60 |
61 | Permission is hereby granted, free of charge, to any person obtaining a copy of
62 | this software and associated documentation files (the "Software"), to deal in
63 | the Software without restriction, including without limitation the rights to
64 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
65 | of the Software, and to permit persons to whom the Software is furnished to do
66 | so, subject to the following conditions:
67 |
68 | The above copyright notice and this permission notice shall be included in all
69 | copies or substantial portions of the Software.
70 |
71 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
72 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
73 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
74 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
75 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
76 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
77 | SOFTWARE.
78 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
2 | github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
3 | github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU=
4 | github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA=
5 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
8 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9 | github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
10 | github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
11 | github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
12 | github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
13 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
14 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
15 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
16 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
17 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
20 | github.com/refraction-networking/utls v1.6.7 h1:zVJ7sP1dJx/WtVuITug3qYUq034cDq9B2MR1K67ULZM=
21 | github.com/refraction-networking/utls v1.6.7/go.mod h1:BC3O4vQzye5hqpmDTWUqi4P5DDhzJfkV1tdqtawQIH0=
22 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
23 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
24 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
25 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
26 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
27 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
28 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
29 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
30 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
31 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
32 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
33 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
34 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
35 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
36 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
37 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
38 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
39 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
40 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
41 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
42 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
43 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
44 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
45 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
46 | gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
47 |
--------------------------------------------------------------------------------
/tun/tuntest/tuntest.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tuntest
7 |
8 | import (
9 | "encoding/binary"
10 | "io"
11 | "net/netip"
12 | "os"
13 |
14 | "golang.zx2c4.com/wireguard/tun"
15 | )
16 |
// Ping returns a deterministic IPv4 ICMP echo-request packet from src to dst.
// Its 4-byte payload carries a fixed local port (1337) and sequence number (0),
// both big-endian.
func Ping(dst, src netip.Addr) []byte {
	localPort := uint16(1337)
	seq := uint16(0)

	payload := make([]byte, 4)
	binary.BigEndian.PutUint16(payload[0:], localPort)
	binary.BigEndian.PutUint16(payload[2:], seq)

	return genICMPv4(payload, dst, src)
}

// checksum computes the "internet checksum" from
// https://tools.ietf.org/html/rfc1071: the one's complement of the
// one's-complement sum of buf's big-endian 16-bit words, starting from
// initial. An odd trailing byte is treated as the high byte of a final word.
//
// The result is already complemented, so it can be stored directly in a
// checksum field, and re-running checksum over data that contains a correct
// checksum yields 0. To continue a sum across buffers, pass
// ^checksum(previous, ...) as initial.
func checksum(buf []byte, initial uint16) uint16 {
	v := uint32(initial)
	for i := 0; i < len(buf)-1; i += 2 {
		v += uint32(binary.BigEndian.Uint16(buf[i:]))
	}
	if len(buf)%2 == 1 {
		v += uint32(buf[len(buf)-1]) << 8
	}
	// Fold carries back into the low 16 bits (end-around carry).
	for v > 0xffff {
		v = (v >> 16) + (v & 0xffff)
	}
	return ^uint16(v)
}

// genICMPv4 builds an IPv4 packet from src to dst carrying an ICMP echo
// request whose data is payload. src and dst are expected to be IPv4
// addresses. Both the IPv4 header checksum and the ICMP checksum are filled
// in, so checksum over either part of the result verifies to 0.
func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
	const (
		icmpv4ProtocolNumber = 1
		icmpv4Echo           = 8
		icmpv4ChecksumOffset = 2
		icmpv4Size           = 8
		ipv4Size             = 20
		ipv4TotalLenOffset   = 2
		ipv4ChecksumOffset   = 10
		ttl                  = 65
		headerSize           = ipv4Size + icmpv4Size
	)

	pkt := make([]byte, headerSize+len(payload))
	copy(pkt[headerSize:], payload)

	ip := pkt[:ipv4Size]
	icmpv4 := pkt[ipv4Size:] // ICMP header followed by the payload

	// https://tools.ietf.org/html/rfc792
	icmpv4[0] = icmpv4Echo // type
	icmpv4[1] = 0          // code
	// The ICMP checksum covers the ICMP header and its data. The checksum
	// field is still zero at this point, as required when computing it.
	binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], checksum(icmpv4, 0))

	// https://tools.ietf.org/html/rfc760 section 3.1
	ip[0] = (4 << 4) | (ipv4Size / 4) // version 4, header length in 32-bit words
	binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], uint16(len(pkt)))
	ip[8] = ttl
	ip[9] = icmpv4ProtocolNumber
	copy(ip[12:], src.AsSlice())
	copy(ip[16:], dst.AsSlice())
	// The IPv4 header checksum covers only the 20-byte header. checksum
	// already returns the complemented value, so it is stored as-is.
	binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], checksum(ip, 0))

	return pkt
}
81 |
82 | type ChannelTUN struct {
83 | Inbound chan []byte // incoming packets, closed on TUN close
84 | Outbound chan []byte // outbound packets, blocks forever on TUN close
85 |
86 | closed chan struct{}
87 | events chan tun.Event
88 | tun chTun
89 | }
90 |
91 | func NewChannelTUN() *ChannelTUN {
92 | c := &ChannelTUN{
93 | Inbound: make(chan []byte),
94 | Outbound: make(chan []byte),
95 | closed: make(chan struct{}),
96 | events: make(chan tun.Event, 1),
97 | }
98 | c.tun.c = c
99 | c.events <- tun.EventUp
100 | return c
101 | }
102 |
103 | func (c *ChannelTUN) TUN() tun.Device {
104 | return &c.tun
105 | }
106 |
107 | type chTun struct {
108 | c *ChannelTUN
109 | }
110 |
111 | func (t *chTun) File() *os.File { return nil }
112 |
113 | func (t *chTun) Read(data []byte, offset int) (int, error) {
114 | select {
115 | case <-t.c.closed:
116 | return 0, os.ErrClosed
117 | case msg := <-t.c.Outbound:
118 | return copy(data[offset:], msg), nil
119 | }
120 | }
121 |
122 | // Write is called by the wireguard device to deliver a packet for routing.
123 | func (t *chTun) Write(data []byte, offset int) (int, error) {
124 | if offset == -1 {
125 | close(t.c.closed)
126 | close(t.c.events)
127 | return 0, io.EOF
128 | }
129 | msg := make([]byte, len(data)-offset)
130 | copy(msg, data[offset:])
131 | select {
132 | case <-t.c.closed:
133 | return 0, os.ErrClosed
134 | case t.c.Inbound <- msg:
135 | return len(data) - offset, nil
136 | }
137 | }
138 |
139 | const DefaultMTU = 1420
140 |
141 | func (t *chTun) Flush() error { return nil }
142 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
143 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
144 | func (t *chTun) Events() <-chan tun.Event { return t.c.events }
145 | func (t *chTun) Close() error {
146 | t.Write(nil, -1)
147 | return nil
148 | }
149 |
--------------------------------------------------------------------------------
/conn/conn.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | // Package conn implements WireGuard's network connections.
7 | package conn
8 |
9 | import (
10 | "errors"
11 | "fmt"
12 | "net/netip"
13 | "reflect"
14 | "runtime"
15 | "strings"
16 | )
17 |
18 | // A ReceiveFunc receives a single inbound packet from the network.
19 | // It writes the data into b. n is the length of the packet.
20 | // ep is the remote endpoint.
21 | type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
22 |
23 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
24 | //
25 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
26 | // depending on the platform-specific implementation.
27 | type Bind interface {
28 | // Open puts the Bind into a listening state on a given port and reports the actual
29 | // port that it bound to. Passing zero results in a random selection.
30 | // fns is the set of functions that will be called to receive packets.
31 | Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
32 |
33 | // Close closes the Bind listener.
34 | // All fns returned by Open must return net.ErrClosed after a call to Close.
35 | Close() error
36 |
37 | // SetMark sets the mark for each packet sent through this Bind.
38 | // This mark is passed to the kernel as the socket option SO_MARK.
39 | SetMark(mark uint32) error
40 |
41 | // Send writes a packet b to address ep.
42 | Send(b []byte, ep Endpoint) error
43 |
44 | // ParseEndpoint creates a new endpoint from a string.
45 | ParseEndpoint(s string) (Endpoint, error)
46 | }
47 |
// Logger bundles two printf-style logging callbacks: Verbosef for verbose
// output and Errorf for errors.
type Logger struct {
	Verbosef, Errorf func(format string, args ...any)
}
52 |
// BindSocketToInterface is an optional capability of a Bind: restricting its
// sockets to a single network interface. wireguard-windows relies on it.
type BindSocketToInterface interface {
	// BindSocketToInterface4 ties the IPv4 socket to interfaceIndex.
	BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
	// BindSocketToInterface6 ties the IPv6 socket to interfaceIndex.
	BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
59 |
// PeekLookAtSocketFd is an optional capability of a Bind: exposing the raw
// file descriptors of its sockets. wireguard-android relies on it.
type PeekLookAtSocketFd interface {
	// PeekLookAtSocketFd4 returns the IPv4 socket's file descriptor.
	PeekLookAtSocketFd4() (fd int, err error)
	// PeekLookAtSocketFd6 returns the IPv6 socket's file descriptor.
	PeekLookAtSocketFd6() (fd int, err error)
}
66 |
// Endpoint caches the source/destination addressing for a peer.
//
//   - dst is the peer's remote address ("endpoint" in uapi terminology).
//   - src is the local address that datagrams to the peer originate from.
type Endpoint interface {
	// ClearSrc forgets the cached source address.
	ClearSrc()
	// SrcToString formats the local source address as ip:port.
	SrcToString() string
	// DstToString formats the destination address as ip:port.
	DstToString() string
	// DstToBytes returns the destination in the form used for mac2 cookie
	// calculations.
	DstToBytes() []byte
	// DstIP returns the destination IP address.
	DstIP() netip.Addr
	// SrcIP returns the source IP address.
	SrcIP() netip.Addr
}
79 |
// ErrBindAlreadyOpen reports an attempt to open a Bind that is already open.
var ErrBindAlreadyOpen = errors.New("bind is already open")

// ErrWrongEndpointType reports an Endpoint whose concrete type does not
// belong to the Bind it was given to.
var ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
84 |
85 | func (fn ReceiveFunc) PrettyName() string {
86 | name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
87 | // 0. cheese/taco.beansIPv6.func12.func21218-fm
88 | name = strings.TrimSuffix(name, "-fm")
89 | // 1. cheese/taco.beansIPv6.func12.func21218
90 | if idx := strings.LastIndexByte(name, '/'); idx != -1 {
91 | name = name[idx+1:]
92 | // 2. taco.beansIPv6.func12.func21218
93 | }
94 | for {
95 | var idx int
96 | for idx = len(name) - 1; idx >= 0; idx-- {
97 | if name[idx] < '0' || name[idx] > '9' {
98 | break
99 | }
100 | }
101 | if idx == len(name)-1 {
102 | break
103 | }
104 | const dotFunc = ".func"
105 | if !strings.HasSuffix(name[:idx+1], dotFunc) {
106 | break
107 | }
108 | name = name[:idx+1-len(dotFunc)]
109 | // 3. taco.beansIPv6.func12
110 | // 4. taco.beansIPv6
111 | }
112 | if idx := strings.LastIndexByte(name, '.'); idx != -1 {
113 | name = name[idx+1:]
114 | // 5. beansIPv6
115 | }
116 | if name == "" {
117 | return fmt.Sprintf("%p", fn)
118 | }
119 | if strings.HasSuffix(name, "IPv4") {
120 | return "v4"
121 | }
122 | if strings.HasSuffix(name, "IPv6") {
123 | return "v6"
124 | }
125 | return name
126 | }
127 |
--------------------------------------------------------------------------------
/device/noise_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "bytes"
10 | "encoding/binary"
11 | "testing"
12 |
13 | "golang.zx2c4.com/wireguard/conn"
14 | "golang.zx2c4.com/wireguard/tun/tuntest"
15 | )
16 |
17 | func TestCurveWrappers(t *testing.T) {
18 | sk1, err := newPrivateKey()
19 | assertNil(t, err)
20 |
21 | sk2, err := newPrivateKey()
22 | assertNil(t, err)
23 |
24 | pk1 := sk1.publicKey()
25 | pk2 := sk2.publicKey()
26 |
27 | ss1, err1 := sk1.sharedSecret(pk2)
28 | ss2, err2 := sk2.sharedSecret(pk1)
29 |
30 | if ss1 != ss2 || err1 != nil || err2 != nil {
31 | t.Fatal("Failed to compute shared secet")
32 | }
33 | }
34 |
35 | func randDevice(t *testing.T) *Device {
36 | sk, err := newPrivateKey()
37 | if err != nil {
38 | t.Fatal(err)
39 | }
40 | tun := tuntest.NewChannelTUN()
41 | logger := NewLogger(LogLevelError, "")
42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState))
43 | device.SetPrivateKey(sk)
44 | return device
45 | }
46 |
// assertNil fails the test immediately when err is non-nil. It is marked as a
// helper so the failure is attributed to the caller's line.
func assertNil(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatal(err)
	}
}
52 |
// assertEqual fails the test immediately unless a and b hold the same bytes
// (nil and empty slices compare equal). It is marked as a helper so the
// failure is attributed to the caller's line.
func assertEqual(t *testing.T, a, b []byte) {
	t.Helper()
	if !bytes.Equal(a, b) {
		t.Fatal(a, "!=", b)
	}
}
58 |
59 | func TestNoiseHandshake(t *testing.T) {
60 | dev1 := randDevice(t)
61 | dev2 := randDevice(t)
62 |
63 | defer dev1.Close()
64 | defer dev2.Close()
65 |
66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
67 | if err != nil {
68 | t.Fatal(err)
69 | }
70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
71 | if err != nil {
72 | t.Fatal(err)
73 | }
74 | peer1.Start()
75 | peer2.Start()
76 |
77 | assertEqual(
78 | t,
79 | peer1.handshake.precomputedStaticStatic[:],
80 | peer2.handshake.precomputedStaticStatic[:],
81 | )
82 |
83 | /* simulate handshake */
84 |
85 | // initiation message
86 |
87 | t.Log("exchange initiation message")
88 |
89 | msg1, err := dev1.CreateMessageInitiation(peer2)
90 | assertNil(t, err)
91 |
92 | packet := make([]byte, 0, 256)
93 | writer := bytes.NewBuffer(packet)
94 | err = binary.Write(writer, binary.LittleEndian, msg1)
95 | assertNil(t, err)
96 | peer := dev2.ConsumeMessageInitiation(msg1)
97 | if peer == nil {
98 | t.Fatal("handshake failed at initiation message")
99 | }
100 |
101 | assertEqual(
102 | t,
103 | peer1.handshake.chainKey[:],
104 | peer2.handshake.chainKey[:],
105 | )
106 |
107 | assertEqual(
108 | t,
109 | peer1.handshake.hash[:],
110 | peer2.handshake.hash[:],
111 | )
112 |
113 | // response message
114 |
115 | t.Log("exchange response message")
116 |
117 | msg2, err := dev2.CreateMessageResponse(peer1)
118 | assertNil(t, err)
119 |
120 | peer = dev1.ConsumeMessageResponse(msg2)
121 | if peer == nil {
122 | t.Fatal("handshake failed at response message")
123 | }
124 |
125 | assertEqual(
126 | t,
127 | peer1.handshake.chainKey[:],
128 | peer2.handshake.chainKey[:],
129 | )
130 |
131 | assertEqual(
132 | t,
133 | peer1.handshake.hash[:],
134 | peer2.handshake.hash[:],
135 | )
136 |
137 | // key pairs
138 |
139 | t.Log("deriving keys")
140 |
141 | err = peer1.BeginSymmetricSession()
142 | if err != nil {
143 | t.Fatal("failed to derive keypair for peer 1", err)
144 | }
145 |
146 | err = peer2.BeginSymmetricSession()
147 | if err != nil {
148 | t.Fatal("failed to derive keypair for peer 2", err)
149 | }
150 |
151 | key1 := peer1.keypairs.next.Load()
152 | key2 := peer2.keypairs.current
153 |
154 | // encrypting / decryption test
155 |
156 | t.Log("test key pairs")
157 |
158 | func() {
159 | testMsg := []byte("wireguard test message 1")
160 | var err error
161 | var out []byte
162 | var nonce [12]byte
163 | out = key1.send.Seal(out, nonce[:], testMsg, nil)
164 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
165 | assertNil(t, err)
166 | assertEqual(t, out, testMsg)
167 | }()
168 |
169 | func() {
170 | testMsg := []byte("wireguard test message 2")
171 | var err error
172 | var out []byte
173 | var nonce [12]byte
174 | out = key2.send.Seal(out, nonce[:], testMsg, nil)
175 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
176 | assertNil(t, err)
177 | assertEqual(t, out, testMsg)
178 | }()
179 | }
180 |
--------------------------------------------------------------------------------
/conn/tcp_tls_utils.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2022. Proton AG
3 | *
4 | * This file is part of ProtonVPN.
5 | *
6 | * ProtonVPN is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * ProtonVPN is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
18 | */
19 |
20 | package conn
21 |
22 | import (
23 | "bytes"
24 | "encoding/binary"
25 | "errors"
26 | )
27 |
// WireGuard data-packet header layout.
var (
	// wgDataPrefix is the first four bytes that identify a WireGuard data
	// packet (presumably the transport-data message type followed by
	// reserved zero bytes).
	wgDataPrefix = []byte{4, 0, 0, 0}
	// wgDataHeaderSize is the full data header: the 8-byte prefix plus an
	// 8-byte little-endian counter.
	wgDataHeaderSize = 16
	// wgDataPrefixSize is the data header without the counter.
	wgDataPrefixSize = 8
)

// TunSafe framing: a 2-byte header whose top two bits hold the packet type.
var (
	tunSafeHeaderSize = 2
	tunSafeNormalType = uint8(0b00) // full WireGuard packet follows
	tunSafeDataType   = uint8(0b10) // data packet with its 16-byte header omitted
)
35 |
// TunSafeData holds the per-connection state used to compress WireGuard data
// headers in the TunSafe TCP framing, separately for each direction.
type TunSafeData struct {
	wgSendPrefix []byte // 8-byte prefix of the last data packet sent in full
	wgSendCount  uint64 // counter of the last data packet sent
	wgRecvPrefix []byte // 8-byte prefix used to rebuild received compressed packets
	wgRecvCount  uint64 // counter to use for the next received compressed packet
}
42 |
43 |
44 | func NewTunSafeData() *TunSafeData {
45 | return &TunSafeData{
46 | wgRecvPrefix: make([]byte, 8),
47 | wgSendPrefix: make([]byte, 8),
48 | }
49 | }
50 |
// parseTunSafeHeader decodes a 2-byte TunSafe header and returns
// (type, size). The top two bits of the big-endian 16-bit word are the packet
// type; the remaining 14 bits are the payload size. header must hold at least
// two bytes.
func parseTunSafeHeader(header []byte) (byte, int) {
	word := binary.BigEndian.Uint16(header)
	return byte(word >> 14), int(word & 0x3fff)
}
57 |
58 | func (tunSafe *TunSafeData) clear() {
59 | tunSafe.wgSendCount = 0
60 | tunSafe.wgRecvCount = 0
61 | }
62 |
63 | func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) {
64 | buffer := new(bytes.Buffer)
65 | buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount))
66 | buffer.Write(tunSafe.wgRecvPrefix)
67 | _ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount)
68 | copy(wgPacket, buffer.Bytes())
69 | }
70 |
71 | func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) {
72 | var wgPacket []byte
73 | offset := 0
74 | switch tunSafeType {
75 | case tunSafeNormalType:
76 | wgPacket = make([]byte, payloadSize)
77 | case tunSafeDataType:
78 | offset = wgDataHeaderSize
79 | wgPacket = make([]byte, payloadSize+offset)
80 | tunSafe.writeWgHeader(wgPacket)
81 | default:
82 | return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type")
83 | }
84 | return wgPacket, offset, nil
85 | }
86 |
87 | func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) {
88 | if tunSafeType == tunSafeNormalType {
89 | isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
90 | if isWgDataPacket {
91 | copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize])
92 | countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize])
93 | _ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount)
94 | }
95 | }
96 | tunSafe.wgRecvCount++
97 | }
98 |
99 | func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte {
100 | wgLen := len(wgPacket)
101 | if wgLen < wgDataHeaderSize {
102 | return wgToTunSafeNormal(wgPacket)
103 | }
104 | wgPrefix := wgPacket[:wgDataPrefixSize]
105 | var wgCount uint64
106 | _ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount)
107 | prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix)
108 | if prefixMatch && wgCount == tunSafe.wgSendCount+1 {
109 | tunSafe.wgSendCount += 1
110 | return wgToTunSafeData(wgPacket)
111 | } else {
112 | isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
113 | if isWgDataPacket {
114 | tunSafe.wgSendPrefix = wgPrefix
115 | tunSafe.wgSendCount = wgCount
116 | }
117 | return wgToTunSafeNormal(wgPacket)
118 | }
119 | }
120 |
121 | func wgToTunSafeNormal(wgPacket []byte) []byte {
122 | payloadSize := len(wgPacket)
123 | result := make([]byte, payloadSize+tunSafeHeaderSize)
124 |
125 | // Tunsafe normal header
126 | result[0] = uint8(payloadSize >> 8)
127 | result[1] = uint8(payloadSize & 0xff)
128 |
129 | // Full packet
130 | copy(result[tunSafeHeaderSize:], wgPacket)
131 |
132 | return result
133 | }
134 |
135 | func wgToTunSafeData(wgPacket []byte) []byte {
136 | payloadSize := len(wgPacket) - wgDataHeaderSize
137 | result := make([]byte, payloadSize+tunSafeHeaderSize)
138 |
139 | // TunSafe data header
140 | result[0] = uint8(0b10<<6 | payloadSize>>8)
141 | result[1] = uint8(payloadSize & 0xff)
142 |
143 | // Packet without header
144 | copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:])
145 |
146 | return result
147 | }
148 |
--------------------------------------------------------------------------------
/device/statemanager_test.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2022. Proton AG
3 | *
4 | * This file is part of ProtonVPN.
5 | *
6 | * ProtonVPN is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * ProtonVPN is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
18 | */
19 |
20 | package device
21 |
22 | import (
23 | "errors"
24 | "github.com/stretchr/testify/assert"
25 | "testing"
26 | "time"
27 | )
28 |
29 | var timeMs int64 = 0
30 | var mockDevice MockDevice
31 | var manager *WireGuardStateManager
32 | var lastState WireGuardState
33 |
// MockDevice stands in for the device driven by the state manager. It records
// whether it is currently up and how many times Up was called.
type MockDevice struct {
	isUp    bool // set by Up, cleared by Down
	upCount int  // number of Up calls
}
38 |
39 | func (dev *MockDevice) Up() error {
40 | dev.isUp = true
41 | dev.upCount++
42 | return nil
43 | }
44 |
45 | func (dev *MockDevice) Down() error {
46 | dev.isUp = false
47 | return nil
48 | }
49 |
50 | func setup() {
51 | timeMs = 0
52 | timeNow = func() time.Time { return time.UnixMilli(timeMs) }
53 | mockDevice.isUp = false
54 |
55 | manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), false)
56 | manager.Start(&mockDevice)
57 | lastState = WireGuardDisabled
58 | go func() {
59 | for lastState != -1 {
60 | lastState = manager.GetState()
61 | }
62 | }()
63 | }
64 |
65 | func setdown() {
66 | manager.Close()
67 | }
68 |
69 | func TestWireGuardStateManager_shouldRestart(t *testing.T) {
70 | assert := assert.New(t)
71 | setup()
72 | defer setdown()
73 |
74 | assert.Equal(initialRestartDelay, manager.nextRestartDelay)
75 |
76 | assert.Equal(false, manager.shouldRestart())
77 | timeMs += initialRestartDelay.Milliseconds()
78 | assert.Equal(false, manager.shouldRestart())
79 | timeMs += 1
80 | assert.Equal(true, manager.shouldRestart())
81 |
82 | assert.Equal(2*initialRestartDelay, manager.nextRestartDelay)
83 | assert.Equal(false, manager.shouldRestart())
84 | timeMs += 2 * initialRestartDelay.Milliseconds()
85 | assert.Equal(false, manager.shouldRestart())
86 | timeMs += 1
87 | assert.Equal(true, manager.shouldRestart())
88 |
89 | timeMs += resetRestartDelay.Milliseconds() + 1
90 | assert.Equal(true, manager.shouldRestart())
91 | assert.Equal(initialRestartDelay, manager.nextRestartDelay)
92 | }
93 |
94 | func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) {
95 | assert := assert.New(t)
96 | setup()
97 | defer setdown()
98 |
99 | assert.Equal(false, mockDevice.isUp)
100 | manager.SetNetworkAvailable(true)
101 | time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking
102 | assert.Equal(true, mockDevice.isUp)
103 | assert.Equal(WireGuardConnecting, lastState)
104 | manager.SetNetworkAvailable(false)
105 | time.Sleep(time.Millisecond)
106 | assert.Equal(WireGuardWaitingForNetwork, lastState)
107 | assert.Equal(false, mockDevice.isUp)
108 | }
109 |
110 | func TestWireGuardStateManager_happyConnectionPath(t *testing.T) {
111 | assert := assert.New(t)
112 | setup()
113 | defer setdown()
114 |
115 | manager.SetNetworkAvailable(true)
116 | time.Sleep(time.Millisecond)
117 | manager.HandshakeStateChan <- HandshakeSuccess
118 | time.Sleep(time.Millisecond)
119 | assert.Equal(WireGuardConnected, lastState)
120 | assert.Equal(true, mockDevice.isUp)
121 | }
122 |
123 | func TestWireGuardStateManager_handshakeFailCausesRestart(t *testing.T) {
124 | assert := assert.New(t)
125 | setup()
126 | defer setdown()
127 |
128 | manager.SetNetworkAvailable(true)
129 | time.Sleep(time.Millisecond)
130 | manager.HandshakeStateChan <- HandshakeFail
131 | time.Sleep(time.Millisecond)
132 | assert.Equal(WireGuardError, lastState)
133 | timeMs += initialRestartDelay.Milliseconds() + 1
134 | manager.HandshakeStateChan <- HandshakeFail
135 | time.Sleep(time.Millisecond)
136 | assert.Equal(WireGuardConnecting, lastState)
137 | assert.Equal(2, mockDevice.upCount)
138 | }
139 |
140 | func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) {
141 | assert := assert.New(t)
142 | setup()
143 | defer setdown()
144 |
145 | manager.SetNetworkAvailable(true)
146 | timeMs += initialRestartDelay.Milliseconds() + 1
147 | time.Sleep(time.Millisecond)
148 | manager.SocketErrChan <- errors.New("broken pipe")
149 | time.Sleep(time.Millisecond)
150 | assert.Equal(WireGuardConnecting, lastState)
151 | assert.Equal(2, mockDevice.upCount)
152 | }
153 |
--------------------------------------------------------------------------------
/conn/server_name_utils/server_name_utils.go:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2025. Proton AG
3 | *
4 | * This file is part of ProtonVPN.
5 | *
6 | * ProtonVPN is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * ProtonVPN is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU General Public License
17 | * along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
18 | */
19 |
20 | package server_name_utils
21 |
22 | import (
23 | cryptoRand "crypto/rand"
24 | "math/big"
25 | "math/rand"
26 | "time"
27 | )
28 |
// ServerNameStrategy selects how a server name is chosen.
type ServerNameStrategy int

const (
	// ServerNameRandom (0): presumably a randomly generated name, e.g. built
	// with topLevelDomains. Confirm against the selection code.
	ServerNameRandom ServerNameStrategy = 0
	// ServerNameTop (1): presumably a name picked from the fixed domains list.
	// Confirm against the selection code.
	ServerNameTop ServerNameStrategy = 1
)
35 |
// topLevelDomains is the fixed set of TLD suffixes available to name
// generation. Order is preserved as-is.
var topLevelDomains = []string{
	"com", "net", "org",
	"it", "fr", "me", "ru", "cn", "es", "tr",
	"top", "xyz", "info",
}
37 |
// domains is the fixed pool of well-known hostnames that the ServerNameTop
// strategy picks from (through domainsSortedByHashes, which is derived from
// this slice at package initialization).
//
// NOTE(review): editing this set presumably remaps some addresses to different
// names (consistent-hash lookup in consistent_hash.go) — confirm before changing.
// NOTE(review): "translate.goo" looks like it may be truncated (Google's
// translate proxy hosts end in "translate.goog") — confirm the intended value.
var domains = []string{
	"accounts.google.com",
	"activity.windows.com",
	"analytics.apis.mcafee.com",
	"android.apis.google.com",
	"android.googleapis.com",
	"api.account.samsung.com",
	"api.accounts.firefox.com",
	"api.accuweather.com",
	"api.amazon.com",
	"api.browser.yandex.net",
	"api.ipify.org",
	"api.onedrive.com",
	"api.reasonsecurity.com",
	"api.samsungcloud.com",
	"api.sec.intl.miui.com",
	"api.vk.com",
	"api.weather.com",
	"app-site-association.cdn-apple.com",
	"apps.mzstatic.com",
	"assets.msn.com",
	"backup.googleapis.com",
	"brave-core-ext.s3.brave.com",
	"caldav.calendar.yahoo.com",
	"cc-api-data.adobe.io",
	"cdn.ampproject.org",
	"cdn.cookielaw.org",
	"client.wns.windows.com",
	"cloudflare.com",
	"cloudflare-dns.com",
	"cloudflare-ech.com",
	"config.extension.grammarly.com",
	"connectivitycheck.android.com",
	"connectivitycheck.gstatic.com",
	"courier.push.apple.com",
	"crl.globalsign.com",
	"dc1-file.ksn.kaspersky-labs.com",
	"dl.google.com",
	"dns.google",
	"dns.quad9.net",
	"doh.cleanbrowsing.org",
	"doh.dns.apple.com",
	"doh.opendns.com",
	"doh.pub",
	"ds.kaspersky.com",
	"ecs.office.com",
	"edge.microsoft.com",
	"events.gfe.nvidia.com",
	"excess.duolingo.com",
	"firefox.settings.services.mozilla.com",
	"fonts.googleapis.com",
	"fonts.gstatic.com",
	"gateway-asset.icloud-content.com",
	"gateway.icloud.com",
	"gdmf.apple.com",
	"github.com",
	"go.microsoft.com",
	"go-updater.brave.com",
	"graph.microsoft.com",
	"gs-loc.apple.com",
	"gtglobal.intl.miui.com",
	"hcaptcha.com",
	"imap.gmail.com",
	"imap-mail.outlook.com",
	"imap.mail.yahoo.com",
	"in.appcenter.ms",
	"ipmcdn.avast.com",
	"itunes.apple.com",
	"loc.map.baidu.com",
	"login.live.com",
	"login.microsoftonline.com",
	"m.media-amazon.com",
	"mobile.events.data.microsoft.com",
	"mozilla.cloudflare-dns.com",
	"mtalk.google.com",
	"nimbus.bitdefender.net",
	"ocsp2.apple.com",
	"outlook.office365.com",
	"play-fe.googleapis.com",
	"play.googleapis.com",
	"play.samsungcloud.com",
	"raw.githubusercontent.com",
	"s3.amazonaws.com",
	"safebrowsing.googleapis.com",
	"s.alicdn.com",
	"self.events.data.microsoft.com",
	"settings-win.data.microsoft.com",
	"setup.icloud.com",
	"sirius.mwbsys.com",
	"spoc.norton.com",
	"ssl.gstatic.com",
	"translate.goo",
	"unpkg.com",
	"update.googleapis.com",
	"weatherapi.intl.xiaomi.com",
	"weatherkit.apple.com",
	"westus-0.in.applicationinsights.azure.com",
	"www.googleapis.com",
	"www.gstatic.com",
	"www.msftconnecttest.com",
	"www.msftncsi.com",
	"www.ntppool.org",
	"www.pool.ntp.org",
	"www.recaptcha.net",
}
143 |
// domainsSortedByHashes is built once, during package initialization, from
// domains using sortValuesByHash and crc32Hash (both from consistent_hash.go;
// presumably the entries are ordered by their crc32 hash — confirm there).
// serverNameForAddr performs its lookups against it.
var domainsSortedByHashes []HashedValue = sortValuesByHash(domains, crc32Hash)
145 |
146 | func ServerNameFor(strategy ServerNameStrategy, addr string) string {
147 | switch strategy {
148 | case ServerNameTop:
149 | return serverNameForAddr(addr)
150 | case ServerNameRandom:
151 | return randomServerName()
152 | default:
153 | return randomServerName()
154 | }
155 | }
156 |
157 | func serverNameForAddr(addr string) string {
158 | return findClosestValue(addr, domainsSortedByHashes, crc32Hash)
159 | }
160 |
161 | func randomServerName() string {
162 | charNum := int('z') - int('a') + 1
163 | size := 3 + randInt(10)
164 | name := make([]byte, size)
165 | for i := range name {
166 | name[i] = byte(int('a') + randInt(charNum))
167 | }
168 | return string(name) + "." + randItem(topLevelDomains)
169 | }
170 |
171 | func randItem(list []string) string {
172 | return list[randInt(len(list))]
173 | }
174 |
// randInt returns a uniformly distributed integer in [0, n), drawn from
// crypto/rand. If the system CSPRNG fails, it falls back to a math/rand
// generator seeded from the current time, so callers always get a value.
// Like both underlying generators, it panics if n <= 0.
//
// The fallback uses a private generator rather than rand.Seed: rand.Seed is
// deprecated (Go 1.20+), and reseeding the process-wide source from here
// would clobber the state of every other math/rand user in the binary.
func randInt(n int) int {
	v, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n)))
	if err == nil {
		return int(v.Int64())
	}
	fallback := rand.New(rand.NewSource(time.Now().UnixNano()))
	return fallback.Intn(n)
}
183 |
--------------------------------------------------------------------------------
/device/cookie.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "crypto/hmac"
10 | "crypto/rand"
11 | "sync"
12 | "time"
13 |
14 | "golang.org/x/crypto/blake2s"
15 | "golang.org/x/crypto/chacha20poly1305"
16 | )
17 |
// CookieChecker holds the state used to verify the mac1/mac2 fields of
// received messages (CheckMAC1, CheckMAC2) and to issue cookie replies
// (CreateReply). Every method takes the embedded RWMutex, which guards all
// fields below.
type CookieChecker struct {
	sync.RWMutex
	mac1 struct {
		// key is BLAKE2s-256(WGLabelMAC1 || pk), set by Init.
		key [blake2s.Size]byte
	}
	mac2 struct {
		// secret keys the per-source cookie; CreateReply regenerates it from
		// crypto/rand once it is older than CookieRefreshTime.
		secret [blake2s.Size]byte
		// secretSet is when secret was generated; the zero time means never.
		secretSet time.Time
		// encryptionKey is BLAKE2s-256(WGLabelCookie || pk), used to seal
		// cookie replies with XChaCha20-Poly1305.
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}

// CookieGenerator holds the state used to stamp mac1/mac2 onto outgoing
// messages (AddMacs) and to accept cookie replies (ConsumeReply). Every method
// that touches the fields takes the embedded RWMutex.
type CookieGenerator struct {
	sync.RWMutex
	mac1 struct {
		// key is BLAKE2s-256(WGLabelMAC1 || pk), set by Init.
		key [blake2s.Size]byte
	}
	mac2 struct {
		// cookie is the last cookie decrypted by ConsumeReply; it keys mac2.
		cookie [blake2s.Size128]byte
		// cookieSet is when cookie was stored; the zero time means never.
		cookieSet time.Time
		// hasLastMAC1 reports whether lastMAC1 holds a value from AddMacs.
		hasLastMAC1 bool
		// lastMAC1 is the mac1 of the most recent message passed to AddMacs;
		// ConsumeReply uses it as the AEAD additional data.
		lastMAC1 [blake2s.Size128]byte
		// encryptionKey is BLAKE2s-256(WGLabelCookie || pk), used to open
		// cookie replies with XChaCha20-Poly1305.
		encryptionKey [chacha20poly1305.KeySize]byte
	}
}
43 |
44 | func (st *CookieChecker) Init(pk NoisePublicKey) {
45 | st.Lock()
46 | defer st.Unlock()
47 |
48 | // mac1 state
49 |
50 | func() {
51 | hash, _ := blake2s.New256(nil)
52 | hash.Write([]byte(WGLabelMAC1))
53 | hash.Write(pk[:])
54 | hash.Sum(st.mac1.key[:0])
55 | }()
56 |
57 | // mac2 state
58 |
59 | func() {
60 | hash, _ := blake2s.New256(nil)
61 | hash.Write([]byte(WGLabelCookie))
62 | hash.Write(pk[:])
63 | hash.Sum(st.mac2.encryptionKey[:0])
64 | }()
65 |
66 | st.mac2.secretSet = time.Time{}
67 | }
68 |
69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool {
70 | st.RLock()
71 | defer st.RUnlock()
72 |
73 | size := len(msg)
74 | smac2 := size - blake2s.Size128
75 | smac1 := smac2 - blake2s.Size128
76 |
77 | var mac1 [blake2s.Size128]byte
78 |
79 | mac, _ := blake2s.New128(st.mac1.key[:])
80 | mac.Write(msg[:smac1])
81 | mac.Sum(mac1[:0])
82 |
83 | return hmac.Equal(mac1[:], msg[smac1:smac2])
84 | }
85 |
86 | func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
87 | st.RLock()
88 | defer st.RUnlock()
89 |
90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime {
91 | return false
92 | }
93 |
94 | // derive cookie key
95 |
96 | var cookie [blake2s.Size128]byte
97 | func() {
98 | mac, _ := blake2s.New128(st.mac2.secret[:])
99 | mac.Write(src)
100 | mac.Sum(cookie[:0])
101 | }()
102 |
103 | // calculate mac of packet (including mac1)
104 |
105 | smac2 := len(msg) - blake2s.Size128
106 |
107 | var mac2 [blake2s.Size128]byte
108 | func() {
109 | mac, _ := blake2s.New128(cookie[:])
110 | mac.Write(msg[:smac2])
111 | mac.Sum(mac2[:0])
112 | }()
113 |
114 | return hmac.Equal(mac2[:], msg[smac2:])
115 | }
116 |
117 | func (st *CookieChecker) CreateReply(
118 | msg []byte,
119 | recv uint32,
120 | src []byte,
121 | ) (*MessageCookieReply, error) {
122 | st.RLock()
123 |
124 | // refresh cookie secret
125 |
126 | if time.Since(st.mac2.secretSet) > CookieRefreshTime {
127 | st.RUnlock()
128 | st.Lock()
129 | _, err := rand.Read(st.mac2.secret[:])
130 | if err != nil {
131 | st.Unlock()
132 | return nil, err
133 | }
134 | st.mac2.secretSet = time.Now()
135 | st.Unlock()
136 | st.RLock()
137 | }
138 |
139 | // derive cookie
140 |
141 | var cookie [blake2s.Size128]byte
142 | func() {
143 | mac, _ := blake2s.New128(st.mac2.secret[:])
144 | mac.Write(src)
145 | mac.Sum(cookie[:0])
146 | }()
147 |
148 | // encrypt cookie
149 |
150 | size := len(msg)
151 |
152 | smac2 := size - blake2s.Size128
153 | smac1 := smac2 - blake2s.Size128
154 |
155 | reply := new(MessageCookieReply)
156 | reply.Type = MessageCookieReplyType
157 | reply.Receiver = recv
158 |
159 | _, err := rand.Read(reply.Nonce[:])
160 | if err != nil {
161 | st.RUnlock()
162 | return nil, err
163 | }
164 |
165 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
166 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
167 |
168 | st.RUnlock()
169 |
170 | return reply, nil
171 | }
172 |
173 | func (st *CookieGenerator) Init(pk NoisePublicKey) {
174 | st.Lock()
175 | defer st.Unlock()
176 |
177 | func() {
178 | hash, _ := blake2s.New256(nil)
179 | hash.Write([]byte(WGLabelMAC1))
180 | hash.Write(pk[:])
181 | hash.Sum(st.mac1.key[:0])
182 | }()
183 |
184 | func() {
185 | hash, _ := blake2s.New256(nil)
186 | hash.Write([]byte(WGLabelCookie))
187 | hash.Write(pk[:])
188 | hash.Sum(st.mac2.encryptionKey[:0])
189 | }()
190 |
191 | st.mac2.cookieSet = time.Time{}
192 | }
193 |
194 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
195 | st.Lock()
196 | defer st.Unlock()
197 |
198 | if !st.mac2.hasLastMAC1 {
199 | return false
200 | }
201 |
202 | var cookie [blake2s.Size128]byte
203 |
204 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
205 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
206 | if err != nil {
207 | return false
208 | }
209 |
210 | st.mac2.cookieSet = time.Now()
211 | st.mac2.cookie = cookie
212 | return true
213 | }
214 |
215 | func (st *CookieGenerator) AddMacs(msg []byte) {
216 | size := len(msg)
217 |
218 | smac2 := size - blake2s.Size128
219 | smac1 := smac2 - blake2s.Size128
220 |
221 | mac1 := msg[smac1:smac2]
222 | mac2 := msg[smac2:]
223 |
224 | st.Lock()
225 | defer st.Unlock()
226 |
227 | // set mac1
228 |
229 | func() {
230 | mac, _ := blake2s.New128(st.mac1.key[:])
231 | mac.Write(msg[:smac1])
232 | mac.Sum(mac1[:0])
233 | }()
234 | copy(st.mac2.lastMAC1[:], mac1)
235 | st.mac2.hasLastMAC1 = true
236 |
237 | // set mac2
238 |
239 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
240 | return
241 | }
242 |
243 | func() {
244 | mac, _ := blake2s.New128(st.mac2.cookie[:])
245 | mac.Write(msg[:smac2])
246 | mac.Sum(mac2[:0])
247 | }()
248 | }
249 |
--------------------------------------------------------------------------------
/device/sticky_linux.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | *
5 | * This implements userspace semantics of "sticky sockets", modeled after
6 | * WireGuard's kernelspace implementation. This is more or less a straight port
7 | * of the sticky-sockets.c example code:
8 | * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
9 | *
10 | * Currently there is no way to achieve this within the net package:
11 | * See e.g. https://github.com/golang/go/issues/17930
12 | * So this code remains platform dependent.
13 | */
14 |
15 | package device
16 |
17 | import (
18 | "sync"
19 | "unsafe"
20 |
21 | "golang.org/x/sys/unix"
22 |
23 | "golang.zx2c4.com/wireguard/conn"
24 | "golang.zx2c4.com/wireguard/rwcancel"
25 | )
26 |
27 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
28 | if _, ok := bind.(*conn.LinuxSocketBind); !ok {
29 | return nil, nil
30 | }
31 |
32 | netlinkSock, err := createNetlinkRouteSocket()
33 | if err != nil {
34 | return nil, err
35 | }
36 | netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
37 | if err != nil {
38 | unix.Close(netlinkSock)
39 | return nil, err
40 | }
41 |
42 | go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
43 |
44 | return netlinkCancel, nil
45 | }
46 |
// routineRouteListener services the netlink route socket opened by
// startRouteListener. It runs until a receive fails with a non-retryable
// error, or until netlinkCancel.ReadyRead reports false (presumably after
// cancellation). On exit it closes both netlinkCancel and netlinkSock. The
// bind parameter is not used by the body.
//
// It handles RTM_NEWROUTE / RTM_DELROUTE messages in two ways:
//   - With 0 < Seq <= MaxPeers, the message is treated as the reply to one of
//     this routine's own RTM_GETROUTE requests, which are numbered 1, 2, ...
//     If the reply's RTA_OIF differs from the interface index cached in the
//     peer's LinuxSocketEndpoint source, that cached source is cleared.
//   - With any other Seq, it is taken as a route change notification. The
//     request table is reset and a goroutine sends one RTM_GETROUTE per
//     eligible IPv4 peer.
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
	// peerEndpointPtr remembers which peer (and endpoint field) a request
	// sequence number was issued for.
	type peerEndpointPtr struct {
		peer     *Peer
		endpoint *conn.Endpoint
	}
	var reqPeer map[uint32]peerEndpointPtr
	var reqPeerLock sync.Mutex

	defer netlinkCancel.Close()
	defer unix.Close(netlinkSock)

	// One 64 KiB receive buffer is reused for every datagram.
	for msg := make([]byte, 1<<16); ; {
		var err error
		var msgn int
		// Retry while the error is retryable (per rwcancel.RetryAfterError)
		// and ReadyRead reports the socket readable; otherwise stop.
		for {
			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
			if err == nil || !rwcancel.RetryAfterError(err) {
				break
			}
			if !netlinkCancel.ReadyRead() {
				return
			}
		}
		if err != nil {
			return
		}

		// Walk every netlink message packed into the datagram.
		// NOTE(review): a header with Len < SizeofNlMsghdr (e.g. 0) would never
		// advance remain below; kernel messages always carry a valid length,
		// but confirm no other sender can reach this socket. Messages are also
		// advanced without NLMSG_ALIGN padding.
		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {

			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))

			// A truncated message ends parsing of this datagram.
			if uint(hdr.Len) > uint(len(remain)) {
				break
			}

			switch hdr.Type {
			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
					// Reply to one of our RTM_GETROUTE requests.
					// (Already guaranteed by the length check above.)
					if uint(len(remain)) < uint(hdr.Len) {
						break
					}
					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
						// Route attributes follow the nlmsghdr and rtmsg headers.
						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
						for {
							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
								break
							}
							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
								break
							}
							// RTA_OIF carries the 4-byte output interface index.
							// Each break inside this if leaves the attribute loop.
							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
								reqPeerLock.Lock()
								if reqPeer == nil {
									reqPeerLock.Unlock()
									break
								}
								// NOTE(review): reqPeer is rebuilt on each
								// notification, so a late reply from an older
								// round may match whichever peer now holds this
								// sequence number.
								pePtr, ok := reqPeer[hdr.Seq]
								reqPeerLock.Unlock()
								if !ok {
									break
								}
								pePtr.peer.Lock()
								// NOTE(review): &pePtr.peer.endpoint is the address
								// of the very field stored in pePtr.endpoint, so
								// this is always false and does not detect a
								// replaced endpoint. The unchecked type assertions
								// below would panic if peer.endpoint is now nil or
								// not a *conn.LinuxSocketEndpoint — confirm.
								if &pePtr.peer.endpoint != pePtr.endpoint {
									pePtr.peer.Unlock()
									break
								}
								// Route still leaves via the cached interface: keep it.
								if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
									pePtr.peer.Unlock()
									break
								}
								// Route moved to another interface: drop the cached source.
								pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
								pePtr.peer.Unlock()
							}
							// NOTE(review): advances by attrhdr.Len without RTA_ALIGN
							// padding; fine while attribute lengths are multiples of
							// 4 — confirm for all attributes the kernel returns.
							attr = attr[attrhdr.Len:]
						}
					}
					// Reply handled; leave the switch.
					break
				}
				// Route change notification: start a new request round.
				reqPeerLock.Lock()
				reqPeer = make(map[uint32]peerEndpointPtr)
				reqPeerLock.Unlock()
				// Requests are sent from a separate goroutine while holding
				// device.peers.RLock; nothing waits for it.
				go func() {
					device.peers.RLock()
					i := uint32(1)
					for _, peer := range device.peers.keyMap {
						peer.RLock()
						if peer.endpoint == nil {
							peer.RUnlock()
							continue
						}
						nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
						if nativeEP == nil {
							peer.RUnlock()
							continue
						}
						// NOTE(review): this break stops requests for all
						// remaining peers (map order is random) rather than
						// skipping only this one — confirm whether continue was
						// intended.
						if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
							peer.RUnlock()
							break
						}
						// RTM_GETROUTE request: rtmsg followed by RTA_DST, RTA_SRC
						// and RTA_MARK attributes, each a 4-byte RtAttr header
						// plus a 4-byte value (hence Len: 8).
						nlmsg := struct {
							hdr     unix.NlMsghdr
							msg     unix.RtMsg
							dsthdr  unix.RtAttr
							dst     [4]byte
							srchdr  unix.RtAttr
							src     [4]byte
							markhdr unix.RtAttr
							mark    uint32
						}{
							unix.NlMsghdr{
								Type:  uint16(unix.RTM_GETROUTE),
								Flags: unix.NLM_F_REQUEST,
								Seq:   i,
							},
							unix.RtMsg{
								Family:  unix.AF_INET,
								Dst_len: 32,
								Src_len: 32,
							},
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_DST,
							},
							nativeEP.Dst4().Addr,
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_SRC,
							},
							nativeEP.Src4().Src,
							unix.RtAttr{
								Len:  8,
								Type: unix.RTA_MARK,
							},
							device.net.fwmark,
						}
						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
						// Record which peer this sequence number belongs to.
						reqPeerLock.Lock()
						reqPeer[i] = peerEndpointPtr{
							peer:     peer,
							endpoint: &peer.endpoint,
						}
						reqPeerLock.Unlock()
						peer.RUnlock()
						i++
						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
						if err != nil {
							break
						}
					}
					device.peers.RUnlock()
				}()
			}
			remain = remain[hdr.Len:]
		}
	}
}
205 |
206 | func createNetlinkRouteSocket() (int, error) {
207 | sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
208 | if err != nil {
209 | return -1, err
210 | }
211 | saddr := &unix.SockaddrNetlink{
212 | Family: unix.AF_NETLINK,
213 | Groups: unix.RTMGRP_IPV4_ROUTE,
214 | }
215 | err = unix.Bind(sock, saddr)
216 | if err != nil {
217 | unix.Close(sock)
218 | return -1, err
219 | }
220 | return sock, nil
221 | }
222 |
--------------------------------------------------------------------------------
/conn/bind_std.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package conn
7 |
8 | import (
9 | "errors"
10 | "fmt"
11 | "net"
12 | "net/netip"
13 | "sync"
14 | "syscall"
15 | )
16 |
// StdNetBind is meant to be a temporary solution on platforms for which
// the sticky socket / source caching behavior has not yet been implemented.
// It uses Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
	mu   sync.Mutex // protects following fields
	ipv4 *net.UDPConn
	ipv6 *net.UDPConn
	// blackhole4/blackhole6 make Send silently drop packets for that family.
	// Close resets them; nothing in this file sets them.
	blackhole4 bool
	blackhole6 bool

	// protectSocket is called with the descriptor of every socket Open binds;
	// a negative result makes Open return an error. It is set once at
	// construction and is called unconditionally, so it must be non-nil.
	protectSocket func(fd int) int
}
30 |
31 | func NewStdNetBind(protectSocket func(fd int) int) Bind {
32 | return &StdNetBind{protectSocket: protectSocket}
33 | }
34 |
// StdNetEndpoint is an Endpoint consisting only of a destination IP and port;
// it carries no source information.
type StdNetEndpoint netip.AddrPort

// Compile-time checks that StdNetBind and StdNetEndpoint satisfy Bind and Endpoint.
var (
	_ Bind     = (*StdNetBind)(nil)
	_ Endpoint = StdNetEndpoint{}
)
41 |
42 | func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
43 | e, err := netip.ParseAddrPort(s)
44 | return asEndpoint(e), err
45 | }
46 |
47 | func (StdNetEndpoint) ClearSrc() {}
48 |
49 | func (e StdNetEndpoint) DstIP() netip.Addr {
50 | return (netip.AddrPort)(e).Addr()
51 | }
52 |
53 | func (e StdNetEndpoint) SrcIP() netip.Addr {
54 | return netip.Addr{} // not supported
55 | }
56 |
57 | func (e StdNetEndpoint) DstToBytes() []byte {
58 | b, _ := (netip.AddrPort)(e).MarshalBinary()
59 | return b
60 | }
61 |
62 | func (e StdNetEndpoint) DstToString() string {
63 | return (netip.AddrPort)(e).String()
64 | }
65 |
66 | func (e StdNetEndpoint) SrcToString() string {
67 | return ""
68 | }
69 |
// listenNet opens a UDP socket for network ("udp4" or "udp6") on port across
// all local addresses (port 0 lets the kernel choose). It returns the socket
// and the port it is actually bound to. On any error no socket is left open
// and (nil, 0, err) is returned.
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
	if err != nil {
		return nil, 0, err
	}

	// Retrieve port. A UDP socket's LocalAddr is normally a *net.UDPAddr, so
	// read it directly; fall back to resolving the string form otherwise.
	laddr := conn.LocalAddr()
	if uaddr, ok := laddr.(*net.UDPAddr); ok && uaddr != nil {
		return conn, uaddr.Port, nil
	}
	uaddr, err := net.ResolveUDPAddr(laddr.Network(), laddr.String())
	if err != nil {
		// Don't leak the socket we just opened.
		conn.Close()
		return nil, 0, err
	}
	return conn, uaddr.Port, nil
}
87 |
88 | func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
89 | bind.mu.Lock()
90 | defer bind.mu.Unlock()
91 |
92 | var err error
93 | var tries int
94 |
95 | if bind.ipv4 != nil || bind.ipv6 != nil {
96 | return nil, 0, ErrBindAlreadyOpen
97 | }
98 |
99 | // Attempt to open ipv4 and ipv6 listeners on the same port.
100 | // If uport is 0, we can retry on failure.
101 | again:
102 | port := int(uport)
103 | var ipv4, ipv6 *net.UDPConn
104 |
105 | ipv4, port, err = listenNet("udp4", port)
106 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
107 | return nil, 0, err
108 | }
109 |
110 | // Listen on the same port as we're using for ipv4.
111 | ipv6, port, err = listenNet("udp6", port)
112 | if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
113 | ipv4.Close()
114 | tries++
115 | goto again
116 | }
117 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
118 | ipv4.Close()
119 | return nil, 0, err
120 | }
121 | var fns []ReceiveFunc
122 | if ipv4 != nil {
123 | fns = append(fns, bind.makeReceiveIPv4(ipv4))
124 | bind.ipv4 = ipv4
125 | }
126 | if ipv6 != nil {
127 | fns = append(fns, bind.makeReceiveIPv6(ipv6))
128 | bind.ipv6 = ipv6
129 | }
130 | if len(fns) == 0 {
131 | return nil, 0, syscall.EAFNOSUPPORT
132 | }
133 |
134 | err = bind.protectSockets()
135 | if err != nil {
136 | return nil, 0, err
137 | }
138 |
139 | return fns, uint16(port), nil
140 | }
141 |
142 | func (bind *StdNetBind) protectSockets() error {
143 | if bind.ipv4 != nil {
144 | fd, err := bind.PeekLookAtSocketFd4()
145 | if err != nil {
146 | return err
147 | }
148 | status := bind.protectSocket(fd)
149 | if status < 0 {
150 | return fmt.Errorf("Failed to protect socket: status=%d", status)
151 | }
152 | }
153 | if bind.ipv6 != nil {
154 | fd, err := bind.PeekLookAtSocketFd6()
155 | if err != nil {
156 | return err
157 | }
158 | status := bind.protectSocket(fd)
159 | if status < 0 {
160 | return fmt.Errorf("Failed to protect socket: status=%d", status)
161 | }
162 | }
163 | return nil
164 | }
165 |
166 | func (bind *StdNetBind) Close() error {
167 | bind.mu.Lock()
168 | defer bind.mu.Unlock()
169 |
170 | var err1, err2 error
171 | if bind.ipv4 != nil {
172 | err1 = bind.ipv4.Close()
173 | bind.ipv4 = nil
174 | }
175 | if bind.ipv6 != nil {
176 | err2 = bind.ipv6.Close()
177 | bind.ipv6 = nil
178 | }
179 | bind.blackhole4 = false
180 | bind.blackhole6 = false
181 | if err1 != nil {
182 | return err1
183 | }
184 | return err2
185 | }
186 |
187 | func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
188 | return func(buff []byte) (int, Endpoint, error) {
189 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
190 | return n, asEndpoint(endpoint), err
191 | }
192 | }
193 |
194 | func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
195 | return func(buff []byte) (int, Endpoint, error) {
196 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
197 | return n, asEndpoint(endpoint), err
198 | }
199 | }
200 |
201 | func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
202 | var err error
203 | nend, ok := endpoint.(StdNetEndpoint)
204 | if !ok {
205 | return ErrWrongEndpointType
206 | }
207 | addrPort := netip.AddrPort(nend)
208 |
209 | bind.mu.Lock()
210 | blackhole := bind.blackhole4
211 | conn := bind.ipv4
212 | if addrPort.Addr().Is6() {
213 | blackhole = bind.blackhole6
214 | conn = bind.ipv6
215 | }
216 | bind.mu.Unlock()
217 |
218 | if blackhole {
219 | return nil
220 | }
221 | if conn == nil {
222 | return syscall.EAFNOSUPPORT
223 | }
224 | _, err = conn.WriteToUDPAddrPort(buff, addrPort)
225 | return err
226 | }
227 |
228 | // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
229 | // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
230 | // but Endpoints are immutable, so we can re-use them.
231 | var endpointPool = sync.Pool{
232 | New: func() any {
233 | return make(map[netip.AddrPort]Endpoint)
234 | },
235 | }
236 |
237 | // asEndpoint returns an Endpoint containing ap.
238 | func asEndpoint(ap netip.AddrPort) Endpoint {
239 | m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
240 | defer endpointPool.Put(m)
241 | e, ok := m[ap]
242 | if !ok {
243 | e = Endpoint(StdNetEndpoint(ap))
244 | m[ap] = e
245 | }
246 | return e
247 | }
248 |
--------------------------------------------------------------------------------
/tun/tun_windows.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tun
7 |
8 | import (
9 | "errors"
10 | "fmt"
11 | "os"
12 | "sync"
13 | "sync/atomic"
14 | "time"
15 | _ "unsafe"
16 |
17 | "golang.org/x/sys/windows"
18 |
19 | "golang.zx2c4.com/wintun"
20 | )
21 |
const (
	// rateMeasurementGranularity is the minimum window, in nanoseconds, before
	// rateJuggler.update recomputes the throughput (half a second).
	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
	// spinloopRateThreshold is the throughput, in bytes per second, at or above
	// which Read busy-polls an empty ring instead of blocking at once.
	spinloopRateThreshold = 800000000 / 8 // 800mbps
	// spinloopDuration bounds, in nanoseconds, how long Read busy-polls before
	// falling back to waiting on the read-wait event.
	spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)

// rateJuggler tracks recent throughput for Read's spin heuristic.
type rateJuggler struct {
	current       atomic.Uint64 // last computed rate, bytes per second
	nextByteCount atomic.Uint64 // bytes counted in the current window
	nextStartTime atomic.Int64  // nanotime() at which the current window began
	changing      atomic.Bool   // set while one caller rolls the window over
}

// NativeTun is a Wintun-backed tun Device.
type NativeTun struct {
	wt        *wintun.Adapter
	name      string
	handle    windows.Handle // initialized to windows.InvalidHandle; not otherwise used in this file
	rate      rateJuggler
	session   wintun.Session
	readWait  windows.Handle // session read-wait event; Close signals it to wake Read
	events    chan Event
	running   sync.WaitGroup // counts in-flight Read/Write/LUID calls for Close
	closeOnce sync.Once
	close     atomic.Bool // set by Close; checked by Read/Write/LUID
	forcedMTU int         // reported by MTU; changed by ForceMTU
}

var (
	// WintunTunnelType is the tunnel type passed to wintun.CreateAdapter.
	WintunTunnelType = "WireGuard"
	// WintunStaticRequestedGUID is the GUID CreateTUN requests (nil if unset).
	WintunStaticRequestedGUID *windows.GUID
)

// procyield and nanotime are bound to unexported runtime functions via
// go:linkname (hence the blank "unsafe" import). NOTE(review): newer Go
// toolchains restrict linkname pulls — confirm these remain permitted.

//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

//go:linkname nanotime runtime.nanotime
func nanotime() int64
59 |
60 | // CreateTUN creates a Wintun interface with the given name. Should a Wintun
61 | // interface with the same name exist, it is reused.
62 | func CreateTUN(ifname string, mtu int) (Device, error) {
63 | return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
64 | }
65 |
66 | // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
67 | // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
68 | func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
69 | wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
70 | if err != nil {
71 | return nil, fmt.Errorf("Error creating interface: %w", err)
72 | }
73 |
74 | forcedMTU := 1420
75 | if mtu > 0 {
76 | forcedMTU = mtu
77 | }
78 |
79 | tun := &NativeTun{
80 | wt: wt,
81 | name: ifname,
82 | handle: windows.InvalidHandle,
83 | events: make(chan Event, 10),
84 | forcedMTU: forcedMTU,
85 | }
86 |
87 | tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
88 | if err != nil {
89 | tun.wt.Close()
90 | close(tun.events)
91 | return nil, fmt.Errorf("Error starting session: %w", err)
92 | }
93 | tun.readWait = tun.session.ReadWaitEvent()
94 | return tun, nil
95 | }
96 |
97 | func (tun *NativeTun) Name() (string, error) {
98 | return tun.name, nil
99 | }
100 |
101 | func (tun *NativeTun) File() *os.File {
102 | return nil
103 | }
104 |
105 | func (tun *NativeTun) Events() <-chan Event {
106 | return tun.events
107 | }
108 |
109 | func (tun *NativeTun) Close() error {
110 | var err error
111 | tun.closeOnce.Do(func() {
112 | tun.close.Store(true)
113 | windows.SetEvent(tun.readWait)
114 | tun.running.Wait()
115 | tun.session.End()
116 | if tun.wt != nil {
117 | tun.wt.Close()
118 | }
119 | close(tun.events)
120 | })
121 | return err
122 | }
123 |
124 | func (tun *NativeTun) MTU() (int, error) {
125 | return tun.forcedMTU, nil
126 | }
127 |
128 | // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
129 | func (tun *NativeTun) ForceMTU(mtu int) {
130 | update := tun.forcedMTU != mtu
131 | tun.forcedMTU = mtu
132 | if update {
133 | tun.events <- EventMTUUpdate
134 | }
135 | }
136 |
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.

// Read copies the next packet from the Wintun session into buff[offset:] and
// returns the packet's length. NOTE(review): that is the full packet length
// even when buff[offset:] is shorter and the copy was truncated — confirm
// callers always supply a large enough buffer.
//
// When the ring is empty, Read busy-polls (procyield) for up to
// spinloopDuration, but only if the measured rate is at least
// spinloopRateThreshold and was measured within the last
// 2*rateMeasurementGranularity. Otherwise it blocks on the read-wait event.
// Close signals that event, so a blocked Read wakes, sees the close flag and
// returns os.ErrClosed.
//
// NOTE(review): running.Add(1) may race with Close's running.Wait(); the sync
// docs require such an Add to happen before Wait — confirm this is acceptable.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			packetSize := len(packet)
			copy(buff[offset:], packet)
			// Return the ring slot before anything else can fail.
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(packetSize))
			return packetSize, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Ring empty: block unless still within the spin budget.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				// Re-check the close flag and re-evaluate spinning.
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}
175 |
176 | func (tun *NativeTun) Flush() error {
177 | return nil
178 | }
179 |
180 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
181 | tun.running.Add(1)
182 | defer tun.running.Done()
183 | if tun.close.Load() {
184 | return 0, os.ErrClosed
185 | }
186 |
187 | packetSize := len(buff) - offset
188 | tun.rate.update(uint64(packetSize))
189 |
190 | packet, err := tun.session.AllocateSendPacket(packetSize)
191 | if err == nil {
192 | copy(packet, buff[offset:])
193 | tun.session.SendPacket(packet)
194 | return packetSize, nil
195 | }
196 | switch err {
197 | case windows.ERROR_HANDLE_EOF:
198 | return 0, os.ErrClosed
199 | case windows.ERROR_BUFFER_OVERFLOW:
200 | return 0, nil // Dropping when ring is full.
201 | }
202 | return 0, fmt.Errorf("Write failed: %w", err)
203 | }
204 |
205 | // LUID returns Windows interface instance ID.
206 | func (tun *NativeTun) LUID() uint64 {
207 | tun.running.Add(1)
208 | defer tun.running.Done()
209 | if tun.close.Load() {
210 | return 0
211 | }
212 | return tun.wt.LUID()
213 | }
214 |
215 | // RunningVersion returns the running version of the Wintun driver.
216 | func (tun *NativeTun) RunningVersion() (version uint32, err error) {
217 | return wintun.RunningVersion()
218 | }
219 |
// update adds packetLen bytes to the current measurement window. Once at least
// rateMeasurementGranularity nanoseconds have elapsed since the window began,
// it publishes the window's throughput in bytes per second to rate.current
// and starts a new window.
//
// The changing flag (CompareAndSwap) lets only one concurrent caller perform
// the rollover; the others only accumulate.
// NOTE(review): bytes added by other callers between the Add below and
// nextByteCount.Store(0) are discarded, and the first window is measured from
// nextStartTime 0. Both look acceptable for a heuristic that in this file only
// feeds Read's spin decision — confirm.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := rate.nextByteCount.Add(packetLen)
	period := uint64(now - rate.nextStartTime.Load())
	if period >= rateMeasurementGranularity {
		if !rate.changing.CompareAndSwap(false, true) {
			return
		}
		rate.nextStartTime.Store(now)
		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
		rate.nextByteCount.Store(0)
		rate.changing.Store(false)
	}
}
234 |
--------------------------------------------------------------------------------
/device/cookie_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "testing"
10 | )
11 |
12 | func TestCookieMAC1(t *testing.T) {
13 | // setup generator / checker
14 |
15 | var (
16 | generator CookieGenerator
17 | checker CookieChecker
18 | )
19 |
20 | sk, err := newPrivateKey()
21 | if err != nil {
22 | t.Fatal(err)
23 | }
24 | pk := sk.publicKey()
25 |
26 | generator.Init(pk)
27 | checker.Init(pk)
28 |
29 | // check mac1
30 |
31 | src := []byte{192, 168, 13, 37, 10, 10, 10}
32 |
33 | checkMAC1 := func(msg []byte) {
34 | generator.AddMacs(msg)
35 | if !checker.CheckMAC1(msg) {
36 | t.Fatal("MAC1 generation/verification failed")
37 | }
38 | if checker.CheckMAC2(msg, src) {
39 | t.Fatal("MAC2 generation/verification failed")
40 | }
41 | }
42 |
43 | checkMAC1([]byte{
44 | 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd,
45 | 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62,
46 | 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64,
47 | 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91,
48 | 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4,
49 | 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c,
50 | 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b,
51 | 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5,
52 | 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8,
53 | 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8,
54 | 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e,
55 | 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82,
56 | 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53,
57 | 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd,
58 | 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad,
59 | 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f,
60 | })
61 |
62 | checkMAC1([]byte{
63 | 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c,
64 | 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56,
65 | 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a,
66 | 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c,
67 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
68 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
69 | 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
70 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
71 | })
72 |
73 | checkMAC1([]byte{
74 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
75 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
76 | 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
77 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
78 | })
79 |
80 | // exchange cookie reply
81 |
82 | func() {
83 | msg := []byte{
84 | 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf,
85 | 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8,
86 | 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5,
87 | 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f,
88 | 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2,
89 | 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b,
90 | 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e,
91 | 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9,
92 | 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22,
93 | 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27,
94 | 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f,
95 | 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2,
96 | 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b,
97 | 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb,
98 | 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61,
99 | 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
100 | }
101 | generator.AddMacs(msg)
102 | reply, err := checker.CreateReply(msg, 1377, src)
103 | if err != nil {
104 | t.Fatal("Failed to create cookie reply:", err)
105 | }
106 | if !generator.ConsumeReply(reply) {
107 | t.Fatal("Failed to consume cookie reply")
108 | }
109 | }()
110 |
111 | // check mac2
112 |
113 | checkMAC2 := func(msg []byte) {
114 | generator.AddMacs(msg)
115 |
116 | if !checker.CheckMAC1(msg) {
117 | t.Fatal("MAC1 generation/verification failed")
118 | }
119 | if !checker.CheckMAC2(msg, src) {
120 | t.Fatal("MAC2 generation/verification failed")
121 | }
122 |
123 | msg[5] ^= 0x20
124 |
125 | if checker.CheckMAC1(msg) {
126 | t.Fatal("MAC1 generation/verification failed")
127 | }
128 | if checker.CheckMAC2(msg, src) {
129 | t.Fatal("MAC2 generation/verification failed")
130 | }
131 |
132 | msg[5] ^= 0x20
133 |
134 | srcBad1 := []byte{192, 168, 13, 37, 40, 1}
135 | if checker.CheckMAC2(msg, srcBad1) {
136 | t.Fatal("MAC2 generation/verification failed")
137 | }
138 |
139 | srcBad2 := []byte{192, 168, 13, 38, 40, 1}
140 | if checker.CheckMAC2(msg, srcBad2) {
141 | t.Fatal("MAC2 generation/verification failed")
142 | }
143 | }
144 |
145 | checkMAC2([]byte{
146 | 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3,
147 | 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15,
148 | 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f,
149 | 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f,
150 | 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69,
151 | 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb,
152 | 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c,
153 | 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38,
154 | })
155 |
156 | checkMAC2([]byte{
157 | 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3,
158 | 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85,
159 | 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84,
160 | 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5,
161 | 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85,
162 | 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55,
163 | 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a,
164 | 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca,
165 | 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59,
166 | 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca,
167 | 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b,
168 | 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7,
169 | 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55,
170 | 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2,
171 | 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f,
172 | 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d,
173 | 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d,
174 | 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f,
175 | 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2,
176 | 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0,
177 | 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35,
178 | 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a,
179 | 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee,
180 | 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe,
181 | 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e,
182 | 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4,
183 | 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06,
184 | 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93,
185 | 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa,
186 | 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64,
187 | 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b,
188 | 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa,
189 | })
190 | }
191 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | //go:build !windows
2 |
3 | /* SPDX-License-Identifier: MIT
4 | *
5 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
6 | */
7 |
8 | package main
9 |
10 | import (
11 | "fmt"
12 | "os"
13 | "os/signal"
14 | "runtime"
15 | "strconv"
16 | "syscall"
17 |
18 | "golang.zx2c4.com/wireguard/conn"
19 | "golang.zx2c4.com/wireguard/device"
20 | "golang.zx2c4.com/wireguard/ipc"
21 | "golang.zx2c4.com/wireguard/tun"
22 | )
23 |
// Process exit codes: zero for a successful setup, one when setting up the
// TUN device, UAPI socket or daemon child fails (see the os.Exit calls in
// main).
const (
	ExitSetupSuccess = iota // 0
	ExitSetupFailed         // 1
)
28 |
// Environment variables read by main. A daemonizing parent sets all three
// when it re-executes itself: the child finds the TUN device on fd 3, the
// UAPI socket on fd 4, and is told to stay in the foreground. When a caller
// sets the fd variables, main uses those descriptors instead of creating or
// opening its own.
const (
	ENV_WG_TUN_FD             = "WG_TUN_FD"             // decimal fd of an already-open TUN device
	ENV_WG_UAPI_FD            = "WG_UAPI_FD"            // decimal fd of an already-open UAPI file
	ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" // "1" forces foreground mode and suppresses warning()
)
34 |
// printUsage writes the one-line command synopsis, using the program name
// from os.Args[0], to standard output.
func printUsage() {
	prog := os.Args[0]
	fmt.Fprintf(os.Stdout, "Usage: %s [-f/--foreground] INTERFACE-NAME\n", prog)
}
38 |
39 | func warning() {
40 | switch runtime.GOOS {
41 | case "linux", "freebsd", "openbsd":
42 | if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
43 | return
44 | }
45 | default:
46 | return
47 | }
48 |
49 | fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
50 | fmt.Fprintln(os.Stderr, "│ │")
51 | fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
52 | fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
53 | fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
54 | fmt.Fprintln(os.Stderr, "│ please visit: │")
55 | fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
56 | fmt.Fprintln(os.Stderr, "│ │")
57 | fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
58 | }
59 |
// main runs the userspace WireGuard daemon for a single interface.
//
// Invocation: `wireguard-go [-f/--foreground] INTERFACE-NAME`, or
// `wireguard-go --version`. Bad argument counts print the usage and return.
//
// Unless running in the foreground (-f/--foreground, or
// WG_PROCESS_FOREGROUND=1), main creates the TUN device and opens the UAPI
// file. It then re-executes its own binary with the same arguments, passing
// those two files as fds 3 and 4 (advertised via WG_TUN_FD / WG_UAPI_FD), and
// returns, leaving the child to run the device.
//
// In the foreground, main runs the device itself until SIGTERM/interrupt, a
// UAPI Accept error, or the device shutting down. Setup failures exit the
// process with ExitSetupFailed.
func main() {
	if len(os.Args) == 2 && os.Args[1] == "--version" {
		fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH)
		return
	}

	warning()

	var foreground bool
	var interfaceName string
	if len(os.Args) < 2 || len(os.Args) > 3 {
		printUsage()
		return
	}

	switch os.Args[1] {

	case "-f", "--foreground":
		foreground = true
		if len(os.Args) != 3 {
			printUsage()
			return
		}
		interfaceName = os.Args[2]

	default:
		foreground = false
		if len(os.Args) != 2 {
			printUsage()
			return
		}
		interfaceName = os.Args[1]
	}

	// The daemon child is re-executed with the parent's arguments (which may
	// lack -f), so it learns to stay in the foreground from the environment.
	if !foreground {
		foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1"
	}

	// get log level from LOG_LEVEL (default, also for unrecognized values:
	// error)

	logLevel := func() int {
		switch os.Getenv("LOG_LEVEL") {
		case "verbose", "debug":
			return device.LogLevelVerbose
		case "error":
			return device.LogLevelError
		case "silent":
			return device.LogLevelSilent
		}
		return device.LogLevelError
	}()

	// open TUN device (or use supplied fd)
	//
	// Note: from here on the local `tun` (a tun.Device) shadows the tun
	// package; inside the closure below `tun` still refers to the package.

	tun, err := func() (tun.Device, error) {
		tunFdStr := os.Getenv(ENV_WG_TUN_FD)
		if tunFdStr == "" {
			return tun.CreateTUN(interfaceName, device.DefaultMTU)
		}

		// construct tun device from supplied fd

		fd, err := strconv.ParseUint(tunFdStr, 10, 32)
		if err != nil {
			return nil, err
		}

		err = syscall.SetNonblock(int(fd), true)
		if err != nil {
			return nil, err
		}

		file := os.NewFile(uintptr(fd), "")
		return tun.CreateTUNFromFile(file, device.DefaultMTU)
	}()

	// Use the name the device actually got (it may differ from the requested
	// one, e.g. when an fd was supplied). A TUN creation error is reported
	// below, once the logger exists.
	if err == nil {
		realInterfaceName, err2 := tun.Name()
		if err2 == nil {
			interfaceName = realInterfaceName
		}
	}

	logger := device.NewLogger(
		logLevel,
		fmt.Sprintf("(%s) ", interfaceName),
	)

	logger.Verbosef("Starting wireguard-go version %s", Version)

	if err != nil {
		logger.Errorf("Failed to create TUN device: %v", err)
		os.Exit(ExitSetupFailed)
	}

	// open UAPI file (or use supplied fd)

	fileUAPI, err := func() (*os.File, error) {
		uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
		if uapiFdStr == "" {
			return ipc.UAPIOpen(interfaceName)
		}

		// use supplied fd

		fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
		if err != nil {
			return nil, err
		}

		return os.NewFile(uintptr(fd), ""), nil
	}()
	if err != nil {
		logger.Errorf("UAPI listen error: %v", err)
		os.Exit(ExitSetupFailed)
		return // NOTE(review): unreachable; os.Exit does not return.
	}
	// daemonize the process

	if !foreground {
		// The child inherits ProcAttr.Files in order, so the TUN file becomes
		// fd 3 and the UAPI file fd 4, matching the values advertised here.
		env := os.Environ()
		env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
		env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
		env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
		// The child keeps our stdout/stderr only when LOG_LEVEL was set
		// explicitly and is not "silent"; otherwise all stdio goes to
		// os.DevNull. NOTE(review): the os.Open errors are ignored and
		// os.Open opens read-only even for the stdout/stderr slots.
		files := [3]*os.File{}
		if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent {
			files[0], _ = os.Open(os.DevNull)
			files[1] = os.Stdout
			files[2] = os.Stderr
		} else {
			files[0], _ = os.Open(os.DevNull)
			files[1], _ = os.Open(os.DevNull)
			files[2], _ = os.Open(os.DevNull)
		}
		attr := &os.ProcAttr{
			Files: []*os.File{
				files[0], // stdin
				files[1], // stdout
				files[2], // stderr
				tun.File(),
				fileUAPI,
			},
			Dir: ".",
			Env: env,
		}

		path, err := os.Executable()
		if err != nil {
			logger.Errorf("Failed to determine executable: %v", err)
			os.Exit(ExitSetupFailed)
		}

		// Re-run this same binary with the same arguments; the parent does
		// not wait for the child.
		process, err := os.StartProcess(
			path,
			os.Args,
			attr,
		)
		if err != nil {
			logger.Errorf("Failed to daemonize: %v", err)
			os.Exit(ExitSetupFailed)
		}
		process.Release()
		return
	}

	// From here on `device` is the *Device, shadowing the device package.
	device := device.NewDevice(tun, conn.NewDefaultBind(), logger)

	logger.Verbosef("Device started")

	errs := make(chan error)
	term := make(chan os.Signal, 1)

	uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
	if err != nil {
		logger.Errorf("Failed to listen on uapi socket: %v", err)
		os.Exit(ExitSetupFailed)
	}

	// Serve each UAPI connection on its own goroutine. The first Accept error
	// is reported on errs and ends the loop. NOTE(review): errs is
	// unbuffered, so an Accept error after the select below has returned
	// blocks this goroutine; harmless only because main exits right after.
	go func() {
		for {
			conn, err := uapi.Accept()
			if err != nil {
				errs <- err
				return
			}
			go device.IpcHandle(conn)
		}
	}()

	logger.Verbosef("UAPI listener started")

	// wait for program to terminate

	signal.Notify(term, syscall.SIGTERM)
	signal.Notify(term, os.Interrupt)

	select {
	case <-term:
	case <-errs:
	case <-device.Wait():
	}

	// clean up

	uapi.Close()
	device.Close()

	logger.Verbosef("Shutting down")
}
269 |
--------------------------------------------------------------------------------
/device/allowedips_test.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "math/rand"
10 | "net"
11 | "net/netip"
12 | "testing"
13 | )
14 |
// testPairCommonBits is one TestCommonBits case: two byte slices and the
// number of leading bits commonBits is expected to report for them.
type testPairCommonBits struct {
	s1, s2 []byte
	match  uint8
}
20 |
21 | func TestCommonBits(t *testing.T) {
22 | tests := []testPairCommonBits{
23 | {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
24 | {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
25 | {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
26 | {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
27 | {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
28 | }
29 |
30 | for _, p := range tests {
31 | v := commonBits(p.s1, p.s2)
32 | if v != p.match {
33 | t.Error(
34 | "For slice", p.s1, p.s2,
35 | "expected match", p.match,
36 | ",but got", v,
37 | )
38 | }
39 | }
40 | }
41 |
42 | func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
43 | var trie *trieEntry
44 | var peers []*Peer
45 | root := parentIndirection{&trie, 2}
46 |
47 | rand.Seed(1)
48 |
49 | const AddressLength = 4
50 |
51 | for n := 0; n < peerNumber; n++ {
52 | peers = append(peers, &Peer{})
53 | }
54 |
55 | for n := 0; n < addressNumber; n++ {
56 | var addr [AddressLength]byte
57 | rand.Read(addr[:])
58 | cidr := uint8(rand.Uint32() % (AddressLength * 8))
59 | index := rand.Int() % peerNumber
60 | root.insert(addr[:], cidr, peers[index])
61 | }
62 |
63 | for n := 0; n < b.N; n++ {
64 | var addr [AddressLength]byte
65 | rand.Read(addr[:])
66 | trie.lookup(addr[:])
67 | }
68 | }
69 |
70 | func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
71 | benchmarkTrie(100, 1000, net.IPv4len, b)
72 | }
73 |
74 | func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
75 | benchmarkTrie(10, 10, net.IPv4len, b)
76 | }
77 |
78 | func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
79 | benchmarkTrie(100, 1000, net.IPv6len, b)
80 | }
81 |
82 | func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
83 | benchmarkTrie(10, 10, net.IPv6len, b)
84 | }
85 |
86 | /* Test ported from kernel implementation:
87 | * selftest/allowedips.h
88 | */
89 | func TestTrieIPv4(t *testing.T) {
90 | a := &Peer{}
91 | b := &Peer{}
92 | c := &Peer{}
93 | d := &Peer{}
94 | e := &Peer{}
95 | g := &Peer{}
96 | h := &Peer{}
97 |
98 | var allowedIPs AllowedIPs
99 |
100 | insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
101 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
102 | }
103 |
104 | assertEQ := func(peer *Peer, a, b, c, d byte) {
105 | p := allowedIPs.Lookup([]byte{a, b, c, d})
106 | if p != peer {
107 | t.Error("Assert EQ failed")
108 | }
109 | }
110 |
111 | assertNEQ := func(peer *Peer, a, b, c, d byte) {
112 | p := allowedIPs.Lookup([]byte{a, b, c, d})
113 | if p == peer {
114 | t.Error("Assert NEQ failed")
115 | }
116 | }
117 |
118 | insert(a, 192, 168, 4, 0, 24)
119 | insert(b, 192, 168, 4, 4, 32)
120 | insert(c, 192, 168, 0, 0, 16)
121 | insert(d, 192, 95, 5, 64, 27)
122 | insert(c, 192, 95, 5, 65, 27)
123 | insert(e, 0, 0, 0, 0, 0)
124 | insert(g, 64, 15, 112, 0, 20)
125 | insert(h, 64, 15, 123, 211, 25)
126 | insert(a, 10, 0, 0, 0, 25)
127 | insert(b, 10, 0, 0, 128, 25)
128 | insert(a, 10, 1, 0, 0, 30)
129 | insert(b, 10, 1, 0, 4, 30)
130 | insert(c, 10, 1, 0, 8, 29)
131 | insert(d, 10, 1, 0, 16, 29)
132 |
133 | assertEQ(a, 192, 168, 4, 20)
134 | assertEQ(a, 192, 168, 4, 0)
135 | assertEQ(b, 192, 168, 4, 4)
136 | assertEQ(c, 192, 168, 200, 182)
137 | assertEQ(c, 192, 95, 5, 68)
138 | assertEQ(e, 192, 95, 5, 96)
139 | assertEQ(g, 64, 15, 116, 26)
140 | assertEQ(g, 64, 15, 127, 3)
141 |
142 | insert(a, 1, 0, 0, 0, 32)
143 | insert(a, 64, 0, 0, 0, 32)
144 | insert(a, 128, 0, 0, 0, 32)
145 | insert(a, 192, 0, 0, 0, 32)
146 | insert(a, 255, 0, 0, 0, 32)
147 |
148 | assertEQ(a, 1, 0, 0, 0)
149 | assertEQ(a, 64, 0, 0, 0)
150 | assertEQ(a, 128, 0, 0, 0)
151 | assertEQ(a, 192, 0, 0, 0)
152 | assertEQ(a, 255, 0, 0, 0)
153 |
154 | allowedIPs.RemoveByPeer(a)
155 |
156 | assertNEQ(a, 1, 0, 0, 0)
157 | assertNEQ(a, 64, 0, 0, 0)
158 | assertNEQ(a, 128, 0, 0, 0)
159 | assertNEQ(a, 192, 0, 0, 0)
160 | assertNEQ(a, 255, 0, 0, 0)
161 |
162 | allowedIPs.RemoveByPeer(a)
163 | allowedIPs.RemoveByPeer(b)
164 | allowedIPs.RemoveByPeer(c)
165 | allowedIPs.RemoveByPeer(d)
166 | allowedIPs.RemoveByPeer(e)
167 | allowedIPs.RemoveByPeer(g)
168 | allowedIPs.RemoveByPeer(h)
169 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
170 | t.Error("Expected removing all the peers to empty trie, but it did not")
171 | }
172 |
173 | insert(a, 192, 168, 0, 0, 16)
174 | insert(a, 192, 168, 0, 0, 24)
175 |
176 | allowedIPs.RemoveByPeer(a)
177 |
178 | assertNEQ(a, 192, 168, 0, 1)
179 | }
180 |
181 | /* Test ported from kernel implementation:
182 | * selftest/allowedips.h
183 | */
184 | func TestTrieIPv6(t *testing.T) {
185 | a := &Peer{}
186 | b := &Peer{}
187 | c := &Peer{}
188 | d := &Peer{}
189 | e := &Peer{}
190 | f := &Peer{}
191 | g := &Peer{}
192 | h := &Peer{}
193 |
194 | var allowedIPs AllowedIPs
195 |
196 | expand := func(a uint32) []byte {
197 | var out [4]byte
198 | out[0] = byte(a >> 24 & 0xff)
199 | out[1] = byte(a >> 16 & 0xff)
200 | out[2] = byte(a >> 8 & 0xff)
201 | out[3] = byte(a & 0xff)
202 | return out[:]
203 | }
204 |
205 | insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
206 | var addr []byte
207 | addr = append(addr, expand(a)...)
208 | addr = append(addr, expand(b)...)
209 | addr = append(addr, expand(c)...)
210 | addr = append(addr, expand(d)...)
211 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
212 | }
213 |
214 | assertEQ := func(peer *Peer, a, b, c, d uint32) {
215 | var addr []byte
216 | addr = append(addr, expand(a)...)
217 | addr = append(addr, expand(b)...)
218 | addr = append(addr, expand(c)...)
219 | addr = append(addr, expand(d)...)
220 | p := allowedIPs.Lookup(addr)
221 | if p != peer {
222 | t.Error("Assert EQ failed")
223 | }
224 | }
225 |
226 | insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
227 | insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
228 | insert(e, 0, 0, 0, 0, 0)
229 | insert(f, 0, 0, 0, 0, 0)
230 | insert(g, 0x24046800, 0, 0, 0, 32)
231 | insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
232 | insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
233 | insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
234 | insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
235 |
236 | assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
237 | assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
238 | assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
239 | assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
240 | assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
241 | assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
242 | assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
243 | assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
244 | assertEQ(h, 0x24046800, 0x40040800, 0, 0)
245 | assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
246 | assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
247 | }
248 |
--------------------------------------------------------------------------------
/tun/tun_openbsd.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tun
7 |
8 | import (
9 | "errors"
10 | "fmt"
11 | "net"
12 | "os"
13 | "sync"
14 | "syscall"
15 | "unsafe"
16 |
17 | "golang.org/x/net/ipv6"
18 | "golang.org/x/sys/unix"
19 | )
20 |
// ifreq_mtu is the request buffer for the interface MTU get/set ioctls
// (SIOCGIFMTU in MTU, SIOCSIFMTU in setMTU). Name holds the NUL-padded
// interface name; MTU holds the value written or read back.
// NOTE(review): field order and the trailing Pad0 presumably reproduce the
// kernel's struct ifreq layout/size — do not reorder or resize unchecked.
type ifreq_mtu struct {
	Name [unix.IFNAMSIZ]byte
	MTU  uint32
	Pad0 [12]byte
}
27 |
// NOTE(review): _TUNSIFMODE is not referenced in this file. By its name and
// value it is presumably the tun(4) TUNSIFMODE ioctl request number — confirm
// before relying on it or removing it.
const _TUNSIFMODE = 0x8004745d
29 |
30 | type NativeTun struct {
31 | name string
32 | tunFile *os.File
33 | events chan Event
34 | errors chan error
35 | routeSocket int
36 | closeOnce sync.Once
37 | }
38 |
39 | func (tun *NativeTun) routineRouteListener(tunIfindex int) {
40 | var (
41 | statusUp bool
42 | statusMTU int
43 | )
44 |
45 | defer close(tun.events)
46 |
47 | check := func() bool {
48 | iface, err := net.InterfaceByIndex(tunIfindex)
49 | if err != nil {
50 | tun.errors <- err
51 | return true
52 | }
53 |
54 | // Up / Down event
55 | up := (iface.Flags & net.FlagUp) != 0
56 | if up != statusUp && up {
57 | tun.events <- EventUp
58 | }
59 | if up != statusUp && !up {
60 | tun.events <- EventDown
61 | }
62 | statusUp = up
63 |
64 | // MTU changes
65 | if iface.MTU != statusMTU {
66 | tun.events <- EventMTUUpdate
67 | }
68 | statusMTU = iface.MTU
69 | return false
70 | }
71 |
72 | if check() {
73 | return
74 | }
75 |
76 | data := make([]byte, os.Getpagesize())
77 | for {
78 | n, err := unix.Read(tun.routeSocket, data)
79 | if err != nil {
80 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
81 | continue
82 | }
83 | tun.errors <- err
84 | return
85 | }
86 |
87 | if n < 8 {
88 | continue
89 | }
90 |
91 | if data[3 /* type */] != unix.RTM_IFINFO {
92 | continue
93 | }
94 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
95 | if ifindex != tunIfindex {
96 | continue
97 | }
98 | if check() {
99 | return
100 | }
101 | }
102 | }
103 |
104 | func CreateTUN(name string, mtu int) (Device, error) {
105 | ifIndex := -1
106 | if name != "tun" {
107 | _, err := fmt.Sscanf(name, "tun%d", &ifIndex)
108 | if err != nil || ifIndex < 0 {
109 | return nil, fmt.Errorf("Interface name must be tun[0-9]*")
110 | }
111 | }
112 |
113 | var tunfile *os.File
114 | var err error
115 |
116 | if ifIndex != -1 {
117 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
118 | } else {
119 | for ifIndex = 0; ifIndex < 256; ifIndex++ {
120 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
121 | if err == nil || !errors.Is(err, syscall.EBUSY) {
122 | break
123 | }
124 | }
125 | }
126 |
127 | if err != nil {
128 | return nil, err
129 | }
130 |
131 | tun, err := CreateTUNFromFile(tunfile, mtu)
132 |
133 | if err == nil && name == "tun" {
134 | fname := os.Getenv("WG_TUN_NAME_FILE")
135 | if fname != "" {
136 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
137 | }
138 | }
139 |
140 | return tun, err
141 | }
142 |
143 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
144 | tun := &NativeTun{
145 | tunFile: file,
146 | events: make(chan Event, 10),
147 | errors: make(chan error, 1),
148 | }
149 |
150 | name, err := tun.Name()
151 | if err != nil {
152 | tun.tunFile.Close()
153 | return nil, err
154 | }
155 |
156 | tunIfindex, err := func() (int, error) {
157 | iface, err := net.InterfaceByName(name)
158 | if err != nil {
159 | return -1, err
160 | }
161 | return iface.Index, nil
162 | }()
163 | if err != nil {
164 | tun.tunFile.Close()
165 | return nil, err
166 | }
167 |
168 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
169 | if err != nil {
170 | tun.tunFile.Close()
171 | return nil, err
172 | }
173 |
174 | go tun.routineRouteListener(tunIfindex)
175 |
176 | currentMTU, err := tun.MTU()
177 | if err != nil || currentMTU != mtu {
178 | err = tun.setMTU(mtu)
179 | if err != nil {
180 | tun.Close()
181 | return nil, err
182 | }
183 | }
184 |
185 | return tun, nil
186 | }
187 |
188 | func (tun *NativeTun) Name() (string, error) {
189 | gostat, err := tun.tunFile.Stat()
190 | if err != nil {
191 | tun.name = ""
192 | return "", err
193 | }
194 | stat := gostat.Sys().(*syscall.Stat_t)
195 | tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
196 | return tun.name, nil
197 | }
198 |
// File returns the *os.File of the opened tun device.
func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}
202 |
// Events returns the channel on which routineRouteListener publishes
// EventUp, EventDown and EventMTUUpdate. The listener closes it on exit; if
// no route socket was open, Close closes it instead.
func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}
206 |
207 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
208 | select {
209 | case err := <-tun.errors:
210 | return 0, err
211 | default:
212 | buff := buff[offset-4:]
213 | n, err := tun.tunFile.Read(buff[:])
214 | if n < 4 {
215 | return 0, err
216 | }
217 | return n - 4, err
218 | }
219 | }
220 |
221 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
222 | // reserve space for header
223 |
224 | buff = buff[offset-4:]
225 |
226 | // add packet information header
227 |
228 | buff[0] = 0x00
229 | buff[1] = 0x00
230 | buff[2] = 0x00
231 |
232 | if buff[4]>>4 == ipv6.Version {
233 | buff[3] = unix.AF_INET6
234 | } else {
235 | buff[3] = unix.AF_INET
236 | }
237 |
238 | // write
239 |
240 | return tun.tunFile.Write(buff)
241 | }
242 |
// Flush is a no-op: Write hands each packet to the device file directly, so
// nothing is buffered.
func (tun *NativeTun) Flush() error {
	// TODO: can flushing be implemented by buffering and using sendmmsg?
	return nil
}
247 |
248 | func (tun *NativeTun) Close() error {
249 | var err1, err2 error
250 | tun.closeOnce.Do(func() {
251 | err1 = tun.tunFile.Close()
252 | if tun.routeSocket != -1 {
253 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
254 | err2 = unix.Close(tun.routeSocket)
255 | tun.routeSocket = -1
256 | } else if tun.events != nil {
257 | close(tun.events)
258 | }
259 | })
260 | if err1 != nil {
261 | return err1
262 | }
263 | return err2
264 | }
265 |
266 | func (tun *NativeTun) setMTU(n int) error {
267 | // open datagram socket
268 |
269 | var fd int
270 |
271 | fd, err := unix.Socket(
272 | unix.AF_INET,
273 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
274 | 0,
275 | )
276 | if err != nil {
277 | return err
278 | }
279 |
280 | defer unix.Close(fd)
281 |
282 | // do ioctl call
283 |
284 | var ifr ifreq_mtu
285 | copy(ifr.Name[:], tun.name)
286 | ifr.MTU = uint32(n)
287 |
288 | _, _, errno := unix.Syscall(
289 | unix.SYS_IOCTL,
290 | uintptr(fd),
291 | uintptr(unix.SIOCSIFMTU),
292 | uintptr(unsafe.Pointer(&ifr)),
293 | )
294 |
295 | if errno != 0 {
296 | return fmt.Errorf("failed to set MTU on %s", tun.name)
297 | }
298 |
299 | return nil
300 | }
301 |
302 | func (tun *NativeTun) MTU() (int, error) {
303 | // open datagram socket
304 |
305 | fd, err := unix.Socket(
306 | unix.AF_INET,
307 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
308 | 0,
309 | )
310 | if err != nil {
311 | return 0, err
312 | }
313 |
314 | defer unix.Close(fd)
315 |
316 | // do ioctl call
317 | var ifr ifreq_mtu
318 | copy(ifr.Name[:], tun.name)
319 |
320 | _, _, errno := unix.Syscall(
321 | unix.SYS_IOCTL,
322 | uintptr(fd),
323 | uintptr(unix.SIOCGIFMTU),
324 | uintptr(unsafe.Pointer(&ifr)),
325 | )
326 | if errno != 0 {
327 | return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
328 | }
329 |
330 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
331 | }
332 |
--------------------------------------------------------------------------------
/tun/tun_darwin.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package tun
7 |
8 | import (
9 | "errors"
10 | "fmt"
11 | "net"
12 | "os"
13 | "sync"
14 | "syscall"
15 | "time"
16 | "unsafe"
17 |
18 | "golang.org/x/net/ipv6"
19 | "golang.org/x/sys/unix"
20 | )
21 |
// utunControlName is the kernel control name that CreateTUN resolves to a
// control ID with IoctlCtlInfo before connecting its AF_SYSTEM socket.
const utunControlName = "com.apple.net.utun_control"
23 |
24 | type NativeTun struct {
25 | name string
26 | tunFile *os.File
27 | events chan Event
28 | errors chan error
29 | routeSocket int
30 | closeOnce sync.Once
31 | }
32 |
// retryInterfaceByIndex calls net.InterfaceByIndex, retrying while it fails
// with ENOMEM. Up to 20 attempts are made; after failed attempt i (0-based)
// it sleeps i/3 seconds. Any other result, success or a different error, is
// returned immediately. If every attempt hits ENOMEM, it returns nil and the
// last error.
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
	const attempts = 20
	for i := 0; i < attempts; i++ {
		iface, err = net.InterfaceByIndex(index)
		if err == nil || !errors.Is(err, syscall.ENOMEM) {
			return iface, err
		}
		time.Sleep(time.Duration(i) * time.Second / 3)
	}
	return nil, err
}
44 |
45 | func (tun *NativeTun) routineRouteListener(tunIfindex int) {
46 | var (
47 | statusUp bool
48 | statusMTU int
49 | )
50 |
51 | defer close(tun.events)
52 |
53 | data := make([]byte, os.Getpagesize())
54 | for {
55 | retry:
56 | n, err := unix.Read(tun.routeSocket, data)
57 | if err != nil {
58 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
59 | goto retry
60 | }
61 | tun.errors <- err
62 | return
63 | }
64 |
65 | if n < 14 {
66 | continue
67 | }
68 |
69 | if data[3 /* type */] != unix.RTM_IFINFO {
70 | continue
71 | }
72 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
73 | if ifindex != tunIfindex {
74 | continue
75 | }
76 |
77 | iface, err := retryInterfaceByIndex(ifindex)
78 | if err != nil {
79 | tun.errors <- err
80 | return
81 | }
82 |
83 | // Up / Down event
84 | up := (iface.Flags & net.FlagUp) != 0
85 | if up != statusUp && up {
86 | tun.events <- EventUp
87 | }
88 | if up != statusUp && !up {
89 | tun.events <- EventDown
90 | }
91 | statusUp = up
92 |
93 | // MTU changes
94 | if iface.MTU != statusMTU {
95 | tun.events <- EventMTUUpdate
96 | }
97 | statusMTU = iface.MTU
98 | }
99 | }
100 |
101 | func CreateTUN(name string, mtu int) (Device, error) {
102 | ifIndex := -1
103 | if name != "utun" {
104 | _, err := fmt.Sscanf(name, "utun%d", &ifIndex)
105 | if err != nil || ifIndex < 0 {
106 | return nil, fmt.Errorf("Interface name must be utun[0-9]*")
107 | }
108 | }
109 |
110 | fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
111 | if err != nil {
112 | return nil, err
113 | }
114 |
115 | ctlInfo := &unix.CtlInfo{}
116 | copy(ctlInfo.Name[:], []byte(utunControlName))
117 | err = unix.IoctlCtlInfo(fd, ctlInfo)
118 | if err != nil {
119 | unix.Close(fd)
120 | return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
121 | }
122 |
123 | sc := &unix.SockaddrCtl{
124 | ID: ctlInfo.Id,
125 | Unit: uint32(ifIndex) + 1,
126 | }
127 |
128 | err = unix.Connect(fd, sc)
129 | if err != nil {
130 | unix.Close(fd)
131 | return nil, err
132 | }
133 |
134 | err = unix.SetNonblock(fd, true)
135 | if err != nil {
136 | unix.Close(fd)
137 | return nil, err
138 | }
139 | tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
140 |
141 | if err == nil && name == "utun" {
142 | fname := os.Getenv("WG_TUN_NAME_FILE")
143 | if fname != "" {
144 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
145 | }
146 | }
147 |
148 | return tun, err
149 | }
150 |
151 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
152 | tun := &NativeTun{
153 | tunFile: file,
154 | events: make(chan Event, 10),
155 | errors: make(chan error, 5),
156 | }
157 |
158 | name, err := tun.Name()
159 | if err != nil {
160 | tun.tunFile.Close()
161 | return nil, err
162 | }
163 |
164 | tunIfindex, err := func() (int, error) {
165 | iface, err := net.InterfaceByName(name)
166 | if err != nil {
167 | return -1, err
168 | }
169 | return iface.Index, nil
170 | }()
171 | if err != nil {
172 | tun.tunFile.Close()
173 | return nil, err
174 | }
175 |
176 | tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
177 | if err != nil {
178 | tun.tunFile.Close()
179 | return nil, err
180 | }
181 |
182 | go tun.routineRouteListener(tunIfindex)
183 |
184 | if mtu > 0 {
185 | err = tun.setMTU(mtu)
186 | if err != nil {
187 | tun.Close()
188 | return nil, err
189 | }
190 | }
191 |
192 | return tun, nil
193 | }
194 |
195 | func (tun *NativeTun) Name() (string, error) {
196 | var err error
197 | tun.operateOnFd(func(fd uintptr) {
198 | tun.name, err = unix.GetsockoptString(
199 | int(fd),
200 | 2, /* #define SYSPROTO_CONTROL 2 */
201 | 2, /* #define UTUN_OPT_IFNAME 2 */
202 | )
203 | })
204 |
205 | if err != nil {
206 | return "", fmt.Errorf("GetSockoptString: %w", err)
207 | }
208 |
209 | return tun.name, nil
210 | }
211 |
// File returns the *os.File wrapping the utun device.
func (tun *NativeTun) File() *os.File {
	return tun.tunFile
}
215 |
// Events returns the channel on which routineRouteListener publishes
// EventUp, EventDown and EventMTUUpdate. The listener closes it on exit; if
// no route socket was open, Close closes it instead.
func (tun *NativeTun) Events() <-chan Event {
	return tun.events
}
219 |
220 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
221 | select {
222 | case err := <-tun.errors:
223 | return 0, err
224 | default:
225 | buff := buff[offset-4:]
226 | n, err := tun.tunFile.Read(buff[:])
227 | if n < 4 {
228 | return 0, err
229 | }
230 | return n - 4, err
231 | }
232 | }
233 |
234 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
235 | // reserve space for header
236 |
237 | buff = buff[offset-4:]
238 |
239 | // add packet information header
240 |
241 | buff[0] = 0x00
242 | buff[1] = 0x00
243 | buff[2] = 0x00
244 |
245 | if buff[4]>>4 == ipv6.Version {
246 | buff[3] = unix.AF_INET6
247 | } else {
248 | buff[3] = unix.AF_INET
249 | }
250 |
251 | // write
252 |
253 | return tun.tunFile.Write(buff)
254 | }
255 |
// Flush is a no-op: Write hands each packet to the device file directly, so
// nothing is buffered.
func (tun *NativeTun) Flush() error {
	// TODO: can flushing be implemented by buffering and using sendmmsg?
	return nil
}
260 |
261 | func (tun *NativeTun) Close() error {
262 | var err1, err2 error
263 | tun.closeOnce.Do(func() {
264 | err1 = tun.tunFile.Close()
265 | if tun.routeSocket != -1 {
266 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
267 | err2 = unix.Close(tun.routeSocket)
268 | } else if tun.events != nil {
269 | close(tun.events)
270 | }
271 | })
272 | if err1 != nil {
273 | return err1
274 | }
275 | return err2
276 | }
277 |
278 | func (tun *NativeTun) setMTU(n int) error {
279 | fd, err := socketCloexec(
280 | unix.AF_INET,
281 | unix.SOCK_DGRAM,
282 | 0,
283 | )
284 | if err != nil {
285 | return err
286 | }
287 |
288 | defer unix.Close(fd)
289 |
290 | var ifr unix.IfreqMTU
291 | copy(ifr.Name[:], tun.name)
292 | ifr.MTU = int32(n)
293 | err = unix.IoctlSetIfreqMTU(fd, &ifr)
294 | if err != nil {
295 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
296 | }
297 |
298 | return nil
299 | }
300 |
301 | func (tun *NativeTun) MTU() (int, error) {
302 | fd, err := socketCloexec(
303 | unix.AF_INET,
304 | unix.SOCK_DGRAM,
305 | 0,
306 | )
307 | if err != nil {
308 | return 0, err
309 | }
310 |
311 | defer unix.Close(fd)
312 |
313 | ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
314 | if err != nil {
315 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
316 | }
317 |
318 | return int(ifr.MTU), nil
319 | }
320 |
321 | func socketCloexec(family, sotype, proto int) (fd int, err error) {
322 | // See go/src/net/sys_cloexec.go for background.
323 | syscall.ForkLock.RLock()
324 | defer syscall.ForkLock.RUnlock()
325 |
326 | fd, err = unix.Socket(family, sotype, proto)
327 | if err == nil {
328 | unix.CloseOnExec(fd)
329 | }
330 | return
331 | }
332 |
--------------------------------------------------------------------------------
/ipc/namedpipe/file.go:
--------------------------------------------------------------------------------
1 | // Copyright 2021 The Go Authors. All rights reserved.
2 | // Copyright 2015 Microsoft
3 | // Use of this source code is governed by a BSD-style
4 | // license that can be found in the LICENSE file.
5 |
6 | //go:build windows
7 | // +build windows
8 |
9 | package namedpipe
10 |
11 | import (
12 | "io"
13 | "os"
14 | "runtime"
15 | "sync"
16 | "sync/atomic"
17 | "time"
18 | "unsafe"
19 |
20 | "golang.org/x/sys/windows"
21 | )
22 |
// timeoutChan is closed by a deadlineHandler to signal that its deadline has
// expired; IO waiting in asyncIo selects on it.
type timeoutChan chan struct{}

// ioCompletionPort is the single completion port shared by every file. It is
// created lazily by initIo, guarded by ioInitOnce (see makeFile).
var (
	ioInitOnce       sync.Once
	ioCompletionPort windows.Handle
)
29 |
// ioResult contains the result of an asynchronous IO operation: the byte
// count and error reported by GetQueuedCompletionStatus.
type ioResult struct {
	bytes uint32
	err   error
}

// ioOperation represents an outstanding asynchronous Win32 IO.
// o must remain the FIRST field: ioCompletionProcessor reinterprets the
// *windows.Overlapped dequeued from the completion port as an *ioOperation.
// ch delivers the completion result to the goroutine waiting in asyncIo.
type ioOperation struct {
	o  windows.Overlapped
	ch chan ioResult
}
41 |
42 | func initIo() {
43 | h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
44 | if err != nil {
45 | panic(err)
46 | }
47 | ioCompletionPort = h
48 | go ioCompletionProcessor(h)
49 | }
50 |
// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
// NOTE(review): no finalizer is installed in this file; the handle is closed
// by Close/closeHandle — confirm the garbage-collection claim.
type file struct {
	handle windows.Handle
	// wg counts in-flight IO; closeHandle waits on it before closing handle.
	wg sync.WaitGroup
	// wgLock orders prepareIo's closing check + wg.Add against closeHandle's
	// closing.Swap, so no new IO can register after Close starts waiting.
	wgLock  sync.RWMutex
	closing atomic.Bool
	// socket makes asyncIo fetch the winsock error via WSAGetOverlappedResult.
	socket        bool
	readDeadline  deadlineHandler
	writeDeadline deadlineHandler
}

// deadlineHandler implements one (read or write) deadline for a file.
type deadlineHandler struct {
	// setLock serializes calls to set.
	setLock sync.Mutex
	// channel is closed when the deadline expires; set replaces a closed
	// channel under channelLock.
	channel     timeoutChan
	channelLock sync.RWMutex
	// timer fires the timeout for a future deadline; nil when none is armed.
	timer *time.Timer
	// timedout is true once the deadline has expired, so new IO fails fast.
	timedout atomic.Bool
}
70 |
71 | // makeFile makes a new file from an existing file handle
72 | func makeFile(h windows.Handle) (*file, error) {
73 | f := &file{handle: h}
74 | ioInitOnce.Do(initIo)
75 | _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
76 | if err != nil {
77 | return nil, err
78 | }
79 | err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
80 | if err != nil {
81 | return nil, err
82 | }
83 | f.readDeadline.channel = make(timeoutChan)
84 | f.writeDeadline.channel = make(timeoutChan)
85 | return f, nil
86 | }
87 |
88 | // closeHandle closes the resources associated with a Win32 handle
89 | func (f *file) closeHandle() {
90 | f.wgLock.Lock()
91 | // Atomically set that we are closing, releasing the resources only once.
92 | if f.closing.Swap(true) == false {
93 | f.wgLock.Unlock()
94 | // cancel all IO and wait for it to complete
95 | windows.CancelIoEx(f.handle, nil)
96 | f.wg.Wait()
97 | // at this point, no new IO can start
98 | windows.Close(f.handle)
99 | f.handle = 0
100 | } else {
101 | f.wgLock.Unlock()
102 | }
103 | }
104 |
105 | // Close closes a file.
106 | func (f *file) Close() error {
107 | f.closeHandle()
108 | return nil
109 | }
110 |
111 | // prepareIo prepares for a new IO operation.
112 | // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
113 | func (f *file) prepareIo() (*ioOperation, error) {
114 | f.wgLock.RLock()
115 | if f.closing.Load() {
116 | f.wgLock.RUnlock()
117 | return nil, os.ErrClosed
118 | }
119 | f.wg.Add(1)
120 | f.wgLock.RUnlock()
121 | c := &ioOperation{}
122 | c.ch = make(chan ioResult)
123 | return c, nil
124 | }
125 |
126 | // ioCompletionProcessor processes completed async IOs forever
127 | func ioCompletionProcessor(h windows.Handle) {
128 | for {
129 | var bytes uint32
130 | var key uintptr
131 | var op *ioOperation
132 | err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
133 | if op == nil {
134 | panic(err)
135 | }
136 | op.ch <- ioResult{bytes, err}
137 | }
138 | }
139 |
140 | // asyncIo processes the return value from ReadFile or WriteFile, blocking until
141 | // the operation has actually completed.
142 | func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
143 | if err != windows.ERROR_IO_PENDING {
144 | return int(bytes), err
145 | }
146 |
147 | if f.closing.Load() {
148 | windows.CancelIoEx(f.handle, &c.o)
149 | }
150 |
151 | var timeout timeoutChan
152 | if d != nil {
153 | d.channelLock.Lock()
154 | timeout = d.channel
155 | d.channelLock.Unlock()
156 | }
157 |
158 | var r ioResult
159 | select {
160 | case r = <-c.ch:
161 | err = r.err
162 | if err == windows.ERROR_OPERATION_ABORTED {
163 | if f.closing.Load() {
164 | err = os.ErrClosed
165 | }
166 | } else if err != nil && f.socket {
167 | // err is from Win32. Query the overlapped structure to get the winsock error.
168 | var bytes, flags uint32
169 | err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
170 | }
171 | case <-timeout:
172 | windows.CancelIoEx(f.handle, &c.o)
173 | r = <-c.ch
174 | err = r.err
175 | if err == windows.ERROR_OPERATION_ABORTED {
176 | err = os.ErrDeadlineExceeded
177 | }
178 | }
179 |
180 | // runtime.KeepAlive is needed, as c is passed via native
181 | // code to ioCompletionProcessor, c must remain alive
182 | // until the channel read is complete.
183 | runtime.KeepAlive(c)
184 | return int(r.bytes), err
185 | }
186 |
187 | // Read reads from a file handle.
188 | func (f *file) Read(b []byte) (int, error) {
189 | c, err := f.prepareIo()
190 | if err != nil {
191 | return 0, err
192 | }
193 | defer f.wg.Done()
194 |
195 | if f.readDeadline.timedout.Load() {
196 | return 0, os.ErrDeadlineExceeded
197 | }
198 |
199 | var bytes uint32
200 | err = windows.ReadFile(f.handle, b, &bytes, &c.o)
201 | n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
202 | runtime.KeepAlive(b)
203 |
204 | // Handle EOF conditions.
205 | if err == nil && n == 0 && len(b) != 0 {
206 | return 0, io.EOF
207 | } else if err == windows.ERROR_BROKEN_PIPE {
208 | return 0, io.EOF
209 | } else {
210 | return n, err
211 | }
212 | }
213 |
214 | // Write writes to a file handle.
215 | func (f *file) Write(b []byte) (int, error) {
216 | c, err := f.prepareIo()
217 | if err != nil {
218 | return 0, err
219 | }
220 | defer f.wg.Done()
221 |
222 | if f.writeDeadline.timedout.Load() {
223 | return 0, os.ErrDeadlineExceeded
224 | }
225 |
226 | var bytes uint32
227 | err = windows.WriteFile(f.handle, b, &bytes, &c.o)
228 | n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
229 | runtime.KeepAlive(b)
230 | return n, err
231 | }
232 |
233 | func (f *file) SetReadDeadline(deadline time.Time) error {
234 | return f.readDeadline.set(deadline)
235 | }
236 |
237 | func (f *file) SetWriteDeadline(deadline time.Time) error {
238 | return f.writeDeadline.set(deadline)
239 | }
240 |
241 | func (f *file) Flush() error {
242 | return windows.FlushFileBuffers(f.handle)
243 | }
244 |
245 | func (f *file) Fd() uintptr {
246 | return uintptr(f.handle)
247 | }
248 |
// set installs deadline as this handler's new deadline and always returns nil.
//
// Steps:
//   - Stop any armed timer; if it already fired, wait for its close of
//     d.channel.
//   - Clear timedout.
//   - If d.channel was already closed, replace it, under channelLock, with a
//     fresh one.
//
// A zero deadline leaves no deadline armed. A future deadline arms a timer.
// Any other deadline expires immediately.
func (d *deadlineHandler) set(deadline time.Time) error {
	d.setLock.Lock()
	defer d.setLock.Unlock()

	if d.timer != nil {
		// Stop reports false when the timer already fired; in that case wait
		// until timeoutIO has closed the channel before touching it.
		if !d.timer.Stop() {
			<-d.channel
		}
		d.timer = nil
	}
	d.timedout.Store(false)

	// A receivable (i.e. closed) channel means a previous deadline expired;
	// swap in a new one so new IO does not observe the stale timeout.
	select {
	case <-d.channel:
		d.channelLock.Lock()
		d.channel = make(chan struct{})
		d.channelLock.Unlock()
	default:
	}

	if deadline.IsZero() {
		return nil
	}

	// timeoutIO marks the deadline expired and wakes every asyncIo waiting on
	// the channel, which then cancels its IO.
	timeoutIO := func() {
		d.timedout.Store(true)
		close(d.channel)
	}

	now := time.Now()
	duration := deadline.Sub(now)
	if deadline.After(now) {
		// Deadline is in the future, set a timer to wait
		d.timer = time.AfterFunc(duration, timeoutIO)
	} else {
		// Deadline is in the past. Cancel all pending IO now.
		timeoutIO()
	}
	return nil
}
289 |
--------------------------------------------------------------------------------
/device/peer.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "container/list"
10 | "errors"
11 | "sync"
12 | "sync/atomic"
13 | "time"
14 |
15 | "golang.zx2c4.com/wireguard/conn"
16 | )
17 |
// Peer holds all per-peer state of a device: key material, handshake state,
// endpoint, traffic counters, protocol timers, packet queues and the peer's
// allowed-IP trie entries.
type Peer struct {
	isRunning         atomic.Bool
	sync.RWMutex      // Mostly protects endpoint, but is generally taken whenever we modify peer
	keypairs          Keypairs
	handshake         Handshake
	device            *Device
	endpoint          conn.Endpoint
	stopping          sync.WaitGroup // routines pending stop (sender + receiver)
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nano seconds since epoch

	// disableRoaming makes SetEndpointFromPacket ignore packet source addresses.
	disableRoaming bool

	// Protocol timers; see timers.go for what arms and fires each of them.
	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElement // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue  // sequential ordering of tun writing
	}

	cookieGenerator             CookieGenerator
	trieEntries                 list.List     // this peer's nodes in the AllowedIPs tries
	persistentKeepaliveInterval atomic.Uint32 // seconds; 0 disables persistent keepalive
}
57 |
58 | func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
59 | if device.isClosed() {
60 | return nil, errors.New("device closed")
61 | }
62 |
63 | // lock resources
64 | device.staticIdentity.RLock()
65 | defer device.staticIdentity.RUnlock()
66 |
67 | device.peers.Lock()
68 | defer device.peers.Unlock()
69 |
70 | // check if over limit
71 | if len(device.peers.keyMap) >= MaxPeers {
72 | return nil, errors.New("too many peers")
73 | }
74 |
75 | // create peer
76 | peer := new(Peer)
77 | peer.Lock()
78 | defer peer.Unlock()
79 |
80 | peer.cookieGenerator.Init(pk)
81 | peer.device = device
82 | peer.queue.outbound = newAutodrainingOutboundQueue(device)
83 | peer.queue.inbound = newAutodrainingInboundQueue(device)
84 | peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
85 |
86 | // map public key
87 | _, ok := device.peers.keyMap[pk]
88 | if ok {
89 | return nil, errors.New("adding existing peer")
90 | }
91 |
92 | // pre-compute DH
93 | handshake := &peer.handshake
94 | handshake.mutex.Lock()
95 | handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
96 | handshake.remoteStatic = pk
97 | handshake.mutex.Unlock()
98 |
99 | // reset endpoint
100 | peer.endpoint = nil
101 |
102 | // init timers
103 | peer.timersInit()
104 |
105 | // add
106 | device.peers.keyMap[pk] = peer
107 |
108 | return peer, nil
109 | }
110 |
111 | func (peer *Peer) SendBuffer(buffer []byte) error {
112 | peer.device.net.RLock()
113 | defer peer.device.net.RUnlock()
114 |
115 | if peer.device.isClosed() {
116 | return nil
117 | }
118 |
119 | peer.RLock()
120 | defer peer.RUnlock()
121 |
122 | if peer.endpoint == nil {
123 | return errors.New("no known endpoint for peer")
124 | }
125 |
126 | err := peer.device.net.bind.Send(buffer, peer.endpoint)
127 | if err == nil {
128 | peer.txBytes.Add(uint64(len(buffer)))
129 | }
130 | return err
131 | }
132 |
// String returns a short log identifier of the form "peer(XXXX…YYYY)", built
// from the standard base64 encoding of the peer's static public key. XXXX is
// characters 0-3 and YYYY is characters 39-42.
func (peer *Peer) String() string {
	// The awful goo that follows is identical to:
	//
	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
	//
	// except that it is considerably more efficient.
	src := peer.handshake.remoteStatic
	// b64 maps a 6-bit value (0-63) to its standard base64 character without
	// branches. It starts from 'A'+input; each term (threshold-input)>>8 is
	// all-ones only when input exceeds the threshold. That selects the offset
	// corrections for the 'a'-'z', '0'-'9', '+' and '/' ranges.
	b64 := func(input byte) byte {
		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
	}
	b := []byte("peer(____…____)")
	// Byte offsets of the two 4-character slots ("…" is 3 bytes of UTF-8).
	const first = len("peer(")
	const second = len("peer(____…")
	// Characters 0-3: the first 3-byte group (key bytes 0-2).
	b[first+0] = b64((src[0] >> 2) & 63)
	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
	b[first+3] = b64(src[2] & 63)
	// Character 39 is the last of the group for bytes 27-29. Characters 40-42
	// encode the trailing bytes 30-31; the '=' pad at index 43 is not shown.
	b[second+0] = b64(src[29] & 63)
	b[second+1] = b64((src[30] >> 2) & 63)
	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
	b[second+3] = b64((src[31] << 2) & 63)
	return string(b)
}
158 |
// Start resets the peer's timers and handshake timestamp, drains its packet
// queues and launches its sequential sender and receiver goroutines. It is a
// no-op on a closed device or an already-running peer. peer.state serializes
// Start against Stop.
func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state
	// Wait for routines of a previous run to exit, then account for the two
	// routines started below.
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// Backdate the last sent handshake by more than RekeyTimeout — presumably
	// so the first initiation is not suppressed by the rekey rate limit
	// (checked elsewhere); confirm.
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes

	peer.timersStart()

	// Drain anything left in the per-peer queues before the routines start.
	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)
	go peer.RoutineSequentialSender()
	go peer.RoutineSequentialReceiver()

	// Only now mark running: timersActive() gates timer arming on this flag.
	peer.isRunning.Store(true)
}
195 |
196 | func (peer *Peer) ZeroAndFlushAll() {
197 | device := peer.device
198 |
199 | // clear key pairs
200 |
201 | keypairs := &peer.keypairs
202 | keypairs.Lock()
203 | device.DeleteKeypair(keypairs.previous)
204 | device.DeleteKeypair(keypairs.current)
205 | device.DeleteKeypair(keypairs.next.Load())
206 | keypairs.previous = nil
207 | keypairs.current = nil
208 | keypairs.next.Store(nil)
209 | keypairs.Unlock()
210 |
211 | // clear handshake state
212 |
213 | handshake := &peer.handshake
214 | handshake.mutex.Lock()
215 | device.indexTable.Delete(handshake.localIndex)
216 | handshake.Clear()
217 | handshake.mutex.Unlock()
218 |
219 | peer.FlushStagedPackets()
220 | }
221 |
222 | func (peer *Peer) ExpireCurrentKeypairs() {
223 | handshake := &peer.handshake
224 | handshake.mutex.Lock()
225 | peer.device.indexTable.Delete(handshake.localIndex)
226 | handshake.Clear()
227 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
228 | handshake.mutex.Unlock()
229 |
230 | keypairs := &peer.keypairs
231 | keypairs.Lock()
232 | if keypairs.current != nil {
233 | keypairs.current.sendNonce.Store(RejectAfterMessages)
234 | }
235 | if next := keypairs.next.Load(); next != nil {
236 | next.sendNonce.Store(RejectAfterMessages)
237 | }
238 | keypairs.Unlock()
239 | }
240 |
// Stop halts a running peer. It stops all timers, waiting for running
// callbacks (DelSync), and tells the sender and receiver routines to exit by
// queueing nil. After waiting for them, it releases the peer's hold on the
// device's encryption queue. Finally it zeroes keys and flushes staged
// packets. It does nothing if the peer is not running.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Clearing isRunning first makes timersActive() false, so no timer is
	// re-armed while we tear down.
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
260 |
261 | func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
262 | if peer.disableRoaming {
263 | return
264 | }
265 | peer.Lock()
266 | peer.endpoint = endpoint
267 | peer.Unlock()
268 | }
269 |
--------------------------------------------------------------------------------
/device/timers.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | *
5 | * This is based heavily on timers.c from the kernel implementation.
6 | */
7 |
8 | package device
9 |
10 | import (
11 | "sync"
12 | "time"
13 | _ "unsafe"
14 | )
15 |
// fastrandn is bound to the Go runtime's internal fastrandn via go:linkname;
// presumably it returns a pseudo-random value in [0, n). It is used below to
// add jitter to handshake timers.
// NOTE(review): this depends on the toolchain still exporting
// runtime.fastrandn to linkname users — re-check when upgrading Go.
//
//go:linkname fastrandn runtime.fastrandn
func fastrandn(n uint32) uint32
18 |
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
	*time.Timer
	// modifyingLock guards isPending against Mod/Del/IsPending and the
	// expiration callback.
	modifyingLock sync.RWMutex
	// runningLock is held for the whole expiration callback; DelSync takes
	// it to wait for an in-flight callback.
	runningLock sync.Mutex
	// isPending is true from Mod until the timer fires or is Del'd.
	isPending bool
}
27 |
// NewTimer creates a stopped Timer. When it fires, it calls
// expirationFunction(peer), but only if the timer is still pending at that
// moment, so a Del that raced with the underlying time.Timer suppresses the
// callback.
func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
	timer := &Timer{}
	// The hour is a placeholder: the timer is stopped just below and only
	// armed later by Mod.
	timer.Timer = time.AfterFunc(time.Hour, func() {
		// Held for the whole callback so DelSync can wait for it.
		timer.runningLock.Lock()
		defer timer.runningLock.Unlock()

		// Consume the pending flag; bail out if the timer was deleted.
		timer.modifyingLock.Lock()
		if !timer.isPending {
			timer.modifyingLock.Unlock()
			return
		}
		timer.isPending = false
		timer.modifyingLock.Unlock()

		expirationFunction(peer)
	})
	timer.Stop()
	return timer
}
47 |
48 | func (timer *Timer) Mod(d time.Duration) {
49 | timer.modifyingLock.Lock()
50 | timer.isPending = true
51 | timer.Reset(d)
52 | timer.modifyingLock.Unlock()
53 | }
54 |
55 | func (timer *Timer) Del() {
56 | timer.modifyingLock.Lock()
57 | timer.isPending = false
58 | timer.Stop()
59 | timer.modifyingLock.Unlock()
60 | }
61 |
// DelSync cancels the timer and waits for any expiration callback already in
// progress to return. The second Del, done while holding runningLock, cancels
// a re-arm the finishing callback may have made. For example,
// expiredSendKeepalive calls Mod on its own timer.
func (timer *Timer) DelSync() {
	timer.Del()
	timer.runningLock.Lock()
	timer.Del()
	timer.runningLock.Unlock()
}
68 |
69 | func (timer *Timer) IsPending() bool {
70 | timer.modifyingLock.RLock()
71 | defer timer.modifyingLock.RUnlock()
72 | return timer.isPending
73 | }
74 |
75 | func (peer *Peer) timersActive() bool {
76 | return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
77 | }
78 |
79 | func expiredRetransmitHandshake(peer *Peer) {
80 | if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
81 | peer.device.UpdateHandshakeState(HandshakeFail)
82 | peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
83 |
84 | if peer.timersActive() {
85 | peer.timers.sendKeepalive.Del()
86 | }
87 |
88 | /* We drop all packets without a keypair and don't try again,
89 | * if we try unsuccessfully for too long to make a handshake.
90 | */
91 | peer.FlushStagedPackets()
92 |
93 | /* We set a timer for destroying any residue that might be left
94 | * of a partial exchange.
95 | */
96 | if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
97 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
98 | }
99 | } else {
100 | peer.timers.handshakeAttempts.Add(1)
101 | peer.device.UpdateHandshakeState(HandshakeFail)
102 | peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
103 |
104 | /* We clear the endpoint address src address, in case this is the cause of trouble. */
105 | peer.Lock()
106 | if peer.endpoint != nil {
107 | peer.endpoint.ClearSrc()
108 | }
109 | peer.Unlock()
110 |
111 | peer.SendHandshakeInitiation(true)
112 | }
113 | }
114 |
115 | func expiredSendKeepalive(peer *Peer) {
116 | peer.SendKeepalive()
117 | if peer.timers.needAnotherKeepalive.Load() {
118 | peer.timers.needAnotherKeepalive.Store(false)
119 | if peer.timersActive() {
120 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
121 | }
122 | }
123 | }
124 |
125 | func expiredNewHandshake(peer *Peer) {
126 | peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
127 | /* We clear the endpoint address src address, in case this is the cause of trouble. */
128 | peer.Lock()
129 | if peer.endpoint != nil {
130 | peer.endpoint.ClearSrc()
131 | }
132 | peer.Unlock()
133 | peer.SendHandshakeInitiation(false)
134 | }
135 |
136 | func expiredZeroKeyMaterial(peer *Peer) {
137 | peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
138 | peer.ZeroAndFlushAll()
139 | }
140 |
141 | func expiredPersistentKeepalive(peer *Peer) {
142 | if peer.persistentKeepaliveInterval.Load() > 0 {
143 | peer.SendKeepalive()
144 | }
145 | }
146 |
147 | /* Should be called after an authenticated data packet is sent. */
148 | func (peer *Peer) timersDataSent() {
149 | if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
150 | peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
151 | }
152 | }
153 |
154 | /* Should be called after an authenticated data packet is received. */
155 | func (peer *Peer) timersDataReceived() {
156 | if peer.timersActive() {
157 | if !peer.timers.sendKeepalive.IsPending() {
158 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
159 | } else {
160 | peer.timers.needAnotherKeepalive.Store(true)
161 | }
162 | }
163 | }
164 |
165 | /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
166 | func (peer *Peer) timersAnyAuthenticatedPacketSent() {
167 | if peer.timersActive() {
168 | peer.timers.sendKeepalive.Del()
169 | }
170 | }
171 |
172 | /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
173 | func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
174 | if peer.timersActive() {
175 | peer.timers.newHandshake.Del()
176 | }
177 | }
178 |
179 | /* Should be called after a handshake initiation message is sent. */
180 | func (peer *Peer) timersHandshakeInitiated() {
181 | if peer.timersActive() {
182 | peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
183 | }
184 | }
185 |
186 | /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
187 | func (peer *Peer) timersHandshakeComplete() {
188 | if peer.timersActive() {
189 | peer.timers.retransmitHandshake.Del()
190 | }
191 | peer.timers.handshakeAttempts.Store(0)
192 | peer.timers.sentLastMinuteHandshake.Store(false)
193 | peer.lastHandshakeNano.Store(time.Now().UnixNano())
194 | }
195 |
196 | /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
197 | func (peer *Peer) timersSessionDerived() {
198 | if peer.timersActive() {
199 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
200 | }
201 | }
202 |
203 | /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
204 | func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
205 | keepalive := peer.persistentKeepaliveInterval.Load()
206 | if keepalive > 0 && peer.timersActive() {
207 | peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
208 | }
209 | }
210 |
211 | func (peer *Peer) timersInit() {
212 | peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
213 | peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
214 | peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
215 | peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
216 | peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
217 | }
218 |
219 | func (peer *Peer) timersStart() {
220 | peer.timers.handshakeAttempts.Store(0)
221 | peer.timers.sentLastMinuteHandshake.Store(false)
222 | peer.timers.needAnotherKeepalive.Store(false)
223 | }
224 |
225 | func (peer *Peer) timersStop() {
226 | peer.timers.retransmitHandshake.DelSync()
227 | peer.timers.sendKeepalive.DelSync()
228 | peer.timers.newHandshake.DelSync()
229 | peer.timers.zeroKeyMaterial.DelSync()
230 | peer.timers.persistentKeepalive.DelSync()
231 | }
232 |
--------------------------------------------------------------------------------
/device/allowedips.go:
--------------------------------------------------------------------------------
1 | /* SPDX-License-Identifier: MIT
2 | *
3 | * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4 | */
5 |
6 | package device
7 |
8 | import (
9 | "container/list"
10 | "encoding/binary"
11 | "errors"
12 | "math/bits"
13 | "net"
14 | "net/netip"
15 | "sync"
16 | "unsafe"
17 | )
18 |
// parentIndirection records where a trieEntry is referenced from. parentBit
// points at the slot holding the node: either a parent's child[0]/child[1]
// or the table's root pointer. parentBitType is that child index (0 or 1),
// or 2 for a root pointer (see AllowedIPs.Insert).
type parentIndirection struct {
	parentBit     **trieEntry
	parentBitType uint8
}

// trieEntry is one node of a binary trie keyed by address bits.
type trieEntry struct {
	peer        *Peer             // owning peer; nil for a pure branch node
	child       [2]*trieEntry     // subtries for the next bit being 0 / 1
	parent      parentIndirection // slot that points at this node
	cidr        uint8             // prefix length in bits
	bitAtByte   uint8             // cidr / 8: byte holding the branching bit
	bitAtShift  uint8             // 7 - cidr%8: shift that isolates that bit
	bits        []byte            // address, masked to cidr bits
	perPeerElem *list.Element     // this node's element in peer.trieEntries, or nil
}
34 |
// commonBits returns the number of leading bits shared by ip1 and ip2. Both
// must be 4-byte (IPv4) or 16-byte (IPv6) addresses of the same length; any
// other length of ip1 panics.
func commonBits(ip1, ip2 []byte) uint8 {
	switch len(ip1) {
	case net.IPv4len:
		diff := binary.BigEndian.Uint32(ip1) ^ binary.BigEndian.Uint32(ip2)
		return uint8(bits.LeadingZeros32(diff))
	case net.IPv6len:
		if hi := binary.BigEndian.Uint64(ip1) ^ binary.BigEndian.Uint64(ip2); hi != 0 {
			return uint8(bits.LeadingZeros64(hi))
		}
		lo := binary.BigEndian.Uint64(ip1[8:]) ^ binary.BigEndian.Uint64(ip2[8:])
		return 64 + uint8(bits.LeadingZeros64(lo))
	default:
		panic("Wrong size bit string")
	}
}
57 |
58 | func (node *trieEntry) addToPeerEntries() {
59 | node.perPeerElem = node.peer.trieEntries.PushBack(node)
60 | }
61 |
62 | func (node *trieEntry) removeFromPeerEntries() {
63 | if node.perPeerElem != nil {
64 | node.peer.trieEntries.Remove(node.perPeerElem)
65 | node.perPeerElem = nil
66 | }
67 | }
68 |
69 | func (node *trieEntry) choose(ip []byte) byte {
70 | return (ip[node.bitAtByte] >> node.bitAtShift) & 1
71 | }
72 |
73 | func (node *trieEntry) maskSelf() {
74 | mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
75 | for i := 0; i < len(mask); i++ {
76 | node.bits[i] &= mask[i]
77 | }
78 | }
79 |
80 | func (node *trieEntry) zeroizePointers() {
81 | // Make the garbage collector's life slightly easier
82 | node.peer = nil
83 | node.child[0] = nil
84 | node.child[1] = nil
85 | node.parent.parentBit = nil
86 | }
87 |
88 | func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
89 | for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
90 | parent = node
91 | if parent.cidr == cidr {
92 | exact = true
93 | return
94 | }
95 | bit := node.choose(ip)
96 | node = node.child[bit]
97 | }
98 | return
99 | }
100 |
// insert routes ip/cidr to peer in the trie whose root slot is trie.parentBit.
// ip must be a 4- or 16-byte address. When a new node is created it keeps
// the ip slice and masks it in place to cidr bits. If a node for exactly
// ip/cidr exists, only its peer is replaced. Otherwise a new node is linked
// in. A peerless branch node is added when the new prefix and the existing
// subtree diverge before either prefix ends.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// Deepest existing node covering ip with cidr no longer than ours.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// Same prefix already present: move it to the new peer.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing node that ends up below newNode (or below a new
	// branch node): the root when nothing covers ip, else node's child in
	// ip's direction. An empty child slot is simply filled with newNode.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// cidr becomes the prefix length shared by down and ip, capped at the
	// new prefix's length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	// newNode's prefix covers down: splice newNode in between parent (or
	// the root slot) and down.
	if newNode.cidr == cidr {
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// Neither prefix covers the other: add a peerless branch node at the
	// shared prefix length with down and newNode as its two children.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
190 |
191 | func (node *trieEntry) lookup(ip []byte) *Peer {
192 | var found *Peer
193 | size := uint8(len(ip))
194 | for node != nil && commonBits(node.bits, ip) >= node.cidr {
195 | if node.peer != nil {
196 | found = node.peer
197 | }
198 | if node.bitAtByte == size {
199 | break
200 | }
201 | bit := node.choose(ip)
202 | node = node.child[bit]
203 | }
204 | return found
205 | }
206 |
// AllowedIPs is a longest-prefix-match table mapping IP prefixes to peers.
// It keeps separate binary tries for IPv4 and IPv6; mutex guards both.
type AllowedIPs struct {
	IPv4  *trieEntry
	IPv6  *trieEntry
	mutex sync.RWMutex
}
212 |
213 | func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
214 | table.mutex.RLock()
215 | defer table.mutex.RUnlock()
216 |
217 | for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
218 | node := elem.Value.(*trieEntry)
219 | a, _ := netip.AddrFromSlice(node.bits)
220 | if !cb(netip.PrefixFrom(a, int(node.cidr))) {
221 | return
222 | }
223 | }
224 | }
225 |
// RemoveByPeer deletes every prefix routed to peer. Each of the peer's nodes
// loses its peer. A node with fewer than two children is spliced out of the
// trie. If that leaves a peerless parent with only one child, the parent is
// spliced out as well.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Capture Next before removeFromPeerEntries unlinks elem.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// With two children the node is still needed as a branch point.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Replace the node in its parent slot by its only child (or nil).
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Had a child (slot still occupied) or was a root (type 2, no parent
		// node): nothing further to compact.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// Recover the parent node from the child slot address: parentBit points
		// at parent.child[parentBitType], so step back over that many child
		// slots and the offset of the child array within trieEntry.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// The parent is a peerless branch left with just the sibling: splice
		// the parent out too, lifting the sibling into its slot.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
267 |
268 | func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
269 | table.mutex.Lock()
270 | defer table.mutex.Unlock()
271 |
272 | if prefix.Addr().Is6() {
273 | ip := prefix.Addr().As16()
274 | parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
275 | } else if prefix.Addr().Is4() {
276 | ip := prefix.Addr().As4()
277 | parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
278 | } else {
279 | panic(errors.New("inserting unknown address type"))
280 | }
281 | }
282 |
283 | func (table *AllowedIPs) Lookup(ip []byte) *Peer {
284 | table.mutex.RLock()
285 | defer table.mutex.RUnlock()
286 | switch len(ip) {
287 | case net.IPv6len:
288 | return table.IPv6.lookup(ip)
289 | case net.IPv4len:
290 | return table.IPv4.lookup(ip)
291 | default:
292 | panic(errors.New("looking up unknown address type"))
293 | }
294 | }
295 |
--------------------------------------------------------------------------------