├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── conn ├── bind_linux.go ├── bind_std.go ├── bind_windows.go ├── bindtest │ └── bindtest.go ├── boundif_android.go ├── conn.go ├── default.go ├── mark_default.go ├── mark_unix.go └── winrio │ └── rio_windows.go ├── device ├── alignment_test.go ├── allowedips.go ├── allowedips_rand_test.go ├── allowedips_test.go ├── bind_test.go ├── channels.go ├── constants.go ├── cookie.go ├── cookie_test.go ├── device.go ├── device_test.go ├── devicestate_string.go ├── endpoint_test.go ├── indextable.go ├── ip.go ├── kdf_test.go ├── keypair.go ├── logger.go ├── misc.go ├── mobilequirks.go ├── noise-helpers.go ├── noise-protocol.go ├── noise-types.go ├── noise_test.go ├── peer.go ├── pools.go ├── pools_test.go ├── queueconstants_android.go ├── queueconstants_default.go ├── queueconstants_ios.go ├── queueconstants_windows.go ├── race_disabled_test.go ├── race_enabled_test.go ├── receive.go ├── send.go ├── sticky_default.go ├── sticky_linux.go ├── timers.go ├── tun.go └── uapi.go ├── format_test.go ├── go.mod ├── go.sum ├── ipc ├── namedpipe │ ├── file.go │ ├── namedpipe.go │ └── namedpipe_test.go ├── uapi_bsd.go ├── uapi_js.go ├── uapi_linux.go ├── uapi_unix.go └── uapi_windows.go ├── main.go ├── main_windows.go ├── ratelimiter ├── ratelimiter.go └── ratelimiter_test.go ├── replay ├── replay.go └── replay_test.go ├── rwcancel ├── rwcancel.go └── rwcancel_stub.go ├── tai64n ├── tai64n.go └── tai64n_test.go ├── tests └── netns.sh ├── tun ├── alignment_windows_test.go ├── netstack │ ├── examples │ │ ├── http_client.go │ │ ├── http_server.go │ │ └── ping_client.go │ ├── go.mod │ ├── go.sum │ └── tun.go ├── operateonfd.go ├── tun.go ├── tun_darwin.go ├── tun_freebsd.go ├── tun_linux.go ├── tun_openbsd.go ├── tun_windows.go └── tuntest │ └── tuntest.go └── version.go /.gitignore: -------------------------------------------------------------------------------- 1 | wireguard-go 2 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of 2 | this software and associated documentation files (the "Software"), to deal in 3 | the Software without restriction, including without limitation the rights to 4 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 5 | of the Software, and to permit persons to whom the Software is furnished to do 6 | so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all 9 | copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 
18 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX ?= /usr 2 | DESTDIR ?= 3 | BINDIR ?= $(PREFIX)/bin 4 | export GO111MODULE := on 5 | 6 | all: generate-version-and-build # default target: refresh version.go, then build 7 | 8 | MAKEFLAGS += --no-print-directory 9 | 10 | generate-version-and-build: # rewrite version.go from `git describe --dirty` only when its content would change and mark it assume-unchanged (failures ignored via `|| true`), then build 11 | @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ 12 | tag="$$(git describe --dirty 2>/dev/null)" && \ 13 | ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ 14 | [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ 15 | echo "$$ver" > version.go && \ 16 | git update-index --assume-unchanged version.go || true 17 | @$(MAKE) wireguard-go 18 | 19 | wireguard-go: $(wildcard *.go) $(wildcard */*.go) # rebuild when a top-level or first-level-subdirectory .go file changes 20 | go build -v -o "$@" 21 | 22 | install: wireguard-go # install the binary as $(DESTDIR)$(BINDIR)/wireguard-go with mode 0755 23 | @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" 24 | 25 | test: # run the tests of every package in the module 26 | go test ./... 27 | 28 | clean: # remove the built binary 29 | rm -f wireguard-go 30 | 31 | .PHONY: all clean test install generate-version-and-build 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go Implementation of [WireGuard](https://www.wireguard.com/) 2 | 3 | This is an implementation of WireGuard in Go. 4 | 5 | ## Usage 6 | 7 | Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run: 8 | 9 | ``` 10 | $ wireguard-go wg0 11 | ``` 12 | 13 | This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-go shutting down.
14 | 15 | To run wireguard-go without forking to the background, pass `-f` or `--foreground`: 16 | 17 | ``` 18 | $ wireguard-go -f wg0 19 | ``` 20 | 21 | When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. 22 | 23 | To run with more logging you may set the environment variable `LOG_LEVEL=debug`. 24 | 25 | ## Platforms 26 | 27 | ### Linux 28 | 29 | This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions. 30 | 31 | ### macOS 32 | 33 | This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. 34 | 35 | ### Windows 36 | 37 | This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module. 38 | 39 | ### FreeBSD 40 | 41 | This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`. 42 | 43 | ### OpenBSD 44 | 45 | This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. 
If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. 46 | 47 | ## Building 48 | 49 | This requires an installation of [go](https://golang.org) ≥ 1.18. 50 | 51 | ``` 52 | $ git clone https://git.zx2c4.com/wireguard-go 53 | $ cd wireguard-go 54 | $ make 55 | ``` 56 | 57 | ## License 58 | 59 | Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 60 | 61 | Permission is hereby granted, free of charge, to any person obtaining a copy of 62 | this software and associated documentation files (the "Software"), to deal in 63 | the Software without restriction, including without limitation the rights to 64 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 65 | of the Software, and to permit persons to whom the Software is furnished to do 66 | so, subject to the following conditions: 67 | 68 | The above copyright notice and this permission notice shall be included in all 69 | copies or substantial portions of the Software. 70 | 71 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 72 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 73 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 74 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 75 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 76 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 77 | SOFTWARE. 78 | -------------------------------------------------------------------------------- /conn/bind_std.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package conn 7 | 8 | import ( 9 | "errors" 10 | "net" 11 | "net/netip" 12 | "sync" 13 | "syscall" 14 | ) 15 | 16 | // StdNetBind is meant to be a temporary solution on platforms for which 17 | // the sticky socket / source caching behavior has not yet been implemented. 18 | // It uses the Go's net package to implement networking. 19 | // See LinuxSocketBind for a proper implementation on the Linux platform. 20 | type StdNetBind struct { 21 | mu sync.Mutex // protects following fields 22 | ipv4 *net.UDPConn 23 | ipv6 *net.UDPConn 24 | blackhole4 bool 25 | blackhole6 bool 26 | } 27 | 28 | func NewStdNetBind() Bind { return &StdNetBind{} } 29 | 30 | type StdNetEndpoint netip.AddrPort 31 | 32 | var ( 33 | _ Bind = (*StdNetBind)(nil) 34 | _ Endpoint = StdNetEndpoint{} 35 | ) 36 | 37 | func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 38 | e, err := netip.ParseAddrPort(s) 39 | return asEndpoint(e), err 40 | } 41 | 42 | func (StdNetEndpoint) ClearSrc() {} 43 | 44 | func (e StdNetEndpoint) DstIP() netip.Addr { 45 | return (netip.AddrPort)(e).Addr() 46 | } 47 | 48 | func (e StdNetEndpoint) SrcIP() netip.Addr { 49 | return netip.Addr{} // not supported 50 | } 51 | 52 | func (e StdNetEndpoint) DstToBytes() []byte { 53 | b, _ := (netip.AddrPort)(e).MarshalBinary() 54 | return b 55 | } 56 | 57 | func (e StdNetEndpoint) DstToString() string { 58 | return (netip.AddrPort)(e).String() 59 | } 60 | 61 | func (e StdNetEndpoint) SrcToString() string { 62 | return "" 63 | } 64 | 65 | func listenNet(network string, port int) (*net.UDPConn, int, error) { 66 | conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 67 | if err != nil { 68 | return nil, 0, err 69 | } 70 | 71 | // Retrieve port. 
72 | laddr := conn.LocalAddr() 73 | uaddr, err := net.ResolveUDPAddr( 74 | laddr.Network(), 75 | laddr.String(), 76 | ) 77 | if err != nil { 78 | return nil, 0, err 79 | } 80 | return conn, uaddr.Port, nil 81 | } 82 | 83 | func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 84 | bind.mu.Lock() 85 | defer bind.mu.Unlock() 86 | 87 | var err error 88 | var tries int 89 | 90 | if bind.ipv4 != nil || bind.ipv6 != nil { 91 | return nil, 0, ErrBindAlreadyOpen 92 | } 93 | 94 | // Attempt to open ipv4 and ipv6 listeners on the same port. 95 | // If uport is 0, we can retry on failure. 96 | again: 97 | port := int(uport) 98 | var ipv4, ipv6 *net.UDPConn 99 | 100 | ipv4, port, err = listenNet("udp4", port) 101 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 102 | return nil, 0, err 103 | } 104 | 105 | // Listen on the same port as we're using for ipv4. 106 | ipv6, port, err = listenNet("udp6", port) 107 | if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 108 | ipv4.Close() 109 | tries++ 110 | goto again 111 | } 112 | if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 113 | ipv4.Close() 114 | return nil, 0, err 115 | } 116 | var fns []ReceiveFunc 117 | if ipv4 != nil { 118 | fns = append(fns, bind.makeReceiveIPv4(ipv4)) 119 | bind.ipv4 = ipv4 120 | } 121 | if ipv6 != nil { 122 | fns = append(fns, bind.makeReceiveIPv6(ipv6)) 123 | bind.ipv6 = ipv6 124 | } 125 | if len(fns) == 0 { 126 | return nil, 0, syscall.EAFNOSUPPORT 127 | } 128 | return fns, uint16(port), nil 129 | } 130 | 131 | func (bind *StdNetBind) Close() error { 132 | bind.mu.Lock() 133 | defer bind.mu.Unlock() 134 | 135 | var err1, err2 error 136 | if bind.ipv4 != nil { 137 | err1 = bind.ipv4.Close() 138 | bind.ipv4 = nil 139 | } 140 | if bind.ipv6 != nil { 141 | err2 = bind.ipv6.Close() 142 | bind.ipv6 = nil 143 | } 144 | bind.blackhole4 = false 145 | bind.blackhole6 = false 146 | if err1 != nil { 147 | return err1 148 | } 149 | return err2 150 | } 151 | 
152 | func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 153 | return func(buff []byte) (int, Endpoint, error) { 154 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 155 | return n, asEndpoint(endpoint), err 156 | } 157 | } 158 | 159 | func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 160 | return func(buff []byte) (int, Endpoint, error) { 161 | n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 162 | return n, asEndpoint(endpoint), err 163 | } 164 | } 165 | 166 | func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 167 | var err error 168 | nend, ok := endpoint.(StdNetEndpoint) 169 | if !ok { 170 | return ErrWrongEndpointType 171 | } 172 | addrPort := netip.AddrPort(nend) 173 | 174 | bind.mu.Lock() 175 | blackhole := bind.blackhole4 176 | conn := bind.ipv4 177 | if addrPort.Addr().Is6() { 178 | blackhole = bind.blackhole6 179 | conn = bind.ipv6 180 | } 181 | bind.mu.Unlock() 182 | 183 | if blackhole { 184 | return nil 185 | } 186 | if conn == nil { 187 | return syscall.EAFNOSUPPORT 188 | } 189 | _, err = conn.WriteToUDPAddrPort(buff, addrPort) 190 | return err 191 | } 192 | 193 | // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. 194 | // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, 195 | // but Endpoints are immutable, so we can re-use them. 196 | var endpointPool = sync.Pool{ 197 | New: func() any { 198 | return make(map[netip.AddrPort]Endpoint) 199 | }, 200 | } 201 | 202 | // asEndpoint returns an Endpoint containing ap. 
203 | func asEndpoint(ap netip.AddrPort) Endpoint { 204 | m := endpointPool.Get().(map[netip.AddrPort]Endpoint) 205 | defer endpointPool.Put(m) 206 | e, ok := m[ap] 207 | if !ok { 208 | e = Endpoint(StdNetEndpoint(ap)) 209 | m[ap] = e 210 | } 211 | return e 212 | } 213 | -------------------------------------------------------------------------------- /conn/bindtest/bindtest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package bindtest 7 | 8 | import ( 9 | "fmt" 10 | "math/rand" 11 | "net" 12 | "net/netip" 13 | "os" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type ChannelBind struct { 19 | rx4, tx4 *chan []byte 20 | rx6, tx6 *chan []byte 21 | closeSignal chan bool 22 | source4, source6 ChannelEndpoint 23 | target4, target6 ChannelEndpoint 24 | } 25 | 26 | type ChannelEndpoint uint16 27 | 28 | var ( 29 | _ conn.Bind = (*ChannelBind)(nil) 30 | _ conn.Endpoint = (*ChannelEndpoint)(nil) 31 | ) 32 | 33 | func NewChannelBinds() [2]conn.Bind { 34 | arx4 := make(chan []byte, 8192) 35 | brx4 := make(chan []byte, 8192) 36 | arx6 := make(chan []byte, 8192) 37 | brx6 := make(chan []byte, 8192) 38 | var binds [2]ChannelBind 39 | binds[0].rx4 = &arx4 40 | binds[0].tx4 = &brx4 41 | binds[1].rx4 = &brx4 42 | binds[1].tx4 = &arx4 43 | binds[0].rx6 = &arx6 44 | binds[0].tx6 = &brx6 45 | binds[1].rx6 = &brx6 46 | binds[1].tx6 = &arx6 47 | binds[0].target4 = ChannelEndpoint(1) 48 | binds[1].target4 = ChannelEndpoint(2) 49 | binds[0].target6 = ChannelEndpoint(3) 50 | binds[1].target6 = ChannelEndpoint(4) 51 | binds[0].source4 = binds[1].target4 52 | binds[0].source6 = binds[1].target6 53 | binds[1].source4 = binds[0].target4 54 | binds[1].source6 = binds[0].target6 55 | return [2]conn.Bind{&binds[0], &binds[1]} 56 | } 57 | 58 | func (c ChannelEndpoint) ClearSrc() {} 59 | 60 | func (c ChannelEndpoint) 
SrcToString() string { return "" } 61 | 62 | func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } 63 | 64 | func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } 65 | 66 | func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } 67 | 68 | func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } 69 | 70 | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 71 | c.closeSignal = make(chan bool) 72 | fns = append(fns, c.makeReceiveFunc(*c.rx4)) 73 | fns = append(fns, c.makeReceiveFunc(*c.rx6)) 74 | if rand.Uint32()&1 == 0 { 75 | return fns, uint16(c.source4), nil 76 | } else { 77 | return fns, uint16(c.source6), nil 78 | } 79 | } 80 | 81 | func (c *ChannelBind) Close() error { 82 | if c.closeSignal != nil { 83 | select { 84 | case <-c.closeSignal: 85 | default: 86 | close(c.closeSignal) 87 | } 88 | } 89 | return nil 90 | } 91 | 92 | func (c *ChannelBind) SetMark(mark uint32) error { return nil } 93 | 94 | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { 95 | return func(b []byte) (n int, ep conn.Endpoint, err error) { 96 | select { 97 | case <-c.closeSignal: 98 | return 0, nil, net.ErrClosed 99 | case rx := <-ch: 100 | return copy(b, rx), c.target6, nil 101 | } 102 | } 103 | } 104 | 105 | func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { 106 | select { 107 | case <-c.closeSignal: 108 | return net.ErrClosed 109 | default: 110 | bc := make([]byte, len(b)) 111 | copy(bc, b) 112 | if ep.(ChannelEndpoint) == c.target4 { 113 | *c.tx4 <- bc 114 | } else if ep.(ChannelEndpoint) == c.target6 { 115 | *c.tx6 <- bc 116 | } else { 117 | return os.ErrInvalid 118 | } 119 | } 120 | return nil 121 | } 122 | 123 | func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { 124 | addr, err := netip.ParseAddrPort(s) 125 | if err != nil { 126 | return nil, err 127 | } 128 | return 
ChannelEndpoint(addr.Port()), nil 129 | } 130 | -------------------------------------------------------------------------------- /conn/boundif_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package conn 7 | 8 | func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { 9 | sysconn, err := bind.ipv4.SyscallConn() 10 | if err != nil { 11 | return -1, err 12 | } 13 | err = sysconn.Control(func(f uintptr) { 14 | fd = int(f) 15 | }) 16 | if err != nil { 17 | return -1, err 18 | } 19 | return 20 | } 21 | 22 | func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { 23 | sysconn, err := bind.ipv6.SyscallConn() 24 | if err != nil { 25 | return -1, err 26 | } 27 | err = sysconn.Control(func(f uintptr) { 28 | fd = int(f) 29 | }) 30 | if err != nil { 31 | return -1, err 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /conn/conn.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package conn implements WireGuard's network connections. 7 | package conn 8 | 9 | import ( 10 | "errors" 11 | "fmt" 12 | "net/netip" 13 | "reflect" 14 | "runtime" 15 | "strings" 16 | ) 17 | 18 | // A ReceiveFunc receives a single inbound packet from the network. 19 | // It writes the data into b. n is the length of the packet. 20 | // ep is the remote endpoint. 21 | type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) 22 | 23 | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. 24 | // 25 | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, 26 | // depending on the platform-specific implementation. 
27 | type Bind interface { 28 | // Open puts the Bind into a listening state on a given port and reports the actual 29 | // port that it bound to. Passing zero results in a random selection. 30 | // fns is the set of functions that will be called to receive packets. 31 | Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) 32 | 33 | // Close closes the Bind listener. 34 | // All fns returned by Open must return net.ErrClosed after a call to Close. 35 | Close() error 36 | 37 | // SetMark sets the mark for each packet sent through this Bind. 38 | // This mark is passed to the kernel as the socket option SO_MARK. 39 | SetMark(mark uint32) error 40 | 41 | // Send writes a packet b to address ep. 42 | Send(b []byte, ep Endpoint) error 43 | 44 | // ParseEndpoint creates a new endpoint from a string. 45 | ParseEndpoint(s string) (Endpoint, error) 46 | } 47 | 48 | // BindSocketToInterface is implemented by Bind objects that support being 49 | // tied to a single network interface. Used by wireguard-windows. 50 | type BindSocketToInterface interface { 51 | BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error 52 | BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error 53 | } 54 | 55 | // PeekLookAtSocketFd is implemented by Bind objects that support having their 56 | // file descriptor peeked at. Used by wireguard-android. 57 | type PeekLookAtSocketFd interface { 58 | PeekLookAtSocketFd4() (fd int, err error) 59 | PeekLookAtSocketFd6() (fd int, err error) 60 | } 61 | 62 | // An Endpoint maintains the source/destination caching for a peer. 
63 | // 64 | // dst: the remote address of a peer ("endpoint" in uapi terminology) 65 | // src: the local address from which datagrams originate going to the peer 66 | type Endpoint interface { 67 | ClearSrc() // clears the source address 68 | SrcToString() string // returns the local source address (ip:port) 69 | DstToString() string // returns the destination address (ip:port) 70 | DstToBytes() []byte // used for mac2 cookie calculations 71 | DstIP() netip.Addr 72 | SrcIP() netip.Addr 73 | } 74 | 75 | var ( 76 | ErrBindAlreadyOpen = errors.New("bind is already open") 77 | ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") 78 | ) 79 | 80 | func (fn ReceiveFunc) PrettyName() string { 81 | name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() 82 | // 0. cheese/taco.beansIPv6.func12.func21218-fm 83 | name = strings.TrimSuffix(name, "-fm") 84 | // 1. cheese/taco.beansIPv6.func12.func21218 85 | if idx := strings.LastIndexByte(name, '/'); idx != -1 { 86 | name = name[idx+1:] 87 | // 2. taco.beansIPv6.func12.func21218 88 | } 89 | for { 90 | var idx int 91 | for idx = len(name) - 1; idx >= 0; idx-- { 92 | if name[idx] < '0' || name[idx] > '9' { 93 | break 94 | } 95 | } 96 | if idx == len(name)-1 { 97 | break 98 | } 99 | const dotFunc = ".func" 100 | if !strings.HasSuffix(name[:idx+1], dotFunc) { 101 | break 102 | } 103 | name = name[:idx+1-len(dotFunc)] 104 | // 3. taco.beansIPv6.func12 105 | // 4. taco.beansIPv6 106 | } 107 | if idx := strings.LastIndexByte(name, '.'); idx != -1 { 108 | name = name[idx+1:] 109 | // 5. 
beansIPv6 110 | } 111 | if name == "" { 112 | return fmt.Sprintf("%p", fn) 113 | } 114 | if strings.HasSuffix(name, "IPv4") { 115 | return "v4" 116 | } 117 | if strings.HasSuffix(name, "IPv6") { 118 | return "v6" 119 | } 120 | return name 121 | } 122 | -------------------------------------------------------------------------------- /conn/default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | func NewDefaultBind() Bind { return NewStdNetBind() } 11 | -------------------------------------------------------------------------------- /conn/mark_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !openbsd && !freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package conn 9 | 10 | func (bind *StdNetBind) SetMark(mark uint32) error { 11 | return nil 12 | } 13 | -------------------------------------------------------------------------------- /conn/mark_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || openbsd || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package conn 9 | 10 | import ( 11 | "runtime" 12 | 13 | "golang.org/x/sys/unix" 14 | ) 15 | 16 | var fwmarkIoctl int 17 | 18 | func init() { 19 | switch runtime.GOOS { 20 | case "linux", "android": 21 | fwmarkIoctl = 36 /* unix.SO_MARK */ 22 | case "freebsd": 23 | fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ 24 | case "openbsd": 25 | fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ 26 | } 27 | } 28 | 29 | func (bind *StdNetBind) SetMark(mark uint32) error { 30 | var operr error 31 | if fwmarkIoctl == 0 { 32 | return nil 33 | } 34 | if bind.ipv4 != nil { 35 | fd, err := bind.ipv4.SyscallConn() 36 | if err != nil { 37 | return err 38 | } 39 | err = fd.Control(func(fd uintptr) { 40 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 41 | }) 42 | if err == nil { 43 | err = operr 44 | } 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | if bind.ipv6 != nil { 50 | fd, err := bind.ipv6.SyscallConn() 51 | if err != nil { 52 | return err 53 | } 54 | err = fd.Control(func(fd uintptr) { 55 | operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) 56 | }) 57 | if err == nil { 58 | err = operr 59 | } 60 | if err != nil { 61 | return err 62 | } 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /conn/winrio/rio_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package winrio 7 | 8 | import ( 9 | "log" 10 | "sync" 11 | "syscall" 12 | "unsafe" 13 | 14 | "golang.org/x/sys/windows" 15 | ) 16 | 17 | const ( 18 | MsgDontNotify = 1 19 | MsgDefer = 2 20 | MsgWaitAll = 4 21 | MsgCommitOnly = 8 22 | 23 | MaxCqSize = 0x8000000 24 | 25 | invalidBufferId = 0xFFFFFFFF 26 | invalidCq = 0 27 | invalidRq = 0 28 | corruptCq = 0xFFFFFFFF 29 | ) 30 | 31 | var extensionFunctionTable struct { 32 | cbSize uint32 33 | rioReceive uintptr 34 | rioReceiveEx uintptr 35 | rioSend uintptr 36 | rioSendEx uintptr 37 | rioCloseCompletionQueue uintptr 38 | rioCreateCompletionQueue uintptr 39 | rioCreateRequestQueue uintptr 40 | rioDequeueCompletion uintptr 41 | rioDeregisterBuffer uintptr 42 | rioNotify uintptr 43 | rioRegisterBuffer uintptr 44 | rioResizeCompletionQueue uintptr 45 | rioResizeRequestQueue uintptr 46 | } 47 | 48 | type Cq uintptr 49 | 50 | type Rq uintptr 51 | 52 | type BufferId uintptr 53 | 54 | type Buffer struct { 55 | Id BufferId 56 | Offset uint32 57 | Length uint32 58 | } 59 | 60 | type Result struct { 61 | Status int32 62 | BytesTransferred uint32 63 | SocketContext uint64 64 | RequestContext uint64 65 | } 66 | 67 | type notificationCompletionType uint32 68 | 69 | const ( 70 | eventCompletion notificationCompletionType = 1 71 | iocpCompletion notificationCompletionType = 2 72 | ) 73 | 74 | type eventNotificationCompletion struct { 75 | completionType notificationCompletionType 76 | event windows.Handle 77 | notifyReset uint32 78 | } 79 | 80 | type iocpNotificationCompletion struct { 81 | completionType notificationCompletionType 82 | iocp windows.Handle 83 | key uintptr 84 | overlapped *windows.Overlapped 85 | } 86 | 87 | var ( 88 | initialized sync.Once 89 | available bool 90 | ) 91 | 92 | func Initialize() bool { 93 | initialized.Do(func() { 94 | var ( 95 | err error 96 | socket windows.Handle 97 | cq Cq 98 | ) 99 | defer func() { 100 | if err == nil { 101 | return 102 | } 103 | if maj, _, _ := 
windows.RtlGetNtVersionNumbers(); maj <= 7 { 104 | return 105 | } 106 | log.Printf("Registered I/O is unavailable: %v", err) 107 | }() 108 | socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) 109 | if err != nil { 110 | return 111 | } 112 | defer windows.CloseHandle(socket) 113 | WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} 114 | const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 115 | ob := uint32(0) 116 | err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, 117 | (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), 118 | (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), 119 | &ob, nil, 0) 120 | if err != nil { 121 | return 122 | } 123 | 124 | // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes 125 | // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. 
126 | var iocp windows.Handle 127 | iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 128 | if err != nil { 129 | return 130 | } 131 | defer windows.CloseHandle(iocp) 132 | var overlapped windows.Overlapped 133 | cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) 134 | if err != nil { 135 | return 136 | } 137 | defer CloseCompletionQueue(cq) 138 | _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) 139 | if err != nil { 140 | return 141 | } 142 | available = true 143 | }) 144 | return available 145 | } 146 | 147 | func Socket(af, typ, proto int32) (windows.Handle, error) { 148 | return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) 149 | } 150 | 151 | func CloseCompletionQueue(cq Cq) { 152 | _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) 153 | } 154 | 155 | func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { 156 | notificationCompletion := &eventNotificationCompletion{ 157 | completionType: eventCompletion, 158 | event: event, 159 | } 160 | if notifyReset { 161 | notificationCompletion.notifyReset = 1 162 | } 163 | ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 164 | if ret == invalidCq { 165 | return 0, err 166 | } 167 | return Cq(ret), nil 168 | } 169 | 170 | func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { 171 | notificationCompletion := &iocpNotificationCompletion{ 172 | completionType: iocpCompletion, 173 | iocp: iocp, 174 | key: key, 175 | overlapped: overlapped, 176 | } 177 | ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) 178 | if ret == invalidCq { 179 | return 0, err 180 | } 181 | return Cq(ret), 
nil 182 | } 183 | 184 | func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { 185 | ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) 186 | if ret == invalidCq { 187 | return 0, err 188 | } 189 | return Cq(ret), nil 190 | } 191 | 192 | func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { 193 | ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) 194 | if ret == invalidRq { 195 | return 0, err 196 | } 197 | return Rq(ret), nil 198 | } 199 | 200 | func DequeueCompletion(cq Cq, results []Result) uint32 { 201 | var array uintptr 202 | if len(results) > 0 { 203 | array = uintptr(unsafe.Pointer(&results[0])) 204 | } 205 | ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) 206 | if ret == corruptCq { 207 | panic("cq is corrupt") 208 | } 209 | return uint32(ret) 210 | } 211 | 212 | func DeregisterBuffer(id BufferId) { 213 | _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) 214 | } 215 | 216 | func RegisterBuffer(buffer []byte) (BufferId, error) { 217 | var buf unsafe.Pointer 218 | if len(buffer) > 0 { 219 | buf = unsafe.Pointer(&buffer[0]) 220 | } 221 | return RegisterPointer(buf, uint32(len(buffer))) 222 | } 223 | 224 | func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { 225 | ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) 226 | if ret == invalidBufferId { 227 | return 0, err 228 | } 229 | return BufferId(ret), nil 230 | } 231 | 232 | func SendEx(rq 
Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 233 | ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 234 | if ret == 0 { 235 | return err 236 | } 237 | return nil 238 | } 239 | 240 | func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { 241 | ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) 242 | if ret == 0 { 243 | return err 244 | } 245 | return nil 246 | } 247 | 248 | func Notify(cq Cq) error { 249 | ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) 250 | if ret != 0 { 251 | return windows.Errno(ret) 252 | } 253 | return nil 254 | } 255 | -------------------------------------------------------------------------------- /device/alignment_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). 
Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestPeerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime. 26 | func TestPeerAlignment(t *testing.T) { 27 | var p Peer 28 | 29 | typ := reflect.TypeOf(&p).Elem() 30 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) 42 | checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) 43 | } 44 | 45 | // TestDeviceAlignment checks that atomically-accessed fields are 46 | // aligned to 64-bit boundaries, as required by the atomic package. 47 | // 48 | // Unfortunately, violating this rule on 32-bit platforms results in a 49 | // hard segfault at runtime. 50 | func TestDeviceAlignment(t *testing.T) { 51 | var d Device 52 | 53 | typ := reflect.TypeOf(&d).Elem() 54 | t.Logf("Device type size: %d, with fields:", typ.Size()) 55 | for i := 0; i < typ.NumField(); i++ { 56 | field := typ.Field(i) 57 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 58 | field.Name, 59 | field.Offset, 60 | field.Type.Size(), 61 | field.Type.Align(), 62 | ) 63 | } 64 | checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil)) 65 | } 66 | -------------------------------------------------------------------------------- /device/allowedips.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "encoding/binary" 11 | "errors" 12 | "math/bits" 13 | "net" 14 | "net/netip" 15 | "sync" 16 | "unsafe" 17 | ) 18 | 19 | type parentIndirection struct { 20 | parentBit **trieEntry 21 | parentBitType uint8 22 | } 23 | 24 | type trieEntry struct { 25 | peer *Peer 26 | child [2]*trieEntry 27 | parent parentIndirection 28 | cidr uint8 29 | bitAtByte uint8 30 | bitAtShift uint8 31 | bits []byte 32 | perPeerElem *list.Element 33 | } 34 | 35 | func commonBits(ip1, ip2 []byte) uint8 { 36 | size := len(ip1) 37 | if size == net.IPv4len { 38 | a := binary.BigEndian.Uint32(ip1) 39 | b := binary.BigEndian.Uint32(ip2) 40 | x := a ^ b 41 | return uint8(bits.LeadingZeros32(x)) 42 | } else if size == net.IPv6len { 43 | a := binary.BigEndian.Uint64(ip1) 44 | b := binary.BigEndian.Uint64(ip2) 45 | x := a ^ b 46 | if x != 0 { 47 | return uint8(bits.LeadingZeros64(x)) 48 | } 49 | a = binary.BigEndian.Uint64(ip1[8:]) 50 | b = binary.BigEndian.Uint64(ip2[8:]) 51 | x = a ^ b 52 | return 64 + uint8(bits.LeadingZeros64(x)) 53 | } else { 54 | panic("Wrong size bit string") 55 | } 56 | } 57 | 58 | func (node *trieEntry) addToPeerEntries() { 59 | node.perPeerElem = node.peer.trieEntries.PushBack(node) 60 | } 61 | 62 | func (node *trieEntry) removeFromPeerEntries() { 63 | if node.perPeerElem != nil { 64 | node.peer.trieEntries.Remove(node.perPeerElem) 65 | node.perPeerElem = nil 66 | } 67 | } 68 | 69 | func (node *trieEntry) choose(ip []byte) byte { 70 | return (ip[node.bitAtByte] >> node.bitAtShift) & 1 71 | } 72 | 73 | func (node *trieEntry) maskSelf() { 74 | mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) 75 | for i := 0; i < len(mask); i++ { 76 | node.bits[i] &= mask[i] 77 | } 78 | } 79 | 80 | func (node *trieEntry) zeroizePointers() { 81 | // Make the garbage collector's life slightly easier 82 | node.peer = nil 83 | node.child[0] = nil 84 | node.child[1] = nil 85 | node.parent.parentBit = nil 86 | } 87 | 88 | 
// nodePlacement descends from node towards ip/cidr. It returns the deepest
// node whose prefix contains ip/cidr: that node's cidr is <= cidr, and ip
// agrees with the node's (masked) bits on the node's first cidr bits. It
// returns nil when even the starting node does not contain ip/cidr. exact is
// true when the returned node's prefix length equals cidr, i.e. the prefix is
// already present.
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
		parent = node
		if parent.cidr == cidr {
			exact = true
			return
		}
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return
}

