├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── Makefile ├── README.md ├── conn ├── bind_linux.go ├── bind_std.go ├── bind_windows.go ├── bindtest │ └── bindtest.go ├── boundif_android.go ├── conn.go ├── default.go ├── mark_default.go ├── mark_unix.go └── winrio │ └── rio_windows.go ├── device ├── alignment_test.go ├── allowedips.go ├── allowedips_rand_test.go ├── allowedips_test.go ├── bind_test.go ├── channels.go ├── constants.go ├── cookie.go ├── cookie_test.go ├── device.go ├── device_test.go ├── devicestate_string.go ├── endpoint_test.go ├── indextable.go ├── ip.go ├── kdf_test.go ├── keypair.go ├── logger.go ├── misc.go ├── mobilequirks.go ├── noise-helpers.go ├── noise-protocol.go ├── noise-types.go ├── noise_test.go ├── peer.go ├── pools.go ├── pools_test.go ├── queueconstants_android.go ├── queueconstants_default.go ├── queueconstants_ios.go ├── queueconstants_windows.go ├── race_disabled_test.go ├── race_enabled_test.go ├── receive.go ├── send.go ├── sticky_default.go ├── sticky_linux.go ├── timers.go ├── tun.go └── uapi.go ├── format_test.go ├── go.mod ├── govpp_crcstring_dump.txt ├── govpp_remove_crcstring_check.patch ├── ipc ├── uapi_bsd.go ├── uapi_linux.go ├── uapi_unix.go ├── uapi_windows.go └── winpipe │ ├── file.go │ ├── winpipe.go │ └── winpipe_test.go ├── main.go ├── main_windows.go ├── ratelimiter ├── ratelimiter.go └── ratelimiter_test.go ├── replay ├── replay.go └── replay_test.go ├── rwcancel ├── rwcancel.go └── rwcancel_windows.go ├── tai64n ├── tai64n.go └── tai64n_test.go ├── tests └── netns.sh ├── tun ├── alignment_windows_test.go ├── netstack │ ├── examples │ │ ├── http_client.go │ │ └── http_server.go │ ├── go.mod │ ├── go.sum │ └── tun.go ├── operateonfd.go ├── tun.go ├── tun_darwin.go ├── tun_freebsd.go ├── tun_linux_vpp.go ├── tun_openbsd.go ├── tun_windows.go ├── tuntest │ └── tuntest.go └── wintun │ ├── dll_fromfile_windows.go │ ├── dll_fromrsrc_windows.go │ ├── dll_windows.go │ ├── memmod │ ├── 
memmod_windows.go │ ├── memmod_windows_32.go │ ├── memmod_windows_386.go │ ├── memmod_windows_64.go │ ├── memmod_windows_amd64.go │ ├── memmod_windows_arm.go │ ├── memmod_windows_arm64.go │ ├── syscall_windows.go │ ├── syscall_windows_32.go │ └── syscall_windows_64.go │ ├── session_windows.go │ └── wintun_windows.go └── version.go /.gitignore: -------------------------------------------------------------------------------- 1 | wireguard-go-vpp 2 | vendor/* 3 | go.sum 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Launch Package", 9 | "type": "go", 10 | "request": "launch", 11 | "mode": "auto", 12 | "program": "${workspaceFolder}", 13 | "args":["-f","home"], 14 | "env": {"LOG_LEVEL":"debug"}, 15 | 16 | } 17 | ] 18 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.words": [ 3 | "memif" 4 | ] 5 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of 2 | this software and associated documentation files (the "Software"), to deal in 3 | the Software without restriction, including without limitation the rights to 4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 5 | of the Software, and to permit persons to whom the Software is furnished to do 6 | so, subject to the following 
conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all 9 | copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX ?= /usr 2 | DESTDIR ?= 3 | BINDIR ?= $(PREFIX)/bin 4 | export GO111MODULE := on 5 | 6 | all: generate-version-and-build 7 | 8 | MAKEFLAGS += --no-print-directory 9 | 10 | generate-version-and-build: 11 | @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ 12 | tag="$$(git describe --dirty 2>/dev/null)" && \ 13 | ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ 14 | [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ 15 | echo "$$ver" > version.go && \ 16 | git update-index --assume-unchanged version.go || true 17 | @$(MAKE) wireguard-go-vpp 18 | 19 | wireguard-go-vpp: export CGO_CFLAGS ?= -I/usr/include/memif 20 | wireguard-go-vpp: $(wildcard *.go) $(wildcard */*.go) 21 | go mod vendor && \ 22 | patch -p0 -i govpp_remove_crcstring_check.patch && \ 23 | go build -v -o "$@" 24 | 25 | install: wireguard-go-vpp 26 | @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go-vpp" 27 | 28 | test: 29 | go test -v ./... 
30 | 31 | clean: 32 | rm -f wireguard-go-vpp 33 | 34 | .PHONY: all clean test install generate-version-and-build 35 | -------------------------------------------------------------------------------- /conn/bind_std.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "errors" 10 | "net" 11 | "sync" 12 | "syscall" 13 | ) 14 | 15 | // StdNetBind is meant to be a temporary solution on platforms for which 16 | // the sticky socket / source caching behavior has not yet been implemented. 17 | // It uses the Go's net package to implement networking. 18 | // See LinuxSocketBind for a proper implementation on the Linux platform. 19 | type StdNetBind struct { 20 | mu sync.Mutex // protects following fields 21 | ipv4 *net.UDPConn 22 | ipv6 *net.UDPConn 23 | blackhole4 bool 24 | blackhole6 bool 25 | } 26 | 27 | func NewStdNetBind() Bind { return &StdNetBind{} } 28 | 29 | type StdNetEndpoint net.UDPAddr 30 | 31 | var _ Bind = (*StdNetBind)(nil) 32 | var _ Endpoint = (*StdNetEndpoint)(nil) 33 | 34 | func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 35 | addr, err := parseEndpoint(s) 36 | return (*StdNetEndpoint)(addr), err 37 | } 38 | 39 | func (*StdNetEndpoint) ClearSrc() {} 40 | 41 | func (e *StdNetEndpoint) DstIP() net.IP { 42 | return (*net.UDPAddr)(e).IP 43 | } 44 | 45 | func (e *StdNetEndpoint) SrcIP() net.IP { 46 | return nil // not supported 47 | } 48 | 49 | func (e *StdNetEndpoint) DstToBytes() []byte { 50 | addr := (*net.UDPAddr)(e) 51 | out := addr.IP.To4() 52 | if out == nil { 53 | out = addr.IP 54 | } 55 | out = append(out, byte(addr.Port&0xff)) 56 | out = append(out, byte((addr.Port>>8)&0xff)) 57 | return out 58 | } 59 | 60 | func (e *StdNetEndpoint) DstToString() string { 61 | return (*net.UDPAddr)(e).String() 62 | } 63 | 64 | func (e *StdNetEndpoint) SrcToString() 
string { 65 | return "" 66 | } 67 | 68 | func listenNet(network string, port int) (*net.UDPConn, int, error) { 69 | conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 70 | if err != nil { 71 | return nil, 0, err 72 | } 73 | 74 | // Retrieve port. 75 | laddr := conn.LocalAddr() 76 | uaddr, err := net.ResolveUDPAddr( 77 | laddr.Network(), 78 | laddr.String(), 79 | ) 80 | if err != nil { 81 | return nil, 0, err 82 | } 83 | return conn, uaddr.Port, nil 84 | } 85 | 86 | func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 87 | bind.mu.Lock() 88 | defer bind.mu.Unlock() 89 | 90 | var err error 91 | var tries int 92 | 93 | if bind.ipv4 != nil || bind.ipv6 != nil { 94 | return nil, 0, ErrBindAlreadyOpen 95 | } 96 | 97 | // Attempt to open ipv4 and ipv6 listeners on the same port. 98 | // If uport is 0, we can retry on failure. 99 | again: 100 | port := int(uport) 101 | var ipv4, ipv6 *net.UDPConn 102 | 103 | ipv4, port, err = listenNet("udp4", port) 104 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 105 | return nil, 0, err 106 | } 107 | 108 | // Listen on the same port as we're using for ipv4. 
109 | ipv6, port, err = listenNet("udp6", port) 110 | if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 111 | ipv4.Close() 112 | tries++ 113 | goto again 114 | } 115 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 116 | ipv4.Close() 117 | return nil, 0, err 118 | } 119 | var fns []ReceiveFunc 120 | if ipv4 != nil { 121 | fns = append(fns, bind.makeReceiveIPv4(ipv4)) 122 | bind.ipv4 = ipv4 123 | } 124 | if ipv6 != nil { 125 | fns = append(fns, bind.makeReceiveIPv6(ipv6)) 126 | bind.ipv6 = ipv6 127 | } 128 | if len(fns) == 0 { 129 | return nil, 0, syscall.EAFNOSUPPORT 130 | } 131 | return fns, uint16(port), nil 132 | } 133 | 134 | func (bind *StdNetBind) Close() error { 135 | bind.mu.Lock() 136 | defer bind.mu.Unlock() 137 | 138 | var err1, err2 error 139 | if bind.ipv4 != nil { 140 | err1 = bind.ipv4.Close() 141 | bind.ipv4 = nil 142 | } 143 | if bind.ipv6 != nil { 144 | err2 = bind.ipv6.Close() 145 | bind.ipv6 = nil 146 | } 147 | bind.blackhole4 = false 148 | bind.blackhole6 = false 149 | if err1 != nil { 150 | return err1 151 | } 152 | return err2 153 | } 154 | 155 | func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 156 | return func(buff []byte) (int, Endpoint, error) { 157 | n, endpoint, err := conn.ReadFromUDP(buff) 158 | if endpoint != nil { 159 | endpoint.IP = endpoint.IP.To4() 160 | } 161 | return n, (*StdNetEndpoint)(endpoint), err 162 | } 163 | } 164 | 165 | func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 166 | return func(buff []byte) (int, Endpoint, error) { 167 | n, endpoint, err := conn.ReadFromUDP(buff) 168 | return n, (*StdNetEndpoint)(endpoint), err 169 | } 170 | } 171 | 172 | func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 173 | var err error 174 | nend, ok := endpoint.(*StdNetEndpoint) 175 | if !ok { 176 | return ErrWrongEndpointType 177 | } 178 | 179 | bind.mu.Lock() 180 | blackhole := bind.blackhole4 181 | conn := bind.ipv4 182 | if nend.IP.To4() == nil { 
183 | blackhole = bind.blackhole6 184 | conn = bind.ipv6 185 | } 186 | bind.mu.Unlock() 187 | 188 | if blackhole { 189 | return nil 190 | } 191 | if conn == nil { 192 | return syscall.EAFNOSUPPORT 193 | } 194 | _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend)) 195 | return err 196 | } 197 | -------------------------------------------------------------------------------- /conn/bindtest/bindtest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package bindtest 7 | 8 | import ( 9 | "fmt" 10 | "math/rand" 11 | "net" 12 | "os" 13 | "strconv" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type ChannelBind struct { 19 | rx4, tx4 *chan []byte 20 | rx6, tx6 *chan []byte 21 | closeSignal chan bool 22 | source4, source6 ChannelEndpoint 23 | target4, target6 ChannelEndpoint 24 | } 25 | 26 | type ChannelEndpoint uint16 27 | 28 | var _ conn.Bind = (*ChannelBind)(nil) 29 | var _ conn.Endpoint = (*ChannelEndpoint)(nil) 30 | 31 | func NewChannelBinds() [2]conn.Bind { 32 | arx4 := make(chan []byte, 8192) 33 | brx4 := make(chan []byte, 8192) 34 | arx6 := make(chan []byte, 8192) 35 | brx6 := make(chan []byte, 8192) 36 | var binds [2]ChannelBind 37 | binds[0].rx4 = &arx4 38 | binds[0].tx4 = &brx4 39 | binds[1].rx4 = &brx4 40 | binds[1].tx4 = &arx4 41 | binds[0].rx6 = &arx6 42 | binds[0].tx6 = &brx6 43 | binds[1].rx6 = &brx6 44 | binds[1].tx6 = &arx6 45 | binds[0].target4 = ChannelEndpoint(1) 46 | binds[1].target4 = ChannelEndpoint(2) 47 | binds[0].target6 = ChannelEndpoint(3) 48 | binds[1].target6 = ChannelEndpoint(4) 49 | binds[0].source4 = binds[1].target4 50 | binds[0].source6 = binds[1].target6 51 | binds[1].source4 = binds[0].target4 52 | binds[1].source6 = binds[0].target6 53 | return [2]conn.Bind{&binds[0], &binds[1]} 54 | } 55 | 56 | func (c ChannelEndpoint) ClearSrc() {} 57 | 58 | func (c 
ChannelEndpoint) SrcToString() string { return "" } 59 | 60 | func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 61 | 62 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 63 | 64 | func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } 65 | 66 | func (c ChannelEndpoint) SrcIP() net.IP { return nil } 67 | 68 | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 69 | c.closeSignal = make(chan bool) 70 | fns = append(fns, c.makeReceiveFunc(*c.rx4)) 71 | fns = append(fns, c.makeReceiveFunc(*c.rx6)) 72 | if rand.Uint32()&1 == 0 { 73 | return fns, uint16(c.source4), nil 74 | } else { 75 | return fns, uint16(c.source6), nil 76 | } 77 | } 78 | 79 | func (c *ChannelBind) Close() error { 80 | if c.closeSignal != nil { 81 | select { 82 | case <-c.closeSignal: 83 | default: 84 | close(c.closeSignal) 85 | } 86 | } 87 | return nil 88 | } 89 | 90 | func (c *ChannelBind) SetMark(mark uint32) error { return nil } 91 | 92 | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 93 | return func(b []byte) (n int, ep conn.Endpoint, err error) { 94 | select { 95 | case <-c.closeSignal: 96 | return 0, nil, net.ErrClosed 97 | case rx := <-ch: 98 | return copy(b, rx), c.target6, nil 99 | } 100 | } 101 | } 102 | 103 | func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { 104 | select { 105 | case <-c.closeSignal: 106 | return net.ErrClosed 107 | default: 108 | bc := make([]byte, len(b)) 109 | copy(bc, b) 110 | if ep.(ChannelEndpoint) == c.target4 { 111 | *c.tx4 <- bc 112 | } else if ep.(ChannelEndpoint) == c.target6 { 113 | *c.tx6 <- bc 114 | } else { 115 | return os.ErrInvalid 116 | } 117 | } 118 | return nil 119 | } 120 | 121 | func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 122 | _, port, err := net.SplitHostPort(s) 123 | if err != nil { 124 | return nil, err 125 | } 126 | i, err := strconv.ParseUint(port, 10, 16) 
127 | if err != nil { 128 | return nil, err 129 | } 130 | return ChannelEndpoint(i), nil 131 | } 132 | -------------------------------------------------------------------------------- /conn/boundif_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { 9 | sysconn, err := bind.ipv4.SyscallConn() 10 | if err != nil { 11 | return -1, err 12 | } 13 | err = sysconn.Control(func(f uintptr) { 14 | fd = int(f) 15 | }) 16 | if err != nil { 17 | return -1, err 18 | } 19 | return 20 | } 21 | 22 | func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { 23 | sysconn, err := bind.ipv6.SyscallConn() 24 | if err != nil { 25 | return -1, err 26 | } 27 | err = sysconn.Control(func(f uintptr) { 28 | fd = int(f) 29 | }) 30 | if err != nil { 31 | return -1, err 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /conn/conn.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package conn implements WireGuard's network connections. 7 | package conn 8 | 9 | import ( 10 | "errors" 11 | "fmt" 12 | "net" 13 | "reflect" 14 | "runtime" 15 | "strings" 16 | ) 17 | 18 | // A ReceiveFunc receives a single inbound packet from the network. 19 | // It writes the data into b. n is the length of the packet. 20 | // ep is the remote endpoint. 21 | type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) 22 | 23 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. 24 | // 25 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, 26 | // depending on the platform-specific implementation. 
27 | type Bind interface { 28 | // Open puts the Bind into a listening state on a given port and reports the actual 29 | // port that it bound to. Passing zero results in a random selection. 30 | // fns is the set of functions that will be called to receive packets. 31 | Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) 32 | 33 | // Close closes the Bind listener. 34 | // All fns returned by Open must return net.ErrClosed after a call to Close. 35 | Close() error 36 | 37 | // SetMark sets the mark for each packet sent through this Bind. 38 | // This mark is passed to the kernel as the socket option SO_MARK. 39 | SetMark(mark uint32) error 40 | 41 | // Send writes a packet b to address ep. 42 | Send(b []byte, ep Endpoint) error 43 | 44 | // ParseEndpoint creates a new endpoint from a string. 45 | ParseEndpoint(s string) (Endpoint, error) 46 | } 47 | 48 | // BindSocketToInterface is implemented by Bind objects that support being 49 | // tied to a single network interface. Used by wireguard-windows. 50 | type BindSocketToInterface interface { 51 | BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error 52 | BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error 53 | } 54 | 55 | // PeekLookAtSocketFd is implemented by Bind objects that support having their 56 | // file descriptor peeked at. Used by wireguard-android. 57 | type PeekLookAtSocketFd interface { 58 | PeekLookAtSocketFd4() (fd int, err error) 59 | PeekLookAtSocketFd6() (fd int, err error) 60 | } 61 | 62 | // An Endpoint maintains the source/destination caching for a peer. 
63 | // 64 | // dst: the remote address of a peer ("endpoint" in uapi terminology) 65 | // src: the local address from which datagrams originate going to the peer 66 | type Endpoint interface { 67 | ClearSrc() // clears the source address 68 | SrcToString() string // returns the local source address (ip:port) 69 | DstToString() string // returns the destination address (ip:port) 70 | DstToBytes() []byte // used for mac2 cookie calculations 71 | DstIP() net.IP 72 | SrcIP() net.IP 73 | } 74 | 75 | var ( 76 | ErrBindAlreadyOpen = errors.New("bind is already open") 77 | ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") 78 | ) 79 | 80 | func (fn ReceiveFunc) PrettyName() string { 81 | name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() 82 | // 0. cheese/taco.beansIPv6.func12.func21218-fm 83 | name = strings.TrimSuffix(name, "-fm") 84 | // 1. cheese/taco.beansIPv6.func12.func21218 85 | if idx := strings.LastIndexByte(name, '/'); idx != -1 { 86 | name = name[idx+1:] 87 | // 2. taco.beansIPv6.func12.func21218 88 | } 89 | for { 90 | var idx int 91 | for idx = len(name) - 1; idx >= 0; idx-- { 92 | if name[idx] < '0' || name[idx] > '9' { 93 | break 94 | } 95 | } 96 | if idx == len(name)-1 { 97 | break 98 | } 99 | const dotFunc = ".func" 100 | if !strings.HasSuffix(name[:idx+1], dotFunc) { 101 | break 102 | } 103 | name = name[:idx+1-len(dotFunc)] 104 | // 3. taco.beansIPv6.func12 105 | // 4. taco.beansIPv6 106 | } 107 | if idx := strings.LastIndexByte(name, '.'); idx != -1 { 108 | name = name[idx+1:] 109 | // 5. 
beansIPv6 110 | } 111 | if name == "" { 112 | return fmt.Sprintf("%p", fn) 113 | } 114 | if strings.HasSuffix(name, "IPv4") { 115 | return "v4" 116 | } 117 | if strings.HasSuffix(name, "IPv6") { 118 | return "v6" 119 | } 120 | return name 121 | } 122 | 123 | func parseEndpoint(s string) (*net.UDPAddr, error) { 124 | // ensure that the host is an IP address 125 | 126 | host, _, err := net.SplitHostPort(s) 127 | if err != nil { 128 | return nil, err 129 | } 130 | if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { 131 | // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just 132 | // trying to make sure with a small sanity test that this is a real IP address and 133 | // not something that's likely to incur DNS lookups. 134 | host = host[:i] 135 | } 136 | if ip := net.ParseIP(host); ip == nil { 137 | return nil, errors.New("Failed to parse IP address: " + host) 138 | } 139 | 140 | // parse address and port 141 | 142 | addr, err := net.ResolveUDPAddr("udp", s) 143 | if err != nil { 144 | return nil, err 145 | } 146 | ip4 := addr.IP.To4() 147 | if ip4 != nil { 148 | addr.IP = ip4 149 | } 150 | return addr, err 151 | } 152 | -------------------------------------------------------------------------------- /conn/default.go: -------------------------------------------------------------------------------- 1 | // +build !linux,!windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | func NewDefaultBind() Bind { return NewStdNetBind() } 11 | -------------------------------------------------------------------------------- /conn/mark_default.go: -------------------------------------------------------------------------------- 1 | // +build !linux,!openbsd,!freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | 10 | func (bind *StdNetBind) SetMark(mark uint32) error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /conn/mark_unix.go: -------------------------------------------------------------------------------- 1 | // +build linux openbsd freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "runtime" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | var fwmarkIoctl int 17 | 18 | func init() { 19 | switch runtime.GOOS { 20 | case "linux", "android": 21 | fwmarkIoctl = 36 /* unix.SO_MARK */ 22 | case "freebsd": 23 | fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ 24 | case "openbsd": 25 | fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ 26 | } 27 | } 28 | 29 | func (bind *StdNetBind) SetMark(mark uint32) error { 30 | var operr error 31 | if fwmarkIoctl == 0 { 32 | return nil 33 | } 34 | if bind.ipv4 != nil { 35 | fd, err := bind.ipv4.SyscallConn() 36 | if err != nil { 37 | return err 38 | } 39 | err = fd.Control(func(fd uintptr) { 40 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 41 | }) 42 | if err == nil { 43 | err = operr 44 | } 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | if bind.ipv6 != nil { 50 | fd, err := bind.ipv6.SyscallConn() 51 | if err != nil { 52 | return err 53 | } 54 | err = fd.Control(func(fd uintptr) { 55 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 56 | }) 57 | if err == nil { 58 | err = operr 59 | } 60 | if err != nil { 61 | return err 62 | } 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /device/alignment_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestPeerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime. 26 | func TestPeerAlignment(t *testing.T) { 27 | var p Peer 28 | 29 | typ := reflect.TypeOf(&p).Elem() 30 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) 42 | checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) 43 | } 44 | 45 | // TestDeviceAlignment checks that atomically-accessed fields are 46 | // aligned to 64-bit boundaries, as required by the atomic package. 47 | // 48 | // Unfortunately, violating this rule on 32-bit platforms results in a 49 | // hard segfault at runtime. 
50 | func TestDeviceAlignment(t *testing.T) { 51 | var d Device 52 | 53 | typ := reflect.TypeOf(&d).Elem() 54 | t.Logf("Device type size: %d, with fields:", typ.Size()) 55 | for i := 0; i < typ.NumField(); i++ { 56 | field := typ.Field(i) 57 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 58 | field.Name, 59 | field.Offset, 60 | field.Type.Size(), 61 | field.Type.Align(), 62 | ) 63 | } 64 | checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil)) 65 | } 66 | -------------------------------------------------------------------------------- /device/allowedips.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "errors" 11 | "math/bits" 12 | "net" 13 | "sync" 14 | "unsafe" 15 | ) 16 | 17 | type parentIndirection struct { 18 | parentBit **trieEntry 19 | parentBitType uint8 20 | } 21 | 22 | type trieEntry struct { 23 | peer *Peer 24 | child [2]*trieEntry 25 | parent parentIndirection 26 | cidr uint8 27 | bitAtByte uint8 28 | bitAtShift uint8 29 | bits net.IP 30 | perPeerElem *list.Element 31 | } 32 | 33 | func isLittleEndian() bool { 34 | one := uint32(1) 35 | return *(*byte)(unsafe.Pointer(&one)) != 0 36 | } 37 | 38 | func swapU32(i uint32) uint32 { 39 | if !isLittleEndian() { 40 | return i 41 | } 42 | 43 | return bits.ReverseBytes32(i) 44 | } 45 | 46 | func swapU64(i uint64) uint64 { 47 | if !isLittleEndian() { 48 | return i 49 | } 50 | 51 | return bits.ReverseBytes64(i) 52 | } 53 | 54 | func commonBits(ip1 net.IP, ip2 net.IP) uint8 { 55 | size := len(ip1) 56 | if size == net.IPv4len { 57 | a := (*uint32)(unsafe.Pointer(&ip1[0])) 58 | b := (*uint32)(unsafe.Pointer(&ip2[0])) 59 | x := *a ^ *b 60 | return uint8(bits.LeadingZeros32(swapU32(x))) 61 | } else if size == net.IPv6len { 62 | a := 
(*uint64)(unsafe.Pointer(&ip1[0])) 63 | b := (*uint64)(unsafe.Pointer(&ip2[0])) 64 | x := *a ^ *b 65 | if x != 0 { 66 | return uint8(bits.LeadingZeros64(swapU64(x))) 67 | } 68 | a = (*uint64)(unsafe.Pointer(&ip1[8])) 69 | b = (*uint64)(unsafe.Pointer(&ip2[8])) 70 | x = *a ^ *b 71 | return 64 + uint8(bits.LeadingZeros64(swapU64(x))) 72 | } else { 73 | panic("Wrong size bit string") 74 | } 75 | } 76 | 77 | func (node *trieEntry) addToPeerEntries() { 78 | node.perPeerElem = node.peer.trieEntries.PushBack(node) 79 | } 80 | 81 | func (node *trieEntry) removeFromPeerEntries() { 82 | if node.perPeerElem != nil { 83 | node.peer.trieEntries.Remove(node.perPeerElem) 84 | node.perPeerElem = nil 85 | } 86 | } 87 | 88 | func (node *trieEntry) choose(ip net.IP) byte { 89 | return (ip[node.bitAtByte] >> node.bitAtShift) & 1 90 | } 91 | 92 | func (node *trieEntry) maskSelf() { 93 | mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 94 | for i := 0; i < len(mask); i++ { 95 | node.bits[i] &= mask[i] 96 | } 97 | } 98 | 99 | func (node *trieEntry) zeroizePointers() { 100 | // Make the garbage collector's life slightly easier 101 | node.peer = nil 102 | node.child[0] = nil 103 | node.child[1] = nil 104 | node.parent.parentBit = nil 105 | } 106 | 107 | func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { 108 | for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { 109 | parent = node 110 | if parent.cidr == cidr { 111 | exact = true 112 | return 113 | } 114 | bit := node.choose(ip) 115 | node = node.child[bit] 116 | } 117 | return 118 | } 119 | 120 | func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { 121 | if *trie.parentBit == nil { 122 | node := &trieEntry{ 123 | peer: peer, 124 | parent: trie, 125 | bits: ip, 126 | cidr: cidr, 127 | bitAtByte: cidr / 8, 128 | bitAtShift: 7 - (cidr % 8), 129 | } 130 | node.maskSelf() 131 | node.addToPeerEntries() 132 | *trie.parentBit = node 133 | return 134 
| } 135 | node, exact := (*trie.parentBit).nodePlacement(ip, cidr) 136 | if exact { 137 | node.removeFromPeerEntries() 138 | node.peer = peer 139 | node.addToPeerEntries() 140 | return 141 | } 142 | 143 | newNode := &trieEntry{ 144 | peer: peer, 145 | bits: ip, 146 | cidr: cidr, 147 | bitAtByte: cidr / 8, 148 | bitAtShift: 7 - (cidr % 8), 149 | } 150 | newNode.maskSelf() 151 | newNode.addToPeerEntries() 152 | 153 | var down *trieEntry 154 | if node == nil { 155 | down = *trie.parentBit 156 | } else { 157 | bit := node.choose(ip) 158 | down = node.child[bit] 159 | if down == nil { 160 | newNode.parent = parentIndirection{&node.child[bit], bit} 161 | node.child[bit] = newNode 162 | return 163 | } 164 | } 165 | common := commonBits(down.bits, ip) 166 | if common < cidr { 167 | cidr = common 168 | } 169 | parent := node 170 | 171 | if newNode.cidr == cidr { 172 | bit := newNode.choose(down.bits) 173 | down.parent = parentIndirection{&newNode.child[bit], bit} 174 | newNode.child[bit] = down 175 | if parent == nil { 176 | newNode.parent = trie 177 | *trie.parentBit = newNode 178 | } else { 179 | bit := parent.choose(newNode.bits) 180 | newNode.parent = parentIndirection{&parent.child[bit], bit} 181 | parent.child[bit] = newNode 182 | } 183 | return 184 | } 185 | 186 | node = &trieEntry{ 187 | bits: append([]byte{}, newNode.bits...), 188 | cidr: cidr, 189 | bitAtByte: cidr / 8, 190 | bitAtShift: 7 - (cidr % 8), 191 | } 192 | node.maskSelf() 193 | 194 | bit := node.choose(down.bits) 195 | down.parent = parentIndirection{&node.child[bit], bit} 196 | node.child[bit] = down 197 | bit = node.choose(newNode.bits) 198 | newNode.parent = parentIndirection{&node.child[bit], bit} 199 | node.child[bit] = newNode 200 | if parent == nil { 201 | node.parent = trie 202 | *trie.parentBit = node 203 | } else { 204 | bit := parent.choose(node.bits) 205 | node.parent = parentIndirection{&parent.child[bit], bit} 206 | parent.child[bit] = node 207 | } 208 | } 209 | 210 | func (node 
*trieEntry) lookup(ip net.IP) *Peer { 211 | var found *Peer 212 | size := uint8(len(ip)) 213 | for node != nil && commonBits(node.bits, ip) >= node.cidr { 214 | if node.peer != nil { 215 | found = node.peer 216 | } 217 | if node.bitAtByte == size { 218 | break 219 | } 220 | bit := node.choose(ip) 221 | node = node.child[bit] 222 | } 223 | return found 224 | } 225 | 226 | type AllowedIPs struct { 227 | IPv4 *trieEntry 228 | IPv6 *trieEntry 229 | mutex sync.RWMutex 230 | } 231 | 232 | func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { 233 | table.mutex.RLock() 234 | defer table.mutex.RUnlock() 235 | 236 | for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { 237 | node := elem.Value.(*trieEntry) 238 | if !cb(node.bits, node.cidr) { 239 | return 240 | } 241 | } 242 | } 243 | 244 | func (table *AllowedIPs) RemoveByPeer(peer *Peer) { 245 | table.mutex.Lock() 246 | defer table.mutex.Unlock() 247 | 248 | var next *list.Element 249 | for elem := peer.trieEntries.Front(); elem != nil; elem = next { 250 | next = elem.Next() 251 | node := elem.Value.(*trieEntry) 252 | 253 | node.removeFromPeerEntries() 254 | node.peer = nil 255 | if node.child[0] != nil && node.child[1] != nil { 256 | continue 257 | } 258 | bit := 0 259 | if node.child[0] == nil { 260 | bit = 1 261 | } 262 | child := node.child[bit] 263 | if child != nil { 264 | child.parent = node.parent 265 | } 266 | *node.parent.parentBit = child 267 | if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { 268 | node.zeroizePointers() 269 | continue 270 | } 271 | parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) 272 | if parent.peer != nil { 273 | node.zeroizePointers() 274 | continue 275 | } 276 | child = parent.child[node.parent.parentBitType^1] 277 | if child != nil { 278 | child.parent = parent.parent 
279 | } 280 | *parent.parent.parentBit = child 281 | node.zeroizePointers() 282 | parent.zeroizePointers() 283 | } 284 | } 285 | 286 | func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { 287 | table.mutex.Lock() 288 | defer table.mutex.Unlock() 289 | 290 | switch len(ip) { 291 | case net.IPv6len: 292 | parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) 293 | case net.IPv4len: 294 | parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) 295 | default: 296 | panic(errors.New("inserting unknown address type")) 297 | } 298 | } 299 | 300 | func (table *AllowedIPs) Lookup(address []byte) *Peer { 301 | table.mutex.RLock() 302 | defer table.mutex.RUnlock() 303 | switch len(address) { 304 | case net.IPv6len: 305 | return table.IPv6.lookup(address) 306 | case net.IPv4len: 307 | return table.IPv4.lookup(address) 308 | default: 309 | panic(errors.New("looking up unknown address type")) 310 | } 311 | } 312 | -------------------------------------------------------------------------------- /device/allowedips_rand_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "sort" 12 | "testing" 13 | ) 14 | 15 | const ( 16 | NumberOfPeers = 100 17 | NumberOfPeerRemovals = 4 18 | NumberOfAddresses = 250 19 | NumberOfTests = 10000 20 | ) 21 | 22 | type SlowNode struct { 23 | peer *Peer 24 | cidr uint8 25 | bits []byte 26 | } 27 | 28 | type SlowRouter []*SlowNode 29 | 30 | func (r SlowRouter) Len() int { 31 | return len(r) 32 | } 33 | 34 | func (r SlowRouter) Less(i, j int) bool { 35 | return r[i].cidr > r[j].cidr 36 | } 37 | 38 | func (r SlowRouter) Swap(i, j int) { 39 | r[i], r[j] = r[j], r[i] 40 | } 41 | 42 | func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 43 | for _, t := range r { 44 | if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 45 | t.peer = peer 46 | t.bits = addr 47 | return r 48 | } 49 | } 50 | r = append(r, &SlowNode{ 51 | cidr: cidr, 52 | bits: addr, 53 | peer: peer, 54 | }) 55 | sort.Sort(r) 56 | return r 57 | } 58 | 59 | func (r SlowRouter) Lookup(addr []byte) *Peer { 60 | for _, t := range r { 61 | common := commonBits(t.bits, addr) 62 | if common >= t.cidr { 63 | return t.peer 64 | } 65 | } 66 | return nil 67 | } 68 | 69 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 70 | n := 0 71 | for _, x := range r { 72 | if x.peer != peer { 73 | r[n] = x 74 | n++ 75 | } 76 | } 77 | return r[:n] 78 | } 79 | 80 | func TestTrieRandom(t *testing.T) { 81 | var slow4, slow6 SlowRouter 82 | var peers []*Peer 83 | var allowedIPs AllowedIPs 84 | 85 | rand.Seed(1) 86 | 87 | for n := 0; n < NumberOfPeers; n++ { 88 | peers = append(peers, &Peer{}) 89 | } 90 | 91 | for n := 0; n < NumberOfAddresses; n++ { 92 | var addr4 [4]byte 93 | rand.Read(addr4[:]) 94 | cidr := uint8(rand.Intn(32) + 1) 95 | index := rand.Intn(NumberOfPeers) 96 | allowedIPs.Insert(addr4[:], cidr, peers[index]) 97 | slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 98 | 99 | var addr6 [16]byte 100 | rand.Read(addr6[:]) 101 | cidr = 
uint8(rand.Intn(128) + 1) 102 | index = rand.Intn(NumberOfPeers) 103 | allowedIPs.Insert(addr6[:], cidr, peers[index]) 104 | slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 105 | } 106 | 107 | var p int 108 | for p = 0; ; p++ { 109 | for n := 0; n < NumberOfTests; n++ { 110 | var addr4 [4]byte 111 | rand.Read(addr4[:]) 112 | peer1 := slow4.Lookup(addr4[:]) 113 | peer2 := allowedIPs.Lookup(addr4[:]) 114 | if peer1 != peer2 { 115 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 116 | } 117 | 118 | var addr6 [16]byte 119 | rand.Read(addr6[:]) 120 | peer1 = slow6.Lookup(addr6[:]) 121 | peer2 = allowedIPs.Lookup(addr6[:]) 122 | if peer1 != peer2 { 123 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 124 | } 125 | } 126 | if p >= len(peers) || p >= NumberOfPeerRemovals { 127 | break 128 | } 129 | allowedIPs.RemoveByPeer(peers[p]) 130 | slow4 = slow4.RemoveByPeer(peers[p]) 131 | slow6 = slow6.RemoveByPeer(peers[p]) 132 | } 133 | for ; p < len(peers); p++ { 134 | allowedIPs.RemoveByPeer(peers[p]) 135 | } 136 | 137 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 138 | t.Error("Failed to remove all nodes from trie by peer") 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /device/allowedips_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "testing" 12 | ) 13 | 14 | type testPairCommonBits struct { 15 | s1 []byte 16 | s2 []byte 17 | match uint8 18 | } 19 | 20 | func TestCommonBits(t *testing.T) { 21 | 22 | tests := []testPairCommonBits{ 23 | {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, 24 | {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, 25 | {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, 26 | {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, 27 | {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, 28 | } 29 | 30 | for _, p := range tests { 31 | v := commonBits(p.s1, p.s2) 32 | if v != p.match { 33 | t.Error( 34 | "For slice", p.s1, p.s2, 35 | "expected match", p.match, 36 | ",but got", v, 37 | ) 38 | } 39 | } 40 | } 41 | 42 | func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { 43 | var trie *trieEntry 44 | var peers []*Peer 45 | root := parentIndirection{&trie, 2} 46 | 47 | rand.Seed(1) 48 | 49 | const AddressLength = 4 50 | 51 | for n := 0; n < peerNumber; n++ { 52 | peers = append(peers, &Peer{}) 53 | } 54 | 55 | for n := 0; n < addressNumber; n++ { 56 | var addr [AddressLength]byte 57 | rand.Read(addr[:]) 58 | cidr := uint8(rand.Uint32() % (AddressLength * 8)) 59 | index := rand.Int() % peerNumber 60 | root.insert(addr[:], cidr, peers[index]) 61 | } 62 | 63 | for n := 0; n < b.N; n++ { 64 | var addr [AddressLength]byte 65 | rand.Read(addr[:]) 66 | trie.lookup(addr[:]) 67 | } 68 | } 69 | 70 | func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { 71 | benchmarkTrie(100, 1000, net.IPv4len, b) 72 | } 73 | 74 | func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { 75 | benchmarkTrie(10, 10, net.IPv4len, b) 76 | } 77 | 78 | func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { 79 | benchmarkTrie(100, 1000, net.IPv6len, b) 80 | } 81 | 82 | func BenchmarkTrieIPv6Peers10Addresses10(b 
*testing.B) { 83 | benchmarkTrie(10, 10, net.IPv6len, b) 84 | } 85 | 86 | /* Test ported from kernel implementation: 87 | * selftest/allowedips.h 88 | */ 89 | func TestTrieIPv4(t *testing.T) { 90 | a := &Peer{} 91 | b := &Peer{} 92 | c := &Peer{} 93 | d := &Peer{} 94 | e := &Peer{} 95 | g := &Peer{} 96 | h := &Peer{} 97 | 98 | var allowedIPs AllowedIPs 99 | 100 | insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { 101 | allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) 102 | } 103 | 104 | assertEQ := func(peer *Peer, a, b, c, d byte) { 105 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 106 | if p != peer { 107 | t.Error("Assert EQ failed") 108 | } 109 | } 110 | 111 | assertNEQ := func(peer *Peer, a, b, c, d byte) { 112 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 113 | if p == peer { 114 | t.Error("Assert NEQ failed") 115 | } 116 | } 117 | 118 | insert(a, 192, 168, 4, 0, 24) 119 | insert(b, 192, 168, 4, 4, 32) 120 | insert(c, 192, 168, 0, 0, 16) 121 | insert(d, 192, 95, 5, 64, 27) 122 | insert(c, 192, 95, 5, 65, 27) 123 | insert(e, 0, 0, 0, 0, 0) 124 | insert(g, 64, 15, 112, 0, 20) 125 | insert(h, 64, 15, 123, 211, 25) 126 | insert(a, 10, 0, 0, 0, 25) 127 | insert(b, 10, 0, 0, 128, 25) 128 | insert(a, 10, 1, 0, 0, 30) 129 | insert(b, 10, 1, 0, 4, 30) 130 | insert(c, 10, 1, 0, 8, 29) 131 | insert(d, 10, 1, 0, 16, 29) 132 | 133 | assertEQ(a, 192, 168, 4, 20) 134 | assertEQ(a, 192, 168, 4, 0) 135 | assertEQ(b, 192, 168, 4, 4) 136 | assertEQ(c, 192, 168, 200, 182) 137 | assertEQ(c, 192, 95, 5, 68) 138 | assertEQ(e, 192, 95, 5, 96) 139 | assertEQ(g, 64, 15, 116, 26) 140 | assertEQ(g, 64, 15, 127, 3) 141 | 142 | insert(a, 1, 0, 0, 0, 32) 143 | insert(a, 64, 0, 0, 0, 32) 144 | insert(a, 128, 0, 0, 0, 32) 145 | insert(a, 192, 0, 0, 0, 32) 146 | insert(a, 255, 0, 0, 0, 32) 147 | 148 | assertEQ(a, 1, 0, 0, 0) 149 | assertEQ(a, 64, 0, 0, 0) 150 | assertEQ(a, 128, 0, 0, 0) 151 | assertEQ(a, 192, 0, 0, 0) 152 | assertEQ(a, 255, 0, 0, 0) 153 | 154 | 
allowedIPs.RemoveByPeer(a) 155 | 156 | assertNEQ(a, 1, 0, 0, 0) 157 | assertNEQ(a, 64, 0, 0, 0) 158 | assertNEQ(a, 128, 0, 0, 0) 159 | assertNEQ(a, 192, 0, 0, 0) 160 | assertNEQ(a, 255, 0, 0, 0) 161 | 162 | allowedIPs.RemoveByPeer(a) 163 | allowedIPs.RemoveByPeer(b) 164 | allowedIPs.RemoveByPeer(c) 165 | allowedIPs.RemoveByPeer(d) 166 | allowedIPs.RemoveByPeer(e) 167 | allowedIPs.RemoveByPeer(g) 168 | allowedIPs.RemoveByPeer(h) 169 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 170 | t.Error("Expected removing all the peers to empty trie, but it did not") 171 | } 172 | 173 | insert(a, 192, 168, 0, 0, 16) 174 | insert(a, 192, 168, 0, 0, 24) 175 | 176 | allowedIPs.RemoveByPeer(a) 177 | 178 | assertNEQ(a, 192, 168, 0, 1) 179 | } 180 | 181 | /* Test ported from kernel implementation: 182 | * selftest/allowedips.h 183 | */ 184 | func TestTrieIPv6(t *testing.T) { 185 | a := &Peer{} 186 | b := &Peer{} 187 | c := &Peer{} 188 | d := &Peer{} 189 | e := &Peer{} 190 | f := &Peer{} 191 | g := &Peer{} 192 | h := &Peer{} 193 | 194 | var allowedIPs AllowedIPs 195 | 196 | expand := func(a uint32) []byte { 197 | var out [4]byte 198 | out[0] = byte(a >> 24 & 0xff) 199 | out[1] = byte(a >> 16 & 0xff) 200 | out[2] = byte(a >> 8 & 0xff) 201 | out[3] = byte(a & 0xff) 202 | return out[:] 203 | } 204 | 205 | insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { 206 | var addr []byte 207 | addr = append(addr, expand(a)...) 208 | addr = append(addr, expand(b)...) 209 | addr = append(addr, expand(c)...) 210 | addr = append(addr, expand(d)...) 211 | allowedIPs.Insert(addr, cidr, peer) 212 | } 213 | 214 | assertEQ := func(peer *Peer, a, b, c, d uint32) { 215 | var addr []byte 216 | addr = append(addr, expand(a)...) 217 | addr = append(addr, expand(b)...) 218 | addr = append(addr, expand(c)...) 219 | addr = append(addr, expand(d)...) 
220 | p := allowedIPs.Lookup(addr) 221 | if p != peer { 222 | t.Error("Assert EQ failed") 223 | } 224 | } 225 | 226 | insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) 227 | insert(c, 0x26075300, 0x60006b00, 0, 0, 64) 228 | insert(e, 0, 0, 0, 0, 0) 229 | insert(f, 0, 0, 0, 0, 0) 230 | insert(g, 0x24046800, 0, 0, 0, 32) 231 | insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) 232 | insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) 233 | insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) 234 | insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) 235 | 236 | assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) 237 | assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) 238 | assertEQ(f, 0x26075300, 0x60006b01, 0, 0) 239 | assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) 240 | assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) 241 | assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) 242 | assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) 243 | assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) 244 | assertEQ(h, 0x24046800, 0x40040800, 0, 0) 245 | assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) 246 | assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) 247 | } 248 | -------------------------------------------------------------------------------- /device/bind_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "errors" 10 | 11 | "golang.zx2c4.com/wireguard/conn" 12 | ) 13 | 14 | type DummyDatagram struct { 15 | msg []byte 16 | endpoint conn.Endpoint 17 | } 18 | 19 | type DummyBind struct { 20 | in6 chan DummyDatagram 21 | in4 chan DummyDatagram 22 | closed bool 23 | } 24 | 25 | func (b *DummyBind) SetMark(v uint32) error { 26 | return nil 27 | } 28 | 29 | func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { 30 | datagram, ok := <-b.in6 31 | if !ok { 32 | return 0, nil, errors.New("closed") 33 | } 34 | copy(buff, datagram.msg) 35 | return len(datagram.msg), datagram.endpoint, nil 36 | } 37 | 38 | func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { 39 | datagram, ok := <-b.in4 40 | if !ok { 41 | return 0, nil, errors.New("closed") 42 | } 43 | copy(buff, datagram.msg) 44 | return len(datagram.msg), datagram.endpoint, nil 45 | } 46 | 47 | func (b *DummyBind) Close() error { 48 | close(b.in6) 49 | close(b.in4) 50 | b.closed = true 51 | return nil 52 | } 53 | 54 | func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /device/channels.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "runtime" 10 | "sync" 11 | ) 12 | 13 | // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. 14 | // An outboundQueue is ref-counted using its wg field. 15 | // An outboundQueue created with newOutboundQueue has one reference. 16 | // Every additional writer must call wg.Add(1). 17 | // Every completed writer must call wg.Done(). 18 | // When no further writers will be added, 19 | // call wg.Done to remove the initial reference. 
20 | // When the refcount hits 0, the queue's channel is closed. 21 | type outboundQueue struct { 22 | c chan *QueueOutboundElement 23 | wg sync.WaitGroup 24 | } 25 | 26 | func newOutboundQueue() *outboundQueue { 27 | q := &outboundQueue{ 28 | c: make(chan *QueueOutboundElement, QueueOutboundSize), 29 | } 30 | q.wg.Add(1) 31 | go func() { 32 | q.wg.Wait() 33 | close(q.c) 34 | }() 35 | return q 36 | } 37 | 38 | // A inboundQueue is similar to an outboundQueue; see those docs. 39 | type inboundQueue struct { 40 | c chan *QueueInboundElement 41 | wg sync.WaitGroup 42 | } 43 | 44 | func newInboundQueue() *inboundQueue { 45 | q := &inboundQueue{ 46 | c: make(chan *QueueInboundElement, QueueInboundSize), 47 | } 48 | q.wg.Add(1) 49 | go func() { 50 | q.wg.Wait() 51 | close(q.c) 52 | }() 53 | return q 54 | } 55 | 56 | // A handshakeQueue is similar to an outboundQueue; see those docs. 57 | type handshakeQueue struct { 58 | c chan QueueHandshakeElement 59 | wg sync.WaitGroup 60 | } 61 | 62 | func newHandshakeQueue() *handshakeQueue { 63 | q := &handshakeQueue{ 64 | c: make(chan QueueHandshakeElement, QueueHandshakeSize), 65 | } 66 | q.wg.Add(1) 67 | go func() { 68 | q.wg.Wait() 69 | close(q.c) 70 | }() 71 | return q 72 | } 73 | 74 | type autodrainingInboundQueue struct { 75 | c chan *QueueInboundElement 76 | } 77 | 78 | // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. 79 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 80 | // The returned channel must not be closed. Senders should signal shutdown using 81 | // some other means, such as sending a sentinel nil values. 
82 | func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { 83 | q := &autodrainingInboundQueue{ 84 | c: make(chan *QueueInboundElement, QueueInboundSize), 85 | } 86 | runtime.SetFinalizer(q, device.flushInboundQueue) 87 | return q 88 | } 89 | 90 | func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { 91 | for { 92 | select { 93 | case elem := <-q.c: 94 | elem.Lock() 95 | device.PutMessageBuffer(elem.buffer) 96 | device.PutInboundElement(elem) 97 | default: 98 | return 99 | } 100 | } 101 | } 102 | 103 | type autodrainingOutboundQueue struct { 104 | c chan *QueueOutboundElement 105 | } 106 | 107 | // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. 108 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 109 | // The returned channel must not be closed. Senders should signal shutdown using 110 | // some other means, such as sending a sentinel nil values. 111 | // All sends to the channel must be best-effort, because there may be no receivers. 112 | func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { 113 | q := &autodrainingOutboundQueue{ 114 | c: make(chan *QueueOutboundElement, QueueOutboundSize), 115 | } 116 | runtime.SetFinalizer(q, device.flushOutboundQueue) 117 | return q 118 | } 119 | 120 | func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { 121 | for { 122 | select { 123 | case elem := <-q.c: 124 | elem.Lock() 125 | device.PutMessageBuffer(elem.buffer) 126 | device.PutOutboundElement(elem) 127 | default: 128 | return 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /device/constants.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "time" 10 | ) 11 | 12 | /* Specification constants */ 13 | 14 | const ( 15 | RekeyAfterMessages = (1 << 60) 16 | RejectAfterMessages = (1 << 64) - (1 << 13) - 1 17 | RekeyAfterTime = time.Second * 120 18 | RekeyAttemptTime = time.Second * 90 19 | RekeyTimeout = time.Second * 5 20 | MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ 21 | RekeyTimeoutJitterMaxMs = 334 22 | RejectAfterTime = time.Second * 180 23 | KeepaliveTimeout = time.Second * 10 24 | CookieRefreshTime = time.Second * 120 25 | HandshakeInitationRate = time.Second / 50 26 | PaddingMultiple = 16 27 | ) 28 | 29 | const ( 30 | MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) 31 | MaxMessageSize = MaxSegmentSize // maximum size of transport message 32 | MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content 33 | ) 34 | 35 | /* Implementation constants */ 36 | 37 | const ( 38 | UnderLoadAfterTime = time.Second // how long does the device remain under load after detected 39 | MaxPeers = 1 << 16 // maximum number of configured peers 40 | ) 41 | -------------------------------------------------------------------------------- /device/cookie.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "sync" 12 | "time" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/chacha20poly1305" 16 | ) 17 | 18 | type CookieChecker struct { 19 | sync.RWMutex 20 | mac1 struct { 21 | key [blake2s.Size]byte 22 | } 23 | mac2 struct { 24 | secret [blake2s.Size]byte 25 | secretSet time.Time 26 | encryptionKey [chacha20poly1305.KeySize]byte 27 | } 28 | } 29 | 30 | type CookieGenerator struct { 31 | sync.RWMutex 32 | mac1 struct { 33 | key [blake2s.Size]byte 34 | } 35 | mac2 struct { 36 | cookie [blake2s.Size128]byte 37 | cookieSet time.Time 38 | hasLastMAC1 bool 39 | lastMAC1 [blake2s.Size128]byte 40 | encryptionKey [chacha20poly1305.KeySize]byte 41 | } 42 | } 43 | 44 | func (st *CookieChecker) Init(pk NoisePublicKey) { 45 | st.Lock() 46 | defer st.Unlock() 47 | 48 | // mac1 state 49 | 50 | func() { 51 | hash, _ := blake2s.New256(nil) 52 | hash.Write([]byte(WGLabelMAC1)) 53 | hash.Write(pk[:]) 54 | hash.Sum(st.mac1.key[:0]) 55 | }() 56 | 57 | // mac2 state 58 | 59 | func() { 60 | hash, _ := blake2s.New256(nil) 61 | hash.Write([]byte(WGLabelCookie)) 62 | hash.Write(pk[:]) 63 | hash.Sum(st.mac2.encryptionKey[:0]) 64 | }() 65 | 66 | st.mac2.secretSet = time.Time{} 67 | } 68 | 69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool { 70 | st.RLock() 71 | defer st.RUnlock() 72 | 73 | size := len(msg) 74 | smac2 := size - blake2s.Size128 75 | smac1 := smac2 - blake2s.Size128 76 | 77 | var mac1 [blake2s.Size128]byte 78 | 79 | mac, _ := blake2s.New128(st.mac1.key[:]) 80 | mac.Write(msg[:smac1]) 81 | mac.Sum(mac1[:0]) 82 | 83 | return hmac.Equal(mac1[:], msg[smac1:smac2]) 84 | } 85 | 86 | func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { 87 | st.RLock() 88 | defer st.RUnlock() 89 | 90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 91 | return false 92 | } 93 | 94 | // derive cookie key 95 | 96 | var cookie [blake2s.Size128]byte 97 | func() { 98 | mac, _ := 
blake2s.New128(st.mac2.secret[:]) 99 | mac.Write(src) 100 | mac.Sum(cookie[:0]) 101 | }() 102 | 103 | // calculate mac of packet (including mac1) 104 | 105 | smac2 := len(msg) - blake2s.Size128 106 | 107 | var mac2 [blake2s.Size128]byte 108 | func() { 109 | mac, _ := blake2s.New128(cookie[:]) 110 | mac.Write(msg[:smac2]) 111 | mac.Sum(mac2[:0]) 112 | }() 113 | 114 | return hmac.Equal(mac2[:], msg[smac2:]) 115 | } 116 | 117 | func (st *CookieChecker) CreateReply( 118 | msg []byte, 119 | recv uint32, 120 | src []byte, 121 | ) (*MessageCookieReply, error) { 122 | 123 | st.RLock() 124 | 125 | // refresh cookie secret 126 | 127 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 128 | st.RUnlock() 129 | st.Lock() 130 | _, err := rand.Read(st.mac2.secret[:]) 131 | if err != nil { 132 | st.Unlock() 133 | return nil, err 134 | } 135 | st.mac2.secretSet = time.Now() 136 | st.Unlock() 137 | st.RLock() 138 | } 139 | 140 | // derive cookie 141 | 142 | var cookie [blake2s.Size128]byte 143 | func() { 144 | mac, _ := blake2s.New128(st.mac2.secret[:]) 145 | mac.Write(src) 146 | mac.Sum(cookie[:0]) 147 | }() 148 | 149 | // encrypt cookie 150 | 151 | size := len(msg) 152 | 153 | smac2 := size - blake2s.Size128 154 | smac1 := smac2 - blake2s.Size128 155 | 156 | reply := new(MessageCookieReply) 157 | reply.Type = MessageCookieReplyType 158 | reply.Receiver = recv 159 | 160 | _, err := rand.Read(reply.Nonce[:]) 161 | if err != nil { 162 | st.RUnlock() 163 | return nil, err 164 | } 165 | 166 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 167 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) 168 | 169 | st.RUnlock() 170 | 171 | return reply, nil 172 | } 173 | 174 | func (st *CookieGenerator) Init(pk NoisePublicKey) { 175 | st.Lock() 176 | defer st.Unlock() 177 | 178 | func() { 179 | hash, _ := blake2s.New256(nil) 180 | hash.Write([]byte(WGLabelMAC1)) 181 | hash.Write(pk[:]) 182 | hash.Sum(st.mac1.key[:0]) 183 | }() 184 | 185 | func() { 
186 | hash, _ := blake2s.New256(nil) 187 | hash.Write([]byte(WGLabelCookie)) 188 | hash.Write(pk[:]) 189 | hash.Sum(st.mac2.encryptionKey[:0]) 190 | }() 191 | 192 | st.mac2.cookieSet = time.Time{} 193 | } 194 | 195 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { 196 | st.Lock() 197 | defer st.Unlock() 198 | 199 | if !st.mac2.hasLastMAC1 { 200 | return false 201 | } 202 | 203 | var cookie [blake2s.Size128]byte 204 | 205 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 206 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) 207 | 208 | if err != nil { 209 | return false 210 | } 211 | 212 | st.mac2.cookieSet = time.Now() 213 | st.mac2.cookie = cookie 214 | return true 215 | } 216 | 217 | func (st *CookieGenerator) AddMacs(msg []byte) { 218 | 219 | size := len(msg) 220 | 221 | smac2 := size - blake2s.Size128 222 | smac1 := smac2 - blake2s.Size128 223 | 224 | mac1 := msg[smac1:smac2] 225 | mac2 := msg[smac2:] 226 | 227 | st.Lock() 228 | defer st.Unlock() 229 | 230 | // set mac1 231 | 232 | func() { 233 | mac, _ := blake2s.New128(st.mac1.key[:]) 234 | mac.Write(msg[:smac1]) 235 | mac.Sum(mac1[:0]) 236 | }() 237 | copy(st.mac2.lastMAC1[:], mac1) 238 | st.mac2.hasLastMAC1 = true 239 | 240 | // set mac2 241 | 242 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime { 243 | return 244 | } 245 | 246 | func() { 247 | mac, _ := blake2s.New128(st.mac2.cookie[:]) 248 | mac.Write(msg[:smac2]) 249 | mac.Sum(mac2[:0]) 250 | }() 251 | } 252 | -------------------------------------------------------------------------------- /device/cookie_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "testing" 10 | ) 11 | 12 | func TestCookieMAC1(t *testing.T) { 13 | 14 | // setup generator / checker 15 | 16 | var ( 17 | generator CookieGenerator 18 | checker CookieChecker 19 | ) 20 | 21 | sk, err := newPrivateKey() 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | pk := sk.publicKey() 26 | 27 | generator.Init(pk) 28 | checker.Init(pk) 29 | 30 | // check mac1 31 | 32 | src := []byte{192, 168, 13, 37, 10, 10, 10} 33 | 34 | checkMAC1 := func(msg []byte) { 35 | generator.AddMacs(msg) 36 | if !checker.CheckMAC1(msg) { 37 | t.Fatal("MAC1 generation/verification failed") 38 | } 39 | if checker.CheckMAC2(msg, src) { 40 | t.Fatal("MAC2 generation/verification failed") 41 | } 42 | } 43 | 44 | checkMAC1([]byte{ 45 | 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, 46 | 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, 47 | 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, 48 | 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, 49 | 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, 50 | 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, 51 | 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, 52 | 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, 53 | 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, 54 | 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, 55 | 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, 56 | 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, 57 | 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, 58 | 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, 59 | 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, 60 | 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, 61 | }) 62 | 63 | checkMAC1([]byte{ 64 | 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, 65 | 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, 66 | 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, 67 | 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, 68 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 69 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, 70 | 0xb8, 0x57, 0x33, 0x45, 
0x6e, 0x8b, 0x09, 0x2b, 71 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, 72 | }) 73 | 74 | checkMAC1([]byte{ 75 | 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, 76 | 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, 77 | 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, 78 | 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, 79 | }) 80 | 81 | // exchange cookie reply 82 | 83 | func() { 84 | msg := []byte{ 85 | 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, 86 | 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, 87 | 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, 88 | 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, 89 | 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, 90 | 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, 91 | 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, 92 | 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, 93 | 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, 94 | 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, 95 | 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, 96 | 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, 97 | 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, 98 | 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, 99 | 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, 100 | 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, 101 | } 102 | generator.AddMacs(msg) 103 | reply, err := checker.CreateReply(msg, 1377, src) 104 | if err != nil { 105 | t.Fatal("Failed to create cookie reply:", err) 106 | } 107 | if !generator.ConsumeReply(reply) { 108 | t.Fatal("Failed to consume cookie reply") 109 | } 110 | }() 111 | 112 | // check mac2 113 | 114 | checkMAC2 := func(msg []byte) { 115 | generator.AddMacs(msg) 116 | 117 | if !checker.CheckMAC1(msg) { 118 | t.Fatal("MAC1 generation/verification failed") 119 | } 120 | if !checker.CheckMAC2(msg, src) { 121 | t.Fatal("MAC2 generation/verification failed") 122 | } 123 | 124 | msg[5] ^= 0x20 125 | 126 | if checker.CheckMAC1(msg) { 127 | t.Fatal("MAC1 generation/verification failed") 128 | } 129 | if 
checker.CheckMAC2(msg, src) { 130 | t.Fatal("MAC2 generation/verification failed") 131 | } 132 | 133 | msg[5] ^= 0x20 134 | 135 | srcBad1 := []byte{192, 168, 13, 37, 40, 01} 136 | if checker.CheckMAC2(msg, srcBad1) { 137 | t.Fatal("MAC2 generation/verification failed") 138 | } 139 | 140 | srcBad2 := []byte{192, 168, 13, 38, 40, 01} 141 | if checker.CheckMAC2(msg, srcBad2) { 142 | t.Fatal("MAC2 generation/verification failed") 143 | } 144 | } 145 | 146 | checkMAC2([]byte{ 147 | 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, 148 | 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, 149 | 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, 150 | 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, 151 | 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, 152 | 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, 153 | 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, 154 | 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, 155 | }) 156 | 157 | checkMAC2([]byte{ 158 | 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, 159 | 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, 160 | 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, 161 | 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, 162 | 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, 163 | 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, 164 | 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, 165 | 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, 166 | 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, 167 | 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, 168 | 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, 169 | 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, 170 | 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, 171 | 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, 172 | 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, 173 | 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, 174 | 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, 175 | 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, 176 | 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, 177 | 0xd4, 0xa1, 0xe0, 
0x21, 0x52, 0x8f, 0x1c, 0xd0, 178 | 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, 179 | 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, 180 | 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, 181 | 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, 182 | 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, 183 | 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, 184 | 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, 185 | 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, 186 | 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, 187 | 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, 188 | 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, 189 | 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, 190 | }) 191 | } 192 | -------------------------------------------------------------------------------- /device/devicestate_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. 2 | 3 | package device 4 | 5 | import "strconv" 6 | 7 | const _deviceState_name = "DownUpClosed" 8 | 9 | var _deviceState_index = [...]uint8{0, 4, 6, 12} 10 | 11 | func (i deviceState) String() string { 12 | if i >= deviceState(len(_deviceState_index)-1) { 13 | return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" 14 | } 15 | return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] 16 | } 17 | -------------------------------------------------------------------------------- /device/endpoint_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | ) 12 | 13 | type DummyEndpoint struct { 14 | src [16]byte 15 | dst [16]byte 16 | } 17 | 18 | func CreateDummyEndpoint() (*DummyEndpoint, error) { 19 | var end DummyEndpoint 20 | if _, err := rand.Read(end.src[:]); err != nil { 21 | return nil, err 22 | } 23 | _, err := rand.Read(end.dst[:]) 24 | return &end, err 25 | } 26 | 27 | func (e *DummyEndpoint) ClearSrc() {} 28 | 29 | func (e *DummyEndpoint) SrcToString() string { 30 | var addr net.UDPAddr 31 | addr.IP = e.SrcIP() 32 | addr.Port = 1000 33 | return addr.String() 34 | } 35 | 36 | func (e *DummyEndpoint) DstToString() string { 37 | var addr net.UDPAddr 38 | addr.IP = e.DstIP() 39 | addr.Port = 1000 40 | return addr.String() 41 | } 42 | 43 | func (e *DummyEndpoint) SrcToBytes() []byte { 44 | return e.src[:] 45 | } 46 | 47 | func (e *DummyEndpoint) DstIP() net.IP { 48 | return e.dst[:] 49 | } 50 | 51 | func (e *DummyEndpoint) SrcIP() net.IP { 52 | return e.src[:] 53 | } 54 | -------------------------------------------------------------------------------- /device/indextable.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/rand" 10 | "encoding/binary" 11 | "sync" 12 | ) 13 | 14 | type IndexTableEntry struct { 15 | peer *Peer 16 | handshake *Handshake 17 | keypair *Keypair 18 | } 19 | 20 | type IndexTable struct { 21 | sync.RWMutex 22 | table map[uint32]IndexTableEntry 23 | } 24 | 25 | func randUint32() (uint32, error) { 26 | var integer [4]byte 27 | _, err := rand.Read(integer[:]) 28 | // Arbitrary endianness; both are intrinsified by the Go compiler.
29 | return binary.LittleEndian.Uint32(integer[:]), err 30 | } 31 | 32 | func (table *IndexTable) Init() { 33 | table.Lock() 34 | defer table.Unlock() 35 | table.table = make(map[uint32]IndexTableEntry) 36 | } 37 | 38 | func (table *IndexTable) Delete(index uint32) { 39 | table.Lock() 40 | defer table.Unlock() 41 | delete(table.table, index) 42 | } 43 | 44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { 45 | table.Lock() 46 | defer table.Unlock() 47 | entry, ok := table.table[index] 48 | if !ok { 49 | return 50 | } 51 | table.table[index] = IndexTableEntry{ 52 | peer: entry.peer, 53 | keypair: keypair, 54 | handshake: nil, 55 | } 56 | } 57 | 58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { 59 | for { 60 | // generate random index 61 | 62 | index, err := randUint32() 63 | if err != nil { 64 | return index, err 65 | } 66 | 67 | // check if index used 68 | 69 | table.RLock() 70 | _, ok := table.table[index] 71 | table.RUnlock() 72 | if ok { 73 | continue 74 | } 75 | 76 | // check again while locked 77 | 78 | table.Lock() 79 | _, found := table.table[index] 80 | if found { 81 | table.Unlock() 82 | continue 83 | } 84 | table.table[index] = IndexTableEntry{ 85 | peer: peer, 86 | handshake: handshake, 87 | keypair: nil, 88 | } 89 | table.Unlock() 90 | return index, nil 91 | } 92 | } 93 | 94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry { 95 | table.RLock() 96 | defer table.RUnlock() 97 | return table.table[id] // zero-value IndexTableEntry (all fields nil) when id is not in the table 98 | } 99 | -------------------------------------------------------------------------------- /device/ip.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "net" 10 | ) 11 | 12 | const ( 13 | IPv4offsetTotalLength = 2 14 | IPv4offsetSrc = 12 15 | IPv4offsetDst = IPv4offsetSrc + net.IPv4len 16 | ) 17 | 18 | const ( 19 | IPv6offsetPayloadLength = 4 20 | IPv6offsetSrc = 8 21 | IPv6offsetDst = IPv6offsetSrc + net.IPv6len 22 | ) 23 | -------------------------------------------------------------------------------- /device/kdf_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "encoding/hex" 10 | "testing" 11 | 12 | "golang.org/x/crypto/blake2s" 13 | ) 14 | 15 | type KDFTest struct { 16 | key string 17 | input string 18 | t0 string 19 | t1 string 20 | t2 string 21 | } 22 | 23 | func assertEquals(t *testing.T, a string, b string) { 24 | if a != b { 25 | t.Fatal("expected", a, "=", b) 26 | } 27 | } 28 | 29 | func TestKDF(t *testing.T) { 30 | tests := []KDFTest{ 31 | { 32 | key: "746573742d6b6579", 33 | input: "746573742d696e707574", 34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 37 | }, 38 | { 39 | key: "776972656775617264", 40 | input: "776972656775617264", 41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 44 | }, 45 | { 46 | key: "", 47 | input: "", 48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 51 | }, 52 | } 53 | 54 | var t0, t1, t2
[blake2s.Size]byte 55 | 56 | for _, test := range tests { 57 | key, _ := hex.DecodeString(test.key) 58 | input, _ := hex.DecodeString(test.input) 59 | KDF3(&t0, &t1, &t2, key, input) 60 | t0s := hex.EncodeToString(t0[:]) 61 | t1s := hex.EncodeToString(t1[:]) 62 | t2s := hex.EncodeToString(t2[:]) 63 | assertEquals(t, t0s, test.t0) 64 | assertEquals(t, t1s, test.t1) 65 | assertEquals(t, t2s, test.t2) 66 | } 67 | 68 | for _, test := range tests { 69 | key, _ := hex.DecodeString(test.key) 70 | input, _ := hex.DecodeString(test.input) 71 | KDF2(&t0, &t1, key, input) 72 | t0s := hex.EncodeToString(t0[:]) 73 | t1s := hex.EncodeToString(t1[:]) 74 | assertEquals(t, t0s, test.t0) 75 | assertEquals(t, t1s, test.t1) 76 | } 77 | 78 | for _, test := range tests { 79 | key, _ := hex.DecodeString(test.key) 80 | input, _ := hex.DecodeString(test.input) 81 | KDF1(&t0, key, input) 82 | t0s := hex.EncodeToString(t0[:]) 83 | assertEquals(t, t0s, test.t0) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /device/keypair.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/cipher" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | "unsafe" 14 | 15 | "golang.zx2c4.com/wireguard/replay" 16 | ) 17 | 18 | /* Due to limitations in Go and /x/crypto there is currently 19 | * no way to ensure that key material is securely erased in memory. 20 | * 21 | * Since this may harm the forward secrecy property, 22 | * we plan to resolve this issue whenever Go allows us to do so.
23 | */ 24 | 25 | type Keypair struct { 26 | sendNonce uint64 // accessed atomically 27 | send cipher.AEAD 28 | receive cipher.AEAD 29 | replayFilter replay.Filter 30 | isInitiator bool 31 | created time.Time 32 | localIndex uint32 33 | remoteIndex uint32 34 | } 35 | 36 | type Keypairs struct { 37 | sync.RWMutex 38 | current *Keypair 39 | previous *Keypair 40 | next *Keypair // accessed via storeNext/loadNext, which use atomic pointer operations 41 | } 42 | 43 | func (kp *Keypairs) storeNext(next *Keypair) { 44 | atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) 45 | } 46 | 47 | func (kp *Keypairs) loadNext() *Keypair { 48 | return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) 49 | } 50 | 51 | func (kp *Keypairs) Current() *Keypair { 52 | kp.RLock() 53 | defer kp.RUnlock() 54 | return kp.current 55 | } 56 | 57 | func (device *Device) DeleteKeypair(key *Keypair) { 58 | if key != nil { 59 | device.indexTable.Delete(key.localIndex) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /device/logger.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "log" 10 | "os" 11 | ) 12 | 13 | // A Logger provides logging for a Device. 14 | // The functions are Printf-style functions. 15 | // They must be safe for concurrent use. 16 | // They do not require a trailing newline in the format. 17 | // If nil, that level of logging will be silent. 18 | type Logger struct { 19 | Verbosef func(format string, args ...interface{}) 20 | Errorf func(format string, args ...interface{}) 21 | } 22 | 23 | // Log levels for use with NewLogger. 24 | const ( 25 | LogLevelSilent = iota 26 | LogLevelError 27 | LogLevelVerbose 28 | ) 29 | 30 | // Function for use in Logger for discarding logged lines.
31 | func DiscardLogf(format string, args ...interface{}) {} 32 | 33 | // NewLogger constructs a Logger that writes to stdout. 34 | // It logs at the specified log level and above. 35 | // It decorates log lines with the log level, date, time, and prepend. 36 | func NewLogger(level int, prepend string) *Logger { 37 | logger := &Logger{DiscardLogf, DiscardLogf} 38 | logf := func(prefix string) func(string, ...interface{}) { 39 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf 40 | } 41 | if level >= LogLevelVerbose { 42 | logger.Verbosef = logf("DEBUG") 43 | } 44 | if level >= LogLevelError { 45 | logger.Errorf = logf("ERROR") 46 | } 47 | return logger 48 | } 49 | -------------------------------------------------------------------------------- /device/misc.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync/atomic" 10 | ) 11 | 12 | /* Atomic Boolean */ 13 | 14 | const ( 15 | AtomicFalse = int32(iota) 16 | AtomicTrue 17 | ) 18 | 19 | type AtomicBool struct { 20 | int32 21 | } 22 | 23 | func (a *AtomicBool) Get() bool { 24 | return atomic.LoadInt32(&a.int32) == AtomicTrue 25 | } 26 | 27 | func (a *AtomicBool) Swap(val bool) bool { 28 | flag := AtomicFalse 29 | if val { 30 | flag = AtomicTrue 31 | } 32 | return atomic.SwapInt32(&a.int32, flag) == AtomicTrue 33 | } 34 | 35 | func (a *AtomicBool) Set(val bool) { 36 | flag := AtomicFalse 37 | if val { 38 | flag = AtomicTrue 39 | } 40 | atomic.StoreInt32(&a.int32, flag) 41 | } 42 | -------------------------------------------------------------------------------- /device/mobilequirks.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { 9 | device.peers.RLock() 10 | for _, peer := range device.peers.keyMap { 11 | peer.Lock() 12 | peer.disableRoaming = peer.endpoint != nil 13 | peer.Unlock() 14 | } 15 | device.peers.RUnlock() 16 | } 17 | -------------------------------------------------------------------------------- /device/noise-helpers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/subtle" 12 | "hash" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/curve25519" 16 | ) 17 | 18 | /* KDF related functions. 19 | * HMAC-based Key Derivation Function (HKDF) 20 | * https://tools.ietf.org/html/rfc5869 21 | */ 22 | 23 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { 24 | mac := hmac.New(func() hash.Hash { 25 | h, _ := blake2s.New256(nil) 26 | return h 27 | }, key) 28 | mac.Write(in0) 29 | mac.Sum(sum[:0]) 30 | } 31 | 32 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { 33 | mac := hmac.New(func() hash.Hash { 34 | h, _ := blake2s.New256(nil) 35 | return h 36 | }, key) 37 | mac.Write(in0) 38 | mac.Write(in1) 39 | mac.Sum(sum[:0]) 40 | } 41 | 42 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { 43 | HMAC1(t0, key, input) 44 | HMAC1(t0, t0[:], []byte{0x1}) 45 | } 46 | 47 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { 48 | var prk [blake2s.Size]byte 49 | HMAC1(&prk, key, input) 50 | HMAC1(t0, prk[:], []byte{0x1}) 51 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 52 | setZero(prk[:]) 53 | } 54 | 55 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { 56 | var prk [blake2s.Size]byte 57 | HMAC1(&prk, key, input) 58 | HMAC1(t0, prk[:], []byte{0x1}) 59 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 60 | HMAC2(t2, 
prk[:], t1[:], []byte{0x3}) 61 | setZero(prk[:]) 62 | } 63 | 64 | func isZero(val []byte) bool { 65 | acc := 1 66 | for _, b := range val { 67 | acc &= subtle.ConstantTimeByteEq(b, 0) 68 | } 69 | return acc == 1 70 | } 71 | 72 | /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ 73 | func setZero(arr []byte) { 74 | for i := range arr { 75 | arr[i] = 0 76 | } 77 | } 78 | 79 | func (sk *NoisePrivateKey) clamp() { 80 | sk[0] &= 248 81 | sk[31] = (sk[31] & 127) | 64 82 | } 83 | 84 | func newPrivateKey() (sk NoisePrivateKey, err error) { 85 | _, err = rand.Read(sk[:]) 86 | sk.clamp() 87 | return 88 | } 89 | 90 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { 91 | apk := (*[NoisePublicKeySize]byte)(&pk) 92 | ask := (*[NoisePrivateKeySize]byte)(sk) 93 | curve25519.ScalarBaseMult(apk, ask) 94 | return 95 | } 96 | 97 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { 98 | apk := (*[NoisePublicKeySize]byte)(&pk) 99 | ask := (*[NoisePrivateKeySize]byte)(sk) 100 | curve25519.ScalarMult(&ss, ask, apk) 101 | return ss 102 | } 103 | -------------------------------------------------------------------------------- /device/noise-types.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/subtle" 10 | "encoding/hex" 11 | "errors" 12 | ) 13 | 14 | const ( 15 | NoisePublicKeySize = 32 16 | NoisePrivateKeySize = 32 17 | NoisePresharedKeySize = 32 18 | ) 19 | 20 | type ( 21 | NoisePublicKey [NoisePublicKeySize]byte 22 | NoisePrivateKey [NoisePrivateKeySize]byte 23 | NoisePresharedKey [NoisePresharedKeySize]byte 24 | NoiseNonce uint64 // padded to 12-bytes 25 | ) 26 | 27 | func loadExactHex(dst []byte, src string) error { 28 | slice, err := hex.DecodeString(src) 29 | if err != nil { 30 | return err 31 | } 32 | if len(slice) != len(dst) { 33 | return errors.New("hex string does not fit the slice") 34 | } 35 | copy(dst, slice) 36 | return nil 37 | } 38 | 39 | func (key NoisePrivateKey) IsZero() bool { 40 | var zero NoisePrivateKey 41 | return key.Equals(zero) 42 | } 43 | 44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { 45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 46 | } 47 | 48 | func (key *NoisePrivateKey) FromHex(src string) (err error) { 49 | err = loadExactHex(key[:], src) 50 | key.clamp() 51 | return 52 | } 53 | 54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { 55 | err = loadExactHex(key[:], src) 56 | if key.IsZero() { 57 | return 58 | } 59 | key.clamp() 60 | return 61 | } 62 | 63 | func (key *NoisePublicKey) FromHex(src string) error { 64 | return loadExactHex(key[:], src) 65 | } 66 | 67 | func (key NoisePublicKey) IsZero() bool { 68 | var zero NoisePublicKey 69 | return key.Equals(zero) 70 | } 71 | 72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { 73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 74 | } 75 | 76 | func (key *NoisePresharedKey) FromHex(src string) error { 77 | return loadExactHex(key[:], src) 78 | } 79 | -------------------------------------------------------------------------------- /device/noise_test.go: -------------------------------------------------------------------------------- 1 | /* 
SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "testing" 12 | 13 | "golang.zx2c4.com/wireguard/conn" 14 | "golang.zx2c4.com/wireguard/tun/tuntest" 15 | ) 16 | 17 | func TestCurveWrappers(t *testing.T) { 18 | sk1, err := newPrivateKey() 19 | assertNil(t, err) 20 | 21 | sk2, err := newPrivateKey() 22 | assertNil(t, err) 23 | 24 | pk1 := sk1.publicKey() 25 | pk2 := sk2.publicKey() 26 | 27 | ss1 := sk1.sharedSecret(pk2) 28 | ss2 := sk2.sharedSecret(pk1) 29 | 30 | if ss1 != ss2 { 31 | t.Fatal("Failed to compute shared secet") 32 | } 33 | } 34 | 35 | func randDevice(t *testing.T) *Device { 36 | sk, err := newPrivateKey() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | tun := tuntest.NewChannelTUN() 41 | logger := NewLogger(LogLevelError, "") 42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) 43 | device.SetPrivateKey(sk) 44 | return device 45 | } 46 | 47 | func assertNil(t *testing.T, err error) { 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | } 52 | 53 | func assertEqual(t *testing.T, a, b []byte) { 54 | if !bytes.Equal(a, b) { 55 | t.Fatal(a, "!=", b) 56 | } 57 | } 58 | 59 | func TestNoiseHandshake(t *testing.T) { 60 | dev1 := randDevice(t) 61 | dev2 := randDevice(t) 62 | 63 | defer dev1.Close() 64 | defer dev2.Close() 65 | 66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | 75 | assertEqual( 76 | t, 77 | peer1.handshake.precomputedStaticStatic[:], 78 | peer2.handshake.precomputedStaticStatic[:], 79 | ) 80 | 81 | /* simulate handshake */ 82 | 83 | // initiation message 84 | 85 | t.Log("exchange initiation message") 86 | 87 | msg1, err := dev1.CreateMessageInitiation(peer2) 88 | assertNil(t, err) 89 | 90 | packet := 
make([]byte, 0, 256) 91 | writer := bytes.NewBuffer(packet) 92 | err = binary.Write(writer, binary.LittleEndian, msg1) 93 | assertNil(t, err) 94 | peer := dev2.ConsumeMessageInitiation(msg1) 95 | if peer == nil { 96 | t.Fatal("handshake failed at initiation message") 97 | } 98 | 99 | assertEqual( 100 | t, 101 | peer1.handshake.chainKey[:], 102 | peer2.handshake.chainKey[:], 103 | ) 104 | 105 | assertEqual( 106 | t, 107 | peer1.handshake.hash[:], 108 | peer2.handshake.hash[:], 109 | ) 110 | 111 | // response message 112 | 113 | t.Log("exchange response message") 114 | 115 | msg2, err := dev2.CreateMessageResponse(peer1) 116 | assertNil(t, err) 117 | 118 | peer = dev1.ConsumeMessageResponse(msg2) 119 | if peer == nil { 120 | t.Fatal("handshake failed at response message") 121 | } 122 | 123 | assertEqual( 124 | t, 125 | peer1.handshake.chainKey[:], 126 | peer2.handshake.chainKey[:], 127 | ) 128 | 129 | assertEqual( 130 | t, 131 | peer1.handshake.hash[:], 132 | peer2.handshake.hash[:], 133 | ) 134 | 135 | // key pairs 136 | 137 | t.Log("deriving keys") 138 | 139 | err = peer1.BeginSymmetricSession() 140 | if err != nil { 141 | t.Fatal("failed to derive keypair for peer 1", err) 142 | } 143 | 144 | err = peer2.BeginSymmetricSession() 145 | if err != nil { 146 | t.Fatal("failed to derive keypair for peer 2", err) 147 | } 148 | 149 | key1 := peer1.keypairs.loadNext() 150 | key2 := peer2.keypairs.current 151 | 152 | // encrypting / decryption test 153 | 154 | t.Log("test key pairs") 155 | 156 | func() { 157 | testMsg := []byte("wireguard test message 1") 158 | var err error 159 | var out []byte 160 | var nonce [12]byte 161 | out = key1.send.Seal(out, nonce[:], testMsg, nil) 162 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 163 | assertNil(t, err) 164 | assertEqual(t, out, testMsg) 165 | }() 166 | 167 | func() { 168 | testMsg := []byte("wireguard test message 2") 169 | var err error 170 | var out []byte 171 | var nonce [12]byte 172 | out = key2.send.Seal(out, 
nonce[:], testMsg, nil) 173 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 174 | assertNil(t, err) 175 | assertEqual(t, out, testMsg) 176 | }() 177 | } 178 | -------------------------------------------------------------------------------- /device/peer.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "errors" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type Peer struct { 19 | isRunning AtomicBool 20 | sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer 21 | keypairs Keypairs 22 | handshake Handshake 23 | device *Device 24 | endpoint conn.Endpoint 25 | stopping sync.WaitGroup // routines pending stop 26 | 27 | // These fields are accessed with atomic operations, which must be 28 | // 64-bit aligned even on 32-bit platforms. Go guarantees that an 29 | // allocated struct will be 64-bit aligned. So we place 30 | // atomically-accessed fields up front, so that they can share in 31 | // this alignment before smaller fields throw it off. 
32 | stats struct { 33 | txBytes uint64 // bytes send to peer (endpoint) 34 | rxBytes uint64 // bytes received from peer 35 | lastHandshakeNano int64 // nano seconds since epoch 36 | } 37 | 38 | disableRoaming bool 39 | 40 | timers struct { 41 | retransmitHandshake *Timer 42 | sendKeepalive *Timer 43 | newHandshake *Timer 44 | zeroKeyMaterial *Timer 45 | persistentKeepalive *Timer 46 | handshakeAttempts uint32 47 | needAnotherKeepalive AtomicBool 48 | sentLastMinuteHandshake AtomicBool 49 | } 50 | 51 | state struct { 52 | sync.Mutex // protects against concurrent Start/Stop 53 | } 54 | 55 | queue struct { 56 | staged chan *QueueOutboundElement // staged packets before a handshake is available 57 | outbound *autodrainingOutboundQueue // sequential ordering of udp transmission 58 | inbound *autodrainingInboundQueue // sequential ordering of tun writing 59 | } 60 | 61 | cookieGenerator CookieGenerator 62 | trieEntries list.List 63 | persistentKeepaliveInterval uint32 // accessed atomically 64 | } 65 | 66 | func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { 67 | if device.isClosed() { 68 | return nil, errors.New("device closed") 69 | } 70 | 71 | // lock resources 72 | device.staticIdentity.RLock() 73 | defer device.staticIdentity.RUnlock() 74 | 75 | device.peers.Lock() 76 | defer device.peers.Unlock() 77 | 78 | // check if over limit 79 | if len(device.peers.keyMap) >= MaxPeers { 80 | return nil, errors.New("too many peers") 81 | } 82 | 83 | // create peer 84 | peer := new(Peer) 85 | peer.Lock() 86 | defer peer.Unlock() 87 | 88 | peer.cookieGenerator.Init(pk) 89 | peer.device = device 90 | peer.queue.outbound = newAutodrainingOutboundQueue(device) 91 | peer.queue.inbound = newAutodrainingInboundQueue(device) 92 | peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) 93 | 94 | // map public key 95 | _, ok := device.peers.keyMap[pk] 96 | if ok { 97 | return nil, errors.New("adding existing peer") 98 | } 99 | 100 | // pre-compute DH 101 | 
handshake := &peer.handshake 102 | handshake.mutex.Lock() 103 | handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) 104 | handshake.remoteStatic = pk 105 | handshake.mutex.Unlock() 106 | 107 | // reset endpoint 108 | peer.endpoint = nil 109 | 110 | // add 111 | device.peers.keyMap[pk] = peer 112 | 113 | // start peer 114 | peer.timersInit() 115 | if peer.device.isUp() { 116 | peer.Start() 117 | } 118 | 119 | return peer, nil 120 | } 121 | 122 | func (peer *Peer) SendBuffer(buffer []byte) error { 123 | peer.device.net.RLock() 124 | defer peer.device.net.RUnlock() 125 | 126 | if peer.device.isClosed() { 127 | return nil 128 | } 129 | 130 | peer.RLock() 131 | defer peer.RUnlock() 132 | 133 | if peer.endpoint == nil { 134 | return errors.New("no known endpoint for peer") 135 | } 136 | 137 | err := peer.device.net.bind.Send(buffer, peer.endpoint) 138 | if err == nil { 139 | atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) 140 | } 141 | return err 142 | } 143 | 144 | func (peer *Peer) String() string { 145 | // The awful goo that follows is identical to: 146 | // 147 | // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) 148 | // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] 149 | // return fmt.Sprintf("peer(%s)", abbreviatedKey) 150 | // 151 | // except that it is considerably more efficient. 
152 | src := peer.handshake.remoteStatic 153 | b64 := func(input byte) byte { 154 | return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) 155 | } 156 | b := []byte("peer(____…____)") 157 | const first = len("peer(") 158 | const second = len("peer(____…") 159 | b[first+0] = b64((src[0] >> 2) & 63) 160 | b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) 161 | b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) 162 | b[first+3] = b64(src[2] & 63) 163 | b[second+0] = b64(src[29] & 63) 164 | b[second+1] = b64((src[30] >> 2) & 63) 165 | b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) 166 | b[second+3] = b64((src[31] << 2) & 63) 167 | return string(b) 168 | } 169 | 170 | func (peer *Peer) Start() { 171 | // should never start a peer on a closed device 172 | if peer.device.isClosed() { 173 | return 174 | } 175 | 176 | // prevent simultaneous start/stop operations 177 | peer.state.Lock() 178 | defer peer.state.Unlock() 179 | 180 | if peer.isRunning.Get() { 181 | return 182 | } 183 | 184 | device := peer.device 185 | device.log.Verbosef("%v - Starting", peer) 186 | 187 | // reset routine state 188 | peer.stopping.Wait() 189 | peer.stopping.Add(2) 190 | 191 | peer.handshake.mutex.Lock() 192 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 193 | peer.handshake.mutex.Unlock() 194 | 195 | peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes 196 | 197 | peer.timersStart() 198 | 199 | device.flushInboundQueue(peer.queue.inbound) 200 | device.flushOutboundQueue(peer.queue.outbound) 201 | go peer.RoutineSequentialSender() 202 | go peer.RoutineSequentialReceiver() 203 | 204 | peer.isRunning.Set(true) 205 | } 206 | 207 | func (peer *Peer) ZeroAndFlushAll() { 208 | device := peer.device 209 | 210 | // clear key pairs 211 | 212 | keypairs := &peer.keypairs 213 | keypairs.Lock() 214 | 
device.DeleteKeypair(keypairs.previous) 215 | device.DeleteKeypair(keypairs.current) 216 | device.DeleteKeypair(keypairs.loadNext()) 217 | keypairs.previous = nil 218 | keypairs.current = nil 219 | keypairs.storeNext(nil) 220 | keypairs.Unlock() 221 | 222 | // clear handshake state 223 | 224 | handshake := &peer.handshake 225 | handshake.mutex.Lock() 226 | device.indexTable.Delete(handshake.localIndex) 227 | handshake.Clear() 228 | handshake.mutex.Unlock() 229 | 230 | peer.FlushStagedPackets() 231 | } 232 | 233 | func (peer *Peer) ExpireCurrentKeypairs() { 234 | handshake := &peer.handshake 235 | handshake.mutex.Lock() 236 | peer.device.indexTable.Delete(handshake.localIndex) 237 | handshake.Clear() 238 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 239 | handshake.mutex.Unlock() 240 | 241 | keypairs := &peer.keypairs 242 | keypairs.Lock() 243 | if keypairs.current != nil { 244 | atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) 245 | } 246 | if keypairs.next != nil { 247 | next := keypairs.loadNext() 248 | atomic.StoreUint64(&next.sendNonce, RejectAfterMessages) 249 | } 250 | keypairs.Unlock() 251 | } 252 | 253 | func (peer *Peer) Stop() { 254 | peer.state.Lock() 255 | defer peer.state.Unlock() 256 | 257 | if !peer.isRunning.Swap(false) { 258 | return 259 | } 260 | 261 | peer.device.log.Verbosef("%v - Stopping", peer) 262 | 263 | peer.timersStop() 264 | // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. 
265 | peer.queue.inbound.c <- nil 266 | peer.queue.outbound.c <- nil 267 | peer.stopping.Wait() 268 | peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us 269 | 270 | peer.ZeroAndFlushAll() 271 | } 272 | 273 | func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { 274 | if peer.disableRoaming { 275 | return 276 | } 277 | peer.Lock() 278 | peer.endpoint = endpoint 279 | peer.Unlock() 280 | } 281 | -------------------------------------------------------------------------------- /device/pools.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync" 10 | "sync/atomic" 11 | ) 12 | 13 | type WaitPool struct { // sync.Pool wrapper that, when max != 0, lets at most max items be checked out at once 14 | pool sync.Pool 15 | cond sync.Cond // signalled by Put (with lock held) when an item is returned 16 | lock sync.Mutex 17 | count uint32 // items currently checked out; modified only while holding lock 18 | max uint32 // 0 means unbounded 19 | } 20 | 21 | func NewWaitPool(max uint32, new func() interface{}) *WaitPool { 22 | p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 | p.cond = sync.Cond{L: &p.lock} 24 | return p 25 | } 26 | 27 | func (p *WaitPool) Get() interface{} { 28 | if p.max != 0 { 29 | p.lock.Lock() 30 | for atomic.LoadUint32(&p.count) >= p.max { // block until a Put frees a slot 31 | p.cond.Wait() 32 | } 33 | atomic.AddUint32(&p.count, 1) 34 | p.lock.Unlock() 35 | } 36 | return p.pool.Get() 37 | } 38 | 39 | func (p *WaitPool) Put(x interface{}) { 40 | p.pool.Put(x) 41 | if p.max == 0 { 42 | return 43 | } 44 | p.lock.Lock() // decrement and signal under the lock; otherwise a Get that has just seen count >= max can miss this wakeup 45 | atomic.AddUint32(&p.count, ^uint32(0)) 46 | p.cond.Signal() 47 | p.lock.Unlock() 48 | } 49 | 50 | func (device *Device) PopulatePools() { 51 | device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { 52 | return new([MaxMessageSize]byte) 53 | }) 54 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { 55 | return new(QueueInboundElement) 56 | }) 57 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func()
interface{} { 58 | return new(QueueOutboundElement) 59 | }) 60 | } 61 | 62 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 63 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 64 | } 65 | 66 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 67 | device.pool.messageBuffers.Put(msg) 68 | } 69 | 70 | func (device *Device) GetInboundElement() *QueueInboundElement { 71 | return device.pool.inboundElements.Get().(*QueueInboundElement) 72 | } 73 | 74 | func (device *Device) PutInboundElement(elem *QueueInboundElement) { 75 | elem.clearPointers() 76 | device.pool.inboundElements.Put(elem) 77 | } 78 | 79 | func (device *Device) GetOutboundElement() *QueueOutboundElement { 80 | return device.pool.outboundElements.Get().(*QueueOutboundElement) 81 | } 82 | 83 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 84 | elem.clearPointers() 85 | device.pool.outboundElements.Put(elem) 86 | } 87 | -------------------------------------------------------------------------------- /device/pools_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "runtime" 11 | "sync" 12 | "sync/atomic" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | func TestWaitPool(t *testing.T) { 18 | t.Skip("Currently disabled") 19 | var wg sync.WaitGroup 20 | trials := int32(100000) 21 | if raceEnabled { 22 | // This test can be very slow with -race.
23 | trials /= 10 24 | } 25 | workers := runtime.NumCPU() + 2 26 | if workers-4 <= 0 { 27 | t.Skip("Not enough cores") 28 | } 29 | p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) 30 | wg.Add(workers) 31 | max := uint32(0) 32 | updateMax := func() { 33 | count := atomic.LoadUint32(&p.count) 34 | if count > p.max { 35 | t.Errorf("count (%d) > max (%d)", count, p.max) 36 | } 37 | for { 38 | old := atomic.LoadUint32(&max) 39 | if count <= old { 40 | break 41 | } 42 | if atomic.CompareAndSwapUint32(&max, old, count) { 43 | break 44 | } 45 | } 46 | } 47 | for i := 0; i < workers; i++ { 48 | go func() { 49 | defer wg.Done() 50 | for atomic.AddInt32(&trials, -1) > 0 { 51 | updateMax() 52 | x := p.Get() 53 | updateMax() 54 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 55 | updateMax() 56 | p.Put(x) 57 | updateMax() 58 | } 59 | }() 60 | } 61 | wg.Wait() 62 | if max != p.max { 63 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) 64 | } 65 | } 66 | 67 | func BenchmarkWaitPool(b *testing.B) { 68 | var wg sync.WaitGroup 69 | trials := int32(b.N) 70 | workers := runtime.NumCPU() + 2 71 | if workers-4 <= 0 { 72 | b.Skip("Not enough cores") 73 | } 74 | p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) 75 | wg.Add(workers) 76 | b.ResetTimer() 77 | for i := 0; i < workers; i++ { 78 | go func() { 79 | defer wg.Done() 80 | for atomic.AddInt32(&trials, -1) > 0 { 81 | x := p.Get() 82 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 83 | p.Put(x) 84 | } 85 | }() 86 | } 87 | wg.Wait() 88 | } 89 | -------------------------------------------------------------------------------- /device/queueconstants_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | /* Reduce memory consumption for Android */ 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = 2200 16 | PreallocatedBuffersPerPool = 4096 /* non-zero, so each WaitPool built from it in PopulatePools blocks Get once this many items are checked out */ 17 | ) 18 | -------------------------------------------------------------------------------- /device/queueconstants_default.go: -------------------------------------------------------------------------------- 1 | // +build !android,!ios,!windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram 16 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 17 | ) 18 | -------------------------------------------------------------------------------- /device/queueconstants_ios.go: -------------------------------------------------------------------------------- 1 | // +build ios 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | // Fit within memory limits for iOS's Network Extension API, which has stricter requirements. 11 | // These are vars instead of consts, because heavier network extensions might want to reduce 12 | // them further.
13 | var ( 14 | QueueStagedSize = 128 15 | QueueOutboundSize = 1024 16 | QueueInboundSize = 1024 17 | QueueHandshakeSize = 1024 18 | PreallocatedBuffersPerPool uint32 = 1024 /* typed uint32 (the other vars here default to int) because it is passed to NewWaitPool(max uint32, ...) */ 19 | ) 20 | 21 | const MaxSegmentSize = 1700 22 | -------------------------------------------------------------------------------- /device/queueconstants_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | const ( 9 | QueueStagedSize = 128 10 | QueueOutboundSize = 1024 11 | QueueInboundSize = 1024 12 | QueueHandshakeSize = 1024 13 | MaxSegmentSize = 2048 - 32 // NOTE(review): this is not the largest possible UDP datagram, which is (1 << 16) - 1 as used on other platforms; the reason for the smaller Windows value is not stated here - confirm 14 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 15 | ) 16 | -------------------------------------------------------------------------------- /device/race_disabled_test.go: -------------------------------------------------------------------------------- 1 | //+build !race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = false 11 | -------------------------------------------------------------------------------- /device/race_enabled_test.go: -------------------------------------------------------------------------------- 1 | //+build race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = true 11 | -------------------------------------------------------------------------------- /device/sticky_default.go: -------------------------------------------------------------------------------- 1 | // +build !linux 2 | 3 | package device 4 | 5 | import ( 6 | "golang.zx2c4.com/wireguard/conn" 7 | "golang.zx2c4.com/wireguard/rwcancel" 8 | ) 9 | 10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 11 | return nil, nil // no netlink route listener outside Linux: nothing to start and nothing to cancel 12 | } 13 | -------------------------------------------------------------------------------- /device/sticky_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This implements userspace semantics of "sticky sockets", modeled after 6 | * WireGuard's kernelspace implementation. This is more or less a straight port 7 | * of the sticky-sockets.c example code: 8 | * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 | * 10 | * Currently there is no way to achieve this within the net package: 11 | * See e.g. https://github.com/golang/go/issues/17930 12 | * So this code remains platform dependent.
13 | */ 14 | 15 | package device 16 | 17 | import ( 18 | "sync" 19 | "unsafe" 20 | 21 | "golang.org/x/sys/unix" 22 | 23 | "golang.zx2c4.com/wireguard/conn" 24 | "golang.zx2c4.com/wireguard/rwcancel" 25 | ) 26 | 27 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 | if _, ok := bind.(*conn.LinuxSocketBind); !ok { 29 | return nil, nil 30 | } 31 | 32 | netlinkSock, err := createNetlinkRouteSocket() 33 | if err != nil { 34 | return nil, err 35 | } 36 | netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 37 | if err != nil { 38 | unix.Close(netlinkSock) 39 | return nil, err 40 | } 41 | 42 | go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 43 | 44 | return netlinkCancel, nil 45 | } 46 | 47 | func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 48 | type peerEndpointPtr struct { 49 | peer *Peer 50 | endpoint *conn.Endpoint 51 | } 52 | var reqPeer map[uint32]peerEndpointPtr 53 | var reqPeerLock sync.Mutex 54 | 55 | defer netlinkCancel.Close() 56 | defer unix.Close(netlinkSock) 57 | 58 | for msg := make([]byte, 1<<16); ; { 59 | var err error 60 | var msgn int 61 | for { 62 | msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 63 | if err == nil || !rwcancel.RetryAfterError(err) { 64 | break 65 | } 66 | if !netlinkCancel.ReadyRead() { 67 | return 68 | } 69 | } 70 | if err != nil { 70 | return 72 | } 73 | 74 | for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 75 | 76 | hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 77 | 78 | if hdr.Len < unix.SizeofNlMsghdr || uint(hdr.Len) > uint(len(remain)) { // a length below the header size would never advance remain (line 201) 79 | break 80 | } 81 | 82 | switch hdr.Type { 83 | case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 84 | if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 85 | if uint(len(remain)) < uint(hdr.Len) { 86 | break 87 | } 88 | if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 89 | attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg : hdr.Len] // only this message's attributes, not the messages that follow it 90 | for { 91 | if uint(len(attr)) <
uint(unix.SizeofRtAttr) { 92 | break 93 | } 94 | attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 95 | if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 96 | break 97 | } 98 | if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 99 | ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 100 | reqPeerLock.Lock() 101 | if reqPeer == nil { 102 | reqPeerLock.Unlock() 103 | break 104 | } 105 | pePtr, ok := reqPeer[hdr.Seq] 106 | reqPeerLock.Unlock() 107 | if !ok { 108 | break 109 | } 110 | pePtr.peer.Lock() 111 | if &pePtr.peer.endpoint != pePtr.endpoint { 112 | pePtr.peer.Unlock() 113 | break 114 | } 115 | if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { 116 | pePtr.peer.Unlock() 117 | break 118 | } 119 | pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() 120 | pePtr.peer.Unlock() 121 | } 122 | attr = attr[attrhdr.Len:] 123 | } 124 | } 125 | break 126 | } 127 | reqPeerLock.Lock() 128 | reqPeer = make(map[uint32]peerEndpointPtr) 129 | reqPeerLock.Unlock() 130 | go func() { 131 | device.peers.RLock() 132 | i := uint32(1) 133 | for _, peer := range device.peers.keyMap { 134 | peer.RLock() 135 | if peer.endpoint == nil { 136 | peer.RUnlock() 137 | continue 138 | } 139 | nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) 140 | if nativeEP == nil { 141 | peer.RUnlock() 142 | continue 143 | } 144 | if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { 145 | peer.RUnlock() 146 | continue // skip just this peer; break would also drop every peer the map iteration happens to visit after it 147 | } 148 | nlmsg := struct { 149 | hdr unix.NlMsghdr 150 | msg unix.RtMsg 151 | dsthdr unix.RtAttr 152 | dst [4]byte 153 | srchdr unix.RtAttr 154 | src [4]byte 155 | markhdr unix.RtAttr 156 | mark uint32 157 | }{ 158 | unix.NlMsghdr{ 159 | Type: uint16(unix.RTM_GETROUTE), 160 | Flags: unix.NLM_F_REQUEST, 161 | Seq: i, 162 | }, 163 | unix.RtMsg{ 164 | Family: unix.AF_INET, 165 | Dst_len: 32, 166 | Src_len: 32, 167 | }, 168 | unix.RtAttr{ 169 | Len: 8, 170 | Type: unix.RTA_DST, 171 | }, 172
| nativeEP.Dst4().Addr, 173 | unix.RtAttr{ 174 | Len: 8, 175 | Type: unix.RTA_SRC, 176 | }, 177 | nativeEP.Src4().Src, 178 | unix.RtAttr{ 179 | Len: 8, 180 | Type: unix.RTA_MARK, 181 | }, 182 | device.net.fwmark, 183 | } 184 | nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 185 | reqPeerLock.Lock() 186 | reqPeer[i] = peerEndpointPtr{ 187 | peer: peer, 188 | endpoint: &peer.endpoint, 189 | } 190 | reqPeerLock.Unlock() 191 | peer.RUnlock() 192 | i++ 193 | _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 194 | if err != nil { 195 | break 196 | } 197 | } 198 | device.peers.RUnlock() 199 | }() 200 | } 201 | remain = remain[hdr.Len:] 202 | } 203 | } 204 | } 205 | 206 | func createNetlinkRouteSocket() (int, error) { 207 | sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) 208 | if err != nil { 209 | return -1, err 210 | } 211 | saddr := &unix.SockaddrNetlink{ 212 | Family: unix.AF_NETLINK, 213 | Groups: unix.RTMGRP_IPV4_ROUTE, 214 | } 215 | err = unix.Bind(sock, saddr) 216 | if err != nil { 217 | unix.Close(sock) 218 | return -1, err 219 | } 220 | return sock, nil 221 | } 222 | -------------------------------------------------------------------------------- /device/timers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This is based heavily on timers.c from the kernel implementation. 6 | */ 7 | 8 | package device 9 | 10 | import ( 11 | "math/rand" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | ) 16 | 17 | // A Timer manages time-based aspects of the WireGuard protocol. 18 | // Timer roughly copies the interface of the Linux kernel's struct timer_list.
19 | type Timer struct { 20 | *time.Timer 21 | modifyingLock sync.RWMutex 22 | runningLock sync.Mutex 23 | isPending bool /* guarded by modifyingLock; set by Mod, cleared by Del and by the expiry callback */ 24 | } 25 | 26 | func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { 27 | timer := &Timer{} 28 | timer.Timer = time.AfterFunc(time.Hour, func() { 29 | timer.runningLock.Lock() 30 | defer timer.runningLock.Unlock() 31 | 32 | timer.modifyingLock.Lock() 33 | if !timer.isPending { 34 | timer.modifyingLock.Unlock() 35 | return 36 | } 37 | timer.isPending = false 38 | timer.modifyingLock.Unlock() 39 | 40 | expirationFunction(peer) 41 | }) 42 | timer.Stop() /* created stopped; Mod (via Reset) is what arms it */ 43 | return timer 44 | } 45 | 46 | func (timer *Timer) Mod(d time.Duration) { 47 | timer.modifyingLock.Lock() 48 | timer.isPending = true 49 | timer.Reset(d) 50 | timer.modifyingLock.Unlock() 51 | } 52 | 53 | func (timer *Timer) Del() { 54 | timer.modifyingLock.Lock() 55 | timer.isPending = false 56 | timer.Stop() 57 | timer.modifyingLock.Unlock() 58 | } 59 | 60 | func (timer *Timer) DelSync() { 61 | timer.Del() 62 | timer.runningLock.Lock() /* waits for an expiry callback that is already running (it holds runningLock) */ 63 | timer.Del() /* disarm again in case that callback re-armed the timer via Mod */ 64 | timer.runningLock.Unlock() 65 | } 66 | 67 | func (timer *Timer) IsPending() bool { 68 | timer.modifyingLock.RLock() 69 | defer timer.modifyingLock.RUnlock() 70 | return timer.isPending 71 | } 72 | 73 | func (peer *Peer) timersActive() bool { 74 | return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() 75 | } 76 | 77 | func expiredRetransmitHandshake(peer *Peer) { 78 | if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { 79 | peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) 80 | 81 | if peer.timersActive() { 82 | peer.timers.sendKeepalive.Del() 83 | } 84 | 85 | /* We drop all packets without a keypair and don't try again, 86 | * if we try unsuccessfully for too long to make a handshake.
87 | */ 88 | peer.FlushStagedPackets() 89 | 90 | /* We set a timer for destroying any residue that might be left 91 | * of a partial exchange. 92 | */ 93 | if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { 94 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 95 | } 96 | } else { 97 | atomic.AddUint32(&peer.timers.handshakeAttempts, 1) 98 | peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) 99 | 100 | /* We clear the endpoint's source address, in case it is the cause of the trouble. */ 101 | peer.Lock() 102 | if peer.endpoint != nil { 103 | peer.endpoint.ClearSrc() 104 | } 105 | peer.Unlock() 106 | 107 | peer.SendHandshakeInitiation(true) 108 | } 109 | } 110 | 111 | func expiredSendKeepalive(peer *Peer) { 112 | peer.SendKeepalive() 113 | if peer.timers.needAnotherKeepalive.Get() { 114 | peer.timers.needAnotherKeepalive.Set(false) 115 | if peer.timersActive() { 116 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 117 | } 118 | } 119 | } 120 | 121 | func expiredNewHandshake(peer *Peer) { 122 | peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) 123 | /* We clear the endpoint's source address, in case it is the cause of the trouble.
*/ 124 | peer.Lock() 125 | if peer.endpoint != nil { 126 | peer.endpoint.ClearSrc() 127 | } 128 | peer.Unlock() 129 | peer.SendHandshakeInitiation(false) 130 | 131 | } 132 | 133 | func expiredZeroKeyMaterial(peer *Peer) { 134 | peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) 135 | peer.ZeroAndFlushAll() 136 | } 137 | 138 | func expiredPersistentKeepalive(peer *Peer) { 139 | if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { 140 | peer.SendKeepalive() 141 | } 142 | } 143 | 144 | /* Should be called after an authenticated data packet is sent. */ 145 | func (peer *Peer) timersDataSent() { 146 | if peer.timersActive() && !peer.timers.newHandshake.IsPending() { 147 | peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) 148 | } 149 | } 150 | 151 | /* Should be called after an authenticated data packet is received. */ 152 | func (peer *Peer) timersDataReceived() { 153 | if peer.timersActive() { 154 | if !peer.timers.sendKeepalive.IsPending() { 155 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 156 | } else { 157 | peer.timers.needAnotherKeepalive.Set(true) /* a keepalive is already pending; expiredSendKeepalive re-arms it once more */ 158 | } 159 | } 160 | } 161 | 162 | /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ 163 | func (peer *Peer) timersAnyAuthenticatedPacketSent() { 164 | if peer.timersActive() { 165 | peer.timers.sendKeepalive.Del() 166 | } 167 | } 168 | 169 | /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ 170 | func (peer *Peer) timersAnyAuthenticatedPacketReceived() { 171 | if peer.timersActive() { 172 | peer.timers.newHandshake.Del() 173 | } 174 | } 175 | 176 | /* Should be called after a handshake initiation message is sent.
*/ 177 | func (peer *Peer) timersHandshakeInitiated() { 178 | if peer.timersActive() { 179 | peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) /* RekeyTimeout plus a random jitter below RekeyTimeoutJitterMaxMs milliseconds */ 180 | } 181 | } 182 | 183 | /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ 184 | func (peer *Peer) timersHandshakeComplete() { 185 | if peer.timersActive() { 186 | peer.timers.retransmitHandshake.Del() 187 | } 188 | atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) /* reset the counter that expiredRetransmitHandshake compares against MaxTimerHandshakes */ 189 | peer.timers.sentLastMinuteHandshake.Set(false) 190 | atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) 191 | } 192 | 193 | /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ 194 | func (peer *Peer) timersSessionDerived() { 195 | if peer.timersActive() { 196 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 197 | } 198 | } 199 | 200 | /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received.
*/ 201 | func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { 202 | keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) 203 | if keepalive > 0 && peer.timersActive() { 204 | peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) /* the interval is stored in whole seconds */ 205 | } 206 | } 207 | 208 | func (peer *Peer) timersInit() { 209 | peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) 210 | peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) 211 | peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) 212 | peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) 213 | peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) 214 | } 215 | 216 | func (peer *Peer) timersStart() { 217 | atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) 218 | peer.timers.sentLastMinuteHandshake.Set(false) 219 | peer.timers.needAnotherKeepalive.Set(false) 220 | } 221 | 222 | func (peer *Peer) timersStop() { 223 | peer.timers.retransmitHandshake.DelSync() /* DelSync also waits for an expiry callback that is already running */ 224 | peer.timers.sendKeepalive.DelSync() 225 | peer.timers.newHandshake.DelSync() 226 | peer.timers.zeroKeyMaterial.DelSync() 227 | peer.timers.persistentKeepalive.DelSync() 228 | } 229 | -------------------------------------------------------------------------------- /device/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | "sync/atomic" 11 | 12 | "golang.zx2c4.com/wireguard/tun" 13 | ) 14 | 15 | const DefaultMTU = 1420 16 | 17 | func (device *Device) RoutineTUNEventReader() { 18 | device.log.Verbosef("Routine: event worker - started") 19 | 20 | for event := range device.tun.device.Events() { /* runs until the Events channel is closed */ 21 | if event&tun.EventMTUUpdate != 0 { 22 | mtu, err := device.tun.device.MTU() 23 | if err != nil { 24 | device.log.Errorf("Failed to load updated MTU of device: %v", err) 25 | continue 26 | } 27 | if mtu < 0 { 28 | device.log.Errorf("MTU not updated to negative value: %v", mtu) 29 | continue 30 | } 31 | var tooLarge string 32 | if mtu > MaxContentSize { 33 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) 34 | mtu = MaxContentSize 35 | } 36 | old := atomic.SwapInt32(&device.tun.mtu, int32(mtu)) /* swap so the log below fires only when the stored MTU actually changes */ 37 | if int(old) != mtu { 38 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) 39 | } 40 | } 41 | 42 | if event&tun.EventUp != 0 { 43 | device.log.Verbosef("Interface up requested") 44 | device.Up() 45 | } 46 | 47 | if event&tun.EventDown != 0 { 48 | device.log.Verbosef("Interface down requested") 49 | device.Down() 50 | } 51 | } 52 | 53 | device.log.Verbosef("Routine: event worker - stopped") 54 | } 55 | -------------------------------------------------------------------------------- /format_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "go/format" 10 | "io/fs" 11 | "os" 12 | "path/filepath" 13 | "runtime" 14 | "sync" 15 | "testing" 16 | ) 17 | 18 | func TestFormatting(t *testing.T) { 19 | var wg sync.WaitGroup 20 | filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { /* NOTE(review): WalkDir's own return value is ignored; walk errors are instead reported per entry via t.Errorf below */ 21 | if err != nil { 22 | t.Errorf("unable to walk %s: %v", path, err) 23 | return nil 24 | } 25 | if d.IsDir() || filepath.Ext(path) != ".go" { 26 | return nil 27 | } 28 | wg.Add(1) 29 | go func(path string) { 30 | defer wg.Done() 31 | src, err := os.ReadFile(path) 32 | if err != nil { 33 | t.Errorf("unable to read %s: %v", path, err) 34 | return 35 | } 36 | if runtime.GOOS == "windows" { 37 | src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) /* compare using LF line endings only */ 38 | } 39 | formatted, err := format.Source(src) 40 | if err != nil { 41 | t.Errorf("unable to format %s: %v", path, err) 42 | return 43 | } 44 | if !bytes.Equal(src, formatted) { 45 | t.Errorf("unformatted code: %s", path) 46 | } 47 | }(path) 48 | return nil 49 | }) 50 | wg.Wait() 51 | } 52 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module golang.zx2c4.com/wireguard 2 | 3 | go 1.16 4 | 5 | require ( 6 | git.fd.io/govpp.git v0.3.6-0.20210727130229-24f179dbb953 7 | git.fd.io/govpp.git/extras v0.0.0-20210727130229-24f179dbb953 8 | github.com/fsnotify/fsnotify v1.4.9 // indirect 9 | github.com/google/gopacket v1.1.19 10 | github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 // indirect 11 | github.com/sirupsen/logrus v1.8.1 12 | golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 13 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 14 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c 15 | gopkg.in/yaml.v2 v2.4.0 // indirect 16 | ) 17 | -------------------------------------------------------------------------------- /govpp_remove_crcstring_check.patch: -------------------------------------------------------------------------------- 1 | --- vendor/git.fd.io/govpp.git/adapter/socketclient/socketclient.go 2021-07-29 22:21:35.560204223 +0800 2 | +++ vendor/git.fd.io/govpp.git/adapter/socketclient/socketclient.go 2021-07-30 04:53:13.528822544 +0800 3 | @@ -314,6 +314,8 @@ 4 | for _, x := range reply.MessageTable { 5 | msgName := strings.Split(x.Name, "\x00")[0] 6 | name := strings.TrimSuffix(msgName, "\x13") 7 | + nameslice := strings.Split(name, "_") 8 | + name = strings.Join(nameslice[:len(nameslice)-1], "_") 9 | c.msgTable[name] = x.Index 10 | if strings.HasPrefix(name, "sockclnt_delete_") { 11 | c.sockDelMsgId = x.Index 12 | @@ -327,7 +329,7 @@ 13 | } 14 | 15 | func (c *Client) GetMsgID(msgName string, msgCrc string) (uint16, error) { 16 | - if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok { 17 | + if msgID, ok := c.msgTable[msgName]; ok {
-------------------------------------------------------------------------------- 1 | --- vendor/git.fd.io/govpp.git/adapter/socketclient/socketclient.go 2021-07-29 22:21:35.560204223 +0800 2 | +++ vendor/git.fd.io/govpp.git/adapter/socketclient/socketclient.go 2021-07-30 04:53:13.528822544 +0800 3 | @@ -314,6 +314,8 @@ 4 | for _, x := range reply.MessageTable { 5 | msgName := strings.Split(x.Name, "\x00")[0] 6 | name := strings.TrimSuffix(msgName, "\x13") 7 | + nameslice := strings.Split(name, "_") 8 | + name = strings.Join(nameslice[:len(nameslice)-1], "_") 9 | c.msgTable[name] = x.Index 10 | if strings.HasPrefix(name, "sockclnt_delete_") { 11 | c.sockDelMsgId = x.Index 12 | @@ -327,7 +329,7 @@ 13 | } 14 | 15 | func (c *Client) GetMsgID(msgName string, msgCrc string) (uint16, error) { 16 | - if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok { 17 | + if msgID, ok := c.msgTable[msgName]; ok { 18 | return msgID, nil 19 | } 20 | return 0, &adapter.UnknownMsgError{ 21 | -------------------------------------------------------------------------------- /ipc/uapi_bsd.go: -------------------------------------------------------------------------------- 1 | // +build darwin freebsd openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "net" 13 | "os" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | type UAPIListener struct { 20 | listener net.Listener // unix socket listener 21 | connNew chan net.Conn 22 | connErr chan error 23 | kqueueFd int 24 | keventFd int 25 | } 26 | 27 | func (l *UAPIListener) Accept() (net.Conn, error) { 28 | for { /* every select case returns, so this loop body runs at most once */ 29 | select { 30 | case conn := <-l.connNew: 31 | return conn, nil 32 | 33 | case err := <-l.connErr: 34 | return nil, err 35 | } 36 | } 37 | } 38 | 39 | func (l *UAPIListener) Close() error { 40 | err1 := unix.Close(l.kqueueFd) 41 | err2 := unix.Close(l.keventFd) 42 | err3 := l.listener.Close() 43 | if err1 != nil { 44 | return err1 45 | } 46 | if err2 != nil { 47 | return err2 48 | } 49 | return err3 50 | } 51 | 52 | func (l *UAPIListener) Addr() net.Addr { 53 | return l.listener.Addr() 54 | } 55 | 56 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 57 | 58 | // wrap file in listener 59 | 60 | listener, err := net.FileListener(file) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | uapi := &UAPIListener{ 66 | listener: listener, 67 | connNew: make(chan net.Conn, 1), 68 | connErr: make(chan error, 1), 69 | } 70 | 71 | if unixListener, ok := listener.(*net.UnixListener); ok { 72 | unixListener.SetUnlinkOnClose(true) 73 | } 74 | 75 | socketPath := sockPath(name) 76 | 77 | // watch for deletion of socket 78 | 79 | uapi.kqueueFd, err = unix.Kqueue() 80 | if err != nil { 81 | return nil, err /* NOTE(review): listener is not closed on this path or at line 86, so its descriptor leaks; closing it would also unlink the socket because of SetUnlinkOnClose(true) above - confirm the intended behavior */ 82 | } 83 | uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) 84 | if err != nil { 85 | unix.Close(uapi.kqueueFd) 86 | return nil, err 87 | } 88 | 89 | go func(l *UAPIListener) { 90 | event := unix.Kevent_t{ 91 | Filter: unix.EVFILT_VNODE, 92 | Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, 93 | Fflags: unix.NOTE_WRITE, 94 | } 95 | // Allow this assignment to work with both the 32-bit and 64-bit version 96 | // of the above struct.
If you know another way, please submit a patch. 97 | *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) 98 | events := make([]unix.Kevent_t, 1) 99 | n := 1 100 | var kerr error 101 | for { 102 | // start with lstat to avoid race condition 103 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 104 | l.connErr <- err 105 | return 106 | } 107 | if kerr != nil || n != 1 { 108 | if kerr != nil { 109 | l.connErr <- kerr 110 | } else { 111 | l.connErr <- errors.New("kqueue returned empty") 112 | } 113 | return 114 | } 115 | n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) 116 | } 117 | }(uapi) 118 | 119 | // watch for new connections 120 | 121 | go func(l *UAPIListener) { 122 | for { 123 | conn, err := l.listener.Accept() 124 | if err != nil { 125 | l.connErr <- err 126 | break 127 | } 128 | l.connNew <- conn 129 | } 130 | }(uapi) 131 | 132 | return uapi, nil 133 | } 134 | -------------------------------------------------------------------------------- /ipc/uapi_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | "os" 11 | 12 | "golang.org/x/sys/unix" 13 | "golang.zx2c4.com/wireguard/rwcancel" 14 | ) 15 | 16 | type UAPIListener struct { 17 | listener net.Listener // unix socket listener 18 | connNew chan net.Conn 19 | connErr chan error 20 | inotifyFd int 21 | inotifyRWCancel *rwcancel.RWCancel 22 | } 23 | 24 | func (l *UAPIListener) Accept() (net.Conn, error) { 25 | for { 26 | select { 27 | case conn := <-l.connNew: 28 | return conn, nil 29 | 30 | case err := <-l.connErr: 31 | return nil, err 32 | } 33 | } 34 | } 35 | 36 | func (l *UAPIListener) Close() error { 37 | err1 := unix.Close(l.inotifyFd) 38 | err2 := l.inotifyRWCancel.Cancel() 39 | err3 := l.listener.Close() 40 | if err1 != nil { 41 | return err1 42 | } 43 | if err2 != nil { 44 | return err2 45 | } 46 | return err3 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 54 | 55 | // wrap file in listener 56 | 57 | listener, err := net.FileListener(file) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | if unixListener, ok := listener.(*net.UnixListener); ok { 63 | unixListener.SetUnlinkOnClose(true) 64 | } 65 | 66 | uapi := &UAPIListener{ 67 | listener: listener, 68 | connNew: make(chan net.Conn, 1), 69 | connErr: make(chan error, 1), 70 | } 71 | 72 | // watch for deletion of socket 73 | 74 | socketPath := sockPath(name) 75 | 76 | uapi.inotifyFd, err = unix.InotifyInit() 77 | if err != nil { 78 | return nil, err 79 | } 80 | 81 | _, err = unix.InotifyAddWatch( 82 | uapi.inotifyFd, 83 | socketPath, 84 | unix.IN_ATTRIB| 85 | unix.IN_DELETE| 86 | unix.IN_DELETE_SELF, 87 | ) 88 | 89 | if err != nil { 90 | unix.Close(uapi.inotifyFd) // don't leak the inotify fd; same cleanup as the NewRWCancel failure path below 91 | return nil, err 92 | } 93 | 94 | uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) 95 | if err != nil { 96 | unix.Close(uapi.inotifyFd) 97 | return nil, err 98 | } 99 | 100 | go func(l *UAPIListener) { 101 | var buff
[0]byte 102 | defer uapi.inotifyRWCancel.Close() // close once when the watcher exits; deferring inside the loop queued one Close per iteration 103 | for { 104 | // start with lstat to avoid race condition 105 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 106 | l.connErr <- err 107 | return 108 | } 109 | _, err := uapi.inotifyRWCancel.Read(buff[:]) 110 | if err != nil { 111 | l.connErr <- err 112 | return 113 | } 114 | } 115 | }(uapi) 116 | 117 | // watch for new connections 118 | 119 | go func(l *UAPIListener) { 120 | for { 121 | conn, err := l.listener.Accept() 122 | if err != nil { 123 | l.connErr <- err 124 | break 125 | } 126 | l.connNew <- conn 127 | } 128 | }(uapi) 129 | 130 | return uapi, nil 131 | } 132 | -------------------------------------------------------------------------------- /ipc/uapi_unix.go: -------------------------------------------------------------------------------- 1 | // +build linux darwin freebsd openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "fmt" 13 | "net" 14 | "os" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | const ( 20 | IpcErrorIO = -int64(unix.EIO) 21 | IpcErrorProtocol = -int64(unix.EPROTO) 22 | IpcErrorInvalid = -int64(unix.EINVAL) 23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE) 24 | IpcErrorUnknown = -55 // ENOANO 25 | ) 26 | 27 | // socketDirectory is variable because it is modified by a linker 28 | // flag in wireguard-android.
29 | var socketDirectory = "/var/run/wireguard" 30 | 31 | func sockPath(iface string) string { 32 | return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) 33 | } 34 | 35 | func UAPIOpen(name string) (*os.File, error) { 36 | if err := os.MkdirAll(socketDirectory, 0755); err != nil { 37 | return nil, err 38 | } 39 | 40 | socketPath := sockPath(name) 41 | addr, err := net.ResolveUnixAddr("unix", socketPath) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | oldUmask := unix.Umask(0077) 47 | defer unix.Umask(oldUmask) 48 | 49 | listener, err := net.ListenUnix("unix", addr) 50 | if err == nil { 51 | return listener.File() 52 | } 53 | 54 | // Probe the socket: if nothing is listening on it, remove the stale file and try again. 55 | if c, err := net.Dial("unix", socketPath); err == nil { 56 | c.Close() // only the successful connect matters; don't leak the probe connection 57 | return nil, errors.New("unix socket in use") 58 | } 59 | if err := os.Remove(socketPath); err != nil { 60 | return nil, err 61 | } 62 | listener, err = net.ListenUnix("unix", addr) 63 | if err != nil { 64 | return nil, err 65 | } 66 | return listener.File() 67 | } 68 | -------------------------------------------------------------------------------- /ipc/uapi_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | 11 | "golang.org/x/sys/windows" 12 | 13 | "golang.zx2c4.com/wireguard/ipc/winpipe" 14 | ) 15 | 16 | // TODO: replace these with actual standard windows error numbers from the win package 17 | const ( 18 | IpcErrorIO = -int64(5) 19 | IpcErrorProtocol = -int64(71) 20 | IpcErrorInvalid = -int64(22) 21 | IpcErrorPortInUse = -int64(98) 22 | IpcErrorUnknown = -int64(55) 23 | ) 24 | 25 | type UAPIListener struct { 26 | listener net.Listener // unix socket listener 27 | connNew chan net.Conn 28 | connErr chan error 29 | kqueueFd int 30 | keventFd int 31 | } 32 | 33 | func (l *UAPIListener) Accept() (net.Conn, error) { 34 | for { 35 | select { 36 | case conn := <-l.connNew: 37 | return conn, nil 38 | 39 | case err := <-l.connErr: 40 | return nil, err 41 | } 42 | } 43 | } 44 | 45 | func (l *UAPIListener) Close() error { 46 | return l.listener.Close() 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR 54 | 55 | func init() { 56 | var err error 57 | /* SDDL_DEVOBJ_SYS_ALL from the WDK */ 58 | UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)") 59 | if err != nil { 60 | panic(err) 61 | } 62 | } 63 | 64 | func UAPIListen(name string) (net.Listener, error) { 65 | config := winpipe.ListenConfig{ 66 | SecurityDescriptor: UAPISecurityDescriptor, 67 | } 68 | listener, err := winpipe.Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | uapi := &UAPIListener{ 74 | listener: listener, 75 | connNew: make(chan net.Conn, 1), 76 | connErr: make(chan error, 1), 77 | } 78 | 79 | go func(l *UAPIListener) { 80 | for { 81 | conn, err := l.listener.Accept() 82 | if err != nil { 83 | l.connErr <- err 84 | break 85 | } 86 | l.connNew <- conn 87 | } 88 | }(uapi) 89 | 90 | return uapi, nil 91 | } 92 | 
-------------------------------------------------------------------------------- /ipc/winpipe/file.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2005 Microsoft 6 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 7 | */ 8 | 9 | package winpipe 10 | 11 | import ( 12 | "io" 13 | "os" 14 | "runtime" 15 | "sync" 16 | "sync/atomic" 17 | "time" 18 | "unsafe" 19 | 20 | "golang.org/x/sys/windows" 21 | ) 22 | 23 | type timeoutChan chan struct{} 24 | 25 | var ioInitOnce sync.Once 26 | var ioCompletionPort windows.Handle 27 | 28 | // ioResult contains the result of an asynchronous IO operation 29 | type ioResult struct { 30 | bytes uint32 31 | err error 32 | } 33 | 34 | // ioOperation represents an outstanding asynchronous Win32 IO 35 | type ioOperation struct { 36 | o windows.Overlapped 37 | ch chan ioResult 38 | } 39 | 40 | func initIo() { 41 | h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 42 | if err != nil { 43 | panic(err) 44 | } 45 | ioCompletionPort = h 46 | go ioCompletionProcessor(h) 47 | } 48 | 49 | // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. 50 | // It takes ownership of this handle and will close it if it is garbage collected. 
51 | type file struct { 52 | handle windows.Handle 53 | wg sync.WaitGroup 54 | wgLock sync.RWMutex 55 | closing uint32 // used as atomic boolean 56 | socket bool 57 | readDeadline deadlineHandler 58 | writeDeadline deadlineHandler 59 | } 60 | 61 | type deadlineHandler struct { 62 | setLock sync.Mutex 63 | channel timeoutChan 64 | channelLock sync.RWMutex 65 | timer *time.Timer 66 | timedout uint32 // used as atomic boolean 67 | } 68 | 69 | // makeFile makes a new file from an existing file handle 70 | func makeFile(h windows.Handle) (*file, error) { 71 | f := &file{handle: h} 72 | ioInitOnce.Do(initIo) 73 | _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) 74 | if err != nil { 75 | return nil, err 76 | } 77 | err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) 78 | if err != nil { 79 | return nil, err 80 | } 81 | f.readDeadline.channel = make(timeoutChan) 82 | f.writeDeadline.channel = make(timeoutChan) 83 | return f, nil 84 | } 85 | 86 | // closeHandle closes the resources associated with a Win32 handle 87 | func (f *file) closeHandle() { 88 | f.wgLock.Lock() 89 | // Atomically set that we are closing, releasing the resources only once. 90 | if atomic.SwapUint32(&f.closing, 1) == 0 { 91 | f.wgLock.Unlock() 92 | // cancel all IO and wait for it to complete 93 | windows.CancelIoEx(f.handle, nil) 94 | f.wg.Wait() 95 | // at this point, no new IO can start 96 | windows.Close(f.handle) 97 | f.handle = 0 98 | } else { 99 | f.wgLock.Unlock() 100 | } 101 | } 102 | 103 | // Close closes a file. 104 | func (f *file) Close() error { 105 | f.closeHandle() 106 | return nil 107 | } 108 | 109 | // prepareIo prepares for a new IO operation. 110 | // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. 
111 | func (f *file) prepareIo() (*ioOperation, error) { 112 | f.wgLock.RLock() 113 | if atomic.LoadUint32(&f.closing) == 1 { 114 | f.wgLock.RUnlock() 115 | return nil, os.ErrClosed 116 | } 117 | f.wg.Add(1) 118 | f.wgLock.RUnlock() 119 | c := &ioOperation{} 120 | c.ch = make(chan ioResult) 121 | return c, nil 122 | } 123 | 124 | // ioCompletionProcessor processes completed async IOs forever 125 | func ioCompletionProcessor(h windows.Handle) { 126 | for { 127 | var bytes uint32 128 | var key uintptr 129 | var op *ioOperation 130 | err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) 131 | if op == nil { 132 | panic(err) 133 | } 134 | op.ch <- ioResult{bytes, err} 135 | } 136 | } 137 | 138 | // asyncIo processes the return value from ReadFile or WriteFile, blocking until 139 | // the operation has actually completed. 140 | func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { 141 | if err != windows.ERROR_IO_PENDING { 142 | return int(bytes), err 143 | } 144 | 145 | if atomic.LoadUint32(&f.closing) == 1 { 146 | windows.CancelIoEx(f.handle, &c.o) 147 | } 148 | 149 | var timeout timeoutChan 150 | if d != nil { 151 | d.channelLock.Lock() 152 | timeout = d.channel 153 | d.channelLock.Unlock() 154 | } 155 | 156 | var r ioResult 157 | select { 158 | case r = <-c.ch: 159 | err = r.err 160 | if err == windows.ERROR_OPERATION_ABORTED { 161 | if atomic.LoadUint32(&f.closing) == 1 { 162 | err = os.ErrClosed 163 | } 164 | } else if err != nil && f.socket { 165 | // err is from Win32. Query the overlapped structure to get the winsock error. 
166 | var bytes, flags uint32 167 | err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) 168 | } 169 | case <-timeout: 170 | windows.CancelIoEx(f.handle, &c.o) 171 | r = <-c.ch 172 | err = r.err 173 | if err == windows.ERROR_OPERATION_ABORTED { 174 | err = os.ErrDeadlineExceeded 175 | } 176 | } 177 | 178 | // runtime.KeepAlive is needed, as c is passed via native 179 | // code to ioCompletionProcessor, c must remain alive 180 | // until the channel read is complete. 181 | runtime.KeepAlive(c) 182 | return int(r.bytes), err 183 | } 184 | 185 | // Read reads from a file handle. 186 | func (f *file) Read(b []byte) (int, error) { 187 | c, err := f.prepareIo() 188 | if err != nil { 189 | return 0, err 190 | } 191 | defer f.wg.Done() 192 | 193 | if atomic.LoadUint32(&f.readDeadline.timedout) == 1 { 194 | return 0, os.ErrDeadlineExceeded 195 | } 196 | 197 | var bytes uint32 198 | err = windows.ReadFile(f.handle, b, &bytes, &c.o) 199 | n, err := f.asyncIo(c, &f.readDeadline, bytes, err) 200 | runtime.KeepAlive(b) 201 | 202 | // Handle EOF conditions. 203 | if err == nil && n == 0 && len(b) != 0 { 204 | return 0, io.EOF 205 | } else if err == windows.ERROR_BROKEN_PIPE { 206 | return 0, io.EOF 207 | } else { 208 | return n, err 209 | } 210 | } 211 | 212 | // Write writes to a file handle. 
213 | func (f *file) Write(b []byte) (int, error) { 214 | c, err := f.prepareIo() 215 | if err != nil { 216 | return 0, err 217 | } 218 | defer f.wg.Done() 219 | 220 | if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { 221 | return 0, os.ErrDeadlineExceeded 222 | } 223 | 224 | var bytes uint32 225 | err = windows.WriteFile(f.handle, b, &bytes, &c.o) 226 | n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) 227 | runtime.KeepAlive(b) 228 | return n, err 229 | } 230 | 231 | func (f *file) SetReadDeadline(deadline time.Time) error { 232 | return f.readDeadline.set(deadline) 233 | } 234 | 235 | func (f *file) SetWriteDeadline(deadline time.Time) error { 236 | return f.writeDeadline.set(deadline) 237 | } 238 | 239 | func (f *file) Flush() error { 240 | return windows.FlushFileBuffers(f.handle) 241 | } 242 | 243 | func (f *file) Fd() uintptr { 244 | return uintptr(f.handle) 245 | } 246 | 247 | func (d *deadlineHandler) set(deadline time.Time) error { 248 | d.setLock.Lock() 249 | defer d.setLock.Unlock() 250 | 251 | if d.timer != nil { 252 | if !d.timer.Stop() { 253 | <-d.channel 254 | } 255 | d.timer = nil 256 | } 257 | atomic.StoreUint32(&d.timedout, 0) 258 | 259 | select { 260 | case <-d.channel: 261 | d.channelLock.Lock() 262 | d.channel = make(chan struct{}) 263 | d.channelLock.Unlock() 264 | default: 265 | } 266 | 267 | if deadline.IsZero() { 268 | return nil 269 | } 270 | 271 | timeoutIO := func() { 272 | atomic.StoreUint32(&d.timedout, 1) 273 | close(d.channel) 274 | } 275 | 276 | now := time.Now() 277 | duration := deadline.Sub(now) 278 | if deadline.After(now) { 279 | // Deadline is in the future, set a timer to wait 280 | d.timer = time.AfterFunc(duration, timeoutIO) 281 | } else { 282 | // Deadline is in the past. Cancel all pending IO now. 
283 | timeoutIO() 284 | } 285 | return nil 286 | } 287 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // +build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "fmt" 12 | "os" 13 | "os/signal" 14 | "runtime" 15 | "strconv" 16 | "syscall" 17 | 18 | "golang.zx2c4.com/wireguard/conn" 19 | "golang.zx2c4.com/wireguard/device" 20 | "golang.zx2c4.com/wireguard/ipc" 21 | "golang.zx2c4.com/wireguard/tun" 22 | ) 23 | 24 | const ( 25 | ExitSetupSuccess = 0 26 | ExitSetupFailed = 1 27 | ) 28 | 29 | const ( 30 | ENV_WG_TUN_FD = "WG_TUN_FD" 31 | ENV_WG_UAPI_FD = "WG_UAPI_FD" 32 | ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" 33 | ) 34 | 35 | func printUsage() { 36 | fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) 37 | } 38 | 39 | func warning() { 40 | switch runtime.GOOS { 41 | case "linux", "freebsd", "openbsd": 42 | if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { 43 | return 44 | } 45 | default: 46 | return 47 | } 48 | 49 | fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐") 50 | fmt.Fprintln(os.Stderr, "│ │") 51 | fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │") 52 | fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. 
For │") 53 | fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") 54 | fmt.Fprintln(os.Stderr, "│ please visit: │") 55 | fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │") 56 | fmt.Fprintln(os.Stderr, "│ │") 57 | fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘") 58 | } 59 | 60 | func main() { 61 | if len(os.Args) == 2 && os.Args[1] == "--version" { 62 | fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH) 63 | return 64 | } 65 | 66 | warning() 67 | 68 | var foreground bool 69 | var interfaceName string 70 | if len(os.Args) < 2 || len(os.Args) > 3 { 71 | printUsage() 72 | return 73 | } 74 | 75 | switch os.Args[1] { 76 | 77 | case "-f", "--foreground": 78 | foreground = true 79 | if len(os.Args) != 3 { 80 | printUsage() 81 | return 82 | } 83 | interfaceName = os.Args[2] 84 | 85 | default: 86 | foreground = false 87 | if len(os.Args) != 2 { 88 | printUsage() 89 | return 90 | } 91 | interfaceName = os.Args[1] 92 | } 93 | 94 | if !foreground { 95 | foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" 96 | } 97 | 98 | // get log level (default: info) 99 | 100 | logLevel := func() int { 101 | switch os.Getenv("LOG_LEVEL") { 102 | case "verbose", "debug": 103 | return device.LogLevelVerbose 104 | case "error": 105 | return device.LogLevelError 106 | case "silent": 107 | return device.LogLevelSilent 108 | } 109 | return device.LogLevelError 110 | }() 111 | 112 | // open TUN device (or use supplied fd) 113 | 114 | tun, err := func() (tun.Device, error) { 115 | tunFdStr := os.Getenv(ENV_WG_TUN_FD) 116 | if tunFdStr == "" { 117 | return tun.CreateTUN(interfaceName, device.DefaultMTU) 118 | } 119 | 120 | // construct tun device from supplied fd 121 | 122 | fd, err := strconv.ParseUint(tunFdStr, 10, 32) 123 | if err != nil { 124 | return nil, err 125 | } 126 
| 127 | err = syscall.SetNonblock(int(fd), true) 128 | if err != nil { 129 | return nil, err 130 | } 131 | 132 | file := os.NewFile(uintptr(fd), "") 133 | return tun.CreateTUNFromFile(file, device.DefaultMTU) 134 | }() 135 | 136 | if err == nil { 137 | realInterfaceName, err2 := tun.Name() 138 | if err2 == nil { 139 | interfaceName = realInterfaceName 140 | } 141 | } 142 | 143 | logger := device.NewLogger( 144 | logLevel, 145 | fmt.Sprintf("(%s) ", interfaceName), 146 | ) 147 | 148 | logger.Verbosef("Starting wireguard-go version %s", Version) 149 | 150 | if err != nil { 151 | logger.Errorf("Failed to create TUN device: %v", err) 152 | os.Exit(ExitSetupFailed) 153 | } 154 | 155 | // open UAPI file (or use supplied fd) 156 | 157 | fileUAPI, err := func() (*os.File, error) { 158 | uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) 159 | if uapiFdStr == "" { 160 | return ipc.UAPIOpen(interfaceName) 161 | } 162 | 163 | // use supplied fd 164 | 165 | fd, err := strconv.ParseUint(uapiFdStr, 10, 32) 166 | if err != nil { 167 | return nil, err 168 | } 169 | 170 | return os.NewFile(uintptr(fd), ""), nil 171 | }() 172 | 173 | if err != nil { 174 | logger.Errorf("UAPI listen error: %v", err) 175 | os.Exit(ExitSetupFailed) 176 | return 177 | } 178 | // daemonize the process 179 | 180 | if !foreground { 181 | env := os.Environ() 182 | env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) 183 | env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) 184 | env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) 185 | files := [3]*os.File{} 186 | if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent { 187 | files[0], _ = os.Open(os.DevNull) 188 | files[1] = os.Stdout 189 | files[2] = os.Stderr 190 | } else { 191 | files[0], _ = os.Open(os.DevNull) 192 | files[1], _ = os.Open(os.DevNull) 193 | files[2], _ = os.Open(os.DevNull) 194 | } 195 | attr := &os.ProcAttr{ 196 | Files: []*os.File{ 197 | files[0], // stdin 198 | files[1], // stdout 199 | files[2], // stderr 200 | 
tun.File(), 201 | fileUAPI, 202 | }, 203 | Dir: ".", 204 | Env: env, 205 | } 206 | 207 | path, err := os.Executable() 208 | if err != nil { 209 | logger.Errorf("Failed to determine executable: %v", err) 210 | os.Exit(ExitSetupFailed) 211 | } 212 | 213 | process, err := os.StartProcess( 214 | path, 215 | os.Args, 216 | attr, 217 | ) 218 | if err != nil { 219 | logger.Errorf("Failed to daemonize: %v", err) 220 | os.Exit(ExitSetupFailed) 221 | } 222 | process.Release() 223 | return 224 | } 225 | 226 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 227 | 228 | logger.Verbosef("Device started") 229 | 230 | errs := make(chan error) 231 | term := make(chan os.Signal, 1) 232 | 233 | uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) 234 | if err != nil { 235 | logger.Errorf("Failed to listen on uapi socket: %v", err) 236 | os.Exit(ExitSetupFailed) 237 | } 238 | 239 | go func() { 240 | for { 241 | conn, err := uapi.Accept() 242 | if err != nil { 243 | errs <- err 244 | return 245 | } 246 | go device.IpcHandle(conn) 247 | } 248 | }() 249 | 250 | logger.Verbosef("UAPI listener started") 251 | 252 | // wait for program to terminate 253 | 254 | signal.Notify(term, syscall.SIGTERM) 255 | signal.Notify(term, os.Interrupt) 256 | 257 | select { 258 | case <-term: 259 | case <-errs: 260 | case <-device.Wait(): 261 | } 262 | 263 | // clean up 264 | 265 | uapi.Close() 266 | device.Close() 267 | 268 | logger.Verbosef("Shutting down") 269 | } 270 | -------------------------------------------------------------------------------- /main_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package main 7 | 8 | import ( 9 | "fmt" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | "golang.zx2c4.com/wireguard/conn" 15 | "golang.zx2c4.com/wireguard/device" 16 | "golang.zx2c4.com/wireguard/ipc" 17 | 18 | "golang.zx2c4.com/wireguard/tun" 19 | ) 20 | 21 | const ( 22 | ExitSetupSuccess = 0 23 | ExitSetupFailed = 1 24 | ) 25 | 26 | func main() { 27 | if len(os.Args) != 2 { 28 | os.Exit(ExitSetupFailed) 29 | } 30 | interfaceName := os.Args[1] 31 | 32 | fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is , which includes this code as a module.") 33 | 34 | logger := device.NewLogger( 35 | device.LogLevelVerbose, 36 | fmt.Sprintf("(%s) ", interfaceName), 37 | ) 38 | logger.Verbosef("Starting wireguard-go version %s", Version) 39 | 40 | tun, err := tun.CreateTUN(interfaceName, 0) 41 | if err == nil { 42 | realInterfaceName, err2 := tun.Name() 43 | if err2 == nil { 44 | interfaceName = realInterfaceName 45 | } 46 | } else { 47 | logger.Errorf("Failed to create TUN device: %v", err) 48 | os.Exit(ExitSetupFailed) 49 | } 50 | 51 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 52 | err = device.Up() 53 | if err != nil { 54 | logger.Errorf("Failed to bring up device: %v", err) 55 | os.Exit(ExitSetupFailed) 56 | } 57 | logger.Verbosef("Device started") 58 | 59 | uapi, err := ipc.UAPIListen(interfaceName) 60 | if err != nil { 61 | logger.Errorf("Failed to listen on uapi socket: %v", err) 62 | os.Exit(ExitSetupFailed) 63 | } 64 | 65 | errs := make(chan error) 66 | term := make(chan os.Signal, 1) 67 | 68 | go func() { 69 | for { 70 | conn, err := uapi.Accept() 71 | if err != nil { 72 | errs <- err 73 | return 74 | } 75 | go device.IpcHandle(conn) 76 | } 77 | }() 78 | logger.Verbosef("UAPI listener started") 79 | 80 | // wait for program to terminate 81 | 82 | signal.Notify(term, os.Interrupt) 83 | signal.Notify(term, 
os.Kill) 84 | signal.Notify(term, syscall.SIGTERM) 85 | 86 | select { 87 | case <-term: 88 | case <-errs: 89 | case <-device.Wait(): 90 | } 91 | 92 | // clean up 93 | 94 | uapi.Close() 95 | device.Close() 96 | 97 | logger.Verbosef("Shutting down") 98 | } 99 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | const ( 15 | packetsPerSecond = 20 16 | packetsBurstable = 5 17 | garbageCollectTime = time.Second 18 | packetCost = 1000000000 / packetsPerSecond 19 | maxTokens = packetCost * packetsBurstable 20 | ) 21 | 22 | type RatelimiterEntry struct { 23 | mu sync.Mutex 24 | lastTime time.Time 25 | tokens int64 26 | } 27 | 28 | type Ratelimiter struct { 29 | mu sync.RWMutex 30 | timeNow func() time.Time 31 | 32 | stopReset chan struct{} // send to reset, close to stop 33 | tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry 34 | tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry 35 | } 36 | 37 | func (rate *Ratelimiter) Close() { 38 | rate.mu.Lock() 39 | defer rate.mu.Unlock() 40 | 41 | if rate.stopReset != nil { 42 | close(rate.stopReset) 43 | } 44 | } 45 | 46 | func (rate *Ratelimiter) Init() { 47 | rate.mu.Lock() 48 | defer rate.mu.Unlock() 49 | 50 | if rate.timeNow == nil { 51 | rate.timeNow = time.Now 52 | } 53 | 54 | // stop any ongoing garbage collection routine 55 | if rate.stopReset != nil { 56 | close(rate.stopReset) 57 | } 58 | 59 | rate.stopReset = make(chan struct{}) 60 | rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) 61 | rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) 62 | 63 | stopReset := rate.stopReset // store in case Init is called again. 
64 | 65 | // Start garbage collection routine. 66 | go func() { 67 | ticker := time.NewTicker(time.Second) 68 | ticker.Stop() 69 | for { 70 | select { 71 | case _, ok := <-stopReset: 72 | ticker.Stop() 73 | if !ok { 74 | return 75 | } 76 | ticker = time.NewTicker(time.Second) 77 | case <-ticker.C: 78 | if rate.cleanup() { 79 | ticker.Stop() 80 | } 81 | } 82 | } 83 | }() 84 | } 85 | 86 | func (rate *Ratelimiter) cleanup() (empty bool) { 87 | rate.mu.Lock() 88 | defer rate.mu.Unlock() 89 | 90 | for key, entry := range rate.tableIPv4 { 91 | entry.mu.Lock() 92 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 93 | delete(rate.tableIPv4, key) 94 | } 95 | entry.mu.Unlock() 96 | } 97 | 98 | for key, entry := range rate.tableIPv6 { 99 | entry.mu.Lock() 100 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 101 | delete(rate.tableIPv6, key) 102 | } 103 | entry.mu.Unlock() 104 | } 105 | 106 | return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 107 | } 108 | 109 | func (rate *Ratelimiter) Allow(ip net.IP) bool { 110 | var entry *RatelimiterEntry 111 | var keyIPv4 [net.IPv4len]byte 112 | var keyIPv6 [net.IPv6len]byte 113 | 114 | // lookup entry 115 | 116 | IPv4 := ip.To4() 117 | IPv6 := ip.To16() 118 | 119 | rate.mu.RLock() 120 | 121 | if IPv4 != nil { 122 | copy(keyIPv4[:], IPv4) 123 | entry = rate.tableIPv4[keyIPv4] 124 | } else { 125 | copy(keyIPv6[:], IPv6) 126 | entry = rate.tableIPv6[keyIPv6] 127 | } 128 | 129 | rate.mu.RUnlock() 130 | 131 | // make new entry if not found 132 | 133 | if entry == nil { 134 | entry = new(RatelimiterEntry) 135 | entry.tokens = maxTokens - packetCost 136 | entry.lastTime = rate.timeNow() 137 | rate.mu.Lock() 138 | if IPv4 != nil { 139 | rate.tableIPv4[keyIPv4] = entry 140 | if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { 141 | rate.stopReset <- struct{}{} 142 | } 143 | } else { 144 | rate.tableIPv6[keyIPv6] = entry 145 | if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { 146 | rate.stopReset 
<- struct{}{} 147 | } 148 | } 149 | rate.mu.Unlock() 150 | return true 151 | } 152 | 153 | // add tokens to entry 154 | 155 | entry.mu.Lock() 156 | now := rate.timeNow() 157 | entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 158 | entry.lastTime = now 159 | if entry.tokens > maxTokens { 160 | entry.tokens = maxTokens 161 | } 162 | 163 | // subtract cost of packet 164 | 165 | if entry.tokens > packetCost { 166 | entry.tokens -= packetCost 167 | entry.mu.Unlock() 168 | return true 169 | } 170 | entry.mu.Unlock() 171 | return false 172 | } 173 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | type result struct { 15 | allowed bool 16 | text string 17 | wait time.Duration 18 | } 19 | 20 | func TestRatelimiter(t *testing.T) { 21 | var rate Ratelimiter 22 | var expectedResults []result 23 | 24 | nano := func(nano int64) time.Duration { 25 | return time.Nanosecond * time.Duration(nano) 26 | } 27 | 28 | add := func(res result) { 29 | expectedResults = append( 30 | expectedResults, 31 | res, 32 | ) 33 | } 34 | 35 | for i := 0; i < packetsBurstable; i++ { 36 | add(result{ 37 | allowed: true, 38 | text: "initial burst", 39 | }) 40 | } 41 | 42 | add(result{ 43 | allowed: false, 44 | text: "after burst", 45 | }) 46 | 47 | add(result{ 48 | allowed: true, 49 | wait: nano(time.Second.Nanoseconds() / packetsPerSecond), 50 | text: "filling tokens for single packet", 51 | }) 52 | 53 | add(result{ 54 | allowed: false, 55 | text: "not having refilled enough", 56 | }) 57 | 58 | add(result{ 59 | allowed: true, 60 | wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), 61 | text: "filling tokens for two packet burst", 
62 | }) 63 | 64 | add(result{ 65 | allowed: true, 66 | text: "second packet in 2 packet burst", 67 | }) 68 | 69 | add(result{ 70 | allowed: false, 71 | text: "packet following 2 packet burst", 72 | }) 73 | 74 | ips := []net.IP{ 75 | net.ParseIP("127.0.0.1"), 76 | net.ParseIP("192.168.1.1"), 77 | net.ParseIP("172.167.2.3"), 78 | net.ParseIP("97.231.252.215"), 79 | net.ParseIP("248.97.91.167"), 80 | net.ParseIP("188.208.233.47"), 81 | net.ParseIP("104.2.183.179"), 82 | net.ParseIP("72.129.46.120"), 83 | net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), 84 | net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), 85 | net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), 86 | net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), 87 | net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), 88 | net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), 89 | } 90 | 91 | now := time.Now() 92 | rate.timeNow = func() time.Time { 93 | return now 94 | } 95 | defer func() { 96 | // Lock to avoid data race with cleanup goroutine from Init. 97 | rate.mu.Lock() 98 | defer rate.mu.Unlock() 99 | 100 | rate.timeNow = time.Now 101 | }() 102 | timeSleep := func(d time.Duration) { 103 | now = now.Add(d + 1) 104 | rate.cleanup() 105 | } 106 | 107 | rate.Init() 108 | defer rate.Close() 109 | 110 | for i, res := range expectedResults { 111 | timeSleep(res.wait) 112 | for _, ip := range ips { 113 | allowed := rate.Allow(ip) 114 | if allowed != res.allowed { 115 | t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /replay/replay.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. 
7 | package replay 8 | 9 | type block uint64 10 | 11 | const ( 12 | blockBitLog = 6 // 1<<6 == 64 bits 13 | blockBits = 1 << blockBitLog // must be power of 2 14 | ringBlocks = 1 << 7 // must be power of 2 15 | windowSize = (ringBlocks - 1) * blockBits 16 | blockMask = ringBlocks - 1 17 | bitMask = blockBits - 1 18 | ) 19 | 20 | // A Filter rejects replayed messages by checking if message counter value is 21 | // within a sliding window of previously received messages. 22 | // The zero value for Filter is an empty filter ready to use. 23 | // Filters are unsafe for concurrent use. 24 | type Filter struct { 25 | last uint64 26 | ring [ringBlocks]block 27 | } 28 | 29 | // Reset resets the filter to empty state. 30 | func (f *Filter) Reset() { 31 | f.last = 0 32 | f.ring[0] = 0 33 | } 34 | 35 | // ValidateCounter checks if the counter should be accepted. 36 | // Overlimit counters (>= limit) are always rejected. 37 | func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool { 38 | if counter >= limit { 39 | return false 40 | } 41 | indexBlock := counter >> blockBitLog 42 | if counter > f.last { // move window forward 43 | current := f.last >> blockBitLog 44 | diff := indexBlock - current 45 | if diff > ringBlocks { 46 | diff = ringBlocks // cap diff to clear the whole ring 47 | } 48 | for i := current + 1; i <= current+diff; i++ { 49 | f.ring[i&blockMask] = 0 50 | } 51 | f.last = counter 52 | } else if f.last-counter > windowSize { // behind current window 53 | return false 54 | } 55 | // check and set bit 56 | indexBlock &= blockMask 57 | indexBit := counter & bitMask 58 | old := f.ring[indexBlock] 59 | new := old | 1< 0; i-- { 91 | T(i, true) 92 | } 93 | 94 | t.Log("Bulk test 4") 95 | filter.Reset() 96 | testNumber = 0 97 | for i := uint64(windowSize + 2); i > 1; i-- { 98 | T(i, true) 99 | } 100 | T(0, false) 101 | 102 | t.Log("Bulk test 5") 103 | filter.Reset() 104 | testNumber = 0 105 | for i := uint64(windowSize); i > 0; i-- { 106 | T(i, true) 107 | } 
108 | T(windowSize+1, true) 109 | T(0, false) 110 | 111 | t.Log("Bulk test 6") 112 | filter.Reset() 113 | testNumber = 0 114 | for i := uint64(windowSize); i > 0; i-- { 115 | T(i, true) 116 | } 117 | T(0, true) 118 | T(windowSize+1, true) 119 | } 120 | -------------------------------------------------------------------------------- /rwcancel/rwcancel.go: -------------------------------------------------------------------------------- 1 | // +build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | // Package rwcancel implements cancelable read/write operations on 9 | // a file descriptor. 10 | package rwcancel 11 | 12 | import ( 13 | "errors" 14 | "os" 15 | "syscall" 16 | 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type RWCancel struct { 21 | fd int 22 | closingReader *os.File 23 | closingWriter *os.File 24 | } 25 | 26 | func NewRWCancel(fd int) (*RWCancel, error) { 27 | err := unix.SetNonblock(fd, true) 28 | if err != nil { 29 | return nil, err 30 | } 31 | rwcancel := RWCancel{fd: fd} 32 | 33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &rwcancel, nil 39 | } 40 | 41 | func RetryAfterError(err error) bool { 42 | return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) 43 | } 44 | 45 | func (rw *RWCancel) ReadyRead() bool { 46 | closeFd := int32(rw.closingReader.Fd()) 47 | 48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} 49 | var err error 50 | for { 51 | _, err = unix.Poll(pollFds, -1) 52 | if err == nil || !RetryAfterError(err) { 53 | break 54 | } 55 | } 56 | if err != nil { 57 | return false 58 | } 59 | if pollFds[1].Revents != 0 { 60 | return false 61 | } 62 | return pollFds[0].Revents != 0 63 | } 64 | 65 | func (rw *RWCancel) ReadyWrite() bool { 66 | closeFd := int32(rw.closingReader.Fd()) 67 | pollFds := 
[]unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} 68 | var err error 69 | for { 70 | _, err = unix.Poll(pollFds, -1) 71 | if err == nil || !RetryAfterError(err) { 72 | break 73 | } 74 | } 75 | if err != nil { 76 | return false 77 | } 78 | 79 | if pollFds[1].Revents != 0 { 80 | return false 81 | } 82 | return pollFds[0].Revents != 0 83 | } 84 | 85 | func (rw *RWCancel) Read(p []byte) (n int, err error) { 86 | for { 87 | n, err := unix.Read(rw.fd, p) 88 | if err == nil || !RetryAfterError(err) { 89 | return n, err 90 | } 91 | if !rw.ReadyRead() { 92 | return 0, os.ErrClosed 93 | } 94 | } 95 | } 96 | 97 | func (rw *RWCancel) Write(p []byte) (n int, err error) { 98 | for { 99 | n, err := unix.Write(rw.fd, p) 100 | if err == nil || !RetryAfterError(err) { 101 | return n, err 102 | } 103 | if !rw.ReadyWrite() { 104 | return 0, os.ErrClosed 105 | } 106 | } 107 | } 108 | 109 | func (rw *RWCancel) Cancel() (err error) { 110 | _, err = rw.closingWriter.Write([]byte{0}) 111 | return 112 | } 113 | 114 | func (rw *RWCancel) Close() { 115 | rw.closingReader.Close() 116 | rw.closingWriter.Close() 117 | } 118 | -------------------------------------------------------------------------------- /rwcancel/rwcancel_windows.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | 3 | package rwcancel 4 | 5 | type RWCancel struct { 6 | } 7 | 8 | func (*RWCancel) Cancel() {} 9 | -------------------------------------------------------------------------------- /tai64n/tai64n.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "time" 12 | ) 13 | // TimestampSize is the encoded length: 8 bytes of seconds followed by 4 bytes of nanoseconds. 14 | const TimestampSize = 12 15 | const base = uint64(0x400000000000000a) // offset added to Unix seconds by stamp and removed again by String 16 | const whitenerMask = uint32(0x1000000 - 1) // low 24 bits of the nanosecond field, cleared by stamp 17 | // Timestamp is a big-endian seconds/nanoseconds label, so bytes.Compare order matches time order. 18 | type Timestamp [TimestampSize]byte 19 | // stamp encodes t, clearing the whitenerMask bits of the nanoseconds so sub-~16.8ms precision is not revealed. 20 | func stamp(t time.Time) Timestamp { 21 | var tai64n Timestamp 22 | secs := base + uint64(t.Unix()) 23 | nano := uint32(t.Nanosecond()) &^ whitenerMask 24 | binary.BigEndian.PutUint64(tai64n[:], secs) 25 | binary.BigEndian.PutUint32(tai64n[8:], nano) 26 | return tai64n 27 | } 28 | // Now returns the whitened timestamp for the current time. 29 | func Now() Timestamp { 30 | return stamp(time.Now()) 31 | } 32 | // After reports whether t1 is strictly later than t2 by comparing the encoded bytes. 33 | func (t1 Timestamp) After(t2 Timestamp) bool { 34 | return bytes.Compare(t1[:], t2[:]) > 0 35 | } 36 | // String decodes t back to a time.Time (nanoseconds already whitened) and formats it. 37 | func (t Timestamp) String() string { 38 | return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() 39 | } 40 | -------------------------------------------------------------------------------- /tai64n/tai64n_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Test that timestamps are monotonic as required by Wireguard and that 14 | // nanosecond-level information is whitened to prevent side channel attacks. 15 | func TestMonotonic(t *testing.T) { 16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern 17 | // Whitening should reduce timestamp granularity 18 | // to more than 10 but fewer than 20 milliseconds.
19 | tests := []struct { 20 | name string 21 | t1, t2 time.Time 22 | wantAfter bool 23 | }{ 24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, 25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, 26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, 27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, 28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, 29 | } 30 | 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2) 34 | got := ts2.After(ts1) 35 | if got != tt.wantAfter { 36 | t.Errorf("after = %v; want %v", got, tt.wantAfter) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /tun/alignment_windows_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | // checkAlignment reports a test error unless offset is a multiple of 8 bytes. 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestRateJugglerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime.
26 | func TestRateJugglerAlignment(t *testing.T) { 27 | var r rateJuggler 28 | // Log the field layout so that an alignment failure is easy to diagnose. 29 | typ := reflect.TypeOf(&r).Elem() 30 | t.Logf("rateJuggler type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) 42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) 43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) 44 | } 45 | 46 | // TestNativeTunAlignment checks that atomically-accessed fields are 47 | // aligned to 64-bit boundaries, as required by the atomic package. 48 | // 49 | // Unfortunately, violating this rule on 32-bit platforms results in a 50 | // hard segfault at runtime. 51 | func TestNativeTunAlignment(t *testing.T) { 52 | var tun NativeTun 53 | // Log the field layout so that an alignment failure is easy to diagnose. 54 | typ := reflect.TypeOf(&tun).Elem() 55 | t.Logf("NativeTun type size: %d, with fields:", typ.Size()) 56 | for i := 0; i < typ.NumField(); i++ { 57 | field := typ.Field(i) 58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 59 | field.Name, 60 | field.Offset, 61 | field.Type.Size(), 62 | field.Type.Align(), 63 | ) 64 | } 65 | 66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) 67 | } 68 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_client.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "io" 12 | "log" 13 | "net" 14 | "net/http" 15 | 16 | "golang.zx2c4.com/wireguard/conn" 17 | "golang.zx2c4.com/wireguard/device" 18 | "golang.zx2c4.com/wireguard/tun/netstack" 19 | ) 20 | // main brings up a WireGuard device on a netstack TUN and fetches a URL through the tunnel. 21 | func main() { 22 | tun, tnet, err := netstack.CreateNetTUN( 23 | []net.IP{net.ParseIP("192.168.4.29")}, 24 | []net.IP{net.ParseIP("8.8.8.8")}, 25 | 1420) 26 | if err != nil { 27 | log.Panic(err) 28 | } 29 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 30 | err = dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 31 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 32 | endpoint=163.172.161.0:12912 33 | allowed_ip=0.0.0.0/0 34 | `) 35 | if err != nil { 36 | log.Panic(err) 37 | } 38 | err = dev.Up() 39 | if err != nil { 40 | log.Panic(err) 41 | } 42 | 43 | client := http.Client{ 44 | Transport: &http.Transport{ 45 | DialContext: tnet.DialContext, 46 | }, 47 | } 48 | resp, err := client.Get("https://www.zx2c4.com/ip") 49 | if err != nil { 50 | log.Panic(err) 51 | } 52 | defer resp.Body.Close() 53 | body, err := io.ReadAll(resp.Body) 54 | if err != nil { 55 | log.Panic(err) 56 | } 57 | log.Println(string(body)) 58 | } 59 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_server.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
6 | */ 7 | 8 | package main 9 | 10 | import ( 11 | "io" 12 | "log" 13 | "net" 14 | "net/http" 15 | 16 | "golang.zx2c4.com/wireguard/conn" 17 | "golang.zx2c4.com/wireguard/device" 18 | "golang.zx2c4.com/wireguard/tun/netstack" 19 | ) 20 | 21 | func main() { 22 | tun, tnet, err := netstack.CreateNetTUN( 23 | []net.IP{net.ParseIP("192.168.4.29")}, 24 | []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, 25 | 1420, 26 | ) 27 | if err != nil { 28 | log.Panic(err) 29 | } 30 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 31 | dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 32 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 33 | endpoint=163.172.161.0:12912 34 | allowed_ip=0.0.0.0/0 35 | persistent_keepalive_interval=25 36 | `) 37 | dev.Up() 38 | listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) 39 | if err != nil { 40 | log.Panicln(err) 41 | } 42 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 43 | log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) 44 | io.WriteString(writer, "Hello from userspace TCP!") 45 | }) 46 | err = http.Serve(listener, nil) 47 | if err != nil { 48 | log.Panicln(err) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /tun/netstack/go.mod: -------------------------------------------------------------------------------- 1 | module golang.zx2c4.com/wireguard/tun/netstack 2 | 3 | go 1.16 4 | 5 | require ( 6 | golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 7 | golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect 8 | golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect 9 | golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 10 | gvisor.dev/gvisor v0.0.0-20210506004418-fbfeba3024f0 11 | ) 12 | 
-------------------------------------------------------------------------------- /tun/operateonfd.go: -------------------------------------------------------------------------------- 1 | // +build !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package tun 9 | 10 | import ( 11 | "fmt" 12 | ) 13 | // operateOnFd runs fn with tun.tunFile's raw descriptor; failures are sent (wrapped) to tun.errors instead of being returned, and the send blocks if that channel is full. 14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { 15 | sysconn, err := tun.tunFile.SyscallConn() 16 | if err != nil { 17 | tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %w", err) 18 | return 19 | } 20 | err = sysconn.Control(fn) 21 | if err != nil { 22 | tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %w", err) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tun/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "os" 10 | ) 11 | // Event is a device state-change notification delivered on the channel returned by Device.Events. 12 | type Event int 13 | // Event values; these constants are untyped (1, 2, 4) and convert to Event implicitly. 14 | const ( 15 | EventUp = 1 << iota 16 | EventDown 17 | EventMTUUpdate 18 | ) 19 | // Device abstracts a TUN device; the int argument of Read/Write is the offset into the buffer at which packet data starts. 20 | type Device interface { 21 | File() *os.File // returns the file descriptor of the device (nil where there is none, e.g. Windows) 22 | Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) 23 | Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) 24 | Flush() error // flush all previous writes to the device 25 | MTU() (int, error) // returns the MTU of the device 26 | Name() (string, error) // fetches and returns the current name 27 | Events() chan Event // returns a constant channel of events related to the device 28 | Close() error // stops the device and closes the event channel 29 | } 30 | -------------------------------------------------------------------------------- /tun/tun_darwin.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "time" 16 | "unsafe" 17 | 18 | "golang.org/x/net/ipv6" 19 | "golang.org/x/sys/unix" 20 | ) 21 | 22 | const utunControlName = "com.apple.net.utun_control" 23 | // NativeTun is a Device backed by a utun kernel-control socket. 24 | type NativeTun struct { 25 | name string 26 | tunFile *os.File 27 | events chan Event 28 | errors chan error 29 | routeSocket int 30 | closeOnce sync.Once 31 | } 32 | // retryInterfaceByIndex calls net.InterfaceByIndex up to 20 times, sleeping i/3 s between attempts while it fails with ENOMEM. 33 | func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { 34 | for i := 0; i < 20; i++ { 35 | iface, err = net.InterfaceByIndex(index) 36 | if err != nil && errors.Is(err, syscall.ENOMEM) { 37 | time.Sleep(time.Duration(i) * time.Second / 3) 38 | continue 39 | } 40 | return iface, err 41 | } 42 | return nil, err 43 | } 44 | // routineRouteListener reads routing-socket messages and sends Up/Down/MTU events for tunIfindex; it closes tun.events on return. 45 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 46 | var ( 47 | statusUp bool 48 | statusMTU int 49 | ) 50 | 51 | defer close(tun.events) 52 | 53 | data := make([]byte, os.Getpagesize()) 54 | for { 55 | retry: 56 | n, err := unix.Read(tun.routeSocket, data) 57 | if err != nil { 58 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 59 | goto retry 60 | } 61 | tun.errors <- err 62 | return 63 | } 64 | 65 | if n < 14 { 66 | continue 67 | } 68 | 69 | if data[3 /* type */] != unix.RTM_IFINFO { 70 | continue 71 | } 72 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) 73 | if ifindex != tunIfindex { 74 | continue 75 | } 76 | 77 | iface, err := retryInterfaceByIndex(ifindex) 78 | if err != nil { 79 | tun.errors <- err 80 | return 81 | } 82 | 83 | // Up / Down event 84 | up := (iface.Flags & net.FlagUp) != 0 85 | if up != statusUp && up { 86 | tun.events <- EventUp 87 | } 88 | if up != statusUp && !up { 89 | tun.events <- EventDown 90 | } 91 | statusUp = up 92 | 93 | // MTU changes 94 | if iface.MTU != statusMTU { 95 | tun.events <- EventMTUUpdate 96 | } 97 | statusMTU = iface.MTU 98 | } 99 | } 100 | // CreateTUN opens a utun control socket; name must be "utun" (unit 0, presumably kernel-assigned) or "utunN". 101 | func CreateTUN(name string,
mtu int) (Device, error) { 102 | ifIndex := -1 103 | if name != "utun" { 104 | _, err := fmt.Sscanf(name, "utun%d", &ifIndex) 105 | if err != nil || ifIndex < 0 { 106 | return nil, fmt.Errorf("Interface name must be utun[0-9]*") 107 | } 108 | } 109 | 110 | fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) 111 | 112 | if err != nil { 113 | return nil, err 114 | } 115 | 116 | ctlInfo := &unix.CtlInfo{} 117 | copy(ctlInfo.Name[:], []byte(utunControlName)) 118 | err = unix.IoctlCtlInfo(fd, ctlInfo) 119 | if err != nil { 120 | unix.Close(fd) // fd is not yet owned by an *os.File, so release it on every early return 121 | return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) 122 | } 123 | sc := &unix.SockaddrCtl{ 124 | ID: ctlInfo.Id, 125 | Unit: uint32(ifIndex) + 1, 126 | } 127 | 128 | err = unix.Connect(fd, sc) 129 | if err == nil { 130 | err = syscall.SetNonblock(fd, true) 131 | } 132 | if err != nil { 133 | unix.Close(fd) 134 | return nil, err 135 | } 136 | 137 | tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) 138 | 139 | if err == nil && name == "utun" { 140 | fname := os.Getenv("WG_TUN_NAME_FILE") 141 | if fname != "" { 142 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) 143 | } 144 | } 145 | 146 | return tun, err 147 | } 148 | // CreateTUNFromFile wraps an already-connected utun socket file; mtu <= 0 leaves the interface MTU unchanged. 149 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 150 | tun := &NativeTun{ 151 | tunFile: file, 152 | events: make(chan Event, 10), 153 | errors: make(chan error, 5), 154 | } 155 | 156 | name, err := tun.Name() 157 | if err != nil { 158 | tun.tunFile.Close() 159 | return nil, err 160 | } 161 | 162 | tunIfindex, err := func() (int, error) { 163 | iface, err := net.InterfaceByName(name) 164 | if err != nil { 165 | return -1, err 166 | } 167 | return iface.Index, nil 168 | }() 169 | if err != nil { 170 | tun.tunFile.Close() 171 | return nil, err 172 | } 173 | 174 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) 175 | if err != nil { 176 | tun.tunFile.Close() 177 | return nil, err 178 | } 179 | 180 | go tun.routineRouteListener(tunIfindex)
181 | 182 | if mtu > 0 { 183 | err = tun.setMTU(mtu) 184 | if err != nil { 185 | tun.Close() 186 | return nil, err 187 | } 188 | } 189 | 190 | return tun, nil 191 | } 192 | // Name reads the interface name from the utun socket (UTUN_OPT_IFNAME) and caches it in tun.name. NOTE(review): if operateOnFd itself fails, err stays nil here and the cached name is returned. 193 | func (tun *NativeTun) Name() (string, error) { 194 | var err error 195 | tun.operateOnFd(func(fd uintptr) { 196 | tun.name, err = unix.GetsockoptString( 197 | int(fd), 198 | 2, /* #define SYSPROTO_CONTROL 2 */ 199 | 2, /* #define UTUN_OPT_IFNAME 2 */ 200 | ) 201 | }) 202 | 203 | if err != nil { 204 | return "", fmt.Errorf("GetSockoptString: %w", err) 205 | } 206 | 207 | return tun.name, nil 208 | } 209 | // File returns the underlying utun socket file. 210 | func (tun *NativeTun) File() *os.File { 211 | return tun.tunFile 212 | } 213 | // Events returns the device event channel. 214 | func (tun *NativeTun) Events() chan Event { 215 | return tun.events 216 | } 217 | // Read strips the 4-byte header it reads into buff[offset-4:offset], so offset must be at least 4. 218 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 219 | select { 220 | case err := <-tun.errors: 221 | return 0, err 222 | default: 223 | buff := buff[offset-4:] 224 | n, err := tun.tunFile.Read(buff[:]) 225 | if n < 4 { 226 | return 0, err 227 | } 228 | return n - 4, err 229 | } 230 | } 231 | // Write fills the 4-byte address-family header at buff[offset-4:offset] (offset must be at least 4); the returned count includes those 4 bytes. 232 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 233 | 234 | // reserve space for header 235 | 236 | buff = buff[offset-4:] 237 | 238 | // add packet information header 239 | 240 | buff[0] = 0x00 241 | buff[1] = 0x00 242 | buff[2] = 0x00 243 | 244 | if buff[4]>>4 == ipv6.Version { 245 | buff[3] = unix.AF_INET6 246 | } else { 247 | buff[3] = unix.AF_INET 248 | } 249 | 250 | // write 251 | 252 | return tun.tunFile.Write(buff) 253 | } 254 | // Flush is a no-op. 255 | func (tun *NativeTun) Flush() error { 256 | // TODO: can flushing be implemented by buffering and using sendmmsg?
257 | return nil 258 | } 259 | // Close closes the utun file and the route socket, at most once. 260 | func (tun *NativeTun) Close() error { 261 | var err1, err2 error 262 | tun.closeOnce.Do(func() { 263 | err1 = tun.tunFile.Close() 264 | if tun.routeSocket != -1 { 265 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 266 | err2 = unix.Close(tun.routeSocket) 267 | } else if tun.events != nil { 268 | close(tun.events) 269 | } 270 | }) 271 | if err1 != nil { 272 | return err1 273 | } 274 | return err2 275 | } 276 | // setMTU sets the interface MTU with an ioctl on a temporary AF_INET datagram socket. 277 | func (tun *NativeTun) setMTU(n int) error { 278 | fd, err := unix.Socket( 279 | unix.AF_INET, 280 | unix.SOCK_DGRAM, 281 | 0, 282 | ) 283 | 284 | if err != nil { 285 | return err 286 | } 287 | 288 | defer unix.Close(fd) 289 | 290 | var ifr unix.IfreqMTU 291 | copy(ifr.Name[:], tun.name) 292 | ifr.MTU = int32(n) 293 | err = unix.IoctlSetIfreqMTU(fd, &ifr) 294 | if err != nil { 295 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) 296 | } 297 | 298 | return nil 299 | } 300 | // MTU reads the interface MTU with an ioctl on a temporary AF_INET datagram socket. 301 | func (tun *NativeTun) MTU() (int, error) { 302 | fd, err := unix.Socket( 303 | unix.AF_INET, 304 | unix.SOCK_DGRAM, 305 | 0, 306 | ) 307 | 308 | if err != nil { 309 | return 0, err 310 | } 311 | 312 | defer unix.Close(fd) 313 | 314 | ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) 315 | if err != nil { 316 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) 317 | } 318 | 319 | return int(ifr.MTU), nil 320 | } 321 | -------------------------------------------------------------------------------- /tun/tun_openbsd.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "unsafe" 16 | 17 | "golang.org/x/net/ipv6" 18 | "golang.org/x/sys/unix" 19 | ) 20 | 21 | // Structure for iface mtu get/set ioctls 22 | type ifreq_mtu struct { 23 | Name [unix.IFNAMSIZ]byte 24 | MTU uint32 25 | Pad0 [12]byte 26 | } 27 | 28 | const _TUNSIFMODE = 0x8004745d 29 | // NativeTun is a Device backed by a /dev/tunN device file. 30 | type NativeTun struct { 31 | name string 32 | tunFile *os.File 33 | events chan Event 34 | errors chan error 35 | routeSocket int 36 | closeOnce sync.Once 37 | } 38 | // routineRouteListener checks tunIfindex once, then again on each matching RTM_IFINFO routing message, sending Up/Down/MTU events; it closes tun.events on return. 39 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 40 | var ( 41 | statusUp bool 42 | statusMTU int 43 | ) 44 | 45 | defer close(tun.events) 46 | 47 | check := func() bool { 48 | iface, err := net.InterfaceByIndex(tunIfindex) 49 | if err != nil { 50 | tun.errors <- err 51 | return true 52 | } 53 | 54 | // Up / Down event 55 | up := (iface.Flags & net.FlagUp) != 0 56 | if up != statusUp && up { 57 | tun.events <- EventUp 58 | } 59 | if up != statusUp && !up { 60 | tun.events <- EventDown 61 | } 62 | statusUp = up 63 | 64 | // MTU changes 65 | if iface.MTU != statusMTU { 66 | tun.events <- EventMTUUpdate 67 | } 68 | statusMTU = iface.MTU 69 | return false 70 | } 71 | 72 | if check() { 73 | return 74 | } 75 | 76 | data := make([]byte, os.Getpagesize()) 77 | for { 78 | n, err := unix.Read(tun.routeSocket, data) 79 | if err != nil { 80 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 81 | continue 82 | } 83 | tun.errors <- err 84 | return 85 | } 86 | 87 | if n < 8 { 88 | continue 89 | } 90 | 91 | if data[3 /* type */] != unix.RTM_IFINFO { 92 | continue 93 | } 94 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */]))) 95 | if ifindex != tunIfindex { 96 | continue 97 | } 98 | if check() { 99 | return 100 | } 101 | } 102 | } 103 | // CreateTUN opens /dev/tunN for name "tunN", or for name "tun" the first of /dev/tun0../dev/tun255 that is not busy. 104 | func CreateTUN(name string, mtu int) (Device, error) { 105 | ifIndex := -1 106 | if name != "tun" { 107 | _, err :=
fmt.Sscanf(name, "tun%d", &ifIndex) 108 | if err != nil || ifIndex < 0 { 109 | return nil, fmt.Errorf("Interface name must be tun[0-9]*") 110 | } 111 | } 112 | 113 | var tunfile *os.File 114 | var err error 115 | 116 | if ifIndex != -1 { 117 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) 118 | } else { 119 | for ifIndex = 0; ifIndex < 256; ifIndex++ { 120 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) 121 | if err == nil || !errors.Is(err, syscall.EBUSY) { 122 | break 123 | } 124 | } 125 | } 126 | 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | tun, err := CreateTUNFromFile(tunfile, mtu) 132 | 133 | if err == nil && name == "tun" { 134 | fname := os.Getenv("WG_TUN_NAME_FILE") 135 | if fname != "" { 136 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) 137 | } 138 | } 139 | 140 | return tun, err 141 | } 142 | // CreateTUNFromFile wraps an opened tun device file and, when mtu > 0, sets the interface MTU if it differs. 143 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 144 | tun := &NativeTun{ 145 | tunFile: file, 146 | events: make(chan Event, 10), 147 | errors: make(chan error, 1), 148 | } 149 | 150 | name, err := tun.Name() 151 | if err != nil { 152 | tun.tunFile.Close() 153 | return nil, err 154 | } 155 | 156 | tunIfindex, err := func() (int, error) { 157 | iface, err := net.InterfaceByName(name) 158 | if err != nil { 159 | return -1, err 160 | } 161 | return iface.Index, nil 162 | }() 163 | if err != nil { 164 | tun.tunFile.Close() 165 | return nil, err 166 | } 167 | 168 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) 169 | if err != nil { 170 | tun.tunFile.Close() 171 | return nil, err 172 | } 173 | 174 | go tun.routineRouteListener(tunIfindex) 175 | // mtu <= 0 leaves the interface MTU untouched, matching the darwin implementation. 176 | currentMTU, err := tun.MTU() 177 | if mtu > 0 && (err != nil || currentMTU != mtu) { 178 | err = tun.setMTU(mtu) 179 | if err != nil { 180 | tun.Close() 181 | return nil, err 182 | } 183 | } 184 | 185 | return tun, nil 186 | } 187 | // Name derives the interface name "tunN" from stat.Rdev % 256 of the device file and caches it in tun.name. 188 | func (tun *NativeTun) Name() (string,
error) { 189 | gostat, err := tun.tunFile.Stat() 190 | if err != nil { 191 | tun.name = "" 192 | return "", err 193 | } 194 | stat := gostat.Sys().(*syscall.Stat_t) 195 | tun.name = fmt.Sprintf("tun%d", stat.Rdev%256) 196 | return tun.name, nil 197 | } 198 | // File returns the underlying tun device file. 199 | func (tun *NativeTun) File() *os.File { 200 | return tun.tunFile 201 | } 202 | // Events returns the device event channel. 203 | func (tun *NativeTun) Events() chan Event { 204 | return tun.events 205 | } 206 | // Read strips the 4-byte header it reads into buff[offset-4:offset], so offset must be at least 4. 207 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 208 | select { 209 | case err := <-tun.errors: 210 | return 0, err 211 | default: 212 | buff := buff[offset-4:] 213 | n, err := tun.tunFile.Read(buff[:]) 214 | if n < 4 { 215 | return 0, err 216 | } 217 | return n - 4, err 218 | } 219 | } 220 | // Write fills the 4-byte address-family header at buff[offset-4:offset] (offset must be at least 4); the returned count includes those 4 bytes. 221 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 222 | 223 | // reserve space for header 224 | 225 | buff = buff[offset-4:] 226 | 227 | // add packet information header 228 | 229 | buff[0] = 0x00 230 | buff[1] = 0x00 231 | buff[2] = 0x00 232 | 233 | if buff[4]>>4 == ipv6.Version { 234 | buff[3] = unix.AF_INET6 235 | } else { 236 | buff[3] = unix.AF_INET 237 | } 238 | 239 | // write 240 | 241 | return tun.tunFile.Write(buff) 242 | } 243 | // Flush is a no-op. 244 | func (tun *NativeTun) Flush() error { 245 | // TODO: can flushing be implemented by buffering and using sendmmsg?
246 | return nil 247 | } 248 | // Close closes the tun file and the route socket, at most once. 249 | func (tun *NativeTun) Close() error { 250 | var err1, err2 error 251 | tun.closeOnce.Do(func() { 252 | err1 = tun.tunFile.Close() 253 | if tun.routeSocket != -1 { 254 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 255 | err2 = unix.Close(tun.routeSocket) 256 | tun.routeSocket = -1 257 | } else if tun.events != nil { 258 | close(tun.events) 259 | } 260 | }) 261 | if err1 != nil { 262 | return err1 263 | } 264 | return err2 265 | } 266 | // setMTU sets the interface MTU with a SIOCSIFMTU ioctl on a temporary AF_INET datagram socket. 267 | func (tun *NativeTun) setMTU(n int) error { 268 | // open datagram socket 269 | 270 | var fd int 271 | 272 | fd, err := unix.Socket( 273 | unix.AF_INET, 274 | unix.SOCK_DGRAM, 275 | 0, 276 | ) 277 | 278 | if err != nil { 279 | return err 280 | } 281 | 282 | defer unix.Close(fd) 283 | 284 | // do ioctl call 285 | 286 | var ifr ifreq_mtu 287 | copy(ifr.Name[:], tun.name) 288 | ifr.MTU = uint32(n) 289 | 290 | _, _, errno := unix.Syscall( 291 | unix.SYS_IOCTL, 292 | uintptr(fd), 293 | uintptr(unix.SIOCSIFMTU), 294 | uintptr(unsafe.Pointer(&ifr)), 295 | ) 296 | 297 | if errno != 0 { 298 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno) 299 | } 300 | 301 | return nil 302 | } 303 | // MTU reads the interface MTU with a SIOCGIFMTU ioctl on a temporary AF_INET datagram socket. 304 | func (tun *NativeTun) MTU() (int, error) { 305 | // open datagram socket 306 | 307 | fd, err := unix.Socket( 308 | unix.AF_INET, 309 | unix.SOCK_DGRAM, 310 | 0, 311 | ) 312 | 313 | if err != nil { 314 | return 0, err 315 | } 316 | 317 | defer unix.Close(fd) 318 | 319 | // do ioctl call 320 | var ifr ifreq_mtu 321 | copy(ifr.Name[:], tun.name) 322 | 323 | _, _, errno := unix.Syscall( 324 | unix.SYS_IOCTL, 325 | uintptr(fd), 326 | uintptr(unix.SIOCGIFMTU), 327 | uintptr(unsafe.Pointer(&ifr)), 328 | ) 329 | if errno != 0 { 330 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno) 331 | } 332 | 333 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil 334 | } 335 | -------------------------------------------------------------------------------- /tun/tun_windows.go:
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "log" 12 | "os" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | _ "unsafe" 17 | 18 | "golang.org/x/sys/windows" 19 | 20 | "golang.zx2c4.com/wireguard/tun/wintun" 21 | ) 22 | 23 | const ( 24 | rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 25 | spinloopRateThreshold = 800000000 / 8 // 800mbps 26 | spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 27 | ) 28 | // rateJuggler tracks recent throughput in bytes per second; its fields are accessed atomically. 29 | type rateJuggler struct { 30 | current uint64 31 | nextByteCount uint64 32 | nextStartTime int64 33 | changing int32 34 | } 35 | // NativeTun is a Device backed by a Wintun adapter session. 36 | type NativeTun struct { 37 | wt *wintun.Adapter 38 | handle windows.Handle 39 | rate rateJuggler 40 | session wintun.Session 41 | readWait windows.Handle 42 | events chan Event 43 | running sync.WaitGroup 44 | closeOnce sync.Once 45 | close int32 46 | forcedMTU int 47 | } 48 | // WintunPool is the pool adapters are opened and created in; the MakePool error is discarded. 49 | var WintunPool, _ = wintun.MakePool("WireGuard") 50 | var WintunStaticRequestedGUID *windows.GUID 51 | 52 | //go:linkname procyield runtime.procyield 53 | func procyield(cycles uint32) 54 | 55 | //go:linkname nanotime runtime.nanotime 56 | func nanotime() int64 57 | 58 | // 59 | // CreateTUN creates a Wintun interface with the given name. Should a Wintun 60 | // interface with the same name exist, it is deleted and a new one is created. 61 | // An mtu <= 0 selects the default forced MTU of 1420. 62 | func CreateTUN(ifname string, mtu int) (Device, error) { 63 | return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 64 | } 65 | 66 | // 67 | // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 68 | // a requested GUID. Should a Wintun interface with the same name exist, it is deleted and recreated.
69 | // An mtu <= 0 selects the default forced MTU of 1420. 70 | func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 71 | var err error 72 | var wt *wintun.Adapter 73 | 74 | // Does an interface with this name already exist? 75 | wt, err = WintunPool.OpenAdapter(ifname) 76 | if err == nil { 77 | // If so, we delete it, in case it has weird residual configuration. 78 | _, err = wt.Delete(true) 79 | if err != nil { 80 | return nil, fmt.Errorf("Error deleting already existing interface: %w", err) 81 | } 82 | } 83 | wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID) 84 | if err != nil { 85 | return nil, fmt.Errorf("Error creating interface: %w", err) 86 | } 87 | if rebootRequired { 88 | log.Println("Windows indicated a reboot is required.") 89 | } 90 | 91 | forcedMTU := 1420 92 | if mtu > 0 { 93 | forcedMTU = mtu 94 | } 95 | 96 | tun := &NativeTun{ 97 | wt: wt, 98 | handle: windows.InvalidHandle, 99 | events: make(chan Event, 10), 100 | forcedMTU: forcedMTU, 101 | } 102 | 103 | tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 104 | if err != nil { 105 | tun.wt.Delete(false) 106 | close(tun.events) 107 | return nil, fmt.Errorf("Error starting session: %w", err) 108 | } 109 | tun.readWait = tun.session.ReadWaitEvent() 110 | return tun, nil 111 | } 112 | // Name returns the adapter name, or os.ErrClosed once Close has started. 113 | func (tun *NativeTun) Name() (string, error) { 114 | tun.running.Add(1) 115 | defer tun.running.Done() 116 | if atomic.LoadInt32(&tun.close) == 1 { 117 | return "", os.ErrClosed 118 | } 119 | return tun.wt.Name() 120 | } 121 | // File returns nil: a Wintun adapter has no file descriptor. 122 | func (tun *NativeTun) File() *os.File { 123 | return nil 124 | } 125 | // Events returns the device event channel. 126 | func (tun *NativeTun) Events() chan Event { 127 | return tun.events 128 | } 129 | // Close marks the device closed, wakes a blocked Read, waits for in-flight calls, ends the session, deletes the adapter and closes the events channel, all at most once. 130 | func (tun *NativeTun) Close() error { 131 | var err error 132 | tun.closeOnce.Do(func() { 133 | atomic.StoreInt32(&tun.close, 1) 134 | windows.SetEvent(tun.readWait) // wakes a Read blocked in WaitForSingleObject on readWait 135 | tun.running.Wait() 136 | tun.session.End() 137 | if tun.wt != nil { 138 | _, err = tun.wt.Delete(false) 139 | }
140 | close(tun.events) 141 | }) 142 | return err 143 | } 144 | // MTU returns the forced MTU; the real interface MTU is not queried. 145 | func (tun *NativeTun) MTU() (int, error) { 146 | return tun.forcedMTU, nil 147 | } 148 | 149 | // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 150 | func (tun *NativeTun) ForceMTU(mtu int) { 151 | update := tun.forcedMTU != mtu 152 | tun.forcedMTU = mtu 153 | if update { 154 | tun.events <- EventMTUUpdate 155 | } 156 | } 157 | 158 | // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 159 | 160 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 161 | tun.running.Add(1) 162 | defer tun.running.Done() 163 | retry: 164 | if atomic.LoadInt32(&tun.close) == 1 { 165 | return 0, os.ErrClosed 166 | } 167 | start := nanotime() 168 | shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 // spin only while throughput is high and the rate sample is recent 169 | for { 170 | if atomic.LoadInt32(&tun.close) == 1 { 171 | return 0, os.ErrClosed 172 | } 173 | packet, err := tun.session.ReceivePacket() 174 | switch err { 175 | case nil: 176 | packetSize := len(packet) 177 | copy(buff[offset:], packet) // NOTE(review): truncates if the packet exceeds len(buff)-offset, yet packetSize is still returned 178 | tun.session.ReleaseReceivePacket(packet) 179 | tun.rate.update(uint64(packetSize)) 180 | return packetSize, nil 181 | case windows.ERROR_NO_MORE_ITEMS: 182 | if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 183 | windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 184 | goto retry 185 | } 186 | procyield(1) 187 | continue 188 | case windows.ERROR_HANDLE_EOF: 189 | return 0, os.ErrClosed 190 | case windows.ERROR_INVALID_DATA: 191 | return 0, errors.New("Send ring corrupt") 192 | } 193 | return 0, fmt.Errorf("Read failed: %w", err) 194 | } 195 | } 196 | // Flush is a no-op. 197 | func (tun *NativeTun) Flush() error { 198 | return nil 199 | } 200 | // Write copies buff[offset:] into the Wintun send ring; when the ring is full the packet is dropped and (0, nil) is returned. 201 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 202 |
tun.running.Add(1) 203 | defer tun.running.Done() 204 | if atomic.LoadInt32(&tun.close) == 1 { 205 | return 0, os.ErrClosed 206 | } 207 | 208 | packetSize := len(buff) - offset 209 | tun.rate.update(uint64(packetSize)) // counted even if the packet is dropped below 210 | 211 | packet, err := tun.session.AllocateSendPacket(packetSize) 212 | if err == nil { 213 | copy(packet, buff[offset:]) 214 | tun.session.SendPacket(packet) 215 | return packetSize, nil 216 | } 217 | switch err { 218 | case windows.ERROR_HANDLE_EOF: 219 | return 0, os.ErrClosed 220 | case windows.ERROR_BUFFER_OVERFLOW: 221 | return 0, nil // Dropping when ring is full. 222 | } 223 | return 0, fmt.Errorf("Write failed: %w", err) 224 | } 225 | 226 | // LUID returns Windows interface instance ID. 227 | func (tun *NativeTun) LUID() uint64 { 228 | tun.running.Add(1) 229 | defer tun.running.Done() 230 | if atomic.LoadInt32(&tun.close) == 1 { 231 | return 0 232 | } 233 | return tun.wt.LUID() 234 | } 235 | 236 | // RunningVersion returns the running version of the Wintun driver.
237 | func (tun *NativeTun) RunningVersion() (version uint32, err error) { 238 | return wintun.RunningVersion() 239 | } 240 | 241 | func (rate *rateJuggler) update(packetLen uint64) { 242 | now := nanotime() 243 | total := atomic.AddUint64(&rate.nextByteCount, packetLen) 244 | period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) 245 | if period >= rateMeasurementGranularity { 246 | if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { 247 | return 248 | } 249 | atomic.StoreInt64(&rate.nextStartTime, now) 250 | atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) 251 | atomic.StoreUint64(&rate.nextByteCount, 0) 252 | atomic.StoreInt32(&rate.changing, 0) 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /tun/tuntest/tuntest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tuntest 7 | 8 | import ( 9 | "encoding/binary" 10 | "io" 11 | "net" 12 | "os" 13 | 14 | "golang.zx2c4.com/wireguard/tun" 15 | ) 16 | 17 | func Ping(dst, src net.IP) []byte { 18 | localPort := uint16(1337) 19 | seq := uint16(0) 20 | 21 | payload := make([]byte, 4) 22 | binary.BigEndian.PutUint16(payload[0:], localPort) 23 | binary.BigEndian.PutUint16(payload[2:], seq) 24 | 25 | return genICMPv4(payload, dst, src) 26 | } 27 | 28 | // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. 
29 | func checksum(buf []byte, initial uint16) uint16 { 30 | v := uint32(initial) 31 | for i := 0; i < len(buf)-1; i += 2 { 32 | v += uint32(binary.BigEndian.Uint16(buf[i:])) 33 | } 34 | if len(buf)%2 == 1 { 35 | v += uint32(buf[len(buf)-1]) << 8 36 | } 37 | for v > 0xffff { 38 | v = (v >> 16) + (v & 0xffff) 39 | } 40 | return ^uint16(v) 41 | } 42 | 43 | func genICMPv4(payload []byte, dst, src net.IP) []byte { 44 | const ( 45 | icmpv4ProtocolNumber = 1 46 | icmpv4Echo = 8 47 | icmpv4ChecksumOffset = 2 48 | icmpv4Size = 8 49 | ipv4Size = 20 50 | ipv4TotalLenOffset = 2 51 | ipv4ChecksumOffset = 10 52 | ttl = 65 53 | headerSize = ipv4Size + icmpv4Size 54 | ) 55 | 56 | pkt := make([]byte, headerSize+len(payload)) 57 | 58 | ip := pkt[0:ipv4Size] 59 | icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] 60 | 61 | // https://tools.ietf.org/html/rfc792 62 | icmpv4[0] = icmpv4Echo // type 63 | icmpv4[1] = 0 // code 64 | chksum := ^checksum(icmpv4, checksum(payload, 0)) 65 | binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) 66 | 67 | // https://tools.ietf.org/html/rfc760 section 3.1 68 | length := uint16(len(pkt)) 69 | ip[0] = (4 << 4) | (ipv4Size / 4) 70 | binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) 71 | ip[8] = ttl 72 | ip[9] = icmpv4ProtocolNumber 73 | copy(ip[12:], src.To4()) 74 | copy(ip[16:], dst.To4()) 75 | chksum = ^checksum(ip[:], 0) 76 | binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) 77 | 78 | copy(pkt[headerSize:], payload) 79 | return pkt 80 | } 81 | 82 | type ChannelTUN struct { 83 | Inbound chan []byte // incoming packets, closed on TUN close 84 | Outbound chan []byte // outbound packets, blocks forever on TUN close 85 | 86 | closed chan struct{} 87 | events chan tun.Event 88 | tun chTun 89 | } 90 | 91 | func NewChannelTUN() *ChannelTUN { 92 | c := &ChannelTUN{ 93 | Inbound: make(chan []byte), 94 | Outbound: make(chan []byte), 95 | closed: make(chan struct{}), 96 | events: make(chan tun.Event, 1), 97 | } 98 | c.tun.c = c 99 
| c.events <- tun.EventUp 100 | return c 101 | } 102 | 103 | func (c *ChannelTUN) TUN() tun.Device { 104 | return &c.tun 105 | } 106 | 107 | type chTun struct { 108 | c *ChannelTUN 109 | } 110 | 111 | func (t *chTun) File() *os.File { return nil } 112 | 113 | func (t *chTun) Read(data []byte, offset int) (int, error) { 114 | select { 115 | case <-t.c.closed: 116 | return 0, os.ErrClosed 117 | case msg := <-t.c.Outbound: 118 | return copy(data[offset:], msg), nil 119 | } 120 | } 121 | 122 | // Write is called by the wireguard device to deliver a packet for routing. 123 | func (t *chTun) Write(data []byte, offset int) (int, error) { 124 | if offset == -1 { 125 | close(t.c.closed) 126 | close(t.c.events) 127 | return 0, io.EOF 128 | } 129 | msg := make([]byte, len(data)-offset) 130 | copy(msg, data[offset:]) 131 | select { 132 | case <-t.c.closed: 133 | return 0, os.ErrClosed 134 | case t.c.Inbound <- msg: 135 | return len(data) - offset, nil 136 | } 137 | } 138 | 139 | const DefaultMTU = 1420 140 | 141 | func (t *chTun) Flush() error { return nil } 142 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } 143 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } 144 | func (t *chTun) Events() chan tun.Event { return t.c.events } 145 | func (t *chTun) Close() error { 146 | t.Write(nil, -1) 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /tun/wintun/dll_fromfile_windows.go: -------------------------------------------------------------------------------- 1 | // +build !load_wintun_from_rsrc 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package wintun 9 | 10 | import ( 11 | "fmt" 12 | "sync" 13 | "sync/atomic" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/windows" 17 | ) 18 | 19 | type lazyDLL struct { 20 | Name string 21 | mu sync.Mutex 22 | module windows.Handle 23 | onLoad func(d *lazyDLL) 24 | } 25 | 26 | func (d *lazyDLL) Load() error { 27 | if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { 28 | return nil 29 | } 30 | d.mu.Lock() 31 | defer d.mu.Unlock() 32 | if d.module != 0 { 33 | return nil 34 | } 35 | 36 | const ( 37 | LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200 38 | LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800 39 | ) 40 | module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32) 41 | if err != nil { 42 | return fmt.Errorf("Unable to load library: %w", err) 43 | } 44 | 45 | atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) 46 | if d.onLoad != nil { 47 | d.onLoad(d) 48 | } 49 | return nil 50 | } 51 | 52 | func (p *lazyProc) nameToAddr() (uintptr, error) { 53 | return windows.GetProcAddress(p.dll.module, p.Name) 54 | } 55 | -------------------------------------------------------------------------------- /tun/wintun/dll_fromrsrc_windows.go: -------------------------------------------------------------------------------- 1 | // +build load_wintun_from_rsrc 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package wintun 9 | 10 | import ( 11 | "fmt" 12 | "sync" 13 | "sync/atomic" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/windows" 17 | 18 | "golang.zx2c4.com/wireguard/tun/wintun/memmod" 19 | ) 20 | 21 | type lazyDLL struct { 22 | Name string 23 | mu sync.Mutex 24 | module *memmod.Module 25 | onLoad func(d *lazyDLL) 26 | } 27 | 28 | func (d *lazyDLL) Load() error { 29 | if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil { 30 | return nil 31 | } 32 | d.mu.Lock() 33 | defer d.mu.Unlock() 34 | if d.module != nil { 35 | return nil 36 | } 37 | 38 | const ourModule windows.Handle = 0 39 | resInfo, err := windows.FindResource(ourModule, d.Name, windows.RT_RCDATA) 40 | if err != nil { 41 | return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err) 42 | } 43 | data, err := windows.LoadResourceData(ourModule, resInfo) 44 | if err != nil { 45 | return fmt.Errorf("Unable to load resource: %w", err) 46 | } 47 | module, err := memmod.LoadLibrary(data) 48 | if err != nil { 49 | return fmt.Errorf("Unable to load library: %w", err) 50 | } 51 | 52 | atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module)) 53 | if d.onLoad != nil { 54 | d.onLoad(d) 55 | } 56 | return nil 57 | } 58 | 59 | func (p *lazyProc) nameToAddr() (uintptr, error) { 60 | return p.dll.module.ProcAddressByName(p.Name) 61 | } 62 | -------------------------------------------------------------------------------- /tun/wintun/dll_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package wintun 7 | 8 | import ( 9 | "fmt" 10 | "sync" 11 | "sync/atomic" 12 | "unsafe" 13 | ) 14 | 15 | func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL { 16 | return &lazyDLL{Name: name, onLoad: onLoad} 17 | } 18 | 19 | func (d *lazyDLL) NewProc(name string) *lazyProc { 20 | return &lazyProc{dll: d, Name: name} 21 | } 22 | 23 | type lazyProc struct { 24 | Name string 25 | mu sync.Mutex 26 | dll *lazyDLL 27 | addr uintptr 28 | } 29 | 30 | func (p *lazyProc) Find() error { 31 | if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil { 32 | return nil 33 | } 34 | p.mu.Lock() 35 | defer p.mu.Unlock() 36 | if p.addr != 0 { 37 | return nil 38 | } 39 | 40 | err := p.dll.Load() 41 | if err != nil { 42 | return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err) 43 | } 44 | addr, err := p.nameToAddr() 45 | if err != nil { 46 | return fmt.Errorf("Error getting %v address: %w", p.Name, err) 47 | } 48 | 49 | atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr)) 50 | return nil 51 | } 52 | 53 | func (p *lazyProc) Addr() uintptr { 54 | err := p.Find() 55 | if err != nil { 56 | panic(err) 57 | } 58 | return p.addr 59 | } 60 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_32.go: -------------------------------------------------------------------------------- 1 | // +build windows,386 windows,arm 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package memmod 9 | 10 | func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr { 11 | return 0 12 | } 13 | 14 | func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) { 15 | return 16 | } 17 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_386.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package memmod 7 | 8 | const imageFileProcess = IMAGE_FILE_MACHINE_I386 9 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_64.go: -------------------------------------------------------------------------------- 1 | // +build windows,amd64 windows,arm64 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package memmod 9 | 10 | import ( 11 | "fmt" 12 | 13 | "golang.org/x/sys/windows" 14 | ) 15 | 16 | func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr { 17 | return uintptr(opthdr.ImageBase & 0xffffffff00000000) 18 | } 19 | 20 | func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) { 21 | for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) { 22 | node := &addressList{ 23 | next: module.blockedMemory, 24 | address: module.codeBase, 25 | } 26 | module.blockedMemory = node 27 | module.codeBase, err = windows.VirtualAlloc(0, 28 | alignedImageSize, 29 | windows.MEM_RESERVE|windows.MEM_COMMIT, 30 | windows.PAGE_READWRITE) 31 | if err != nil { 32 | return fmt.Errorf("Error allocating memory block: %w", err) 33 | } 34 | } 35 | return 36 | } 37 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_amd64.go: 
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package memmod 7 | 8 | const imageFileProcess = IMAGE_FILE_MACHINE_AMD64 9 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_arm.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package memmod 7 | 8 | const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT 9 | -------------------------------------------------------------------------------- /tun/wintun/memmod/memmod_windows_arm64.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package memmod 7 | 8 | const imageFileProcess = IMAGE_FILE_MACHINE_ARM64 9 | -------------------------------------------------------------------------------- /tun/wintun/memmod/syscall_windows_32.go: -------------------------------------------------------------------------------- 1 | // +build windows,386 windows,arm 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package memmod 9 | 10 | // Optional header format 11 | type IMAGE_OPTIONAL_HEADER struct { 12 | Magic uint16 13 | MajorLinkerVersion uint8 14 | MinorLinkerVersion uint8 15 | SizeOfCode uint32 16 | SizeOfInitializedData uint32 17 | SizeOfUninitializedData uint32 18 | AddressOfEntryPoint uint32 19 | BaseOfCode uint32 20 | BaseOfData uint32 21 | ImageBase uintptr 22 | SectionAlignment uint32 23 | FileAlignment uint32 24 | MajorOperatingSystemVersion uint16 25 | MinorOperatingSystemVersion uint16 26 | MajorImageVersion uint16 27 | MinorImageVersion uint16 28 | MajorSubsystemVersion uint16 29 | MinorSubsystemVersion uint16 30 | Win32VersionValue uint32 31 | SizeOfImage uint32 32 | SizeOfHeaders uint32 33 | CheckSum uint32 34 | Subsystem uint16 35 | DllCharacteristics uint16 36 | SizeOfStackReserve uintptr 37 | SizeOfStackCommit uintptr 38 | SizeOfHeapReserve uintptr 39 | SizeOfHeapCommit uintptr 40 | LoaderFlags uint32 41 | NumberOfRvaAndSizes uint32 42 | DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY 43 | } 44 | 45 | const IMAGE_ORDINAL_FLAG uintptr = 0x80000000 46 | -------------------------------------------------------------------------------- /tun/wintun/memmod/syscall_windows_64.go: -------------------------------------------------------------------------------- 1 | // +build windows,amd64 windows,arm64 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package memmod 9 | 10 | // Optional header format 11 | type IMAGE_OPTIONAL_HEADER struct { 12 | Magic uint16 13 | MajorLinkerVersion uint8 14 | MinorLinkerVersion uint8 15 | SizeOfCode uint32 16 | SizeOfInitializedData uint32 17 | SizeOfUninitializedData uint32 18 | AddressOfEntryPoint uint32 19 | BaseOfCode uint32 20 | ImageBase uintptr 21 | SectionAlignment uint32 22 | FileAlignment uint32 23 | MajorOperatingSystemVersion uint16 24 | MinorOperatingSystemVersion uint16 25 | MajorImageVersion uint16 26 | MinorImageVersion uint16 27 | MajorSubsystemVersion uint16 28 | MinorSubsystemVersion uint16 29 | Win32VersionValue uint32 30 | SizeOfImage uint32 31 | SizeOfHeaders uint32 32 | CheckSum uint32 33 | Subsystem uint16 34 | DllCharacteristics uint16 35 | SizeOfStackReserve uintptr 36 | SizeOfStackCommit uintptr 37 | SizeOfHeapReserve uintptr 38 | SizeOfHeapCommit uintptr 39 | LoaderFlags uint32 40 | NumberOfRvaAndSizes uint32 41 | DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY 42 | } 43 | 44 | const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000 45 | -------------------------------------------------------------------------------- /tun/wintun/session_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package wintun 7 | 8 | import ( 9 | "syscall" 10 | "unsafe" 11 | 12 | "golang.org/x/sys/windows" 13 | ) 14 | 15 | type Session struct { 16 | handle uintptr 17 | } 18 | 19 | const ( 20 | PacketSizeMax = 0xffff // Maximum packet size 21 | RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB) 22 | RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB) 23 | ) 24 | 25 | // Packet with data 26 | type Packet struct { 27 | Next *Packet // Pointer to next packet in queue 28 | Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE) 29 | Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet 30 | } 31 | 32 | var ( 33 | procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket") 34 | procWintunEndSession = modwintun.NewProc("WintunEndSession") 35 | procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") 36 | procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket") 37 | procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket") 38 | procWintunSendPacket = modwintun.NewProc("WintunSendPacket") 39 | procWintunStartSession = modwintun.NewProc("WintunStartSession") 40 | ) 41 | 42 | func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) { 43 | r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0) 44 | if r0 == 0 { 45 | err = e1 46 | } else { 47 | session = Session{r0} 48 | } 49 | return 50 | } 51 | 52 | func (session Session) End() { 53 | syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0) 54 | session.handle = 0 55 | } 56 | 57 | func (session Session) ReadWaitEvent() (handle windows.Handle) { 58 | r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0) 59 | handle = windows.Handle(r0) 60 | return 61 | } 62 | 63 | func (session Session) ReceivePacket() (packet []byte, err error) { 64 | var packetSize uint32 65 | r0, _, e1 := 
syscall.Syscall(procWintunReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) 66 | if r0 == 0 { 67 | err = e1 68 | return 69 | } 70 | unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) 71 | return 72 | } 73 | 74 | func (session Session) ReleaseReceivePacket(packet []byte) { 75 | syscall.Syscall(procWintunReleaseReceivePacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) 76 | } 77 | 78 | func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { 79 | r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket.Addr(), 2, session.handle, uintptr(packetSize), 0) 80 | if r0 == 0 { 81 | err = e1 82 | return 83 | } 84 | unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) 85 | return 86 | } 87 | 88 | func (session Session) SendPacket(packet []byte) { 89 | syscall.Syscall(procWintunSendPacket.Addr(), 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) 90 | } 91 | 92 | // unsafeSlice updates the slice slicePtr to be a slice 93 | // referencing the provided data with its length & capacity set to 94 | // lenCap. 95 | // 96 | // TODO: when Go 1.16 or Go 1.17 is the minimum supported version, 97 | // update callers to use unsafe.Slice instead of this. 98 | func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { 99 | type sliceHeader struct { 100 | Data unsafe.Pointer 101 | Len int 102 | Cap int 103 | } 104 | h := (*sliceHeader)(slicePtr) 105 | h.Data = data 106 | h.Len = lenCap 107 | h.Cap = lenCap 108 | } 109 | -------------------------------------------------------------------------------- /tun/wintun/wintun_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package wintun 7 | 8 | import ( 9 | "errors" 10 | "log" 11 | "runtime" 12 | "syscall" 13 | "unsafe" 14 | 15 | "golang.org/x/sys/windows" 16 | ) 17 | 18 | type loggerLevel int 19 | 20 | const ( 21 | logInfo loggerLevel = iota 22 | logWarn 23 | logErr 24 | ) 25 | 26 | const ( 27 | PoolNameMax = 256 28 | AdapterNameMax = 128 29 | ) 30 | 31 | type Pool [PoolNameMax]uint16 32 | type Adapter struct { 33 | handle uintptr 34 | } 35 | 36 | var ( 37 | modwintun = newLazyDLL("wintun.dll", setupLogger) 38 | 39 | procWintunCreateAdapter = modwintun.NewProc("WintunCreateAdapter") 40 | procWintunDeleteAdapter = modwintun.NewProc("WintunDeleteAdapter") 41 | procWintunDeletePoolDriver = modwintun.NewProc("WintunDeletePoolDriver") 42 | procWintunEnumAdapters = modwintun.NewProc("WintunEnumAdapters") 43 | procWintunFreeAdapter = modwintun.NewProc("WintunFreeAdapter") 44 | procWintunOpenAdapter = modwintun.NewProc("WintunOpenAdapter") 45 | procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") 46 | procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName") 47 | procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") 48 | procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName") 49 | ) 50 | 51 | func setupLogger(dll *lazyDLL) { 52 | syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, windows.NewCallback(func(level loggerLevel, msg *uint16) int { 53 | log.Println("[Wintun]", windows.UTF16PtrToString(msg)) 54 | return 0 55 | }), 0, 0) 56 | } 57 | 58 | func MakePool(poolName string) (pool *Pool, err error) { 59 | poolName16, err := windows.UTF16FromString(poolName) 60 | if err != nil { 61 | return 62 | } 63 | if len(poolName16) > PoolNameMax { 64 | err = errors.New("Pool name too long") 65 | return 66 | } 67 | pool = &Pool{} 68 | copy(pool[:], poolName16) 69 | return 70 | } 71 | 72 | func (pool *Pool) String() string { 73 | return windows.UTF16ToString(pool[:]) 74 | } 75 | 76 | func freeAdapter(wintun 
*Adapter) { 77 | syscall.Syscall(procWintunFreeAdapter.Addr(), 1, uintptr(wintun.handle), 0, 0) 78 | } 79 | 80 | // OpenAdapter finds a Wintun adapter by its name. This function returns the adapter if found, or 81 | // windows.ERROR_FILE_NOT_FOUND otherwise. If the adapter is found but not a Wintun-class or a 82 | // member of the pool, this function returns windows.ERROR_ALREADY_EXISTS. The adapter must be 83 | // released after use. 84 | func (pool *Pool) OpenAdapter(ifname string) (wintun *Adapter, err error) { 85 | ifname16, err := windows.UTF16PtrFromString(ifname) 86 | if err != nil { 87 | return nil, err 88 | } 89 | r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), 0) 90 | if r0 == 0 { 91 | err = e1 92 | return 93 | } 94 | wintun = &Adapter{r0} 95 | runtime.SetFinalizer(wintun, freeAdapter) 96 | return 97 | } 98 | 99 | // CreateAdapter creates a Wintun adapter. ifname is the requested name of the adapter, while 100 | // requestedGUID is the GUID of the created network adapter, which then influences NLA generation 101 | // deterministically. If it is set to nil, the GUID is chosen by the system at random, and hence a 102 | // new NLA entry is created for each new adapter. It is called "requested" GUID because the API it 103 | // uses is completely undocumented, and so there could be minor interesting complications with its 104 | // usage. This function returns the network adapter ID and a flag if reboot is required. 
105 | func (pool *Pool) CreateAdapter(ifname string, requestedGUID *windows.GUID) (wintun *Adapter, rebootRequired bool, err error) { 106 | var ifname16 *uint16 107 | ifname16, err = windows.UTF16PtrFromString(ifname) 108 | if err != nil { 109 | return 110 | } 111 | var _p0 uint32 112 | r0, _, e1 := syscall.Syscall6(procWintunCreateAdapter.Addr(), 4, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(ifname16)), uintptr(unsafe.Pointer(requestedGUID)), uintptr(unsafe.Pointer(&_p0)), 0, 0) 113 | rebootRequired = _p0 != 0 114 | if r0 == 0 { 115 | err = e1 116 | return 117 | } 118 | wintun = &Adapter{r0} 119 | runtime.SetFinalizer(wintun, freeAdapter) 120 | return 121 | } 122 | 123 | // Delete deletes a Wintun adapter. This function succeeds if the adapter was not found. It returns 124 | // a bool indicating whether a reboot is required. 125 | func (wintun *Adapter) Delete(forceCloseSessions bool) (rebootRequired bool, err error) { 126 | var _p0 uint32 127 | if forceCloseSessions { 128 | _p0 = 1 129 | } 130 | var _p1 uint32 131 | r1, _, e1 := syscall.Syscall(procWintunDeleteAdapter.Addr(), 3, uintptr(wintun.handle), uintptr(_p0), uintptr(unsafe.Pointer(&_p1))) 132 | rebootRequired = _p1 != 0 133 | if r1 == 0 { 134 | err = e1 135 | } 136 | return 137 | } 138 | 139 | // DeleteMatchingAdapters deletes all Wintun adapters, which match 140 | // given criteria, and returns which ones it deleted, whether a reboot 141 | // is required after, and which errors occurred during the process. 
142 | func (pool *Pool) DeleteMatchingAdapters(matches func(adapter *Adapter) bool, forceCloseSessions bool) (rebootRequired bool, errors []error) { 143 | cb := func(handle uintptr, _ uintptr) int { 144 | adapter := &Adapter{handle} 145 | if !matches(adapter) { 146 | return 1 147 | } 148 | rebootRequired2, err := adapter.Delete(forceCloseSessions) 149 | if err != nil { 150 | errors = append(errors, err) 151 | return 1 152 | } 153 | rebootRequired = rebootRequired || rebootRequired2 154 | return 1 155 | } 156 | r1, _, e1 := syscall.Syscall(procWintunEnumAdapters.Addr(), 3, uintptr(unsafe.Pointer(pool)), uintptr(windows.NewCallback(cb)), 0) 157 | if r1 == 0 { 158 | errors = append(errors, e1) 159 | } 160 | return 161 | } 162 | 163 | // Name returns the name of the Wintun adapter. 164 | func (wintun *Adapter) Name() (ifname string, err error) { 165 | var ifname16 [AdapterNameMax]uint16 166 | r1, _, e1 := syscall.Syscall(procWintunGetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0) 167 | if r1 == 0 { 168 | err = e1 169 | return 170 | } 171 | ifname = windows.UTF16ToString(ifname16[:]) 172 | return 173 | } 174 | 175 | // DeleteDriver deletes all Wintun adapters in a pool and if there are no more adapters in any other 176 | // pools, also removes Wintun from the driver store, usually called by uninstallers. 177 | func (pool *Pool) DeleteDriver() (rebootRequired bool, err error) { 178 | var _p0 uint32 179 | r1, _, e1 := syscall.Syscall(procWintunDeletePoolDriver.Addr(), 2, uintptr(unsafe.Pointer(pool)), uintptr(unsafe.Pointer(&_p0)), 0) 180 | rebootRequired = _p0 != 0 181 | if r1 == 0 { 182 | err = e1 183 | } 184 | return 185 | 186 | } 187 | 188 | // SetName sets name of the Wintun adapter. 
189 | func (wintun *Adapter) SetName(ifname string) (err error) { 190 | ifname16, err := windows.UTF16FromString(ifname) 191 | if err != nil { 192 | return err 193 | } 194 | r1, _, e1 := syscall.Syscall(procWintunSetAdapterName.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&ifname16[0])), 0) 195 | if r1 == 0 { 196 | err = e1 197 | } 198 | return 199 | } 200 | 201 | // RunningVersion returns the version of the running Wintun driver. 202 | func RunningVersion() (version uint32, err error) { 203 | r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0) 204 | version = uint32(r0) 205 | if version == 0 { 206 | err = e1 207 | } 208 | return 209 | } 210 | 211 | // LUID returns the LUID of the adapter. 212 | func (wintun *Adapter) LUID() (luid uint64) { 213 | syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0) 214 | return 215 | } 216 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const Version = "0.0.20210424" 4 | --------------------------------------------------------------------------------