// insert adds the prefix ip/cidr for peer to the trie whose root pointer is
// *trie.parentBit (callers pass the root indirection, e.g. {&table.IPv4, 2}).
// The new node keeps ip as its bits and maskSelf clears the host bits in
// place, so ip must be a slice the caller no longer needs. Re-inserting an
// existing prefix just moves it to the new peer.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// The same prefix is already present: only reassign its owner.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the subtree the new prefix must be merged with: the root when no
	// node contains the new prefix, otherwise node's child on ip's side. An
	// empty child slot is simply filled with the new node.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// From here on, cidr is the length of the prefix shared by down and the
	// new prefix, capped at the new prefix's own length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	// The new prefix contains down: hang down beneath newNode, and let
	// newNode take down's old position.
	if newNode.cidr == cidr {
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// The two prefixes diverge before newNode.cidr bits. Create a peer-less
	// branch node at the shared prefix length, with down and newNode as its
	// children. bits is copied because maskSelf truncates it in place.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}

// lookup returns the peer of the longest prefix that contains ip along the
// path starting at node, or nil if there is none. The walk stops at a
// full-length prefix (bitAtByte == len(ip)), which has no further bit to
// branch on.
func (node *trieEntry) lookup(ip []byte) *Peer {
	var found *Peer
	size := uint8(len(ip))
	for node != nil && commonBits(node.bits, ip) >= node.cidr {
		if node.peer != nil {
			found = node.peer
		}
		if node.bitAtByte == size {
			break
		}
		bit := node.choose(ip)
		node = node.child[bit]
	}
	return found
}

// AllowedIPs maps IP prefixes to peers, with one trie per address family.
// All access goes through mutex.
type AllowedIPs struct {
	IPv4  *trieEntry
	IPv6  *trieEntry
	mutex sync.RWMutex
}

// EntriesForPeer calls cb with each prefix owned by peer, in the order of
// peer.trieEntries, and stops early when cb returns false. cb runs while the
// read lock is held, so it must not call Insert or RemoveByPeer on the same
// table (that would deadlock).
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
	table.mutex.RLock()
	defer table.mutex.RUnlock()

	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
		node := elem.Value.(*trieEntry)
		a, _ := netip.AddrFromSlice(node.bits)
		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
			return
		}
	}
}

// RemoveByPeer removes every prefix owned by peer from both tries.
//
// A node with two children cannot be removed; it stays as a peer-less branch.
// Otherwise the node's single child (or nil) takes its slot. If the removed
// node was a leaf under another node, that parent may now be a peer-less
// branch with one child, and it is spliced out the same way.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Fetch next first: removeFromPeerEntries unlinks elem.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Nothing more to collapse if the node still had a child, or if it
		// hung directly off a table root (parentBitType 2).
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// Recover the parent trieEntry from the address of its child slot:
		// subtract the offset of child[parentBitType] within trieEntry. The
		// uintptr arithmetic must stay within this single expression (see the
		// unsafe.Pointer documentation, pattern 3).
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// The parent is a peer-less branch that has lost one child: replace
		// it with its remaining child.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}

// Insert assigns prefix to peer, replacing any previous owner of the exact
// same prefix. It panics if the prefix address is neither IPv4 nor IPv6.
// NOTE(review): for an invalid prefix, Bits() is -1 and uint8(...) yields 255;
// presumably callers only pass validated prefixes — confirm.
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	if prefix.Addr().Is6() {
		ip := prefix.Addr().As16()
		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
	} else if prefix.Addr().Is4() {
		ip := prefix.Addr().As4()
		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
	} else {
		panic(errors.New("inserting unknown address type"))
	}
}

// Lookup returns the peer owning the longest prefix that contains ip, or nil.
// ip must be 4 or 16 bytes long; any other length panics.
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
	table.mutex.RLock()
	defer table.mutex.RUnlock()
	switch len(ip) {
	case net.IPv6len:
		return table.IPv6.lookup(ip)
	case net.IPv4len:
		return table.IPv4.lookup(ip)
	default:
		panic(errors.New("looking up unknown address type"))
	}
}

--------------------------------------------------------------------------------
/device/allowedips_rand_test.go:
--------------------------------------------------------------------------------
/*
SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "sort" 13 | "testing" 14 | ) 15 | 16 | const ( 17 | NumberOfPeers = 100 18 | NumberOfPeerRemovals = 4 19 | NumberOfAddresses = 250 20 | NumberOfTests = 10000 21 | ) 22 | 23 | type SlowNode struct { 24 | peer *Peer 25 | cidr uint8 26 | bits []byte 27 | } 28 | 29 | type SlowRouter []*SlowNode 30 | 31 | func (r SlowRouter) Len() int { 32 | return len(r) 33 | } 34 | 35 | func (r SlowRouter) Less(i, j int) bool { 36 | return r[i].cidr > r[j].cidr 37 | } 38 | 39 | func (r SlowRouter) Swap(i, j int) { 40 | r[i], r[j] = r[j], r[i] 41 | } 42 | 43 | func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { 44 | for _, t := range r { 45 | if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { 46 | t.peer = peer 47 | t.bits = addr 48 | return r 49 | } 50 | } 51 | r = append(r, &SlowNode{ 52 | cidr: cidr, 53 | bits: addr, 54 | peer: peer, 55 | }) 56 | sort.Sort(r) 57 | return r 58 | } 59 | 60 | func (r SlowRouter) Lookup(addr []byte) *Peer { 61 | for _, t := range r { 62 | common := commonBits(t.bits, addr) 63 | if common >= t.cidr { 64 | return t.peer 65 | } 66 | } 67 | return nil 68 | } 69 | 70 | func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { 71 | n := 0 72 | for _, x := range r { 73 | if x.peer != peer { 74 | r[n] = x 75 | n++ 76 | } 77 | } 78 | return r[:n] 79 | } 80 | 81 | func TestTrieRandom(t *testing.T) { 82 | var slow4, slow6 SlowRouter 83 | var peers []*Peer 84 | var allowedIPs AllowedIPs 85 | 86 | rand.Seed(1) 87 | 88 | for n := 0; n < NumberOfPeers; n++ { 89 | peers = append(peers, &Peer{}) 90 | } 91 | 92 | for n := 0; n < NumberOfAddresses; n++ { 93 | var addr4 [4]byte 94 | rand.Read(addr4[:]) 95 | cidr := uint8(rand.Intn(32) + 1) 96 | index := rand.Intn(NumberOfPeers) 97 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), 
int(cidr)), peers[index]) 98 | slow4 = slow4.Insert(addr4[:], cidr, peers[index]) 99 | 100 | var addr6 [16]byte 101 | rand.Read(addr6[:]) 102 | cidr = uint8(rand.Intn(128) + 1) 103 | index = rand.Intn(NumberOfPeers) 104 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) 105 | slow6 = slow6.Insert(addr6[:], cidr, peers[index]) 106 | } 107 | 108 | var p int 109 | for p = 0; ; p++ { 110 | for n := 0; n < NumberOfTests; n++ { 111 | var addr4 [4]byte 112 | rand.Read(addr4[:]) 113 | peer1 := slow4.Lookup(addr4[:]) 114 | peer2 := allowedIPs.Lookup(addr4[:]) 115 | if peer1 != peer2 { 116 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) 117 | } 118 | 119 | var addr6 [16]byte 120 | rand.Read(addr6[:]) 121 | peer1 = slow6.Lookup(addr6[:]) 122 | peer2 = allowedIPs.Lookup(addr6[:]) 123 | if peer1 != peer2 { 124 | t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) 125 | } 126 | } 127 | if p >= len(peers) || p >= NumberOfPeerRemovals { 128 | break 129 | } 130 | allowedIPs.RemoveByPeer(peers[p]) 131 | slow4 = slow4.RemoveByPeer(peers[p]) 132 | slow6 = slow6.RemoveByPeer(peers[p]) 133 | } 134 | for ; p < len(peers); p++ { 135 | allowedIPs.RemoveByPeer(peers[p]) 136 | } 137 | 138 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 139 | t.Error("Failed to remove all nodes from trie by peer") 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /device/allowedips_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net" 11 | "net/netip" 12 | "testing" 13 | ) 14 | 15 | type testPairCommonBits struct { 16 | s1 []byte 17 | s2 []byte 18 | match uint8 19 | } 20 | 21 | func TestCommonBits(t *testing.T) { 22 | tests := []testPairCommonBits{ 23 | {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, 24 | {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, 25 | {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, 26 | {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, 27 | {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, 28 | } 29 | 30 | for _, p := range tests { 31 | v := commonBits(p.s1, p.s2) 32 | if v != p.match { 33 | t.Error( 34 | "For slice", p.s1, p.s2, 35 | "expected match", p.match, 36 | ",but got", v, 37 | ) 38 | } 39 | } 40 | } 41 | 42 | func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { 43 | var trie *trieEntry 44 | var peers []*Peer 45 | root := parentIndirection{&trie, 2} 46 | 47 | rand.Seed(1) 48 | 49 | const AddressLength = 4 50 | 51 | for n := 0; n < peerNumber; n++ { 52 | peers = append(peers, &Peer{}) 53 | } 54 | 55 | for n := 0; n < addressNumber; n++ { 56 | var addr [AddressLength]byte 57 | rand.Read(addr[:]) 58 | cidr := uint8(rand.Uint32() % (AddressLength * 8)) 59 | index := rand.Int() % peerNumber 60 | root.insert(addr[:], cidr, peers[index]) 61 | } 62 | 63 | for n := 0; n < b.N; n++ { 64 | var addr [AddressLength]byte 65 | rand.Read(addr[:]) 66 | trie.lookup(addr[:]) 67 | } 68 | } 69 | 70 | func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { 71 | benchmarkTrie(100, 1000, net.IPv4len, b) 72 | } 73 | 74 | func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { 75 | benchmarkTrie(10, 10, net.IPv4len, b) 76 | } 77 | 78 | func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { 79 | benchmarkTrie(100, 1000, net.IPv6len, b) 80 | } 81 | 82 | func BenchmarkTrieIPv6Peers10Addresses10(b 
*testing.B) { 83 | benchmarkTrie(10, 10, net.IPv6len, b) 84 | } 85 | 86 | /* Test ported from kernel implementation: 87 | * selftest/allowedips.h 88 | */ 89 | func TestTrieIPv4(t *testing.T) { 90 | a := &Peer{} 91 | b := &Peer{} 92 | c := &Peer{} 93 | d := &Peer{} 94 | e := &Peer{} 95 | g := &Peer{} 96 | h := &Peer{} 97 | 98 | var allowedIPs AllowedIPs 99 | 100 | insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { 101 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) 102 | } 103 | 104 | assertEQ := func(peer *Peer, a, b, c, d byte) { 105 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 106 | if p != peer { 107 | t.Error("Assert EQ failed") 108 | } 109 | } 110 | 111 | assertNEQ := func(peer *Peer, a, b, c, d byte) { 112 | p := allowedIPs.Lookup([]byte{a, b, c, d}) 113 | if p == peer { 114 | t.Error("Assert NEQ failed") 115 | } 116 | } 117 | 118 | insert(a, 192, 168, 4, 0, 24) 119 | insert(b, 192, 168, 4, 4, 32) 120 | insert(c, 192, 168, 0, 0, 16) 121 | insert(d, 192, 95, 5, 64, 27) 122 | insert(c, 192, 95, 5, 65, 27) 123 | insert(e, 0, 0, 0, 0, 0) 124 | insert(g, 64, 15, 112, 0, 20) 125 | insert(h, 64, 15, 123, 211, 25) 126 | insert(a, 10, 0, 0, 0, 25) 127 | insert(b, 10, 0, 0, 128, 25) 128 | insert(a, 10, 1, 0, 0, 30) 129 | insert(b, 10, 1, 0, 4, 30) 130 | insert(c, 10, 1, 0, 8, 29) 131 | insert(d, 10, 1, 0, 16, 29) 132 | 133 | assertEQ(a, 192, 168, 4, 20) 134 | assertEQ(a, 192, 168, 4, 0) 135 | assertEQ(b, 192, 168, 4, 4) 136 | assertEQ(c, 192, 168, 200, 182) 137 | assertEQ(c, 192, 95, 5, 68) 138 | assertEQ(e, 192, 95, 5, 96) 139 | assertEQ(g, 64, 15, 116, 26) 140 | assertEQ(g, 64, 15, 127, 3) 141 | 142 | insert(a, 1, 0, 0, 0, 32) 143 | insert(a, 64, 0, 0, 0, 32) 144 | insert(a, 128, 0, 0, 0, 32) 145 | insert(a, 192, 0, 0, 0, 32) 146 | insert(a, 255, 0, 0, 0, 32) 147 | 148 | assertEQ(a, 1, 0, 0, 0) 149 | assertEQ(a, 64, 0, 0, 0) 150 | assertEQ(a, 128, 0, 0, 0) 151 | assertEQ(a, 192, 0, 0, 0) 152 | assertEQ(a, 255, 0, 0, 
0) 153 | 154 | allowedIPs.RemoveByPeer(a) 155 | 156 | assertNEQ(a, 1, 0, 0, 0) 157 | assertNEQ(a, 64, 0, 0, 0) 158 | assertNEQ(a, 128, 0, 0, 0) 159 | assertNEQ(a, 192, 0, 0, 0) 160 | assertNEQ(a, 255, 0, 0, 0) 161 | 162 | allowedIPs.RemoveByPeer(a) 163 | allowedIPs.RemoveByPeer(b) 164 | allowedIPs.RemoveByPeer(c) 165 | allowedIPs.RemoveByPeer(d) 166 | allowedIPs.RemoveByPeer(e) 167 | allowedIPs.RemoveByPeer(g) 168 | allowedIPs.RemoveByPeer(h) 169 | if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { 170 | t.Error("Expected removing all the peers to empty trie, but it did not") 171 | } 172 | 173 | insert(a, 192, 168, 0, 0, 16) 174 | insert(a, 192, 168, 0, 0, 24) 175 | 176 | allowedIPs.RemoveByPeer(a) 177 | 178 | assertNEQ(a, 192, 168, 0, 1) 179 | } 180 | 181 | /* Test ported from kernel implementation: 182 | * selftest/allowedips.h 183 | */ 184 | func TestTrieIPv6(t *testing.T) { 185 | a := &Peer{} 186 | b := &Peer{} 187 | c := &Peer{} 188 | d := &Peer{} 189 | e := &Peer{} 190 | f := &Peer{} 191 | g := &Peer{} 192 | h := &Peer{} 193 | 194 | var allowedIPs AllowedIPs 195 | 196 | expand := func(a uint32) []byte { 197 | var out [4]byte 198 | out[0] = byte(a >> 24 & 0xff) 199 | out[1] = byte(a >> 16 & 0xff) 200 | out[2] = byte(a >> 8 & 0xff) 201 | out[3] = byte(a & 0xff) 202 | return out[:] 203 | } 204 | 205 | insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { 206 | var addr []byte 207 | addr = append(addr, expand(a)...) 208 | addr = append(addr, expand(b)...) 209 | addr = append(addr, expand(c)...) 210 | addr = append(addr, expand(d)...) 211 | allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) 212 | } 213 | 214 | assertEQ := func(peer *Peer, a, b, c, d uint32) { 215 | var addr []byte 216 | addr = append(addr, expand(a)...) 217 | addr = append(addr, expand(b)...) 218 | addr = append(addr, expand(c)...) 219 | addr = append(addr, expand(d)...) 
220 | p := allowedIPs.Lookup(addr) 221 | if p != peer { 222 | t.Error("Assert EQ failed") 223 | } 224 | } 225 | 226 | insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) 227 | insert(c, 0x26075300, 0x60006b00, 0, 0, 64) 228 | insert(e, 0, 0, 0, 0, 0) 229 | insert(f, 0, 0, 0, 0, 0) 230 | insert(g, 0x24046800, 0, 0, 0, 32) 231 | insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) 232 | insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) 233 | insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) 234 | insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) 235 | 236 | assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) 237 | assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) 238 | assertEQ(f, 0x26075300, 0x60006b01, 0, 0) 239 | assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) 240 | assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) 241 | assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) 242 | assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) 243 | assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) 244 | assertEQ(h, 0x24046800, 0x40040800, 0, 0) 245 | assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) 246 | assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) 247 | } 248 | -------------------------------------------------------------------------------- /device/bind_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "errors" 10 | 11 | "golang.zx2c4.com/wireguard/conn" 12 | ) 13 | 14 | type DummyDatagram struct { 15 | msg []byte 16 | endpoint conn.Endpoint 17 | } 18 | 19 | type DummyBind struct { 20 | in6 chan DummyDatagram 21 | in4 chan DummyDatagram 22 | closed bool 23 | } 24 | 25 | func (b *DummyBind) SetMark(v uint32) error { 26 | return nil 27 | } 28 | 29 | func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { 30 | datagram, ok := <-b.in6 31 | if !ok { 32 | return 0, nil, errors.New("closed") 33 | } 34 | copy(buff, datagram.msg) 35 | return len(datagram.msg), datagram.endpoint, nil 36 | } 37 | 38 | func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { 39 | datagram, ok := <-b.in4 40 | if !ok { 41 | return 0, nil, errors.New("closed") 42 | } 43 | copy(buff, datagram.msg) 44 | return len(datagram.msg), datagram.endpoint, nil 45 | } 46 | 47 | func (b *DummyBind) Close() error { 48 | close(b.in6) 49 | close(b.in4) 50 | b.closed = true 51 | return nil 52 | } 53 | 54 | func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /device/channels.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "runtime" 10 | "sync" 11 | ) 12 | 13 | // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. 14 | // An outboundQueue is ref-counted using its wg field. 15 | // An outboundQueue created with newOutboundQueue has one reference. 16 | // Every additional writer must call wg.Add(1). 17 | // Every completed writer must call wg.Done(). 18 | // When no further writers will be added, 19 | // call wg.Done to remove the initial reference. 
20 | // When the refcount hits 0, the queue's channel is closed. 21 | type outboundQueue struct { 22 | c chan *QueueOutboundElement 23 | wg sync.WaitGroup 24 | } 25 | 26 | func newOutboundQueue() *outboundQueue { 27 | q := &outboundQueue{ 28 | c: make(chan *QueueOutboundElement, QueueOutboundSize), 29 | } 30 | q.wg.Add(1) 31 | go func() { 32 | q.wg.Wait() 33 | close(q.c) 34 | }() 35 | return q 36 | } 37 | 38 | // A inboundQueue is similar to an outboundQueue; see those docs. 39 | type inboundQueue struct { 40 | c chan *QueueInboundElement 41 | wg sync.WaitGroup 42 | } 43 | 44 | func newInboundQueue() *inboundQueue { 45 | q := &inboundQueue{ 46 | c: make(chan *QueueInboundElement, QueueInboundSize), 47 | } 48 | q.wg.Add(1) 49 | go func() { 50 | q.wg.Wait() 51 | close(q.c) 52 | }() 53 | return q 54 | } 55 | 56 | // A handshakeQueue is similar to an outboundQueue; see those docs. 57 | type handshakeQueue struct { 58 | c chan QueueHandshakeElement 59 | wg sync.WaitGroup 60 | } 61 | 62 | func newHandshakeQueue() *handshakeQueue { 63 | q := &handshakeQueue{ 64 | c: make(chan QueueHandshakeElement, QueueHandshakeSize), 65 | } 66 | q.wg.Add(1) 67 | go func() { 68 | q.wg.Wait() 69 | close(q.c) 70 | }() 71 | return q 72 | } 73 | 74 | type autodrainingInboundQueue struct { 75 | c chan *QueueInboundElement 76 | } 77 | 78 | // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. 79 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 80 | // The returned channel must not be closed. Senders should signal shutdown using 81 | // some other means, such as sending a sentinel nil values. 
82 | func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { 83 | q := &autodrainingInboundQueue{ 84 | c: make(chan *QueueInboundElement, QueueInboundSize), 85 | } 86 | runtime.SetFinalizer(q, device.flushInboundQueue) 87 | return q 88 | } 89 | 90 | func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { 91 | for { 92 | select { 93 | case elem := <-q.c: 94 | elem.Lock() 95 | device.PutMessageBuffer(elem.buffer) 96 | device.PutInboundElement(elem) 97 | default: 98 | return 99 | } 100 | } 101 | } 102 | 103 | type autodrainingOutboundQueue struct { 104 | c chan *QueueOutboundElement 105 | } 106 | 107 | // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. 108 | // It is useful in cases in which is it hard to manage the lifetime of the channel. 109 | // The returned channel must not be closed. Senders should signal shutdown using 110 | // some other means, such as sending a sentinel nil values. 111 | // All sends to the channel must be best-effort, because there may be no receivers. 112 | func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { 113 | q := &autodrainingOutboundQueue{ 114 | c: make(chan *QueueOutboundElement, QueueOutboundSize), 115 | } 116 | runtime.SetFinalizer(q, device.flushOutboundQueue) 117 | return q 118 | } 119 | 120 | func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { 121 | for { 122 | select { 123 | case elem := <-q.c: 124 | elem.Lock() 125 | device.PutMessageBuffer(elem.buffer) 126 | device.PutOutboundElement(elem) 127 | default: 128 | return 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /device/constants.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "time" 10 | ) 11 | 12 | /* Specification constants */ 13 | 14 | const ( 15 | RekeyAfterMessages = (1 << 60) 16 | RejectAfterMessages = (1 << 64) - (1 << 13) - 1 17 | RekeyAfterTime = time.Second * 120 18 | RekeyAttemptTime = time.Second * 90 19 | RekeyTimeout = time.Second * 5 20 | MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ 21 | RekeyTimeoutJitterMaxMs = 334 22 | RejectAfterTime = time.Second * 180 23 | KeepaliveTimeout = time.Second * 10 24 | CookieRefreshTime = time.Second * 120 25 | HandshakeInitationRate = time.Second / 50 26 | PaddingMultiple = 16 27 | ) 28 | 29 | const ( 30 | MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) 31 | MaxMessageSize = MaxSegmentSize // maximum size of transport message 32 | MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content 33 | ) 34 | 35 | /* Implementation constants */ 36 | 37 | const ( 38 | UnderLoadAfterTime = time.Second // how long does the device remain under load after detected 39 | MaxPeers = 1 << 16 // maximum number of configured peers 40 | ) 41 | -------------------------------------------------------------------------------- /device/cookie.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "sync" 12 | "time" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/chacha20poly1305" 16 | ) 17 | 18 | type CookieChecker struct { 19 | sync.RWMutex 20 | mac1 struct { 21 | key [blake2s.Size]byte 22 | } 23 | mac2 struct { 24 | secret [blake2s.Size]byte 25 | secretSet time.Time 26 | encryptionKey [chacha20poly1305.KeySize]byte 27 | } 28 | } 29 | 30 | type CookieGenerator struct { 31 | sync.RWMutex 32 | mac1 struct { 33 | key [blake2s.Size]byte 34 | } 35 | mac2 struct { 36 | cookie [blake2s.Size128]byte 37 | cookieSet time.Time 38 | hasLastMAC1 bool 39 | lastMAC1 [blake2s.Size128]byte 40 | encryptionKey [chacha20poly1305.KeySize]byte 41 | } 42 | } 43 | 44 | func (st *CookieChecker) Init(pk NoisePublicKey) { 45 | st.Lock() 46 | defer st.Unlock() 47 | 48 | // mac1 state 49 | 50 | func() { 51 | hash, _ := blake2s.New256(nil) 52 | hash.Write([]byte(WGLabelMAC1)) 53 | hash.Write(pk[:]) 54 | hash.Sum(st.mac1.key[:0]) 55 | }() 56 | 57 | // mac2 state 58 | 59 | func() { 60 | hash, _ := blake2s.New256(nil) 61 | hash.Write([]byte(WGLabelCookie)) 62 | hash.Write(pk[:]) 63 | hash.Sum(st.mac2.encryptionKey[:0]) 64 | }() 65 | 66 | st.mac2.secretSet = time.Time{} 67 | } 68 | 69 | func (st *CookieChecker) CheckMAC1(msg []byte) bool { 70 | st.RLock() 71 | defer st.RUnlock() 72 | 73 | size := len(msg) 74 | smac2 := size - blake2s.Size128 75 | smac1 := smac2 - blake2s.Size128 76 | 77 | var mac1 [blake2s.Size128]byte 78 | 79 | mac, _ := blake2s.New128(st.mac1.key[:]) 80 | mac.Write(msg[:smac1]) 81 | mac.Sum(mac1[:0]) 82 | 83 | return hmac.Equal(mac1[:], msg[smac1:smac2]) 84 | } 85 | 86 | func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { 87 | st.RLock() 88 | defer st.RUnlock() 89 | 90 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 91 | return false 92 | } 93 | 94 | // derive cookie key 95 | 96 | var cookie [blake2s.Size128]byte 97 | func() { 98 | mac, _ := 
blake2s.New128(st.mac2.secret[:]) 99 | mac.Write(src) 100 | mac.Sum(cookie[:0]) 101 | }() 102 | 103 | // calculate mac of packet (including mac1) 104 | 105 | smac2 := len(msg) - blake2s.Size128 106 | 107 | var mac2 [blake2s.Size128]byte 108 | func() { 109 | mac, _ := blake2s.New128(cookie[:]) 110 | mac.Write(msg[:smac2]) 111 | mac.Sum(mac2[:0]) 112 | }() 113 | 114 | return hmac.Equal(mac2[:], msg[smac2:]) 115 | } 116 | 117 | func (st *CookieChecker) CreateReply( 118 | msg []byte, 119 | recv uint32, 120 | src []byte, 121 | ) (*MessageCookieReply, error) { 122 | st.RLock() 123 | 124 | // refresh cookie secret 125 | 126 | if time.Since(st.mac2.secretSet) > CookieRefreshTime { 127 | st.RUnlock() 128 | st.Lock() 129 | _, err := rand.Read(st.mac2.secret[:]) 130 | if err != nil { 131 | st.Unlock() 132 | return nil, err 133 | } 134 | st.mac2.secretSet = time.Now() 135 | st.Unlock() 136 | st.RLock() 137 | } 138 | 139 | // derive cookie 140 | 141 | var cookie [blake2s.Size128]byte 142 | func() { 143 | mac, _ := blake2s.New128(st.mac2.secret[:]) 144 | mac.Write(src) 145 | mac.Sum(cookie[:0]) 146 | }() 147 | 148 | // encrypt cookie 149 | 150 | size := len(msg) 151 | 152 | smac2 := size - blake2s.Size128 153 | smac1 := smac2 - blake2s.Size128 154 | 155 | reply := new(MessageCookieReply) 156 | reply.Type = MessageCookieReplyType 157 | reply.Receiver = recv 158 | 159 | _, err := rand.Read(reply.Nonce[:]) 160 | if err != nil { 161 | st.RUnlock() 162 | return nil, err 163 | } 164 | 165 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 166 | xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) 167 | 168 | st.RUnlock() 169 | 170 | return reply, nil 171 | } 172 | 173 | func (st *CookieGenerator) Init(pk NoisePublicKey) { 174 | st.Lock() 175 | defer st.Unlock() 176 | 177 | func() { 178 | hash, _ := blake2s.New256(nil) 179 | hash.Write([]byte(WGLabelMAC1)) 180 | hash.Write(pk[:]) 181 | hash.Sum(st.mac1.key[:0]) 182 | }() 183 | 184 | func() { 185 | 
hash, _ := blake2s.New256(nil) 186 | hash.Write([]byte(WGLabelCookie)) 187 | hash.Write(pk[:]) 188 | hash.Sum(st.mac2.encryptionKey[:0]) 189 | }() 190 | 191 | st.mac2.cookieSet = time.Time{} 192 | } 193 | 194 | func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { 195 | st.Lock() 196 | defer st.Unlock() 197 | 198 | if !st.mac2.hasLastMAC1 { 199 | return false 200 | } 201 | 202 | var cookie [blake2s.Size128]byte 203 | 204 | xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) 205 | _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) 206 | if err != nil { 207 | return false 208 | } 209 | 210 | st.mac2.cookieSet = time.Now() 211 | st.mac2.cookie = cookie 212 | return true 213 | } 214 | 215 | func (st *CookieGenerator) AddMacs(msg []byte) { 216 | size := len(msg) 217 | 218 | smac2 := size - blake2s.Size128 219 | smac1 := smac2 - blake2s.Size128 220 | 221 | mac1 := msg[smac1:smac2] 222 | mac2 := msg[smac2:] 223 | 224 | st.Lock() 225 | defer st.Unlock() 226 | 227 | // set mac1 228 | 229 | func() { 230 | mac, _ := blake2s.New128(st.mac1.key[:]) 231 | mac.Write(msg[:smac1]) 232 | mac.Sum(mac1[:0]) 233 | }() 234 | copy(st.mac2.lastMAC1[:], mac1) 235 | st.mac2.hasLastMAC1 = true 236 | 237 | // set mac2 238 | 239 | if time.Since(st.mac2.cookieSet) > CookieRefreshTime { 240 | return 241 | } 242 | 243 | func() { 244 | mac, _ := blake2s.New128(st.mac2.cookie[:]) 245 | mac.Write(msg[:smac2]) 246 | mac.Sum(mac2[:0]) 247 | }() 248 | } 249 | -------------------------------------------------------------------------------- /device/cookie_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
 */

package device

import (
	"testing"
)

// TestCookieMAC1 drives a CookieGenerator and a CookieChecker that share one
// public key through three phases:
//   - before any cookie exists, mac1 must verify and mac2 must not;
//   - a cookie reply created by the checker must be accepted by the generator;
//   - afterwards both MACs must verify, both must fail after a single byte of
//     the message is flipped, and mac2 must fail for different source bytes.
func TestCookieMAC1(t *testing.T) {
	// setup generator / checker

	var (
		generator CookieGenerator
		checker   CookieChecker
	)

	sk, err := newPrivateKey()
	if err != nil {
		t.Fatal(err)
	}
	pk := sk.publicKey()

	generator.Init(pk)
	checker.Init(pk)

	// check mac1

	// Source bytes fed into the cookie derivation; used by every mac2 check below.
	src := []byte{192, 168, 13, 37, 10, 10, 10}

	// checkMAC1 stamps msg and expects mac1 to verify. mac2 must not verify,
	// because no cookie reply has been consumed yet.
	checkMAC1 := func(msg []byte) {
		generator.AddMacs(msg)
		if !checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}
	}

	checkMAC1([]byte{
		0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd,
		0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62,
		0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64,
		0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91,
		0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4,
		0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c,
		0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b,
		0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5,
		0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8,
		0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8,
		0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e,
		0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82,
		0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53,
		0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd,
		0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad,
		0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f,
	})

	checkMAC1([]byte{
		0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c,
		0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56,
		0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a,
		0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c,
		0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
		0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
		0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
		0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
	})

	// 32 bytes: exactly the two MAC fields, with an empty body in front of mac1.
	checkMAC1([]byte{
		0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
		0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
		0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
		0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
	})

	// exchange cookie reply

	// The generator records this message's mac1, and the checker seals a
	// cookie bound to that mac1. The generator must be able to open it; after
	// that, its AddMacs calls also fill in mac2.
	func() {
		msg := []byte{
			0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf,
			0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8,
			0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5,
			0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f,
			0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2,
			0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b,
			0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e,
			0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9,
			0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22,
			0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27,
			0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f,
			0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2,
			0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b,
			0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb,
			0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61,
			0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
		}
		generator.AddMacs(msg)
		reply, err := checker.CreateReply(msg, 1377, src)
		if err != nil {
			t.Fatal("Failed to create cookie reply:", err)
		}
		if !generator.ConsumeReply(reply) {
			t.Fatal("Failed to consume cookie reply")
		}
	}()

	// check mac2

	// checkMAC2 stamps msg and expects both MACs to verify.
	// It then flips bit 0x20 of byte 5 and expects both to fail, then restores
	// the byte and expects mac2 to fail for two other source byte strings.
	checkMAC2 := func(msg []byte) {
		generator.AddMacs(msg)

		if !checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if !checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}

		msg[5] ^= 0x20

		if checker.CheckMAC1(msg) {
			t.Fatal("MAC1 generation/verification failed")
		}
		if checker.CheckMAC2(msg, src) {
			t.Fatal("MAC2 generation/verification failed")
		}

		msg[5] ^= 0x20

		srcBad1 := []byte{192, 168, 13, 37, 40, 1}
		if checker.CheckMAC2(msg, srcBad1) {
			t.Fatal("MAC2 generation/verification failed")
		}

		srcBad2 := []byte{192, 168, 13, 38, 40, 1}
		if checker.CheckMAC2(msg, srcBad2) {
			t.Fatal("MAC2 generation/verification failed")
		}
	}

	checkMAC2([]byte{
		0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3,
		0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15,
		0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f,
		0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f,
		0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69,
		0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb,
		0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c,
		0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38,
	})

	checkMAC2([]byte{
		0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3,
		0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85,
		0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84,
		0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5,
		0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85,
		0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55,
		0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a,
		0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca,
		0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59,
		0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca,
		0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b,
		0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7,
		0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55,
		0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2,
		0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f,
		0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d,
		0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d,
		0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f,
		0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2,
		0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0,
		0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35,
		0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a,
		0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee,
		0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe,
		0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e,
		0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4,
		0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06,
		0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93,
		0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa,
		0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64,
		0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b,
		0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa,
	})
}

-------------------------------------------------------------------------------- /device/devicestate_string.go: --------------------------------------------------------------------------------
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.

package device

import "strconv"

const _deviceState_name = "DownUpClosed"

var _deviceState_index = [...]uint8{0, 4, 6, 12}

func (i deviceState) String() string {
	if i >= deviceState(len(_deviceState_index)-1) {
		return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
}

-------------------------------------------------------------------------------- /device/endpoint_test.go: --------------------------------------------------------------------------------
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "net/netip" 11 | ) 12 | 13 | type DummyEndpoint struct { 14 | src, dst netip.Addr 15 | } 16 | 17 | func CreateDummyEndpoint() (*DummyEndpoint, error) { 18 | var src, dst [16]byte 19 | if _, err := rand.Read(src[:]); err != nil { 20 | return nil, err 21 | } 22 | _, err := rand.Read(dst[:]) 23 | return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err 24 | } 25 | 26 | func (e *DummyEndpoint) ClearSrc() {} 27 | 28 | func (e *DummyEndpoint) SrcToString() string { 29 | return netip.AddrPortFrom(e.SrcIP(), 1000).String() 30 | } 31 | 32 | func (e *DummyEndpoint) DstToString() string { 33 | return netip.AddrPortFrom(e.DstIP(), 1000).String() 34 | } 35 | 36 | func (e *DummyEndpoint) DstToBytes() []byte { 37 | out := e.DstIP().AsSlice() 38 | out = append(out, byte(1000&0xff)) 39 | out = append(out, byte((1000>>8)&0xff)) 40 | return out 41 | } 42 | 43 | func (e *DummyEndpoint) DstIP() netip.Addr { 44 | return e.dst 45 | } 46 | 47 | func (e *DummyEndpoint) SrcIP() netip.Addr { 48 | return e.src 49 | } 50 | -------------------------------------------------------------------------------- /device/indextable.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/rand" 10 | "encoding/binary" 11 | "sync" 12 | ) 13 | 14 | type IndexTableEntry struct { 15 | peer *Peer 16 | handshake *Handshake 17 | keypair *Keypair 18 | } 19 | 20 | type IndexTable struct { 21 | sync.RWMutex 22 | table map[uint32]IndexTableEntry 23 | } 24 | 25 | func randUint32() (uint32, error) { 26 | var integer [4]byte 27 | _, err := rand.Read(integer[:]) 28 | // Arbitrary endianness; both are intrinsified by the Go compiler. 
29 | return binary.LittleEndian.Uint32(integer[:]), err 30 | } 31 | 32 | func (table *IndexTable) Init() { 33 | table.Lock() 34 | defer table.Unlock() 35 | table.table = make(map[uint32]IndexTableEntry) 36 | } 37 | 38 | func (table *IndexTable) Delete(index uint32) { 39 | table.Lock() 40 | defer table.Unlock() 41 | delete(table.table, index) 42 | } 43 | 44 | func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { 45 | table.Lock() 46 | defer table.Unlock() 47 | entry, ok := table.table[index] 48 | if !ok { 49 | return 50 | } 51 | table.table[index] = IndexTableEntry{ 52 | peer: entry.peer, 53 | keypair: keypair, 54 | handshake: nil, 55 | } 56 | } 57 | 58 | func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { 59 | for { 60 | // generate random index 61 | 62 | index, err := randUint32() 63 | if err != nil { 64 | return index, err 65 | } 66 | 67 | // check if index used 68 | 69 | table.RLock() 70 | _, ok := table.table[index] 71 | table.RUnlock() 72 | if ok { 73 | continue 74 | } 75 | 76 | // check again while locked 77 | 78 | table.Lock() 79 | _, found := table.table[index] 80 | if found { 81 | table.Unlock() 82 | continue 83 | } 84 | table.table[index] = IndexTableEntry{ 85 | peer: peer, 86 | handshake: handshake, 87 | keypair: nil, 88 | } 89 | table.Unlock() 90 | return index, nil 91 | } 92 | } 93 | 94 | func (table *IndexTable) Lookup(id uint32) IndexTableEntry { 95 | table.RLock() 96 | defer table.RUnlock() 97 | return table.table[id] 98 | } 99 | -------------------------------------------------------------------------------- /device/ip.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "net" 10 | ) 11 | 12 | const ( 13 | IPv4offsetTotalLength = 2 14 | IPv4offsetSrc = 12 15 | IPv4offsetDst = IPv4offsetSrc + net.IPv4len 16 | ) 17 | 18 | const ( 19 | IPv6offsetPayloadLength = 4 20 | IPv6offsetSrc = 8 21 | IPv6offsetDst = IPv6offsetSrc + net.IPv6len 22 | ) 23 | -------------------------------------------------------------------------------- /device/kdf_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "encoding/hex" 10 | "testing" 11 | 12 | "golang.org/x/crypto/blake2s" 13 | ) 14 | 15 | type KDFTest struct { 16 | key string 17 | input string 18 | t0 string 19 | t1 string 20 | t2 string 21 | } 22 | 23 | func assertEquals(t *testing.T, a, b string) { 24 | if a != b { 25 | t.Fatal("expected", a, "=", b) 26 | } 27 | } 28 | 29 | func TestKDF(t *testing.T) { 30 | tests := []KDFTest{ 31 | { 32 | key: "746573742d6b6579", 33 | input: "746573742d696e707574", 34 | t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 35 | t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 36 | t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 37 | }, 38 | { 39 | key: "776972656775617264", 40 | input: "776972656775617264", 41 | t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 42 | t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 43 | t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 44 | }, 45 | { 46 | key: "", 47 | input: "", 48 | t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 49 | t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 50 | t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 51 | }, 52 | } 53 | 54 | var t0, t1, t2 
[blake2s.Size]byte 55 | 56 | for _, test := range tests { 57 | key, _ := hex.DecodeString(test.key) 58 | input, _ := hex.DecodeString(test.input) 59 | KDF3(&t0, &t1, &t2, key, input) 60 | t0s := hex.EncodeToString(t0[:]) 61 | t1s := hex.EncodeToString(t1[:]) 62 | t2s := hex.EncodeToString(t2[:]) 63 | assertEquals(t, t0s, test.t0) 64 | assertEquals(t, t1s, test.t1) 65 | assertEquals(t, t2s, test.t2) 66 | } 67 | 68 | for _, test := range tests { 69 | key, _ := hex.DecodeString(test.key) 70 | input, _ := hex.DecodeString(test.input) 71 | KDF2(&t0, &t1, key, input) 72 | t0s := hex.EncodeToString(t0[:]) 73 | t1s := hex.EncodeToString(t1[:]) 74 | assertEquals(t, t0s, test.t0) 75 | assertEquals(t, t1s, test.t1) 76 | } 77 | 78 | for _, test := range tests { 79 | key, _ := hex.DecodeString(test.key) 80 | input, _ := hex.DecodeString(test.input) 81 | KDF1(&t0, key, input) 82 | t0s := hex.EncodeToString(t0[:]) 83 | assertEquals(t, t0s, test.t0) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /device/keypair.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/cipher" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | "unsafe" 14 | 15 | "golang.zx2c4.com/wireguard/replay" 16 | ) 17 | 18 | /* Due to limitations in Go and /x/crypto there is currently 19 | * no way to ensure that key material is securely ereased in memory. 20 | * 21 | * Since this may harm the forward secrecy property, 22 | * we plan to resolve this issue; whenever Go allows us to do so. 
23 | */ 24 | 25 | type Keypair struct { 26 | sendNonce uint64 // accessed atomically 27 | send cipher.AEAD 28 | receive cipher.AEAD 29 | replayFilter replay.Filter 30 | isInitiator bool 31 | created time.Time 32 | localIndex uint32 33 | remoteIndex uint32 34 | } 35 | 36 | type Keypairs struct { 37 | sync.RWMutex 38 | current *Keypair 39 | previous *Keypair 40 | next *Keypair 41 | } 42 | 43 | func (kp *Keypairs) storeNext(next *Keypair) { 44 | atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) 45 | } 46 | 47 | func (kp *Keypairs) loadNext() *Keypair { 48 | return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) 49 | } 50 | 51 | func (kp *Keypairs) Current() *Keypair { 52 | kp.RLock() 53 | defer kp.RUnlock() 54 | return kp.current 55 | } 56 | 57 | func (device *Device) DeleteKeypair(key *Keypair) { 58 | if key != nil { 59 | device.indexTable.Delete(key.localIndex) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /device/logger.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "log" 10 | "os" 11 | ) 12 | 13 | // A Logger provides logging for a Device. 14 | // The functions are Printf-style functions. 15 | // They must be safe for concurrent use. 16 | // They do not require a trailing newline in the format. 17 | // If nil, that level of logging will be silent. 18 | type Logger struct { 19 | Verbosef func(format string, args ...any) 20 | Errorf func(format string, args ...any) 21 | } 22 | 23 | // Log levels for use with NewLogger. 24 | const ( 25 | LogLevelSilent = iota 26 | LogLevelError 27 | LogLevelVerbose 28 | ) 29 | 30 | // Function for use in Logger for discarding logged lines. 
31 | func DiscardLogf(format string, args ...any) {} 32 | 33 | // NewLogger constructs a Logger that writes to stdout. 34 | // It logs at the specified log level and above. 35 | // It decorates log lines with the log level, date, time, and prepend. 36 | func NewLogger(level int, prepend string) *Logger { 37 | logger := &Logger{DiscardLogf, DiscardLogf} 38 | logf := func(prefix string) func(string, ...any) { 39 | return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf 40 | } 41 | if level >= LogLevelVerbose { 42 | logger.Verbosef = logf("DEBUG") 43 | } 44 | if level >= LogLevelError { 45 | logger.Errorf = logf("ERROR") 46 | } 47 | return logger 48 | } 49 | -------------------------------------------------------------------------------- /device/misc.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync/atomic" 10 | ) 11 | 12 | /* Atomic Boolean */ 13 | 14 | const ( 15 | AtomicFalse = int32(iota) 16 | AtomicTrue 17 | ) 18 | 19 | type AtomicBool struct { 20 | int32 21 | } 22 | 23 | func (a *AtomicBool) Get() bool { 24 | return atomic.LoadInt32(&a.int32) == AtomicTrue 25 | } 26 | 27 | func (a *AtomicBool) Swap(val bool) bool { 28 | flag := AtomicFalse 29 | if val { 30 | flag = AtomicTrue 31 | } 32 | return atomic.SwapInt32(&a.int32, flag) == AtomicTrue 33 | } 34 | 35 | func (a *AtomicBool) Set(val bool) { 36 | flag := AtomicFalse 37 | if val { 38 | flag = AtomicTrue 39 | } 40 | atomic.StoreInt32(&a.int32, flag) 41 | } 42 | -------------------------------------------------------------------------------- /device/mobilequirks.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | // DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, 9 | // though it will try to deal with it, and race maybe, if called after. 10 | func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { 11 | device.net.brokenRoaming = true 12 | device.peers.RLock() 13 | for _, peer := range device.peers.keyMap { 14 | peer.Lock() 15 | peer.disableRoaming = peer.endpoint != nil 16 | peer.Unlock() 17 | } 18 | device.peers.RUnlock() 19 | } 20 | -------------------------------------------------------------------------------- /device/noise-helpers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/hmac" 10 | "crypto/rand" 11 | "crypto/subtle" 12 | "hash" 13 | 14 | "golang.org/x/crypto/blake2s" 15 | "golang.org/x/crypto/curve25519" 16 | ) 17 | 18 | /* KDF related functions. 
19 | * HMAC-based Key Derivation Function (HKDF) 20 | * https://tools.ietf.org/html/rfc5869 21 | */ 22 | 23 | func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { 24 | mac := hmac.New(func() hash.Hash { 25 | h, _ := blake2s.New256(nil) 26 | return h 27 | }, key) 28 | mac.Write(in0) 29 | mac.Sum(sum[:0]) 30 | } 31 | 32 | func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { 33 | mac := hmac.New(func() hash.Hash { 34 | h, _ := blake2s.New256(nil) 35 | return h 36 | }, key) 37 | mac.Write(in0) 38 | mac.Write(in1) 39 | mac.Sum(sum[:0]) 40 | } 41 | 42 | func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { 43 | HMAC1(t0, key, input) 44 | HMAC1(t0, t0[:], []byte{0x1}) 45 | } 46 | 47 | func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { 48 | var prk [blake2s.Size]byte 49 | HMAC1(&prk, key, input) 50 | HMAC1(t0, prk[:], []byte{0x1}) 51 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 52 | setZero(prk[:]) 53 | } 54 | 55 | func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { 56 | var prk [blake2s.Size]byte 57 | HMAC1(&prk, key, input) 58 | HMAC1(t0, prk[:], []byte{0x1}) 59 | HMAC2(t1, prk[:], t0[:], []byte{0x2}) 60 | HMAC2(t2, prk[:], t1[:], []byte{0x3}) 61 | setZero(prk[:]) 62 | } 63 | 64 | func isZero(val []byte) bool { 65 | acc := 1 66 | for _, b := range val { 67 | acc &= subtle.ConstantTimeByteEq(b, 0) 68 | } 69 | return acc == 1 70 | } 71 | 72 | /* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ 73 | func setZero(arr []byte) { 74 | for i := range arr { 75 | arr[i] = 0 76 | } 77 | } 78 | 79 | func (sk *NoisePrivateKey) clamp() { 80 | sk[0] &= 248 81 | sk[31] = (sk[31] & 127) | 64 82 | } 83 | 84 | func newPrivateKey() (sk NoisePrivateKey, err error) { 85 | _, err = rand.Read(sk[:]) 86 | sk.clamp() 87 | return 88 | } 89 | 90 | func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { 91 | apk := (*[NoisePublicKeySize]byte)(&pk) 92 | ask := (*[NoisePrivateKeySize]byte)(sk) 93 | 
curve25519.ScalarBaseMult(apk, ask) 94 | return 95 | } 96 | 97 | func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { 98 | apk := (*[NoisePublicKeySize]byte)(&pk) 99 | ask := (*[NoisePrivateKeySize]byte)(sk) 100 | curve25519.ScalarMult(&ss, ask, apk) 101 | return ss 102 | } 103 | -------------------------------------------------------------------------------- /device/noise-types.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "crypto/subtle" 10 | "encoding/hex" 11 | "errors" 12 | ) 13 | 14 | const ( 15 | NoisePublicKeySize = 32 16 | NoisePrivateKeySize = 32 17 | NoisePresharedKeySize = 32 18 | ) 19 | 20 | type ( 21 | NoisePublicKey [NoisePublicKeySize]byte 22 | NoisePrivateKey [NoisePrivateKeySize]byte 23 | NoisePresharedKey [NoisePresharedKeySize]byte 24 | NoiseNonce uint64 // padded to 12-bytes 25 | ) 26 | 27 | func loadExactHex(dst []byte, src string) error { 28 | slice, err := hex.DecodeString(src) 29 | if err != nil { 30 | return err 31 | } 32 | if len(slice) != len(dst) { 33 | return errors.New("hex string does not fit the slice") 34 | } 35 | copy(dst, slice) 36 | return nil 37 | } 38 | 39 | func (key NoisePrivateKey) IsZero() bool { 40 | var zero NoisePrivateKey 41 | return key.Equals(zero) 42 | } 43 | 44 | func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { 45 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 46 | } 47 | 48 | func (key *NoisePrivateKey) FromHex(src string) (err error) { 49 | err = loadExactHex(key[:], src) 50 | key.clamp() 51 | return 52 | } 53 | 54 | func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { 55 | err = loadExactHex(key[:], src) 56 | if key.IsZero() { 57 | return 58 | } 59 | key.clamp() 60 | return 61 | } 62 | 63 | func (key *NoisePublicKey) FromHex(src string) 
error { 64 | return loadExactHex(key[:], src) 65 | } 66 | 67 | func (key NoisePublicKey) IsZero() bool { 68 | var zero NoisePublicKey 69 | return key.Equals(zero) 70 | } 71 | 72 | func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { 73 | return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 74 | } 75 | 76 | func (key *NoisePresharedKey) FromHex(src string) error { 77 | return loadExactHex(key[:], src) 78 | } 79 | -------------------------------------------------------------------------------- /device/noise_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "testing" 12 | 13 | "golang.zx2c4.com/wireguard/conn" 14 | "golang.zx2c4.com/wireguard/tun/tuntest" 15 | ) 16 | 17 | func TestCurveWrappers(t *testing.T) { 18 | sk1, err := newPrivateKey() 19 | assertNil(t, err) 20 | 21 | sk2, err := newPrivateKey() 22 | assertNil(t, err) 23 | 24 | pk1 := sk1.publicKey() 25 | pk2 := sk2.publicKey() 26 | 27 | ss1 := sk1.sharedSecret(pk2) 28 | ss2 := sk2.sharedSecret(pk1) 29 | 30 | if ss1 != ss2 { 31 | t.Fatal("Failed to compute shared secet") 32 | } 33 | } 34 | 35 | func randDevice(t *testing.T) *Device { 36 | sk, err := newPrivateKey() 37 | if err != nil { 38 | t.Fatal(err) 39 | } 40 | tun := tuntest.NewChannelTUN() 41 | logger := NewLogger(LogLevelError, "") 42 | device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) 43 | device.SetPrivateKey(sk) 44 | return device 45 | } 46 | 47 | func assertNil(t *testing.T, err error) { 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | } 52 | 53 | func assertEqual(t *testing.T, a, b []byte) { 54 | if !bytes.Equal(a, b) { 55 | t.Fatal(a, "!=", b) 56 | } 57 | } 58 | 59 | func TestNoiseHandshake(t *testing.T) { 60 | dev1 := randDevice(t) 61 | dev2 := randDevice(t) 62 | 63 | defer dev1.Close() 64 | 
defer dev2.Close() 65 | 66 | peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | peer1.Start() 75 | peer2.Start() 76 | 77 | assertEqual( 78 | t, 79 | peer1.handshake.precomputedStaticStatic[:], 80 | peer2.handshake.precomputedStaticStatic[:], 81 | ) 82 | 83 | /* simulate handshake */ 84 | 85 | // initiation message 86 | 87 | t.Log("exchange initiation message") 88 | 89 | msg1, err := dev1.CreateMessageInitiation(peer2) 90 | assertNil(t, err) 91 | 92 | packet := make([]byte, 0, 256) 93 | writer := bytes.NewBuffer(packet) 94 | err = binary.Write(writer, binary.LittleEndian, msg1) 95 | assertNil(t, err) 96 | peer := dev2.ConsumeMessageInitiation(msg1) 97 | if peer == nil { 98 | t.Fatal("handshake failed at initiation message") 99 | } 100 | 101 | assertEqual( 102 | t, 103 | peer1.handshake.chainKey[:], 104 | peer2.handshake.chainKey[:], 105 | ) 106 | 107 | assertEqual( 108 | t, 109 | peer1.handshake.hash[:], 110 | peer2.handshake.hash[:], 111 | ) 112 | 113 | // response message 114 | 115 | t.Log("exchange response message") 116 | 117 | msg2, err := dev2.CreateMessageResponse(peer1) 118 | assertNil(t, err) 119 | 120 | peer = dev1.ConsumeMessageResponse(msg2) 121 | if peer == nil { 122 | t.Fatal("handshake failed at response message") 123 | } 124 | 125 | assertEqual( 126 | t, 127 | peer1.handshake.chainKey[:], 128 | peer2.handshake.chainKey[:], 129 | ) 130 | 131 | assertEqual( 132 | t, 133 | peer1.handshake.hash[:], 134 | peer2.handshake.hash[:], 135 | ) 136 | 137 | // key pairs 138 | 139 | t.Log("deriving keys") 140 | 141 | err = peer1.BeginSymmetricSession() 142 | if err != nil { 143 | t.Fatal("failed to derive keypair for peer 1", err) 144 | } 145 | 146 | err = peer2.BeginSymmetricSession() 147 | if err != nil { 148 | t.Fatal("failed to derive keypair for peer 2", err) 149 | } 
150 | 151 | key1 := peer1.keypairs.loadNext() 152 | key2 := peer2.keypairs.current 153 | 154 | // encrypting / decryption test 155 | 156 | t.Log("test key pairs") 157 | 158 | func() { 159 | testMsg := []byte("wireguard test message 1") 160 | var err error 161 | var out []byte 162 | var nonce [12]byte 163 | out = key1.send.Seal(out, nonce[:], testMsg, nil) 164 | out, err = key2.receive.Open(out[:0], nonce[:], out, nil) 165 | assertNil(t, err) 166 | assertEqual(t, out, testMsg) 167 | }() 168 | 169 | func() { 170 | testMsg := []byte("wireguard test message 2") 171 | var err error 172 | var out []byte 173 | var nonce [12]byte 174 | out = key2.send.Seal(out, nonce[:], testMsg, nil) 175 | out, err = key1.receive.Open(out[:0], nonce[:], out, nil) 176 | assertNil(t, err) 177 | assertEqual(t, out, testMsg) 178 | }() 179 | } 180 | -------------------------------------------------------------------------------- /device/peer.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "container/list" 10 | "errors" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "golang.zx2c4.com/wireguard/conn" 16 | ) 17 | 18 | type Peer struct { 19 | isRunning AtomicBool 20 | sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer 21 | keypairs Keypairs 22 | handshake Handshake 23 | device *Device 24 | endpoint conn.Endpoint 25 | stopping sync.WaitGroup // routines pending stop 26 | 27 | // These fields are accessed with atomic operations, which must be 28 | // 64-bit aligned even on 32-bit platforms. Go guarantees that an 29 | // allocated struct will be 64-bit aligned. So we place 30 | // atomically-accessed fields up front, so that they can share in 31 | // this alignment before smaller fields throw it off. 
32 | stats struct { 33 | txBytes uint64 // bytes send to peer (endpoint) 34 | rxBytes uint64 // bytes received from peer 35 | lastHandshakeNano int64 // nano seconds since epoch 36 | } 37 | 38 | disableRoaming bool 39 | 40 | timers struct { 41 | retransmitHandshake *Timer 42 | sendKeepalive *Timer 43 | newHandshake *Timer 44 | zeroKeyMaterial *Timer 45 | persistentKeepalive *Timer 46 | handshakeAttempts uint32 47 | needAnotherKeepalive AtomicBool 48 | sentLastMinuteHandshake AtomicBool 49 | } 50 | 51 | state struct { 52 | sync.Mutex // protects against concurrent Start/Stop 53 | } 54 | 55 | queue struct { 56 | staged chan *QueueOutboundElement // staged packets before a handshake is available 57 | outbound *autodrainingOutboundQueue // sequential ordering of udp transmission 58 | inbound *autodrainingInboundQueue // sequential ordering of tun writing 59 | } 60 | 61 | cookieGenerator CookieGenerator 62 | trieEntries list.List 63 | persistentKeepaliveInterval uint32 // accessed atomically 64 | } 65 | 66 | func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { 67 | if device.isClosed() { 68 | return nil, errors.New("device closed") 69 | } 70 | 71 | // lock resources 72 | device.staticIdentity.RLock() 73 | defer device.staticIdentity.RUnlock() 74 | 75 | device.peers.Lock() 76 | defer device.peers.Unlock() 77 | 78 | // check if over limit 79 | if len(device.peers.keyMap) >= MaxPeers { 80 | return nil, errors.New("too many peers") 81 | } 82 | 83 | // create peer 84 | peer := new(Peer) 85 | peer.Lock() 86 | defer peer.Unlock() 87 | 88 | peer.cookieGenerator.Init(pk) 89 | peer.device = device 90 | peer.queue.outbound = newAutodrainingOutboundQueue(device) 91 | peer.queue.inbound = newAutodrainingInboundQueue(device) 92 | peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) 93 | 94 | // map public key 95 | _, ok := device.peers.keyMap[pk] 96 | if ok { 97 | return nil, errors.New("adding existing peer") 98 | } 99 | 100 | // pre-compute DH 101 | 
handshake := &peer.handshake 102 | handshake.mutex.Lock() 103 | handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) 104 | handshake.remoteStatic = pk 105 | handshake.mutex.Unlock() 106 | 107 | // reset endpoint 108 | peer.endpoint = nil 109 | 110 | // init timers 111 | peer.timersInit() 112 | 113 | // add 114 | device.peers.keyMap[pk] = peer 115 | 116 | return peer, nil 117 | } 118 | 119 | func (peer *Peer) SendBuffer(buffer []byte) error { 120 | peer.device.net.RLock() 121 | defer peer.device.net.RUnlock() 122 | 123 | if peer.device.isClosed() { 124 | return nil 125 | } 126 | 127 | peer.RLock() 128 | defer peer.RUnlock() 129 | 130 | if peer.endpoint == nil { 131 | return errors.New("no known endpoint for peer") 132 | } 133 | 134 | err := peer.device.net.bind.Send(buffer, peer.endpoint) 135 | if err == nil { 136 | atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) 137 | } 138 | return err 139 | } 140 | 141 | func (peer *Peer) String() string { 142 | // The awful goo that follows is identical to: 143 | // 144 | // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) 145 | // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] 146 | // return fmt.Sprintf("peer(%s)", abbreviatedKey) 147 | // 148 | // except that it is considerably more efficient. 
149 | src := peer.handshake.remoteStatic 150 | b64 := func(input byte) byte { 151 | return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) 152 | } 153 | b := []byte("peer(____…____)") 154 | const first = len("peer(") 155 | const second = len("peer(____…") 156 | b[first+0] = b64((src[0] >> 2) & 63) 157 | b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) 158 | b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) 159 | b[first+3] = b64(src[2] & 63) 160 | b[second+0] = b64(src[29] & 63) 161 | b[second+1] = b64((src[30] >> 2) & 63) 162 | b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) 163 | b[second+3] = b64((src[31] << 2) & 63) 164 | return string(b) 165 | } 166 | 167 | func (peer *Peer) Start() { 168 | // should never start a peer on a closed device 169 | if peer.device.isClosed() { 170 | return 171 | } 172 | 173 | // prevent simultaneous start/stop operations 174 | peer.state.Lock() 175 | defer peer.state.Unlock() 176 | 177 | if peer.isRunning.Get() { 178 | return 179 | } 180 | 181 | device := peer.device 182 | device.log.Verbosef("%v - Starting", peer) 183 | 184 | // reset routine state 185 | peer.stopping.Wait() 186 | peer.stopping.Add(2) 187 | 188 | peer.handshake.mutex.Lock() 189 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 190 | peer.handshake.mutex.Unlock() 191 | 192 | peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes 193 | 194 | peer.timersStart() 195 | 196 | device.flushInboundQueue(peer.queue.inbound) 197 | device.flushOutboundQueue(peer.queue.outbound) 198 | go peer.RoutineSequentialSender() 199 | go peer.RoutineSequentialReceiver() 200 | 201 | peer.isRunning.Set(true) 202 | } 203 | 204 | func (peer *Peer) ZeroAndFlushAll() { 205 | device := peer.device 206 | 207 | // clear key pairs 208 | 209 | keypairs := &peer.keypairs 210 | keypairs.Lock() 211 | 
device.DeleteKeypair(keypairs.previous) 212 | device.DeleteKeypair(keypairs.current) 213 | device.DeleteKeypair(keypairs.loadNext()) 214 | keypairs.previous = nil 215 | keypairs.current = nil 216 | keypairs.storeNext(nil) 217 | keypairs.Unlock() 218 | 219 | // clear handshake state 220 | 221 | handshake := &peer.handshake 222 | handshake.mutex.Lock() 223 | device.indexTable.Delete(handshake.localIndex) 224 | handshake.Clear() 225 | handshake.mutex.Unlock() 226 | 227 | peer.FlushStagedPackets() 228 | } 229 | 230 | func (peer *Peer) ExpireCurrentKeypairs() { 231 | handshake := &peer.handshake 232 | handshake.mutex.Lock() 233 | peer.device.indexTable.Delete(handshake.localIndex) 234 | handshake.Clear() 235 | peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) 236 | handshake.mutex.Unlock() 237 | 238 | keypairs := &peer.keypairs 239 | keypairs.Lock() 240 | if keypairs.current != nil { 241 | atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) 242 | } 243 | if keypairs.next != nil { 244 | next := keypairs.loadNext() 245 | atomic.StoreUint64(&next.sendNonce, RejectAfterMessages) 246 | } 247 | keypairs.Unlock() 248 | } 249 | 250 | func (peer *Peer) Stop() { 251 | peer.state.Lock() 252 | defer peer.state.Unlock() 253 | 254 | if !peer.isRunning.Swap(false) { 255 | return 256 | } 257 | 258 | peer.device.log.Verbosef("%v - Stopping", peer) 259 | 260 | peer.timersStop() 261 | // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. 
262 | peer.queue.inbound.c <- nil 263 | peer.queue.outbound.c <- nil 264 | peer.stopping.Wait() 265 | peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us 266 | 267 | peer.ZeroAndFlushAll() 268 | } 269 | 270 | func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { 271 | if peer.disableRoaming { 272 | return 273 | } 274 | peer.Lock() 275 | peer.endpoint = endpoint 276 | peer.Unlock() 277 | } 278 | -------------------------------------------------------------------------------- /device/pools.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "sync" 10 | "sync/atomic" 11 | ) 12 | 13 | type WaitPool struct { 14 | pool sync.Pool 15 | cond sync.Cond 16 | lock sync.Mutex 17 | count uint32 18 | max uint32 19 | } 20 | 21 | func NewWaitPool(max uint32, new func() any) *WaitPool { 22 | p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 | p.cond = sync.Cond{L: &p.lock} 24 | return p 25 | } 26 | 27 | func (p *WaitPool) Get() any { 28 | if p.max != 0 { 29 | p.lock.Lock() 30 | for atomic.LoadUint32(&p.count) >= p.max { 31 | p.cond.Wait() 32 | } 33 | atomic.AddUint32(&p.count, 1) 34 | p.lock.Unlock() 35 | } 36 | return p.pool.Get() 37 | } 38 | 39 | func (p *WaitPool) Put(x any) { 40 | p.pool.Put(x) 41 | if p.max == 0 { 42 | return 43 | } 44 | atomic.AddUint32(&p.count, ^uint32(0)) 45 | p.cond.Signal() 46 | } 47 | 48 | func (device *Device) PopulatePools() { 49 | device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 50 | return new([MaxMessageSize]byte) 51 | }) 52 | device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 53 | return new(QueueInboundElement) 54 | }) 55 | device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 56 | return new(QueueOutboundElement) 57 
| }) 58 | } 59 | 60 | func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 61 | return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 62 | } 63 | 64 | func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 65 | device.pool.messageBuffers.Put(msg) 66 | } 67 | 68 | func (device *Device) GetInboundElement() *QueueInboundElement { 69 | return device.pool.inboundElements.Get().(*QueueInboundElement) 70 | } 71 | 72 | func (device *Device) PutInboundElement(elem *QueueInboundElement) { 73 | elem.clearPointers() 74 | device.pool.inboundElements.Put(elem) 75 | } 76 | 77 | func (device *Device) GetOutboundElement() *QueueOutboundElement { 78 | return device.pool.outboundElements.Get().(*QueueOutboundElement) 79 | } 80 | 81 | func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 82 | elem.clearPointers() 83 | device.pool.outboundElements.Put(elem) 84 | } 85 | -------------------------------------------------------------------------------- /device/pools_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "math/rand" 10 | "runtime" 11 | "sync" 12 | "sync/atomic" 13 | "testing" 14 | "time" 15 | ) 16 | 17 | func TestWaitPool(t *testing.T) { 18 | t.Skip("Currently disabled") 19 | var wg sync.WaitGroup 20 | trials := int32(100000) 21 | if raceEnabled { 22 | // This test can be very slow with -race. 
23 | trials /= 10 24 | } 25 | workers := runtime.NumCPU() + 2 26 | if workers-4 <= 0 { 27 | t.Skip("Not enough cores") 28 | } 29 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 30 | wg.Add(workers) 31 | max := uint32(0) 32 | updateMax := func() { 33 | count := atomic.LoadUint32(&p.count) 34 | if count > p.max { 35 | t.Errorf("count (%d) > max (%d)", count, p.max) 36 | } 37 | for { 38 | old := atomic.LoadUint32(&max) 39 | if count <= old { 40 | break 41 | } 42 | if atomic.CompareAndSwapUint32(&max, old, count) { 43 | break 44 | } 45 | } 46 | } 47 | for i := 0; i < workers; i++ { 48 | go func() { 49 | defer wg.Done() 50 | for atomic.AddInt32(&trials, -1) > 0 { 51 | updateMax() 52 | x := p.Get() 53 | updateMax() 54 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 55 | updateMax() 56 | p.Put(x) 57 | updateMax() 58 | } 59 | }() 60 | } 61 | wg.Wait() 62 | if max != p.max { 63 | t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) 64 | } 65 | } 66 | 67 | func BenchmarkWaitPool(b *testing.B) { 68 | var wg sync.WaitGroup 69 | trials := int32(b.N) 70 | workers := runtime.NumCPU() + 2 71 | if workers-4 <= 0 { 72 | b.Skip("Not enough cores") 73 | } 74 | p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) 75 | wg.Add(workers) 76 | b.ResetTimer() 77 | for i := 0; i < workers; i++ { 78 | go func() { 79 | defer wg.Done() 80 | for atomic.AddInt32(&trials, -1) > 0 { 81 | x := p.Get() 82 | time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) 83 | p.Put(x) 84 | } 85 | }() 86 | } 87 | wg.Wait() 88 | } 89 | -------------------------------------------------------------------------------- /device/queueconstants_android.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | /* Reduce memory consumption for Android */ 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = 2200 16 | PreallocatedBuffersPerPool = 4096 17 | ) 18 | -------------------------------------------------------------------------------- /device/queueconstants_default.go: -------------------------------------------------------------------------------- 1 | //go:build !android && !ios && !windows 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const ( 11 | QueueStagedSize = 128 12 | QueueOutboundSize = 1024 13 | QueueInboundSize = 1024 14 | QueueHandshakeSize = 1024 15 | MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram 16 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 17 | ) 18 | -------------------------------------------------------------------------------- /device/queueconstants_ios.go: -------------------------------------------------------------------------------- 1 | //go:build ios 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | // Fit within memory limits for iOS's Network Extension API, which has stricter requirements. 11 | // These are vars instead of consts, because heavier network extensions might want to reduce 12 | // them further. 
13 | var ( 14 | QueueStagedSize = 128 15 | QueueOutboundSize = 1024 16 | QueueInboundSize = 1024 17 | QueueHandshakeSize = 1024 18 | PreallocatedBuffersPerPool uint32 = 1024 19 | ) 20 | 21 | const MaxSegmentSize = 1700 22 | -------------------------------------------------------------------------------- /device/queueconstants_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package device 7 | 8 | const ( 9 | QueueStagedSize = 128 10 | QueueOutboundSize = 1024 11 | QueueInboundSize = 1024 12 | QueueHandshakeSize = 1024 13 | MaxSegmentSize = 2048 - 32 // largest possible UDP datagram 14 | PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth 15 | ) 16 | -------------------------------------------------------------------------------- /device/race_disabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build !race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = false 11 | -------------------------------------------------------------------------------- /device/race_enabled_test.go: -------------------------------------------------------------------------------- 1 | //go:build race 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | package device 9 | 10 | const raceEnabled = true 11 | -------------------------------------------------------------------------------- /device/sticky_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | 3 | package device 4 | 5 | import ( 6 | "golang.zx2c4.com/wireguard/conn" 7 | "golang.zx2c4.com/wireguard/rwcancel" 8 | ) 9 | 10 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 11 | return nil, nil 12 | } 13 | -------------------------------------------------------------------------------- /device/sticky_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This implements userspace semantics of "sticky sockets", modeled after 6 | * WireGuard's kernelspace implementation. This is more or less a straight port 7 | * of the sticky-sockets.c example code: 8 | * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 | * 10 | * Currently there is no way to achieve this within the net package: 11 | * See e.g. https://github.com/golang/go/issues/17930 12 | * So this code is remains platform dependent. 
13 | */ 14 | 15 | package device 16 | 17 | import ( 18 | "sync" 19 | "unsafe" 20 | 21 | "golang.org/x/sys/unix" 22 | 23 | "golang.zx2c4.com/wireguard/conn" 24 | "golang.zx2c4.com/wireguard/rwcancel" 25 | ) 26 | 27 | func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 | if _, ok := bind.(*conn.LinuxSocketBind); !ok { 29 | return nil, nil 30 | } 31 | 32 | netlinkSock, err := createNetlinkRouteSocket() 33 | if err != nil { 34 | return nil, err 35 | } 36 | netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 37 | if err != nil { 38 | unix.Close(netlinkSock) 39 | return nil, err 40 | } 41 | 42 | go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 43 | 44 | return netlinkCancel, nil 45 | } 46 | 47 | func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 48 | type peerEndpointPtr struct { 49 | peer *Peer 50 | endpoint *conn.Endpoint 51 | } 52 | var reqPeer map[uint32]peerEndpointPtr 53 | var reqPeerLock sync.Mutex 54 | 55 | defer netlinkCancel.Close() 56 | defer unix.Close(netlinkSock) 57 | 58 | for msg := make([]byte, 1<<16); ; { 59 | var err error 60 | var msgn int 61 | for { 62 | msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 63 | if err == nil || !rwcancel.RetryAfterError(err) { 64 | break 65 | } 66 | if !netlinkCancel.ReadyRead() { 67 | return 68 | } 69 | } 70 | if err != nil { 71 | return 72 | } 73 | 74 | for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 75 | 76 | hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 77 | 78 | if uint(hdr.Len) > uint(len(remain)) { 79 | break 80 | } 81 | 82 | switch hdr.Type { 83 | case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 84 | if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 85 | if uint(len(remain)) < uint(hdr.Len) { 86 | break 87 | } 88 | if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 89 | attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 90 | for { 91 | if uint(len(attr)) < 
uint(unix.SizeofRtAttr) { 92 | break 93 | } 94 | attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 95 | if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 96 | break 97 | } 98 | if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 99 | ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 100 | reqPeerLock.Lock() 101 | if reqPeer == nil { 102 | reqPeerLock.Unlock() 103 | break 104 | } 105 | pePtr, ok := reqPeer[hdr.Seq] 106 | reqPeerLock.Unlock() 107 | if !ok { 108 | break 109 | } 110 | pePtr.peer.Lock() 111 | if &pePtr.peer.endpoint != pePtr.endpoint { 112 | pePtr.peer.Unlock() 113 | break 114 | } 115 | if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { 116 | pePtr.peer.Unlock() 117 | break 118 | } 119 | pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() 120 | pePtr.peer.Unlock() 121 | } 122 | attr = attr[attrhdr.Len:] 123 | } 124 | } 125 | break 126 | } 127 | reqPeerLock.Lock() 128 | reqPeer = make(map[uint32]peerEndpointPtr) 129 | reqPeerLock.Unlock() 130 | go func() { 131 | device.peers.RLock() 132 | i := uint32(1) 133 | for _, peer := range device.peers.keyMap { 134 | peer.RLock() 135 | if peer.endpoint == nil { 136 | peer.RUnlock() 137 | continue 138 | } 139 | nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) 140 | if nativeEP == nil { 141 | peer.RUnlock() 142 | continue 143 | } 144 | if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { 145 | peer.RUnlock() 146 | break 147 | } 148 | nlmsg := struct { 149 | hdr unix.NlMsghdr 150 | msg unix.RtMsg 151 | dsthdr unix.RtAttr 152 | dst [4]byte 153 | srchdr unix.RtAttr 154 | src [4]byte 155 | markhdr unix.RtAttr 156 | mark uint32 157 | }{ 158 | unix.NlMsghdr{ 159 | Type: uint16(unix.RTM_GETROUTE), 160 | Flags: unix.NLM_F_REQUEST, 161 | Seq: i, 162 | }, 163 | unix.RtMsg{ 164 | Family: unix.AF_INET, 165 | Dst_len: 32, 166 | Src_len: 32, 167 | }, 168 | unix.RtAttr{ 169 | Len: 8, 170 | Type: unix.RTA_DST, 171 | }, 172 
| nativeEP.Dst4().Addr, 173 | unix.RtAttr{ 174 | Len: 8, 175 | Type: unix.RTA_SRC, 176 | }, 177 | nativeEP.Src4().Src, 178 | unix.RtAttr{ 179 | Len: 8, 180 | Type: unix.RTA_MARK, 181 | }, 182 | device.net.fwmark, 183 | } 184 | nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 185 | reqPeerLock.Lock() 186 | reqPeer[i] = peerEndpointPtr{ 187 | peer: peer, 188 | endpoint: &peer.endpoint, 189 | } 190 | reqPeerLock.Unlock() 191 | peer.RUnlock() 192 | i++ 193 | _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 194 | if err != nil { 195 | break 196 | } 197 | } 198 | device.peers.RUnlock() 199 | }() 200 | } 201 | remain = remain[hdr.Len:] 202 | } 203 | } 204 | } 205 | 206 | func createNetlinkRouteSocket() (int, error) { 207 | sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 208 | if err != nil { 209 | return -1, err 210 | } 211 | saddr := &unix.SockaddrNetlink{ 212 | Family: unix.AF_NETLINK, 213 | Groups: unix.RTMGRP_IPV4_ROUTE, 214 | } 215 | err = unix.Bind(sock, saddr) 216 | if err != nil { 217 | unix.Close(sock) 218 | return -1, err 219 | } 220 | return sock, nil 221 | } 222 | -------------------------------------------------------------------------------- /device/timers.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | * 5 | * This is based heavily on timers.c from the kernel implementation. 6 | */ 7 | 8 | package device 9 | 10 | import ( 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | _ "unsafe" 15 | ) 16 | 17 | //go:linkname fastrandn runtime.fastrandn 18 | func fastrandn(n uint32) uint32 19 | 20 | // A Timer manages time-based aspects of the WireGuard protocol. 21 | // Timer roughly copies the interface of the Linux kernel's struct timer_list. 
22 | type Timer struct { 23 | *time.Timer 24 | modifyingLock sync.RWMutex 25 | runningLock sync.Mutex 26 | isPending bool 27 | } 28 | 29 | func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { 30 | timer := &Timer{} 31 | timer.Timer = time.AfterFunc(time.Hour, func() { 32 | timer.runningLock.Lock() 33 | defer timer.runningLock.Unlock() 34 | 35 | timer.modifyingLock.Lock() 36 | if !timer.isPending { 37 | timer.modifyingLock.Unlock() 38 | return 39 | } 40 | timer.isPending = false 41 | timer.modifyingLock.Unlock() 42 | 43 | expirationFunction(peer) 44 | }) 45 | timer.Stop() 46 | return timer 47 | } 48 | 49 | func (timer *Timer) Mod(d time.Duration) { 50 | timer.modifyingLock.Lock() 51 | timer.isPending = true 52 | timer.Reset(d) 53 | timer.modifyingLock.Unlock() 54 | } 55 | 56 | func (timer *Timer) Del() { 57 | timer.modifyingLock.Lock() 58 | timer.isPending = false 59 | timer.Stop() 60 | timer.modifyingLock.Unlock() 61 | } 62 | 63 | func (timer *Timer) DelSync() { 64 | timer.Del() 65 | timer.runningLock.Lock() 66 | timer.Del() 67 | timer.runningLock.Unlock() 68 | } 69 | 70 | func (timer *Timer) IsPending() bool { 71 | timer.modifyingLock.RLock() 72 | defer timer.modifyingLock.RUnlock() 73 | return timer.isPending 74 | } 75 | 76 | func (peer *Peer) timersActive() bool { 77 | return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() 78 | } 79 | 80 | func expiredRetransmitHandshake(peer *Peer) { 81 | if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { 82 | peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) 83 | 84 | if peer.timersActive() { 85 | peer.timers.sendKeepalive.Del() 86 | } 87 | 88 | /* We drop all packets without a keypair and don't try again, 89 | * if we try unsuccessfully for too long to make a handshake. 
90 | */ 91 | peer.FlushStagedPackets() 92 | 93 | /* We set a timer for destroying any residue that might be left 94 | * of a partial exchange. 95 | */ 96 | if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { 97 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 98 | } 99 | } else { 100 | atomic.AddUint32(&peer.timers.handshakeAttempts, 1) 101 | peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) 102 | 103 | /* We clear the endpoint address src address, in case this is the cause of trouble. */ 104 | peer.Lock() 105 | if peer.endpoint != nil { 106 | peer.endpoint.ClearSrc() 107 | } 108 | peer.Unlock() 109 | 110 | peer.SendHandshakeInitiation(true) 111 | } 112 | } 113 | 114 | func expiredSendKeepalive(peer *Peer) { 115 | peer.SendKeepalive() 116 | if peer.timers.needAnotherKeepalive.Get() { 117 | peer.timers.needAnotherKeepalive.Set(false) 118 | if peer.timersActive() { 119 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 120 | } 121 | } 122 | } 123 | 124 | func expiredNewHandshake(peer *Peer) { 125 | peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) 126 | /* We clear the endpoint address src address, in case this is the cause of trouble. 
*/ 127 | peer.Lock() 128 | if peer.endpoint != nil { 129 | peer.endpoint.ClearSrc() 130 | } 131 | peer.Unlock() 132 | peer.SendHandshakeInitiation(false) 133 | } 134 | 135 | func expiredZeroKeyMaterial(peer *Peer) { 136 | peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) 137 | peer.ZeroAndFlushAll() 138 | } 139 | 140 | func expiredPersistentKeepalive(peer *Peer) { 141 | if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { 142 | peer.SendKeepalive() 143 | } 144 | } 145 | 146 | /* Should be called after an authenticated data packet is sent. */ 147 | func (peer *Peer) timersDataSent() { 148 | if peer.timersActive() && !peer.timers.newHandshake.IsPending() { 149 | peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 150 | } 151 | } 152 | 153 | /* Should be called after an authenticated data packet is received. */ 154 | func (peer *Peer) timersDataReceived() { 155 | if peer.timersActive() { 156 | if !peer.timers.sendKeepalive.IsPending() { 157 | peer.timers.sendKeepalive.Mod(KeepaliveTimeout) 158 | } else { 159 | peer.timers.needAnotherKeepalive.Set(true) 160 | } 161 | } 162 | } 163 | 164 | /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ 165 | func (peer *Peer) timersAnyAuthenticatedPacketSent() { 166 | if peer.timersActive() { 167 | peer.timers.sendKeepalive.Del() 168 | } 169 | } 170 | 171 | /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ 172 | func (peer *Peer) timersAnyAuthenticatedPacketReceived() { 173 | if peer.timersActive() { 174 | peer.timers.newHandshake.Del() 175 | } 176 | } 177 | 178 | /* Should be called after a handshake initiation message is sent. 
*/ 179 | func (peer *Peer) timersHandshakeInitiated() { 180 | if peer.timersActive() { 181 | peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) 182 | } 183 | } 184 | 185 | /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ 186 | func (peer *Peer) timersHandshakeComplete() { 187 | if peer.timersActive() { 188 | peer.timers.retransmitHandshake.Del() 189 | } 190 | atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) 191 | peer.timers.sentLastMinuteHandshake.Set(false) 192 | atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) 193 | } 194 | 195 | /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ 196 | func (peer *Peer) timersSessionDerived() { 197 | if peer.timersActive() { 198 | peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) 199 | } 200 | } 201 | 202 | /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. 
*/ 203 | func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { 204 | keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) 205 | if keepalive > 0 && peer.timersActive() { 206 | peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) 207 | } 208 | } 209 | 210 | func (peer *Peer) timersInit() { 211 | peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) 212 | peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) 213 | peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) 214 | peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) 215 | peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) 216 | } 217 | 218 | func (peer *Peer) timersStart() { 219 | atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) 220 | peer.timers.sentLastMinuteHandshake.Set(false) 221 | peer.timers.needAnotherKeepalive.Set(false) 222 | } 223 | 224 | func (peer *Peer) timersStop() { 225 | peer.timers.retransmitHandshake.DelSync() 226 | peer.timers.sendKeepalive.DelSync() 227 | peer.timers.newHandshake.DelSync() 228 | peer.timers.zeroKeyMaterial.DelSync() 229 | peer.timers.persistentKeepalive.DelSync() 230 | } 231 | -------------------------------------------------------------------------------- /device/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package device 7 | 8 | import ( 9 | "fmt" 10 | "sync/atomic" 11 | 12 | "golang.zx2c4.com/wireguard/tun" 13 | ) 14 | 15 | const DefaultMTU = 1420 16 | 17 | func (device *Device) RoutineTUNEventReader() { 18 | device.log.Verbosef("Routine: event worker - started") 19 | 20 | for event := range device.tun.device.Events() { 21 | if event&tun.EventMTUUpdate != 0 { 22 | mtu, err := device.tun.device.MTU() 23 | if err != nil { 24 | device.log.Errorf("Failed to load updated MTU of device: %v", err) 25 | continue 26 | } 27 | if mtu < 0 { 28 | device.log.Errorf("MTU not updated to negative value: %v", mtu) 29 | continue 30 | } 31 | var tooLarge string 32 | if mtu > MaxContentSize { 33 | tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) 34 | mtu = MaxContentSize 35 | } 36 | old := atomic.SwapInt32(&device.tun.mtu, int32(mtu)) 37 | if int(old) != mtu { 38 | device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) 39 | } 40 | } 41 | 42 | if event&tun.EventUp != 0 { 43 | device.log.Verbosef("Interface up requested") 44 | device.Up() 45 | } 46 | 47 | if event&tun.EventDown != 0 { 48 | device.log.Verbosef("Interface down requested") 49 | device.Down() 50 | } 51 | } 52 | 53 | device.log.Verbosef("Routine: event worker - stopped") 54 | } 55 | -------------------------------------------------------------------------------- /format_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "go/format" 10 | "io/fs" 11 | "os" 12 | "path/filepath" 13 | "runtime" 14 | "sync" 15 | "testing" 16 | ) 17 | 18 | func TestFormatting(t *testing.T) { 19 | var wg sync.WaitGroup 20 | filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { 21 | if err != nil { 22 | t.Errorf("unable to walk %s: %v", path, err) 23 | return nil 24 | } 25 | if d.IsDir() || filepath.Ext(path) != ".go" { 26 | return nil 27 | } 28 | wg.Add(1) 29 | go func(path string) { 30 | defer wg.Done() 31 | src, err := os.ReadFile(path) 32 | if err != nil { 33 | t.Errorf("unable to read %s: %v", path, err) 34 | return 35 | } 36 | if runtime.GOOS == "windows" { 37 | src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) 38 | } 39 | formatted, err := format.Source(src) 40 | if err != nil { 41 | t.Errorf("unable to format %s: %v", path, err) 42 | return 43 | } 44 | if !bytes.Equal(src, formatted) { 45 | t.Errorf("unformatted code: %s", path) 46 | } 47 | }(path) 48 | return nil 49 | }) 50 | wg.Wait() 51 | } 52 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module golang.zx2c4.com/wireguard 2 | 3 | go 1.18 4 | 5 | require ( 6 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd 7 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f 8 | golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 9 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= 2 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 3 | golang.org/x/net 
v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= 4 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= 5 | golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 h1:A9i04dxx7Cribqbs8jf3FQLogkL/CV2YN7hj9KWJCkc= 6 | golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 7 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= 8 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= 9 | -------------------------------------------------------------------------------- /ipc/namedpipe/file.go: -------------------------------------------------------------------------------- 1 | // Copyright 2021 The Go Authors. All rights reserved. 2 | // Copyright 2015 Microsoft 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file. 
5 | 6 | //go:build windows 7 | // +build windows 8 | 9 | package namedpipe 10 | 11 | import ( 12 | "io" 13 | "os" 14 | "runtime" 15 | "sync" 16 | "sync/atomic" 17 | "time" 18 | "unsafe" 19 | 20 | "golang.org/x/sys/windows" 21 | ) 22 | 23 | type timeoutChan chan struct{} 24 | 25 | var ( 26 | ioInitOnce sync.Once 27 | ioCompletionPort windows.Handle 28 | ) 29 | 30 | // ioResult contains the result of an asynchronous IO operation 31 | type ioResult struct { 32 | bytes uint32 33 | err error 34 | } 35 | 36 | // ioOperation represents an outstanding asynchronous Win32 IO 37 | type ioOperation struct { 38 | o windows.Overlapped 39 | ch chan ioResult 40 | } 41 | 42 | func initIo() { 43 | h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) 44 | if err != nil { 45 | panic(err) 46 | } 47 | ioCompletionPort = h 48 | go ioCompletionProcessor(h) 49 | } 50 | 51 | // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. 52 | // It takes ownership of this handle and will close it if it is garbage collected. 
53 | type file struct { 54 | handle windows.Handle 55 | wg sync.WaitGroup 56 | wgLock sync.RWMutex 57 | closing uint32 // used as atomic boolean 58 | socket bool 59 | readDeadline deadlineHandler 60 | writeDeadline deadlineHandler 61 | } 62 | 63 | type deadlineHandler struct { 64 | setLock sync.Mutex 65 | channel timeoutChan 66 | channelLock sync.RWMutex 67 | timer *time.Timer 68 | timedout uint32 // used as atomic boolean 69 | } 70 | 71 | // makeFile makes a new file from an existing file handle 72 | func makeFile(h windows.Handle) (*file, error) { 73 | f := &file{handle: h} 74 | ioInitOnce.Do(initIo) 75 | _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) 76 | if err != nil { 77 | return nil, err 78 | } 79 | err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) 80 | if err != nil { 81 | return nil, err 82 | } 83 | f.readDeadline.channel = make(timeoutChan) 84 | f.writeDeadline.channel = make(timeoutChan) 85 | return f, nil 86 | } 87 | 88 | // closeHandle closes the resources associated with a Win32 handle 89 | func (f *file) closeHandle() { 90 | f.wgLock.Lock() 91 | // Atomically set that we are closing, releasing the resources only once. 92 | if atomic.SwapUint32(&f.closing, 1) == 0 { 93 | f.wgLock.Unlock() 94 | // cancel all IO and wait for it to complete 95 | windows.CancelIoEx(f.handle, nil) 96 | f.wg.Wait() 97 | // at this point, no new IO can start 98 | windows.Close(f.handle) 99 | f.handle = 0 100 | } else { 101 | f.wgLock.Unlock() 102 | } 103 | } 104 | 105 | // Close closes a file. 106 | func (f *file) Close() error { 107 | f.closeHandle() 108 | return nil 109 | } 110 | 111 | // prepareIo prepares for a new IO operation. 112 | // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. 
func (f *file) prepareIo() (*ioOperation, error) {
	// The read lock pairs with closeHandle's write lock: the closing check and
	// wg.Add happen atomically with respect to closeHandle setting the flag.
	f.wgLock.RLock()
	if atomic.LoadUint32(&f.closing) == 1 {
		f.wgLock.RUnlock()
		return nil, os.ErrClosed
	}
	f.wg.Add(1)
	f.wgLock.RUnlock()
	c := &ioOperation{}
	// Unbuffered: ioCompletionProcessor's send blocks until the issuing
	// goroutine receives the result in asyncIo.
	c.ch = make(chan ioResult)
	return c, nil
}

// ioCompletionProcessor processes completed async IOs forever
//
// Each dequeued OVERLAPPED pointer is reinterpreted as the *ioOperation that
// embeds it (o is that struct's first field) and the result is delivered on
// the operation's channel. A nil operation means no packet was dequeued, so
// the error is treated as fatal.
func ioCompletionProcessor(h windows.Handle) {
	for {
		var bytes uint32
		var key uintptr
		var op *ioOperation
		err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
		if op == nil {
			panic(err)
		}
		op.ch <- ioResult{bytes, err}
	}
}

// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
//
// d, when non-nil, supplies the deadline channel for this direction; if it
// fires first, the operation is cancelled and os.ErrDeadlineExceeded is
// returned. An operation aborted because the file is closing yields
// os.ErrClosed.
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
	// Anything other than ERROR_IO_PENDING finished synchronously (no packet is
	// queued, see makeFile's notification modes), so report it directly.
	if err != windows.ERROR_IO_PENDING {
		return int(bytes), err
	}

	// closeHandle may have issued its CancelIoEx before this operation was
	// started; cancel it here so it cannot outlive the close.
	if atomic.LoadUint32(&f.closing) == 1 {
		windows.CancelIoEx(f.handle, &c.o)
	}

	// Snapshot the current deadline channel; set may replace it concurrently.
	var timeout timeoutChan
	if d != nil {
		d.channelLock.Lock()
		timeout = d.channel
		d.channelLock.Unlock()
	}

	var r ioResult
	select {
	case r = <-c.ch:
		err = r.err
		if err == windows.ERROR_OPERATION_ABORTED {
			if atomic.LoadUint32(&f.closing) == 1 {
				err = os.ErrClosed
			}
		} else if err != nil && f.socket {
			// err is from Win32. Query the overlapped structure to get the winsock error.
			var bytes, flags uint32
			err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
		}
	case <-timeout:
		// Deadline expired: cancel, then still wait for the completion packet
		// so the OVERLAPPED is no longer referenced by the kernel.
		windows.CancelIoEx(f.handle, &c.o)
		r = <-c.ch
		err = r.err
		if err == windows.ERROR_OPERATION_ABORTED {
			err = os.ErrDeadlineExceeded
		}
	}

	// runtime.KeepAlive is needed, as c is passed via native
	// code to ioCompletionProcessor, c must remain alive
	// until the channel read is complete.
	runtime.KeepAlive(c)
	return int(r.bytes), err
}

// Read reads from a file handle.
//
// A successful zero-byte read into a non-empty buffer, or ERROR_BROKEN_PIPE,
// is reported as io.EOF.
func (f *file) Read(b []byte) (int, error) {
	c, err := f.prepareIo()
	if err != nil {
		return 0, err
	}
	defer f.wg.Done()

	// A deadline that already expired fails fast without issuing the read.
	if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
		return 0, os.ErrDeadlineExceeded
	}

	var bytes uint32
	err = windows.ReadFile(f.handle, b, &bytes, &c.o)
	n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
	// The kernel may write into b until the operation completes.
	runtime.KeepAlive(b)

	// Handle EOF conditions.
	if err == nil && n == 0 && len(b) != 0 {
		return 0, io.EOF
	} else if err == windows.ERROR_BROKEN_PIPE {
		return 0, io.EOF
	} else {
		return n, err
	}
}

// Write writes to a file handle.
215 | func (f *file) Write(b []byte) (int, error) { 216 | c, err := f.prepareIo() 217 | if err != nil { 218 | return 0, err 219 | } 220 | defer f.wg.Done() 221 | 222 | if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 { 223 | return 0, os.ErrDeadlineExceeded 224 | } 225 | 226 | var bytes uint32 227 | err = windows.WriteFile(f.handle, b, &bytes, &c.o) 228 | n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) 229 | runtime.KeepAlive(b) 230 | return n, err 231 | } 232 | 233 | func (f *file) SetReadDeadline(deadline time.Time) error { 234 | return f.readDeadline.set(deadline) 235 | } 236 | 237 | func (f *file) SetWriteDeadline(deadline time.Time) error { 238 | return f.writeDeadline.set(deadline) 239 | } 240 | 241 | func (f *file) Flush() error { 242 | return windows.FlushFileBuffers(f.handle) 243 | } 244 | 245 | func (f *file) Fd() uintptr { 246 | return uintptr(f.handle) 247 | } 248 | 249 | func (d *deadlineHandler) set(deadline time.Time) error { 250 | d.setLock.Lock() 251 | defer d.setLock.Unlock() 252 | 253 | if d.timer != nil { 254 | if !d.timer.Stop() { 255 | <-d.channel 256 | } 257 | d.timer = nil 258 | } 259 | atomic.StoreUint32(&d.timedout, 0) 260 | 261 | select { 262 | case <-d.channel: 263 | d.channelLock.Lock() 264 | d.channel = make(chan struct{}) 265 | d.channelLock.Unlock() 266 | default: 267 | } 268 | 269 | if deadline.IsZero() { 270 | return nil 271 | } 272 | 273 | timeoutIO := func() { 274 | atomic.StoreUint32(&d.timedout, 1) 275 | close(d.channel) 276 | } 277 | 278 | now := time.Now() 279 | duration := deadline.Sub(now) 280 | if deadline.After(now) { 281 | // Deadline is in the future, set a timer to wait 282 | d.timer = time.AfterFunc(duration, timeoutIO) 283 | } else { 284 | // Deadline is in the past. Cancel all pending IO now. 
285 | timeoutIO() 286 | } 287 | return nil 288 | } 289 | -------------------------------------------------------------------------------- /ipc/uapi_bsd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "net" 13 | "os" 14 | "unsafe" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | type UAPIListener struct { 20 | listener net.Listener // unix socket listener 21 | connNew chan net.Conn 22 | connErr chan error 23 | kqueueFd int 24 | keventFd int 25 | } 26 | 27 | func (l *UAPIListener) Accept() (net.Conn, error) { 28 | for { 29 | select { 30 | case conn := <-l.connNew: 31 | return conn, nil 32 | 33 | case err := <-l.connErr: 34 | return nil, err 35 | } 36 | } 37 | } 38 | 39 | func (l *UAPIListener) Close() error { 40 | err1 := unix.Close(l.kqueueFd) 41 | err2 := unix.Close(l.keventFd) 42 | err3 := l.listener.Close() 43 | if err1 != nil { 44 | return err1 45 | } 46 | if err2 != nil { 47 | return err2 48 | } 49 | return err3 50 | } 51 | 52 | func (l *UAPIListener) Addr() net.Addr { 53 | return l.listener.Addr() 54 | } 55 | 56 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 57 | // wrap file in listener 58 | 59 | listener, err := net.FileListener(file) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | uapi := &UAPIListener{ 65 | listener: listener, 66 | connNew: make(chan net.Conn, 1), 67 | connErr: make(chan error, 1), 68 | } 69 | 70 | if unixListener, ok := listener.(*net.UnixListener); ok { 71 | unixListener.SetUnlinkOnClose(true) 72 | } 73 | 74 | socketPath := sockPath(name) 75 | 76 | // watch for deletion of socket 77 | 78 | uapi.kqueueFd, err = unix.Kqueue() 79 | if err != nil { 80 | return nil, err 81 | } 82 | uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) 83 | if 
err != nil { 84 | unix.Close(uapi.kqueueFd) 85 | return nil, err 86 | } 87 | 88 | go func(l *UAPIListener) { 89 | event := unix.Kevent_t{ 90 | Filter: unix.EVFILT_VNODE, 91 | Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, 92 | Fflags: unix.NOTE_WRITE, 93 | } 94 | // Allow this assignment to work with both the 32-bit and 64-bit version 95 | // of the above struct. If you know another way, please submit a patch. 96 | *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) 97 | events := make([]unix.Kevent_t, 1) 98 | n := 1 99 | var kerr error 100 | for { 101 | // start with lstat to avoid race condition 102 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 103 | l.connErr <- err 104 | return 105 | } 106 | if (kerr != nil || n != 1) && kerr != unix.EINTR { 107 | if kerr != nil { 108 | l.connErr <- kerr 109 | } else { 110 | l.connErr <- errors.New("kqueue returned empty") 111 | } 112 | return 113 | } 114 | n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) 115 | } 116 | }(uapi) 117 | 118 | // watch for new connections 119 | 120 | go func(l *UAPIListener) { 121 | for { 122 | conn, err := l.listener.Accept() 123 | if err != nil { 124 | l.connErr <- err 125 | break 126 | } 127 | l.connNew <- conn 128 | } 129 | }(uapi) 130 | 131 | return uapi, nil 132 | } 133 | -------------------------------------------------------------------------------- /ipc/uapi_js.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | // Made up sentinel error codes for the js/wasm platform. 
9 | const ( 10 | IpcErrorIO = 1 11 | IpcErrorInvalid = 2 12 | IpcErrorPortInUse = 3 13 | IpcErrorUnknown = 4 14 | IpcErrorProtocol = 5 15 | ) 16 | -------------------------------------------------------------------------------- /ipc/uapi_linux.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | "os" 11 | 12 | "golang.org/x/sys/unix" 13 | "golang.zx2c4.com/wireguard/rwcancel" 14 | ) 15 | 16 | type UAPIListener struct { 17 | listener net.Listener // unix socket listener 18 | connNew chan net.Conn 19 | connErr chan error 20 | inotifyFd int 21 | inotifyRWCancel *rwcancel.RWCancel 22 | } 23 | 24 | func (l *UAPIListener) Accept() (net.Conn, error) { 25 | for { 26 | select { 27 | case conn := <-l.connNew: 28 | return conn, nil 29 | 30 | case err := <-l.connErr: 31 | return nil, err 32 | } 33 | } 34 | } 35 | 36 | func (l *UAPIListener) Close() error { 37 | err1 := unix.Close(l.inotifyFd) 38 | err2 := l.inotifyRWCancel.Cancel() 39 | err3 := l.listener.Close() 40 | if err1 != nil { 41 | return err1 42 | } 43 | if err2 != nil { 44 | return err2 45 | } 46 | return err3 47 | } 48 | 49 | func (l *UAPIListener) Addr() net.Addr { 50 | return l.listener.Addr() 51 | } 52 | 53 | func UAPIListen(name string, file *os.File) (net.Listener, error) { 54 | // wrap file in listener 55 | 56 | listener, err := net.FileListener(file) 57 | if err != nil { 58 | return nil, err 59 | } 60 | 61 | if unixListener, ok := listener.(*net.UnixListener); ok { 62 | unixListener.SetUnlinkOnClose(true) 63 | } 64 | 65 | uapi := &UAPIListener{ 66 | listener: listener, 67 | connNew: make(chan net.Conn, 1), 68 | connErr: make(chan error, 1), 69 | } 70 | 71 | // watch for deletion of socket 72 | 73 | socketPath := sockPath(name) 74 | 75 | uapi.inotifyFd, err = unix.InotifyInit() 76 | if err != nil { 77 | return nil, 
err 78 | } 79 | 80 | _, err = unix.InotifyAddWatch( 81 | uapi.inotifyFd, 82 | socketPath, 83 | unix.IN_ATTRIB| 84 | unix.IN_DELETE| 85 | unix.IN_DELETE_SELF, 86 | ) 87 | 88 | if err != nil { 89 | return nil, err 90 | } 91 | 92 | uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) 93 | if err != nil { 94 | unix.Close(uapi.inotifyFd) 95 | return nil, err 96 | } 97 | 98 | go func(l *UAPIListener) { 99 | var buff [0]byte 100 | for { 101 | defer uapi.inotifyRWCancel.Close() 102 | // start with lstat to avoid race condition 103 | if _, err := os.Lstat(socketPath); os.IsNotExist(err) { 104 | l.connErr <- err 105 | return 106 | } 107 | _, err := uapi.inotifyRWCancel.Read(buff[:]) 108 | if err != nil { 109 | l.connErr <- err 110 | return 111 | } 112 | } 113 | }(uapi) 114 | 115 | // watch for new connections 116 | 117 | go func(l *UAPIListener) { 118 | for { 119 | conn, err := l.listener.Accept() 120 | if err != nil { 121 | l.connErr <- err 122 | break 123 | } 124 | l.connNew <- conn 125 | } 126 | }(uapi) 127 | 128 | return uapi, nil 129 | } 130 | -------------------------------------------------------------------------------- /ipc/uapi_unix.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin || freebsd || openbsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package ipc 9 | 10 | import ( 11 | "errors" 12 | "fmt" 13 | "net" 14 | "os" 15 | 16 | "golang.org/x/sys/unix" 17 | ) 18 | 19 | const ( 20 | IpcErrorIO = -int64(unix.EIO) 21 | IpcErrorProtocol = -int64(unix.EPROTO) 22 | IpcErrorInvalid = -int64(unix.EINVAL) 23 | IpcErrorPortInUse = -int64(unix.EADDRINUSE) 24 | IpcErrorUnknown = -55 // ENOANO 25 | ) 26 | 27 | // socketDirectory is variable because it is modified by a linker 28 | // flag in wireguard-android. 
29 | var socketDirectory = "/var/run/wireguard" 30 | 31 | func sockPath(iface string) string { 32 | return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) 33 | } 34 | 35 | func UAPIOpen(name string) (*os.File, error) { 36 | if err := os.MkdirAll(socketDirectory, 0o755); err != nil { 37 | return nil, err 38 | } 39 | 40 | socketPath := sockPath(name) 41 | addr, err := net.ResolveUnixAddr("unix", socketPath) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | oldUmask := unix.Umask(0o077) 47 | defer unix.Umask(oldUmask) 48 | 49 | listener, err := net.ListenUnix("unix", addr) 50 | if err == nil { 51 | return listener.File() 52 | } 53 | 54 | // Test socket, if not in use cleanup and try again. 55 | if _, err := net.Dial("unix", socketPath); err == nil { 56 | return nil, errors.New("unix socket in use") 57 | } 58 | if err := os.Remove(socketPath); err != nil { 59 | return nil, err 60 | } 61 | listener, err = net.ListenUnix("unix", addr) 62 | if err != nil { 63 | return nil, err 64 | } 65 | return listener.File() 66 | } 67 | -------------------------------------------------------------------------------- /ipc/uapi_windows.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package ipc 7 | 8 | import ( 9 | "net" 10 | 11 | "golang.org/x/sys/windows" 12 | "golang.zx2c4.com/wireguard/ipc/namedpipe" 13 | ) 14 | 15 | // TODO: replace these with actual standard windows error numbers from the win package 16 | const ( 17 | IpcErrorIO = -int64(5) 18 | IpcErrorProtocol = -int64(71) 19 | IpcErrorInvalid = -int64(22) 20 | IpcErrorPortInUse = -int64(98) 21 | IpcErrorUnknown = -int64(55) 22 | ) 23 | 24 | type UAPIListener struct { 25 | listener net.Listener // unix socket listener 26 | connNew chan net.Conn 27 | connErr chan error 28 | kqueueFd int 29 | keventFd int 30 | } 31 | 32 | func (l *UAPIListener) Accept() (net.Conn, error) { 33 | for { 34 | select { 35 | case conn := <-l.connNew: 36 | return conn, nil 37 | 38 | case err := <-l.connErr: 39 | return nil, err 40 | } 41 | } 42 | } 43 | 44 | func (l *UAPIListener) Close() error { 45 | return l.listener.Close() 46 | } 47 | 48 | func (l *UAPIListener) Addr() net.Addr { 49 | return l.listener.Addr() 50 | } 51 | 52 | var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR 53 | 54 | func init() { 55 | var err error 56 | UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)") 57 | if err != nil { 58 | panic(err) 59 | } 60 | } 61 | 62 | func UAPIListen(name string) (net.Listener, error) { 63 | listener, err := (&namedpipe.ListenConfig{ 64 | SecurityDescriptor: UAPISecurityDescriptor, 65 | }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | uapi := &UAPIListener{ 71 | listener: listener, 72 | connNew: make(chan net.Conn, 1), 73 | connErr: make(chan error, 1), 74 | } 75 | 76 | go func(l *UAPIListener) { 77 | for { 78 | conn, err := l.listener.Accept() 79 | if err != nil { 80 | l.connErr <- err 81 | break 82 | } 83 | l.connNew <- conn 84 | } 85 | }(uapi) 86 | 87 | return uapi, nil 88 | } 89 | 
-------------------------------------------------------------------------------- /main.go: --------------------------------------------------------------------------------
//go:build !windows

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 */

package main

import (
	"fmt"
	"os"
	"os/signal"
	"runtime"
	"strconv"
	"syscall"

	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/ipc"
	"golang.zx2c4.com/wireguard/tun"
)

// Process exit codes used during setup.
const (
	ExitSetupSuccess = 0
	ExitSetupFailed  = 1
)

// Environment variables through which a launching process hands the TUN fd,
// the UAPI fd and foreground mode to the daemonized child (see main).
const (
	ENV_WG_TUN_FD             = "WG_TUN_FD"
	ENV_WG_UAPI_FD            = "WG_UAPI_FD"
	ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)

// printUsage writes the command-line synopsis to standard output.
func printUsage() {
	fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
}

// warning prints a notice to standard error recommending the kernel
// implementation. It prints only on linux, freebsd and openbsd, and is
// suppressed when WG_PROCESS_FOREGROUND is "1" (as set for the daemonized
// child).
func warning() {
	switch runtime.GOOS {
	case "linux", "freebsd", "openbsd":
		if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
			return
		}
	default:
		return
	}

	fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
	fmt.Fprintln(os.Stderr, "│ │")
	fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
	fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
	fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
	fmt.Fprintln(os.Stderr, "│ please visit: │")
	fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
	fmt.Fprintln(os.Stderr, "│ │")
	fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
}

// main parses the command line and opens the TUN device and the UAPI socket.
// Unless running in the foreground, it then re-executes itself in the
// background and returns. In foreground mode it runs the device until a
// SIGTERM/interrupt, a UAPI accept error, or device shutdown.
func main() {
	// NOTE(review): the copyright line below prints an empty contact
	// ("Donenfeld .") — confirm the address was not lost.
	if len(os.Args) == 2 && os.Args[1] == "--version" {
		fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", Version, runtime.GOOS, runtime.GOARCH)
		return
	}

	warning()

	var foreground bool
	var interfaceName string
	if len(os.Args) < 2 || len(os.Args) > 3 {
		printUsage()
		return
	}

	switch os.Args[1] {

	case "-f", "--foreground":
		foreground = true
		if len(os.Args) != 3 {
			printUsage()
			return
		}
		interfaceName = os.Args[2]

	default:
		foreground = false
		if len(os.Args) != 2 {
			printUsage()
			return
		}
		interfaceName = os.Args[1]
	}

	// The daemonized child is started with WG_PROCESS_FOREGROUND=1.
	if !foreground {
		foreground = os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1"
	}

	// get log level (default: error)

	logLevel := func() int {
		switch os.Getenv("LOG_LEVEL") {
		case "verbose", "debug":
			return device.LogLevelVerbose
		case "error":
			return device.LogLevelError
		case "silent":
			return device.LogLevelSilent
		}
		return device.LogLevelError
	}()

	// open TUN device (or use supplied fd)
	//
	// Inside the closure, tun still names the imported package: the variable
	// declared by this statement only comes into scope after it.

	tun, err := func() (tun.Device, error) {
		tunFdStr := os.Getenv(ENV_WG_TUN_FD)
		if tunFdStr == "" {
			return tun.CreateTUN(interfaceName, device.DefaultMTU)
		}

		// construct tun device from supplied fd

		fd, err := strconv.ParseUint(tunFdStr, 10, 32)
		if err != nil {
			return nil, err
		}

		err = syscall.SetNonblock(int(fd), true)
		if err != nil {
			return nil, err
		}

		file := os.NewFile(uintptr(fd), "")
		return tun.CreateTUNFromFile(file, device.DefaultMTU)
	}()

	// Use the name reported by the TUN device, when available, for logging
	// and the UAPI socket.
	if err == nil {
		realInterfaceName, err2 := tun.Name()
		if err2 == nil {
			interfaceName = realInterfaceName
		}
	}

	logger := device.NewLogger(
		logLevel,
		fmt.Sprintf("(%s) ", interfaceName),
	)

	logger.Verbosef("Starting wireguard-go version %s", Version)

	// The TUN error is reported only now, so it goes through the logger.
	if err != nil {
		logger.Errorf("Failed to create TUN device: %v", err)
		os.Exit(ExitSetupFailed)
	}

	// open UAPI file (or use supplied fd)

	fileUAPI, err := func() (*os.File, error) {
		uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
		if uapiFdStr == "" {
			return ipc.UAPIOpen(interfaceName)
		}

		// use supplied fd

		fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
		if err != nil {
			return nil, err
		}

		return os.NewFile(uintptr(fd), ""), nil
	}()
	if err != nil {
		logger.Errorf("UAPI listen error: %v", err)
		os.Exit(ExitSetupFailed)
		return
	}
	// daemonize the process

	if !foreground {
		// The child receives the TUN device as fd 3 and the UAPI socket as fd 4
		// (their positions in attr.Files below) and is told to stay in the
		// foreground.
		env := os.Environ()
		env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
		env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
		env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
		// The child inherits this process's stdout/stderr only when LOG_LEVEL
		// is set explicitly to a non-silent level; otherwise all three standard
		// streams are redirected to the null device.
		files := [3]*os.File{}
		if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent {
			files[0], _ = os.Open(os.DevNull)
			files[1] = os.Stdout
			files[2] = os.Stderr
		} else {
			files[0], _ = os.Open(os.DevNull)
			files[1], _ = os.Open(os.DevNull)
			files[2], _ = os.Open(os.DevNull)
		}
		attr := &os.ProcAttr{
			Files: []*os.File{
				files[0], // stdin
				files[1], // stdout
				files[2], // stderr
				tun.File(),
				fileUAPI,
			},
			Dir: ".",
			Env: env,
		}

		path, err := os.Executable()
		if err != nil {
			logger.Errorf("Failed to determine executable: %v", err)
			os.Exit(ExitSetupFailed)
		}

		process, err := os.StartProcess(
			path,
			os.Args,
			attr,
		)
		if err != nil {
			logger.Errorf("Failed to daemonize: %v", err)
			os.Exit(ExitSetupFailed)
		}
		// The parent does not wait for the child; returning from main exits 0.
		process.Release()
		return
	}

	device := device.NewDevice(tun, conn.NewDefaultBind(), logger)

	logger.Verbosef("Device started")

	errs := make(chan error)
	term := make(chan os.Signal, 1)

	uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
	if err != nil {
		logger.Errorf("Failed to listen on uapi socket: %v", err)
		os.Exit(ExitSetupFailed)
	}

	// Serve each UAPI connection on its own goroutine; the first Accept error
	// is forwarded to errs and ends the loop.
	go func() {
		for {
			conn, err := uapi.Accept()
			if err != nil {
				errs <- err
				return
			}
			go device.IpcHandle(conn)
		}
	}()

	logger.Verbosef("UAPI listener started")

	// wait for program to terminate

	signal.Notify(term, syscall.SIGTERM)
	signal.Notify(term, os.Interrupt)

	select {
	case <-term:
	case <-errs:
	case <-device.Wait():
	}

	// clean up

	uapi.Close()
	device.Close()

	logger.Verbosef("Shutting down")
}

-------------------------------------------------------------------------------- /main_windows.go: --------------------------------------------------------------------------------
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 | */ 5 | 6 | package main 7 | 8 | import ( 9 | "fmt" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | "golang.zx2c4.com/wireguard/conn" 15 | "golang.zx2c4.com/wireguard/device" 16 | "golang.zx2c4.com/wireguard/ipc" 17 | 18 | "golang.zx2c4.com/wireguard/tun" 19 | ) 20 | 21 | const ( 22 | ExitSetupSuccess = 0 23 | ExitSetupFailed = 1 24 | ) 25 | 26 | func main() { 27 | if len(os.Args) != 2 { 28 | os.Exit(ExitSetupFailed) 29 | } 30 | interfaceName := os.Args[1] 31 | 32 | fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is , which includes this code as a module.") 33 | 34 | logger := device.NewLogger( 35 | device.LogLevelVerbose, 36 | fmt.Sprintf("(%s) ", interfaceName), 37 | ) 38 | logger.Verbosef("Starting wireguard-go version %s", Version) 39 | 40 | tun, err := tun.CreateTUN(interfaceName, 0) 41 | if err == nil { 42 | realInterfaceName, err2 := tun.Name() 43 | if err2 == nil { 44 | interfaceName = realInterfaceName 45 | } 46 | } else { 47 | logger.Errorf("Failed to create TUN device: %v", err) 48 | os.Exit(ExitSetupFailed) 49 | } 50 | 51 | device := device.NewDevice(tun, conn.NewDefaultBind(), logger) 52 | err = device.Up() 53 | if err != nil { 54 | logger.Errorf("Failed to bring up device: %v", err) 55 | os.Exit(ExitSetupFailed) 56 | } 57 | logger.Verbosef("Device started") 58 | 59 | uapi, err := ipc.UAPIListen(interfaceName) 60 | if err != nil { 61 | logger.Errorf("Failed to listen on uapi socket: %v", err) 62 | os.Exit(ExitSetupFailed) 63 | } 64 | 65 | errs := make(chan error) 66 | term := make(chan os.Signal, 1) 67 | 68 | go func() { 69 | for { 70 | conn, err := uapi.Accept() 71 | if err != nil { 72 | errs <- err 73 | return 74 | } 75 | go device.IpcHandle(conn) 76 | } 77 | }() 78 | logger.Verbosef("UAPI listener started") 79 | 80 | // wait for program to terminate 81 | 82 | signal.Notify(term, os.Interrupt) 83 | signal.Notify(term, 
os.Kill) 84 | signal.Notify(term, syscall.SIGTERM) 85 | 86 | select { 87 | case <-term: 88 | case <-errs: 89 | case <-device.Wait(): 90 | } 91 | 92 | // clean up 93 | 94 | uapi.Close() 95 | device.Close() 96 | 97 | logger.Verbosef("Shutting down") 98 | } 99 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "sync" 11 | "time" 12 | ) 13 | 14 | const ( 15 | packetsPerSecond = 20 16 | packetsBurstable = 5 17 | garbageCollectTime = time.Second 18 | packetCost = 1000000000 / packetsPerSecond 19 | maxTokens = packetCost * packetsBurstable 20 | ) 21 | 22 | type RatelimiterEntry struct { 23 | mu sync.Mutex 24 | lastTime time.Time 25 | tokens int64 26 | } 27 | 28 | type Ratelimiter struct { 29 | mu sync.RWMutex 30 | timeNow func() time.Time 31 | 32 | stopReset chan struct{} // send to reset, close to stop 33 | table map[netip.Addr]*RatelimiterEntry 34 | } 35 | 36 | func (rate *Ratelimiter) Close() { 37 | rate.mu.Lock() 38 | defer rate.mu.Unlock() 39 | 40 | if rate.stopReset != nil { 41 | close(rate.stopReset) 42 | } 43 | } 44 | 45 | func (rate *Ratelimiter) Init() { 46 | rate.mu.Lock() 47 | defer rate.mu.Unlock() 48 | 49 | if rate.timeNow == nil { 50 | rate.timeNow = time.Now 51 | } 52 | 53 | // stop any ongoing garbage collection routine 54 | if rate.stopReset != nil { 55 | close(rate.stopReset) 56 | } 57 | 58 | rate.stopReset = make(chan struct{}) 59 | rate.table = make(map[netip.Addr]*RatelimiterEntry) 60 | 61 | stopReset := rate.stopReset // store in case Init is called again. 62 | 63 | // Start garbage collection routine. 
64 | go func() { 65 | ticker := time.NewTicker(time.Second) 66 | ticker.Stop() 67 | for { 68 | select { 69 | case _, ok := <-stopReset: 70 | ticker.Stop() 71 | if !ok { 72 | return 73 | } 74 | ticker = time.NewTicker(time.Second) 75 | case <-ticker.C: 76 | if rate.cleanup() { 77 | ticker.Stop() 78 | } 79 | } 80 | } 81 | }() 82 | } 83 | 84 | func (rate *Ratelimiter) cleanup() (empty bool) { 85 | rate.mu.Lock() 86 | defer rate.mu.Unlock() 87 | 88 | for key, entry := range rate.table { 89 | entry.mu.Lock() 90 | if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 91 | delete(rate.table, key) 92 | } 93 | entry.mu.Unlock() 94 | } 95 | 96 | return len(rate.table) == 0 97 | } 98 | 99 | func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 100 | var entry *RatelimiterEntry 101 | // lookup entry 102 | rate.mu.RLock() 103 | entry = rate.table[ip] 104 | rate.mu.RUnlock() 105 | 106 | // make new entry if not found 107 | if entry == nil { 108 | entry = new(RatelimiterEntry) 109 | entry.tokens = maxTokens - packetCost 110 | entry.lastTime = rate.timeNow() 111 | rate.mu.Lock() 112 | rate.table[ip] = entry 113 | if len(rate.table) == 1 { 114 | rate.stopReset <- struct{}{} 115 | } 116 | rate.mu.Unlock() 117 | return true 118 | } 119 | 120 | // add tokens to entry 121 | entry.mu.Lock() 122 | now := rate.timeNow() 123 | entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 124 | entry.lastTime = now 125 | if entry.tokens > maxTokens { 126 | entry.tokens = maxTokens 127 | } 128 | 129 | // subtract cost of packet 130 | if entry.tokens > packetCost { 131 | entry.tokens -= packetCost 132 | entry.mu.Unlock() 133 | return true 134 | } 135 | entry.mu.Unlock() 136 | return false 137 | } 138 | -------------------------------------------------------------------------------- /ratelimiter/ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. 
All Rights Reserved. 4 | */ 5 | 6 | package ratelimiter 7 | 8 | import ( 9 | "net/netip" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | type result struct { 15 | allowed bool 16 | text string 17 | wait time.Duration 18 | } 19 | 20 | func TestRatelimiter(t *testing.T) { 21 | var rate Ratelimiter 22 | var expectedResults []result 23 | 24 | nano := func(nano int64) time.Duration { 25 | return time.Nanosecond * time.Duration(nano) 26 | } 27 | 28 | add := func(res result) { 29 | expectedResults = append( 30 | expectedResults, 31 | res, 32 | ) 33 | } 34 | 35 | for i := 0; i < packetsBurstable; i++ { 36 | add(result{ 37 | allowed: true, 38 | text: "initial burst", 39 | }) 40 | } 41 | 42 | add(result{ 43 | allowed: false, 44 | text: "after burst", 45 | }) 46 | 47 | add(result{ 48 | allowed: true, 49 | wait: nano(time.Second.Nanoseconds() / packetsPerSecond), 50 | text: "filling tokens for single packet", 51 | }) 52 | 53 | add(result{ 54 | allowed: false, 55 | text: "not having refilled enough", 56 | }) 57 | 58 | add(result{ 59 | allowed: true, 60 | wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)), 61 | text: "filling tokens for two packet burst", 62 | }) 63 | 64 | add(result{ 65 | allowed: true, 66 | text: "second packet in 2 packet burst", 67 | }) 68 | 69 | add(result{ 70 | allowed: false, 71 | text: "packet following 2 packet burst", 72 | }) 73 | 74 | ips := []netip.Addr{ 75 | netip.MustParseAddr("127.0.0.1"), 76 | netip.MustParseAddr("192.168.1.1"), 77 | netip.MustParseAddr("172.167.2.3"), 78 | netip.MustParseAddr("97.231.252.215"), 79 | netip.MustParseAddr("248.97.91.167"), 80 | netip.MustParseAddr("188.208.233.47"), 81 | netip.MustParseAddr("104.2.183.179"), 82 | netip.MustParseAddr("72.129.46.120"), 83 | netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), 84 | netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), 85 | netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), 86 | 
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), 87 | netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), 88 | netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), 89 | } 90 | 91 | now := time.Now() 92 | rate.timeNow = func() time.Time { 93 | return now 94 | } 95 | defer func() { 96 | // Lock to avoid data race with cleanup goroutine from Init. 97 | rate.mu.Lock() 98 | defer rate.mu.Unlock() 99 | 100 | rate.timeNow = time.Now 101 | }() 102 | timeSleep := func(d time.Duration) { 103 | now = now.Add(d + 1) 104 | rate.cleanup() 105 | } 106 | 107 | rate.Init() 108 | defer rate.Close() 109 | 110 | for i, res := range expectedResults { 111 | timeSleep(res.wait) 112 | for _, ip := range ips { 113 | allowed := rate.Allow(ip) 114 | if allowed != res.allowed { 115 | t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed) 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /replay/replay.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. 7 | package replay 8 | 9 | type block uint64 10 | 11 | const ( 12 | blockBitLog = 6 // 1<<6 == 64 bits 13 | blockBits = 1 << blockBitLog // must be power of 2 14 | ringBlocks = 1 << 7 // must be power of 2 15 | windowSize = (ringBlocks - 1) * blockBits 16 | blockMask = ringBlocks - 1 17 | bitMask = blockBits - 1 18 | ) 19 | 20 | // A Filter rejects replayed messages by checking if message counter value is 21 | // within a sliding window of previously received messages. 22 | // The zero value for Filter is an empty filter ready to use. 23 | // Filters are unsafe for concurrent use. 
24 | type Filter struct { 25 | last uint64 26 | ring [ringBlocks]block 27 | } 28 | 29 | // Reset resets the filter to empty state. 30 | func (f *Filter) Reset() { 31 | f.last = 0 32 | f.ring[0] = 0 33 | } 34 | 35 | // ValidateCounter checks if the counter should be accepted. 36 | // Overlimit counters (>= limit) are always rejected. 37 | func (f *Filter) ValidateCounter(counter, limit uint64) bool { 38 | if counter >= limit { 39 | return false 40 | } 41 | indexBlock := counter >> blockBitLog 42 | if counter > f.last { // move window forward 43 | current := f.last >> blockBitLog 44 | diff := indexBlock - current 45 | if diff > ringBlocks { 46 | diff = ringBlocks // cap diff to clear the whole ring 47 | } 48 | for i := current + 1; i <= current+diff; i++ { 49 | f.ring[i&blockMask] = 0 50 | } 51 | f.last = counter 52 | } else if f.last-counter > windowSize { // behind current window 53 | return false 54 | } 55 | // check and set bit 56 | indexBlock &= blockMask 57 | indexBit := counter & bitMask 58 | old := f.ring[indexBlock] 59 | new := old | 1< 0; i-- { 91 | T(i, true) 92 | } 93 | 94 | t.Log("Bulk test 4") 95 | filter.Reset() 96 | testNumber = 0 97 | for i := uint64(windowSize + 2); i > 1; i-- { 98 | T(i, true) 99 | } 100 | T(0, false) 101 | 102 | t.Log("Bulk test 5") 103 | filter.Reset() 104 | testNumber = 0 105 | for i := uint64(windowSize); i > 0; i-- { 106 | T(i, true) 107 | } 108 | T(windowSize+1, true) 109 | T(0, false) 110 | 111 | t.Log("Bulk test 6") 112 | filter.Reset() 113 | testNumber = 0 114 | for i := uint64(windowSize); i > 0; i-- { 115 | T(i, true) 116 | } 117 | T(0, true) 118 | T(windowSize+1, true) 119 | } 120 | -------------------------------------------------------------------------------- /rwcancel/rwcancel.go: -------------------------------------------------------------------------------- 1 | //go:build !windows && !js 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
6 | */ 7 | 8 | // Package rwcancel implements cancelable read/write operations on 9 | // a file descriptor. 10 | package rwcancel 11 | 12 | import ( 13 | "errors" 14 | "os" 15 | "syscall" 16 | 17 | "golang.org/x/sys/unix" 18 | ) 19 | 20 | type RWCancel struct { 21 | fd int 22 | closingReader *os.File 23 | closingWriter *os.File 24 | } 25 | 26 | func NewRWCancel(fd int) (*RWCancel, error) { 27 | err := unix.SetNonblock(fd, true) 28 | if err != nil { 29 | return nil, err 30 | } 31 | rwcancel := RWCancel{fd: fd} 32 | 33 | rwcancel.closingReader, rwcancel.closingWriter, err = os.Pipe() 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | return &rwcancel, nil 39 | } 40 | 41 | func RetryAfterError(err error) bool { 42 | return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) 43 | } 44 | 45 | func (rw *RWCancel) ReadyRead() bool { 46 | closeFd := int32(rw.closingReader.Fd()) 47 | 48 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} 49 | var err error 50 | for { 51 | _, err = unix.Poll(pollFds, -1) 52 | if err == nil || !RetryAfterError(err) { 53 | break 54 | } 55 | } 56 | if err != nil { 57 | return false 58 | } 59 | if pollFds[1].Revents != 0 { 60 | return false 61 | } 62 | return pollFds[0].Revents != 0 63 | } 64 | 65 | func (rw *RWCancel) ReadyWrite() bool { 66 | closeFd := int32(rw.closingReader.Fd()) 67 | pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} 68 | var err error 69 | for { 70 | _, err = unix.Poll(pollFds, -1) 71 | if err == nil || !RetryAfterError(err) { 72 | break 73 | } 74 | } 75 | if err != nil { 76 | return false 77 | } 78 | 79 | if pollFds[1].Revents != 0 { 80 | return false 81 | } 82 | return pollFds[0].Revents != 0 83 | } 84 | 85 | func (rw *RWCancel) Read(p []byte) (n int, err error) { 86 | for { 87 | n, err := unix.Read(rw.fd, p) 88 | if err == nil || !RetryAfterError(err) { 89 | return n, err 90 | } 91 | if 
!rw.ReadyRead() { 92 | return 0, os.ErrClosed 93 | } 94 | } 95 | } 96 | 97 | func (rw *RWCancel) Write(p []byte) (n int, err error) { 98 | for { 99 | n, err := unix.Write(rw.fd, p) 100 | if err == nil || !RetryAfterError(err) { 101 | return n, err 102 | } 103 | if !rw.ReadyWrite() { 104 | return 0, os.ErrClosed 105 | } 106 | } 107 | } 108 | 109 | func (rw *RWCancel) Cancel() (err error) { 110 | _, err = rw.closingWriter.Write([]byte{0}) 111 | return 112 | } 113 | 114 | func (rw *RWCancel) Close() { 115 | rw.closingReader.Close() 116 | rw.closingWriter.Close() 117 | } 118 | -------------------------------------------------------------------------------- /rwcancel/rwcancel_stub.go: -------------------------------------------------------------------------------- 1 | //go:build windows || js 2 | 3 | // SPDX-License-Identifier: MIT 4 | 5 | package rwcancel 6 | 7 | type RWCancel struct{} 8 | 9 | func (*RWCancel) Cancel() {} 10 | -------------------------------------------------------------------------------- /tai64n/tai64n.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "bytes" 10 | "encoding/binary" 11 | "time" 12 | ) 13 | 14 | const ( 15 | TimestampSize = 12 16 | base = uint64(0x400000000000000a) 17 | whitenerMask = uint32(0x1000000 - 1) 18 | ) 19 | 20 | type Timestamp [TimestampSize]byte 21 | 22 | func stamp(t time.Time) Timestamp { 23 | var tai64n Timestamp 24 | secs := base + uint64(t.Unix()) 25 | nano := uint32(t.Nanosecond()) &^ whitenerMask 26 | binary.BigEndian.PutUint64(tai64n[:], secs) 27 | binary.BigEndian.PutUint32(tai64n[8:], nano) 28 | return tai64n 29 | } 30 | 31 | func Now() Timestamp { 32 | return stamp(time.Now()) 33 | } 34 | 35 | func (t1 Timestamp) After(t2 Timestamp) bool { 36 | return bytes.Compare(t1[:], t2[:]) > 0 37 | } 38 | 39 | func (t Timestamp) String() string { 40 | return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() 41 | } 42 | -------------------------------------------------------------------------------- /tai64n/tai64n_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tai64n 7 | 8 | import ( 9 | "testing" 10 | "time" 11 | ) 12 | 13 | // Test that timestamps are monotonic as required by Wireguard and that 14 | // nanosecond-level information is whitened to prevent side channel attacks. 15 | func TestMonotonic(t *testing.T) { 16 | startTime := time.Unix(0, 123456789) // a nontrivial bit pattern 17 | // Whitening should reduce timestamp granularity 18 | // to more than 10 but fewer than 20 milliseconds. 
19 | tests := []struct { 20 | name string 21 | t1, t2 time.Time 22 | wantAfter bool 23 | }{ 24 | {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, 25 | {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, 26 | {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, 27 | {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, 28 | {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, 29 | } 30 | 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | ts1, ts2 := stamp(tt.t1), stamp(tt.t2) 34 | got := ts2.After(ts1) 35 | if got != tt.wantAfter { 36 | t.Errorf("after = %v; want %v", got, tt.wantAfter) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /tun/alignment_windows_test.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "reflect" 10 | "testing" 11 | "unsafe" 12 | ) 13 | 14 | func checkAlignment(t *testing.T, name string, offset uintptr) { 15 | t.Helper() 16 | if offset%8 != 0 { 17 | t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) 18 | } 19 | } 20 | 21 | // TestRateJugglerAlignment checks that atomically-accessed fields are 22 | // aligned to 64-bit boundaries, as required by the atomic package. 23 | // 24 | // Unfortunately, violating this rule on 32-bit platforms results in a 25 | // hard segfault at runtime. 
26 | func TestRateJugglerAlignment(t *testing.T) { 27 | var r rateJuggler 28 | 29 | typ := reflect.TypeOf(&r).Elem() 30 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 31 | for i := 0; i < typ.NumField(); i++ { 32 | field := typ.Field(i) 33 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 34 | field.Name, 35 | field.Offset, 36 | field.Type.Size(), 37 | field.Type.Align(), 38 | ) 39 | } 40 | 41 | checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) 42 | checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) 43 | checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) 44 | } 45 | 46 | // TestNativeTunAlignment checks that atomically-accessed fields are 47 | // aligned to 64-bit boundaries, as required by the atomic package. 48 | // 49 | // Unfortunately, violating this rule on 32-bit platforms results in a 50 | // hard segfault at runtime. 51 | func TestNativeTunAlignment(t *testing.T) { 52 | var tun NativeTun 53 | 54 | typ := reflect.TypeOf(&tun).Elem() 55 | t.Logf("Peer type size: %d, with fields:", typ.Size()) 56 | for i := 0; i < typ.NumField(); i++ { 57 | field := typ.Field(i) 58 | t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", 59 | field.Name, 60 | field.Offset, 61 | field.Type.Size(), 62 | field.Type.Align(), 63 | ) 64 | } 65 | 66 | checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) 67 | } 68 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_client.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 
7 | */ 8 | 9 | package main 10 | 11 | import ( 12 | "io" 13 | "log" 14 | "net/http" 15 | "net/netip" 16 | 17 | "golang.zx2c4.com/wireguard/conn" 18 | "golang.zx2c4.com/wireguard/device" 19 | "golang.zx2c4.com/wireguard/tun/netstack" 20 | ) 21 | 22 | func main() { 23 | tun, tnet, err := netstack.CreateNetTUN( 24 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 25 | []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 26 | 1420) 27 | if err != nil { 28 | log.Panic(err) 29 | } 30 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 31 | dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 32 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 33 | endpoint=163.172.161.0:12912 34 | allowed_ip=0.0.0.0/0 35 | `) 36 | err = dev.Up() 37 | if err != nil { 38 | log.Panic(err) 39 | } 40 | 41 | client := http.Client{ 42 | Transport: &http.Transport{ 43 | DialContext: tnet.DialContext, 44 | }, 45 | } 46 | resp, err := client.Get("https://www.zx2c4.com/ip") 47 | if err != nil { 48 | log.Panic(err) 49 | } 50 | body, err := io.ReadAll(resp.Body) 51 | if err != nil { 52 | log.Panic(err) 53 | } 54 | log.Println(string(body)) 55 | } 56 | -------------------------------------------------------------------------------- /tun/netstack/examples/http_server.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 
7 | */ 8 | 9 | package main 10 | 11 | import ( 12 | "io" 13 | "log" 14 | "net" 15 | "net/http" 16 | "net/netip" 17 | 18 | "golang.zx2c4.com/wireguard/conn" 19 | "golang.zx2c4.com/wireguard/device" 20 | "golang.zx2c4.com/wireguard/tun/netstack" 21 | ) 22 | 23 | func main() { 24 | tun, tnet, err := netstack.CreateNetTUN( 25 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 26 | []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 27 | 1420, 28 | ) 29 | if err != nil { 30 | log.Panic(err) 31 | } 32 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 33 | dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 34 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 35 | endpoint=163.172.161.0:12912 36 | allowed_ip=0.0.0.0/0 37 | persistent_keepalive_interval=25 38 | `) 39 | dev.Up() 40 | listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) 41 | if err != nil { 42 | log.Panicln(err) 43 | } 44 | http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { 45 | log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) 46 | io.WriteString(writer, "Hello from userspace TCP!") 47 | }) 48 | err = http.Serve(listener, nil) 49 | if err != nil { 50 | log.Panicln(err) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /tun/netstack/examples/ping_client.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | /* SPDX-License-Identifier: MIT 5 | * 6 | * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 
7 | */ 8 | 9 | package main 10 | 11 | import ( 12 | "bytes" 13 | "log" 14 | "math/rand" 15 | "net/netip" 16 | "time" 17 | 18 | "golang.org/x/net/icmp" 19 | "golang.org/x/net/ipv4" 20 | 21 | "golang.zx2c4.com/wireguard/conn" 22 | "golang.zx2c4.com/wireguard/device" 23 | "golang.zx2c4.com/wireguard/tun/netstack" 24 | ) 25 | 26 | func main() { 27 | tun, tnet, err := netstack.CreateNetTUN( 28 | []netip.Addr{netip.MustParseAddr("192.168.4.29")}, 29 | []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 30 | 1420) 31 | if err != nil { 32 | log.Panic(err) 33 | } 34 | dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) 35 | dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f 36 | public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b 37 | endpoint=163.172.161.0:12912 38 | allowed_ip=0.0.0.0/0 39 | `) 40 | err = dev.Up() 41 | if err != nil { 42 | log.Panic(err) 43 | } 44 | 45 | socket, err := tnet.Dial("ping4", "zx2c4.com") 46 | if err != nil { 47 | log.Panic(err) 48 | } 49 | requestPing := icmp.Echo{ 50 | Seq: rand.Intn(1 << 16), 51 | Data: []byte("gopher burrow"), 52 | } 53 | icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) 54 | socket.SetReadDeadline(time.Now().Add(time.Second * 10)) 55 | start := time.Now() 56 | _, err = socket.Write(icmpBytes) 57 | if err != nil { 58 | log.Panic(err) 59 | } 60 | n, err := socket.Read(icmpBytes[:]) 61 | if err != nil { 62 | log.Panic(err) 63 | } 64 | replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) 65 | if err != nil { 66 | log.Panic(err) 67 | } 68 | replyPing, ok := replyPacket.Body.(*icmp.Echo) 69 | if !ok { 70 | log.Panicf("invalid reply type: %v", replyPacket) 71 | } 72 | if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { 73 | log.Panicf("invalid ping reply: %v", replyPing) 74 | } 75 | log.Printf("Ping latency: %v", time.Since(start)) 76 | } 77 
| -------------------------------------------------------------------------------- /tun/netstack/go.mod: -------------------------------------------------------------------------------- 1 | module golang.zx2c4.com/wireguard/tun/netstack 2 | 3 | go 1.18 4 | 5 | require ( 6 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f 7 | golang.zx2c4.com/wireguard v0.0.0-20220316235147-5aff28b14c24 8 | gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 9 | ) 10 | 11 | require ( 12 | github.com/google/btree v1.0.1 // indirect 13 | golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd // indirect 14 | golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86 // indirect 15 | golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect 16 | golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /tun/operateonfd.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || freebsd 2 | 3 | /* SPDX-License-Identifier: MIT 4 | * 5 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 6 | */ 7 | 8 | package tun 9 | 10 | import ( 11 | "fmt" 12 | ) 13 | 14 | func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { 15 | sysconn, err := tun.tunFile.SyscallConn() 16 | if err != nil { 17 | tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) 18 | return 19 | } 20 | err = sysconn.Control(fn) 21 | if err != nil { 22 | tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /tun/tun.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "os" 10 | ) 11 | 12 | type Event int 13 | 14 | const ( 15 | EventUp = 1 << iota 16 | EventDown 17 | EventMTUUpdate 18 | ) 19 | 20 | type Device interface { 21 | File() *os.File // returns the file descriptor of the device 22 | Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) 23 | Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) 24 | Flush() error // flush all previous writes to the device 25 | MTU() (int, error) // returns the MTU of the device 26 | Name() (string, error) // fetches and returns the current name 27 | Events() chan Event // returns a constant channel of events related to the device 28 | Close() error // stops the device and closes the event channel 29 | } 30 | -------------------------------------------------------------------------------- /tun/tun_darwin.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "time" 16 | "unsafe" 17 | 18 | "golang.org/x/net/ipv6" 19 | "golang.org/x/sys/unix" 20 | ) 21 | 22 | const utunControlName = "com.apple.net.utun_control" 23 | 24 | type NativeTun struct { 25 | name string 26 | tunFile *os.File 27 | events chan Event 28 | errors chan error 29 | routeSocket int 30 | closeOnce sync.Once 31 | } 32 | 33 | func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { 34 | for i := 0; i < 20; i++ { 35 | iface, err = net.InterfaceByIndex(index) 36 | if err != nil && errors.Is(err, syscall.ENOMEM) { 37 | time.Sleep(time.Duration(i) * time.Second / 3) 38 | continue 39 | } 40 | return iface, err 41 | } 42 | return nil, err 43 | } 44 | 45 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 46 | var ( 47 | statusUp bool 48 | statusMTU int 49 | ) 50 | 51 | defer close(tun.events) 52 | 53 | data := make([]byte, os.Getpagesize()) 54 | for { 55 | retry: 56 | n, err := unix.Read(tun.routeSocket, data) 57 | if err != nil { 58 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 59 | goto retry 60 | } 61 | tun.errors <- err 62 | return 63 | } 64 | 65 | if n < 14 { 66 | continue 67 | } 68 | 69 | if data[3 /* type */] != unix.RTM_IFINFO { 70 | continue 71 | } 72 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) 73 | if ifindex != tunIfindex { 74 | continue 75 | } 76 | 77 | iface, err := retryInterfaceByIndex(ifindex) 78 | if err != nil { 79 | tun.errors <- err 80 | return 81 | } 82 | 83 | // Up / Down event 84 | up := (iface.Flags & net.FlagUp) != 0 85 | if up != statusUp && up { 86 | tun.events <- EventUp 87 | } 88 | if up != statusUp && !up { 89 | tun.events <- EventDown 90 | } 91 | statusUp = up 92 | 93 | // MTU changes 94 | if iface.MTU != statusMTU { 95 | tun.events <- EventMTUUpdate 96 | } 97 | statusMTU = iface.MTU 98 | } 99 | } 100 | 101 | func CreateTUN(name string, 
mtu int) (Device, error) { 102 | ifIndex := -1 103 | if name != "utun" { 104 | _, err := fmt.Sscanf(name, "utun%d", &ifIndex) 105 | if err != nil || ifIndex < 0 { 106 | return nil, fmt.Errorf("Interface name must be utun[0-9]*") 107 | } 108 | } 109 | 110 | fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | ctlInfo := &unix.CtlInfo{} 116 | copy(ctlInfo.Name[:], []byte(utunControlName)) 117 | err = unix.IoctlCtlInfo(fd, ctlInfo) 118 | if err != nil { 119 | unix.Close(fd) 120 | return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) 121 | } 122 | 123 | sc := &unix.SockaddrCtl{ 124 | ID: ctlInfo.Id, 125 | Unit: uint32(ifIndex) + 1, 126 | } 127 | 128 | err = unix.Connect(fd, sc) 129 | if err != nil { 130 | unix.Close(fd) 131 | return nil, err 132 | } 133 | 134 | err = unix.SetNonblock(fd, true) 135 | if err != nil { 136 | unix.Close(fd) 137 | return nil, err 138 | } 139 | tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) 140 | 141 | if err == nil && name == "utun" { 142 | fname := os.Getenv("WG_TUN_NAME_FILE") 143 | if fname != "" { 144 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 145 | } 146 | } 147 | 148 | return tun, err 149 | } 150 | 151 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 152 | tun := &NativeTun{ 153 | tunFile: file, 154 | events: make(chan Event, 10), 155 | errors: make(chan error, 5), 156 | } 157 | 158 | name, err := tun.Name() 159 | if err != nil { 160 | tun.tunFile.Close() 161 | return nil, err 162 | } 163 | 164 | tunIfindex, err := func() (int, error) { 165 | iface, err := net.InterfaceByName(name) 166 | if err != nil { 167 | return -1, err 168 | } 169 | return iface.Index, nil 170 | }() 171 | if err != nil { 172 | tun.tunFile.Close() 173 | return nil, err 174 | } 175 | 176 | tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) 177 | if err != nil { 178 | tun.tunFile.Close() 179 | return nil, err 
180 | } 181 | 182 | go tun.routineRouteListener(tunIfindex) 183 | 184 | if mtu > 0 { 185 | err = tun.setMTU(mtu) 186 | if err != nil { 187 | tun.Close() 188 | return nil, err 189 | } 190 | } 191 | 192 | return tun, nil 193 | } 194 | 195 | func (tun *NativeTun) Name() (string, error) { 196 | var err error 197 | tun.operateOnFd(func(fd uintptr) { 198 | tun.name, err = unix.GetsockoptString( 199 | int(fd), 200 | 2, /* #define SYSPROTO_CONTROL 2 */ 201 | 2, /* #define UTUN_OPT_IFNAME 2 */ 202 | ) 203 | }) 204 | 205 | if err != nil { 206 | return "", fmt.Errorf("GetSockoptString: %w", err) 207 | } 208 | 209 | return tun.name, nil 210 | } 211 | 212 | func (tun *NativeTun) File() *os.File { 213 | return tun.tunFile 214 | } 215 | 216 | func (tun *NativeTun) Events() chan Event { 217 | return tun.events 218 | } 219 | 220 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 221 | select { 222 | case err := <-tun.errors: 223 | return 0, err 224 | default: 225 | buff := buff[offset-4:] 226 | n, err := tun.tunFile.Read(buff[:]) 227 | if n < 4 { 228 | return 0, err 229 | } 230 | return n - 4, err 231 | } 232 | } 233 | 234 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 235 | // reserve space for header 236 | 237 | buff = buff[offset-4:] 238 | 239 | // add packet information header 240 | 241 | buff[0] = 0x00 242 | buff[1] = 0x00 243 | buff[2] = 0x00 244 | 245 | if buff[4]>>4 == ipv6.Version { 246 | buff[3] = unix.AF_INET6 247 | } else { 248 | buff[3] = unix.AF_INET 249 | } 250 | 251 | // write 252 | 253 | return tun.tunFile.Write(buff) 254 | } 255 | 256 | func (tun *NativeTun) Flush() error { 257 | // TODO: can flushing be implemented by buffering and using sendmmsg? 
258 | return nil 259 | } 260 | 261 | func (tun *NativeTun) Close() error { 262 | var err1, err2 error 263 | tun.closeOnce.Do(func() { 264 | err1 = tun.tunFile.Close() 265 | if tun.routeSocket != -1 { 266 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 267 | err2 = unix.Close(tun.routeSocket) 268 | } else if tun.events != nil { 269 | close(tun.events) 270 | } 271 | }) 272 | if err1 != nil { 273 | return err1 274 | } 275 | return err2 276 | } 277 | 278 | func (tun *NativeTun) setMTU(n int) error { 279 | fd, err := socketCloexec( 280 | unix.AF_INET, 281 | unix.SOCK_DGRAM, 282 | 0, 283 | ) 284 | if err != nil { 285 | return err 286 | } 287 | 288 | defer unix.Close(fd) 289 | 290 | var ifr unix.IfreqMTU 291 | copy(ifr.Name[:], tun.name) 292 | ifr.MTU = int32(n) 293 | err = unix.IoctlSetIfreqMTU(fd, &ifr) 294 | if err != nil { 295 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) 296 | } 297 | 298 | return nil 299 | } 300 | 301 | func (tun *NativeTun) MTU() (int, error) { 302 | fd, err := socketCloexec( 303 | unix.AF_INET, 304 | unix.SOCK_DGRAM, 305 | 0, 306 | ) 307 | if err != nil { 308 | return 0, err 309 | } 310 | 311 | defer unix.Close(fd) 312 | 313 | ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) 314 | if err != nil { 315 | return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) 316 | } 317 | 318 | return int(ifr.MTU), nil 319 | } 320 | 321 | func socketCloexec(family, sotype, proto int) (fd int, err error) { 322 | // See go/src/net/sys_cloexec.go for background. 323 | syscall.ForkLock.RLock() 324 | defer syscall.ForkLock.RUnlock() 325 | 326 | fd, err = unix.Socket(family, sotype, proto) 327 | if err == nil { 328 | unix.CloseOnExec(fd) 329 | } 330 | return 331 | } 332 | -------------------------------------------------------------------------------- /tun/tun_freebsd.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. 
All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "io" 12 | "net" 13 | "os" 14 | "sync" 15 | "syscall" 16 | "unsafe" 17 | 18 | "golang.org/x/sys/unix" 19 | ) 20 | 21 | const ( 22 | _TUNSIFHEAD = 0x80047460 23 | _TUNSIFMODE = 0x8004745e 24 | _TUNGIFNAME = 0x4020745d 25 | _TUNSIFPID = 0x2000745f 26 | 27 | _SIOCGIFINFO_IN6 = 0xc048696c 28 | _SIOCSIFINFO_IN6 = 0xc048696d 29 | _ND6_IFF_AUTO_LINKLOCAL = 0x20 30 | _ND6_IFF_NO_DAD = 0x100 31 | ) 32 | 33 | // Iface requests with just the name 34 | type ifreqName struct { 35 | Name [unix.IFNAMSIZ]byte 36 | _ [16]byte 37 | } 38 | 39 | // Iface requests with a pointer 40 | type ifreqPtr struct { 41 | Name [unix.IFNAMSIZ]byte 42 | Data uintptr 43 | _ [16 - unsafe.Sizeof(uintptr(0))]byte 44 | } 45 | 46 | // Iface requests with MTU 47 | type ifreqMtu struct { 48 | Name [unix.IFNAMSIZ]byte 49 | MTU uint32 50 | _ [12]byte 51 | } 52 | 53 | // ND6 flag manipulation 54 | type nd6Req struct { 55 | Name [unix.IFNAMSIZ]byte 56 | Linkmtu uint32 57 | Maxmtu uint32 58 | Basereachable uint32 59 | Reachable uint32 60 | Retrans uint32 61 | Flags uint32 62 | Recalctm int 63 | Chlim uint8 64 | Initialized uint8 65 | Randomseed0 [8]byte 66 | Randomseed1 [8]byte 67 | Randomid [8]byte 68 | } 69 | 70 | type NativeTun struct { 71 | name string 72 | tunFile *os.File 73 | events chan Event 74 | errors chan error 75 | routeSocket int 76 | closeOnce sync.Once 77 | } 78 | 79 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 80 | var ( 81 | statusUp bool 82 | statusMTU int 83 | ) 84 | 85 | defer close(tun.events) 86 | 87 | data := make([]byte, os.Getpagesize()) 88 | for { 89 | retry: 90 | n, err := unix.Read(tun.routeSocket, data) 91 | if err != nil { 92 | if errors.Is(err, syscall.EINTR) { 93 | goto retry 94 | } 95 | tun.errors <- err 96 | return 97 | } 98 | 99 | if n < 14 { 100 | continue 101 | } 102 | 103 | if data[3 /* type */] != unix.RTM_IFINFO { 104 | continue 105 | } 106 | ifindex := 
int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */]))) 107 | if ifindex != tunIfindex { 108 | continue 109 | } 110 | 111 | iface, err := net.InterfaceByIndex(ifindex) 112 | if err != nil { 113 | tun.errors <- err 114 | return 115 | } 116 | 117 | // Up / Down event 118 | up := (iface.Flags & net.FlagUp) != 0 119 | if up != statusUp && up { 120 | tun.events <- EventUp 121 | } 122 | if up != statusUp && !up { 123 | tun.events <- EventDown 124 | } 125 | statusUp = up 126 | 127 | // MTU changes 128 | if iface.MTU != statusMTU { 129 | tun.events <- EventMTUUpdate 130 | } 131 | statusMTU = iface.MTU 132 | } 133 | } 134 | 135 | func tunName(fd uintptr) (string, error) { 136 | var ifreq ifreqName 137 | _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq))) 138 | if err != 0 { 139 | return "", err 140 | } 141 | return unix.ByteSliceToString(ifreq.Name[:]), nil 142 | } 143 | 144 | // Destroy a named system interface 145 | func tunDestroy(name string) error { 146 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) 147 | if err != nil { 148 | return err 149 | } 150 | defer unix.Close(fd) 151 | 152 | var ifr [32]byte 153 | copy(ifr[:], name) 154 | _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0]))) 155 | if errno != 0 { 156 | return fmt.Errorf("failed to destroy interface %s: %w", name, errno) 157 | } 158 | 159 | return nil 160 | } 161 | 162 | func CreateTUN(name string, mtu int) (Device, error) { 163 | if len(name) > unix.IFNAMSIZ-1 { 164 | return nil, errors.New("interface name too long") 165 | } 166 | 167 | // See if interface already exists 168 | iface, _ := net.InterfaceByName(name) 169 | if iface != nil { 170 | return nil, fmt.Errorf("interface %s already exists", name) 171 | } 172 | 173 | tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0) 174 | if err != nil { 175 | return nil, err 176 | } 177 | 178 | tun := 
NativeTun{tunFile: tunFile} 179 | var assignedName string 180 | tun.operateOnFd(func(fd uintptr) { 181 | assignedName, err = tunName(fd) 182 | }) 183 | if err != nil { 184 | tunFile.Close() 185 | return nil, err 186 | } 187 | 188 | // Enable ifhead mode, otherwise tun will complain if it gets a non-AF_INET packet 189 | ifheadmode := 1 190 | var errno syscall.Errno 191 | tun.operateOnFd(func(fd uintptr) { 192 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode))) 193 | }) 194 | 195 | if errno != 0 { 196 | tunFile.Close() 197 | tunDestroy(assignedName) 198 | return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno) 199 | } 200 | 201 | // Get out of PTP mode. 202 | ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST 203 | tun.operateOnFd(func(fd uintptr) { 204 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags))) 205 | }) 206 | 207 | if errno != 0 { 208 | tunFile.Close() 209 | tunDestroy(assignedName) 210 | return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno) 211 | } 212 | 213 | // Disable link-local v6, not just because WireGuard doesn't do that anyway, but 214 | // also because there are serious races with attaching and detaching LLv6 addresses 215 | // in relation to interface lifetime within the FreeBSD kernel. 
216 | confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) 217 | if err != nil { 218 | tunFile.Close() 219 | tunDestroy(assignedName) 220 | return nil, err 221 | } 222 | defer unix.Close(confd6) 223 | var ndireq nd6Req 224 | copy(ndireq.Name[:], assignedName) 225 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) 226 | if errno != 0 { 227 | tunFile.Close() 228 | tunDestroy(assignedName) 229 | return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno) 230 | } 231 | ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL 232 | ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD 233 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) 234 | if errno != 0 { 235 | tunFile.Close() 236 | tunDestroy(assignedName) 237 | return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno) 238 | } 239 | 240 | if name != "" { 241 | confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) 242 | if err != nil { 243 | tunFile.Close() 244 | tunDestroy(assignedName) 245 | return nil, err 246 | } 247 | defer unix.Close(confd) 248 | var newnp [unix.IFNAMSIZ]byte 249 | copy(newnp[:], name) 250 | var ifr ifreqPtr 251 | copy(ifr.Name[:], assignedName) 252 | ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) 253 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr))) 254 | if errno != 0 { 255 | tunFile.Close() 256 | tunDestroy(assignedName) 257 | return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno) 258 | } 259 | } 260 | 261 | return CreateTUNFromFile(tunFile, mtu) 262 | } 263 | 264 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 265 | tun := &NativeTun{ 266 | tunFile: file, 267 | events: make(chan Event, 10), 268 | errors: make(chan error, 1), 269 | } 270 | 271 | var 
errno syscall.Errno 272 | tun.operateOnFd(func(fd uintptr) { 273 | _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0)) 274 | }) 275 | if errno != 0 { 276 | tun.tunFile.Close() 277 | return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno) 278 | } 279 | 280 | name, err := tun.Name() 281 | if err != nil { 282 | tun.tunFile.Close() 283 | return nil, err 284 | } 285 | 286 | tunIfindex, err := func() (int, error) { 287 | iface, err := net.InterfaceByName(name) 288 | if err != nil { 289 | return -1, err 290 | } 291 | return iface.Index, nil 292 | }() 293 | if err != nil { 294 | tun.tunFile.Close() 295 | return nil, err 296 | } 297 | 298 | tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) 299 | if err != nil { 300 | tun.tunFile.Close() 301 | return nil, err 302 | } 303 | 304 | go tun.routineRouteListener(tunIfindex) 305 | 306 | err = tun.setMTU(mtu) 307 | if err != nil { 308 | tun.Close() 309 | return nil, err 310 | } 311 | 312 | return tun, nil 313 | } 314 | 315 | func (tun *NativeTun) Name() (string, error) { 316 | var name string 317 | var err error 318 | tun.operateOnFd(func(fd uintptr) { 319 | name, err = tunName(fd) 320 | }) 321 | if err != nil { 322 | return "", err 323 | } 324 | tun.name = name 325 | return name, nil 326 | } 327 | 328 | func (tun *NativeTun) File() *os.File { 329 | return tun.tunFile 330 | } 331 | 332 | func (tun *NativeTun) Events() chan Event { 333 | return tun.events 334 | } 335 | 336 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 337 | select { 338 | case err := <-tun.errors: 339 | return 0, err 340 | default: 341 | buff := buff[offset-4:] 342 | n, err := tun.tunFile.Read(buff[:]) 343 | if n < 4 { 344 | return 0, err 345 | } 346 | return n - 4, err 347 | } 348 | } 349 | 350 | func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { 351 | if offset < 4 { 352 | return 0, io.ErrShortBuffer 353 | } 354 | buf = buf[offset-4:] 355 
| if len(buf) < 5 { 356 | return 0, io.ErrShortBuffer 357 | } 358 | buf[0] = 0x00 359 | buf[1] = 0x00 360 | buf[2] = 0x00 361 | switch buf[4] >> 4 { 362 | case 4: 363 | buf[3] = unix.AF_INET 364 | case 6: 365 | buf[3] = unix.AF_INET6 366 | default: 367 | return 0, unix.EAFNOSUPPORT 368 | } 369 | return tun.tunFile.Write(buf) 370 | } 371 | 372 | func (tun *NativeTun) Flush() error { 373 | // TODO: can flushing be implemented by buffering and using sendmmsg? 374 | return nil 375 | } 376 | 377 | func (tun *NativeTun) Close() error { 378 | var err1, err2, err3 error 379 | tun.closeOnce.Do(func() { 380 | err1 = tun.tunFile.Close() 381 | err2 = tunDestroy(tun.name) 382 | if tun.routeSocket != -1 { 383 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 384 | err3 = unix.Close(tun.routeSocket) 385 | tun.routeSocket = -1 386 | } else if tun.events != nil { 387 | close(tun.events) 388 | } 389 | }) 390 | if err1 != nil { 391 | return err1 392 | } 393 | if err2 != nil { 394 | return err2 395 | } 396 | return err3 397 | } 398 | 399 | func (tun *NativeTun) setMTU(n int) error { 400 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) 401 | if err != nil { 402 | return err 403 | } 404 | defer unix.Close(fd) 405 | 406 | var ifr ifreqMtu 407 | copy(ifr.Name[:], tun.name) 408 | ifr.MTU = uint32(n) 409 | _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr))) 410 | if errno != 0 { 411 | return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno) 412 | } 413 | return nil 414 | } 415 | 416 | func (tun *NativeTun) MTU() (int, error) { 417 | fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) 418 | if err != nil { 419 | return 0, err 420 | } 421 | defer unix.Close(fd) 422 | 423 | var ifr ifreqMtu 424 | copy(ifr.Name[:], tun.name) 425 | _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr))) 426 | if errno != 0 { 427 | 
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno) 428 | } 429 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil 430 | } 431 | -------------------------------------------------------------------------------- /tun/tun_openbsd.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "net" 12 | "os" 13 | "sync" 14 | "syscall" 15 | "unsafe" 16 | 17 | "golang.org/x/net/ipv6" 18 | "golang.org/x/sys/unix" 19 | ) 20 | 21 | // Structure for iface mtu get/set ioctls 22 | type ifreq_mtu struct { 23 | Name [unix.IFNAMSIZ]byte 24 | MTU uint32 25 | Pad0 [12]byte 26 | } 27 | 28 | const _TUNSIFMODE = 0x8004745d 29 | 30 | type NativeTun struct { 31 | name string 32 | tunFile *os.File 33 | events chan Event 34 | errors chan error 35 | routeSocket int 36 | closeOnce sync.Once 37 | } 38 | 39 | func (tun *NativeTun) routineRouteListener(tunIfindex int) { 40 | var ( 41 | statusUp bool 42 | statusMTU int 43 | ) 44 | 45 | defer close(tun.events) 46 | 47 | check := func() bool { 48 | iface, err := net.InterfaceByIndex(tunIfindex) 49 | if err != nil { 50 | tun.errors <- err 51 | return true 52 | } 53 | 54 | // Up / Down event 55 | up := (iface.Flags & net.FlagUp) != 0 56 | if up != statusUp && up { 57 | tun.events <- EventUp 58 | } 59 | if up != statusUp && !up { 60 | tun.events <- EventDown 61 | } 62 | statusUp = up 63 | 64 | // MTU changes 65 | if iface.MTU != statusMTU { 66 | tun.events <- EventMTUUpdate 67 | } 68 | statusMTU = iface.MTU 69 | return false 70 | } 71 | 72 | if check() { 73 | return 74 | } 75 | 76 | data := make([]byte, os.Getpagesize()) 77 | for { 78 | n, err := unix.Read(tun.routeSocket, data) 79 | if err != nil { 80 | if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { 81 | continue 82 | } 83 | tun.errors <- err 84 | return 85 
| } 86 | 87 | if n < 8 { 88 | continue 89 | } 90 | 91 | if data[3 /* type */] != unix.RTM_IFINFO { 92 | continue 93 | } 94 | ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */]))) 95 | if ifindex != tunIfindex { 96 | continue 97 | } 98 | if check() { 99 | return 100 | } 101 | } 102 | } 103 | 104 | func CreateTUN(name string, mtu int) (Device, error) { 105 | ifIndex := -1 106 | if name != "tun" { 107 | _, err := fmt.Sscanf(name, "tun%d", &ifIndex) 108 | if err != nil || ifIndex < 0 { 109 | return nil, fmt.Errorf("Interface name must be tun[0-9]*") 110 | } 111 | } 112 | 113 | var tunfile *os.File 114 | var err error 115 | 116 | if ifIndex != -1 { 117 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 118 | } else { 119 | for ifIndex = 0; ifIndex < 256; ifIndex++ { 120 | tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) 121 | if err == nil || !errors.Is(err, syscall.EBUSY) { 122 | break 123 | } 124 | } 125 | } 126 | 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | tun, err := CreateTUNFromFile(tunfile, mtu) 132 | 133 | if err == nil && name == "tun" { 134 | fname := os.Getenv("WG_TUN_NAME_FILE") 135 | if fname != "" { 136 | os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) 137 | } 138 | } 139 | 140 | return tun, err 141 | } 142 | 143 | func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { 144 | tun := &NativeTun{ 145 | tunFile: file, 146 | events: make(chan Event, 10), 147 | errors: make(chan error, 1), 148 | } 149 | 150 | name, err := tun.Name() 151 | if err != nil { 152 | tun.tunFile.Close() 153 | return nil, err 154 | } 155 | 156 | tunIfindex, err := func() (int, error) { 157 | iface, err := net.InterfaceByName(name) 158 | if err != nil { 159 | return -1, err 160 | } 161 | return iface.Index, nil 162 | }() 163 | if err != nil { 164 | tun.tunFile.Close() 165 | return nil, err 166 | } 167 | 168 | tun.routeSocket, err = 
unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) 169 | if err != nil { 170 | tun.tunFile.Close() 171 | return nil, err 172 | } 173 | 174 | go tun.routineRouteListener(tunIfindex) 175 | 176 | currentMTU, err := tun.MTU() 177 | if err != nil || currentMTU != mtu { 178 | err = tun.setMTU(mtu) 179 | if err != nil { 180 | tun.Close() 181 | return nil, err 182 | } 183 | } 184 | 185 | return tun, nil 186 | } 187 | 188 | func (tun *NativeTun) Name() (string, error) { 189 | gostat, err := tun.tunFile.Stat() 190 | if err != nil { 191 | tun.name = "" 192 | return "", err 193 | } 194 | stat := gostat.Sys().(*syscall.Stat_t) 195 | tun.name = fmt.Sprintf("tun%d", stat.Rdev%256) 196 | return tun.name, nil 197 | } 198 | 199 | func (tun *NativeTun) File() *os.File { 200 | return tun.tunFile 201 | } 202 | 203 | func (tun *NativeTun) Events() chan Event { 204 | return tun.events 205 | } 206 | 207 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 208 | select { 209 | case err := <-tun.errors: 210 | return 0, err 211 | default: 212 | buff := buff[offset-4:] 213 | n, err := tun.tunFile.Read(buff[:]) 214 | if n < 4 { 215 | return 0, err 216 | } 217 | return n - 4, err 218 | } 219 | } 220 | 221 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 222 | // reserve space for header 223 | 224 | buff = buff[offset-4:] 225 | 226 | // add packet information header 227 | 228 | buff[0] = 0x00 229 | buff[1] = 0x00 230 | buff[2] = 0x00 231 | 232 | if buff[4]>>4 == ipv6.Version { 233 | buff[3] = unix.AF_INET6 234 | } else { 235 | buff[3] = unix.AF_INET 236 | } 237 | 238 | // write 239 | 240 | return tun.tunFile.Write(buff) 241 | } 242 | 243 | func (tun *NativeTun) Flush() error { 244 | // TODO: can flushing be implemented by buffering and using sendmmsg? 
245 | return nil 246 | } 247 | 248 | func (tun *NativeTun) Close() error { 249 | var err1, err2 error 250 | tun.closeOnce.Do(func() { 251 | err1 = tun.tunFile.Close() 252 | if tun.routeSocket != -1 { 253 | unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) 254 | err2 = unix.Close(tun.routeSocket) 255 | tun.routeSocket = -1 256 | } else if tun.events != nil { 257 | close(tun.events) 258 | } 259 | }) 260 | if err1 != nil { 261 | return err1 262 | } 263 | return err2 264 | } 265 | 266 | func (tun *NativeTun) setMTU(n int) error { 267 | // open datagram socket 268 | 269 | var fd int 270 | 271 | fd, err := unix.Socket( 272 | unix.AF_INET, 273 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 274 | 0, 275 | ) 276 | if err != nil { 277 | return err 278 | } 279 | 280 | defer unix.Close(fd) 281 | 282 | // do ioctl call 283 | 284 | var ifr ifreq_mtu 285 | copy(ifr.Name[:], tun.name) 286 | ifr.MTU = uint32(n) 287 | 288 | _, _, errno := unix.Syscall( 289 | unix.SYS_IOCTL, 290 | uintptr(fd), 291 | uintptr(unix.SIOCSIFMTU), 292 | uintptr(unsafe.Pointer(&ifr)), 293 | ) 294 | 295 | if errno != 0 { 296 | return fmt.Errorf("failed to set MTU on %s", tun.name) 297 | } 298 | 299 | return nil 300 | } 301 | 302 | func (tun *NativeTun) MTU() (int, error) { 303 | // open datagram socket 304 | 305 | fd, err := unix.Socket( 306 | unix.AF_INET, 307 | unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 308 | 0, 309 | ) 310 | if err != nil { 311 | return 0, err 312 | } 313 | 314 | defer unix.Close(fd) 315 | 316 | // do ioctl call 317 | var ifr ifreq_mtu 318 | copy(ifr.Name[:], tun.name) 319 | 320 | _, _, errno := unix.Syscall( 321 | unix.SYS_IOCTL, 322 | uintptr(fd), 323 | uintptr(unix.SIOCGIFMTU), 324 | uintptr(unsafe.Pointer(&ifr)), 325 | ) 326 | if errno != 0 { 327 | return 0, fmt.Errorf("failed to get MTU on %s", tun.name) 328 | } 329 | 330 | return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil 331 | } 332 | -------------------------------------------------------------------------------- /tun/tun_windows.go: 
-------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tun 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | "os" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | _ "unsafe" 16 | 17 | "golang.org/x/sys/windows" 18 | 19 | "golang.zx2c4.com/wintun" 20 | ) 21 | 22 | const ( 23 | rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) 24 | spinloopRateThreshold = 800000000 / 8 // 800mbps 25 | spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s 26 | ) 27 | 28 | type rateJuggler struct { 29 | current uint64 30 | nextByteCount uint64 31 | nextStartTime int64 32 | changing int32 33 | } 34 | 35 | type NativeTun struct { 36 | wt *wintun.Adapter 37 | name string 38 | handle windows.Handle 39 | rate rateJuggler 40 | session wintun.Session 41 | readWait windows.Handle 42 | events chan Event 43 | running sync.WaitGroup 44 | closeOnce sync.Once 45 | close int32 46 | forcedMTU int 47 | } 48 | 49 | var ( 50 | WintunTunnelType = "WireGuard" 51 | WintunStaticRequestedGUID *windows.GUID 52 | ) 53 | 54 | //go:linkname procyield runtime.procyield 55 | func procyield(cycles uint32) 56 | 57 | //go:linkname nanotime runtime.nanotime 58 | func nanotime() int64 59 | 60 | // 61 | // CreateTUN creates a Wintun interface with the given name. Should a Wintun 62 | // interface with the same name exist, it is reused. 63 | // 64 | func CreateTUN(ifname string, mtu int) (Device, error) { 65 | return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) 66 | } 67 | 68 | // 69 | // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and 70 | // a requested GUID. Should a Wintun interface with the same name exist, it is reused. 
71 | // 72 | func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { 73 | wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) 74 | if err != nil { 75 | return nil, fmt.Errorf("Error creating interface: %w", err) 76 | } 77 | 78 | forcedMTU := 1420 79 | if mtu > 0 { 80 | forcedMTU = mtu 81 | } 82 | 83 | tun := &NativeTun{ 84 | wt: wt, 85 | name: ifname, 86 | handle: windows.InvalidHandle, 87 | events: make(chan Event, 10), 88 | forcedMTU: forcedMTU, 89 | } 90 | 91 | tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB 92 | if err != nil { 93 | tun.wt.Close() 94 | close(tun.events) 95 | return nil, fmt.Errorf("Error starting session: %w", err) 96 | } 97 | tun.readWait = tun.session.ReadWaitEvent() 98 | return tun, nil 99 | } 100 | 101 | func (tun *NativeTun) Name() (string, error) { 102 | return tun.name, nil 103 | } 104 | 105 | func (tun *NativeTun) File() *os.File { 106 | return nil 107 | } 108 | 109 | func (tun *NativeTun) Events() chan Event { 110 | return tun.events 111 | } 112 | 113 | func (tun *NativeTun) Close() error { 114 | var err error 115 | tun.closeOnce.Do(func() { 116 | atomic.StoreInt32(&tun.close, 1) 117 | windows.SetEvent(tun.readWait) 118 | tun.running.Wait() 119 | tun.session.End() 120 | if tun.wt != nil { 121 | tun.wt.Close() 122 | } 123 | close(tun.events) 124 | }) 125 | return err 126 | } 127 | 128 | func (tun *NativeTun) MTU() (int, error) { 129 | return tun.forcedMTU, nil 130 | } 131 | 132 | // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. 133 | func (tun *NativeTun) ForceMTU(mtu int) { 134 | update := tun.forcedMTU != mtu 135 | tun.forcedMTU = mtu 136 | if update { 137 | tun.events <- EventMTUUpdate 138 | } 139 | } 140 | 141 | // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 
142 | 143 | func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { 144 | tun.running.Add(1) 145 | defer tun.running.Done() 146 | retry: 147 | if atomic.LoadInt32(&tun.close) == 1 { 148 | return 0, os.ErrClosed 149 | } 150 | start := nanotime() 151 | shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 152 | for { 153 | if atomic.LoadInt32(&tun.close) == 1 { 154 | return 0, os.ErrClosed 155 | } 156 | packet, err := tun.session.ReceivePacket() 157 | switch err { 158 | case nil: 159 | packetSize := len(packet) 160 | copy(buff[offset:], packet) 161 | tun.session.ReleaseReceivePacket(packet) 162 | tun.rate.update(uint64(packetSize)) 163 | return packetSize, nil 164 | case windows.ERROR_NO_MORE_ITEMS: 165 | if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { 166 | windows.WaitForSingleObject(tun.readWait, windows.INFINITE) 167 | goto retry 168 | } 169 | procyield(1) 170 | continue 171 | case windows.ERROR_HANDLE_EOF: 172 | return 0, os.ErrClosed 173 | case windows.ERROR_INVALID_DATA: 174 | return 0, errors.New("Send ring corrupt") 175 | } 176 | return 0, fmt.Errorf("Read failed: %w", err) 177 | } 178 | } 179 | 180 | func (tun *NativeTun) Flush() error { 181 | return nil 182 | } 183 | 184 | func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { 185 | tun.running.Add(1) 186 | defer tun.running.Done() 187 | if atomic.LoadInt32(&tun.close) == 1 { 188 | return 0, os.ErrClosed 189 | } 190 | 191 | packetSize := len(buff) - offset 192 | tun.rate.update(uint64(packetSize)) 193 | 194 | packet, err := tun.session.AllocateSendPacket(packetSize) 195 | if err == nil { 196 | copy(packet, buff[offset:]) 197 | tun.session.SendPacket(packet) 198 | return packetSize, nil 199 | } 200 | switch err { 201 | case windows.ERROR_HANDLE_EOF: 202 | return 0, os.ErrClosed 203 | case windows.ERROR_BUFFER_OVERFLOW: 204 | return 0, nil // Dropping when 
ring is full. 205 | } 206 | return 0, fmt.Errorf("Write failed: %w", err) 207 | } 208 | 209 | // LUID returns Windows interface instance ID. 210 | func (tun *NativeTun) LUID() uint64 { 211 | tun.running.Add(1) 212 | defer tun.running.Done() 213 | if atomic.LoadInt32(&tun.close) == 1 { 214 | return 0 215 | } 216 | return tun.wt.LUID() 217 | } 218 | 219 | // RunningVersion returns the running version of the Wintun driver. 220 | func (tun *NativeTun) RunningVersion() (version uint32, err error) { 221 | return wintun.RunningVersion() 222 | } 223 | 224 | func (rate *rateJuggler) update(packetLen uint64) { 225 | now := nanotime() 226 | total := atomic.AddUint64(&rate.nextByteCount, packetLen) 227 | period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) 228 | if period >= rateMeasurementGranularity { 229 | if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { 230 | return 231 | } 232 | atomic.StoreInt64(&rate.nextStartTime, now) 233 | atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) 234 | atomic.StoreUint64(&rate.nextByteCount, 0) 235 | atomic.StoreInt32(&rate.changing, 0) 236 | } 237 | } 238 | -------------------------------------------------------------------------------- /tun/tuntest/tuntest.go: -------------------------------------------------------------------------------- 1 | /* SPDX-License-Identifier: MIT 2 | * 3 | * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 | */ 5 | 6 | package tuntest 7 | 8 | import ( 9 | "encoding/binary" 10 | "io" 11 | "net/netip" 12 | "os" 13 | 14 | "golang.zx2c4.com/wireguard/tun" 15 | ) 16 | 17 | func Ping(dst, src netip.Addr) []byte { 18 | localPort := uint16(1337) 19 | seq := uint16(0) 20 | 21 | payload := make([]byte, 4) 22 | binary.BigEndian.PutUint16(payload[0:], localPort) 23 | binary.BigEndian.PutUint16(payload[2:], seq) 24 | 25 | return genICMPv4(payload, dst, src) 26 | } 27 | 28 | // Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. 
29 | func checksum(buf []byte, initial uint16) uint16 { 30 | v := uint32(initial) 31 | for i := 0; i < len(buf)-1; i += 2 { 32 | v += uint32(binary.BigEndian.Uint16(buf[i:])) 33 | } 34 | if len(buf)%2 == 1 { 35 | v += uint32(buf[len(buf)-1]) << 8 36 | } 37 | for v > 0xffff { 38 | v = (v >> 16) + (v & 0xffff) 39 | } 40 | return ^uint16(v) 41 | } 42 | 43 | func genICMPv4(payload []byte, dst, src netip.Addr) []byte { 44 | const ( 45 | icmpv4ProtocolNumber = 1 46 | icmpv4Echo = 8 47 | icmpv4ChecksumOffset = 2 48 | icmpv4Size = 8 49 | ipv4Size = 20 50 | ipv4TotalLenOffset = 2 51 | ipv4ChecksumOffset = 10 52 | ttl = 65 53 | headerSize = ipv4Size + icmpv4Size 54 | ) 55 | 56 | pkt := make([]byte, headerSize+len(payload)) 57 | 58 | ip := pkt[0:ipv4Size] 59 | icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] 60 | 61 | // https://tools.ietf.org/html/rfc792 62 | icmpv4[0] = icmpv4Echo // type 63 | icmpv4[1] = 0 // code 64 | chksum := ^checksum(icmpv4, checksum(payload, 0)) 65 | binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) 66 | 67 | // https://tools.ietf.org/html/rfc760 section 3.1 68 | length := uint16(len(pkt)) 69 | ip[0] = (4 << 4) | (ipv4Size / 4) 70 | binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) 71 | ip[8] = ttl 72 | ip[9] = icmpv4ProtocolNumber 73 | copy(ip[12:], src.AsSlice()) 74 | copy(ip[16:], dst.AsSlice()) 75 | chksum = ^checksum(ip[:], 0) 76 | binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) 77 | 78 | copy(pkt[headerSize:], payload) 79 | return pkt 80 | } 81 | 82 | type ChannelTUN struct { 83 | Inbound chan []byte // incoming packets, closed on TUN close 84 | Outbound chan []byte // outbound packets, blocks forever on TUN close 85 | 86 | closed chan struct{} 87 | events chan tun.Event 88 | tun chTun 89 | } 90 | 91 | func NewChannelTUN() *ChannelTUN { 92 | c := &ChannelTUN{ 93 | Inbound: make(chan []byte), 94 | Outbound: make(chan []byte), 95 | closed: make(chan struct{}), 96 | events: make(chan tun.Event, 1), 97 | } 98 | 
c.tun.c = c 99 | c.events <- tun.EventUp 100 | return c 101 | } 102 | 103 | func (c *ChannelTUN) TUN() tun.Device { 104 | return &c.tun 105 | } 106 | 107 | type chTun struct { 108 | c *ChannelTUN 109 | } 110 | 111 | func (t *chTun) File() *os.File { return nil } 112 | 113 | func (t *chTun) Read(data []byte, offset int) (int, error) { 114 | select { 115 | case <-t.c.closed: 116 | return 0, os.ErrClosed 117 | case msg := <-t.c.Outbound: 118 | return copy(data[offset:], msg), nil 119 | } 120 | } 121 | 122 | // Write is called by the wireguard device to deliver a packet for routing. 123 | func (t *chTun) Write(data []byte, offset int) (int, error) { 124 | if offset == -1 { 125 | close(t.c.closed) 126 | close(t.c.events) 127 | return 0, io.EOF 128 | } 129 | msg := make([]byte, len(data)-offset) 130 | copy(msg, data[offset:]) 131 | select { 132 | case <-t.c.closed: 133 | return 0, os.ErrClosed 134 | case t.c.Inbound <- msg: 135 | return len(data) - offset, nil 136 | } 137 | } 138 | 139 | const DefaultMTU = 1420 140 | 141 | func (t *chTun) Flush() error { return nil } 142 | func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } 143 | func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } 144 | func (t *chTun) Events() chan tun.Event { return t.c.events } 145 | func (t *chTun) Close() error { 146 | t.Write(nil, -1) 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const Version = "0.0.20220316" 4 | --------------------------------------------------------------------------------