├── .gitignore ├── go.mod ├── client ├── util │ └── util.go ├── auth │ └── auth.go ├── cmd │ └── cmd.go ├── client.go └── network │ └── network.go ├── LICENSE ├── antireplay ├── window_test.go └── window.go ├── server ├── auth │ └── auth.go └── server.go ├── README.md └── go.sum /.gitignore: -------------------------------------------------------------------------------- 1 | testing/ 2 | client/client 3 | server/server 4 | .idea/ -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/malcolmseyd/natpunch-go 2 | 3 | go 1.15 4 | 5 | require ( 6 | github.com/flynn/noise v1.0.0 7 | github.com/google/gopacket v1.1.18 8 | github.com/ogier/pflag v0.0.1 9 | github.com/vishvananda/netlink v1.1.0 10 | golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 11 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 12 | ) 13 | -------------------------------------------------------------------------------- /client/util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "encoding/base64" 5 | "log" 6 | 7 | "github.com/malcolmseyd/natpunch-go/client/network" 8 | ) 9 | 10 | // MakePeerSlice constructs a slice of Peers, each with a Pubkey 11 | func MakePeerSlice(peerKeys []string) []network.Peer { 12 | keys := make([]network.Peer, len(peerKeys)) 13 | for i, key := range peerKeys { 14 | keyBytes, err := base64.StdEncoding.DecodeString(key) 15 | if err != nil { 16 | log.Fatalln("Error decoding key "+key+":", err) 17 | } 18 | 19 | keyArr := [32]byte{} 20 | copy(keyArr[:], keyBytes) 21 | 22 | peer := network.Peer{ 23 | Pubkey: network.Key(keyArr), 24 | Resolved: false, 25 | } 26 | keys[i] = peer 27 | } 28 | return keys 29 | } 30 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Malcolm Seyd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/antireplay/window_test.go:
--------------------------------------------------------------------------------
1 | package antireplay
2 | 
3 | import (
4 | 	"testing"
5 | )
6 | 
7 | func TestWindow(t *testing.T) {
8 | 	w := Window{}
9 | 	w.testCheck(t, 0, true)
10 | 	w.testCheck(t, 0, false)
11 | 	w.testCheck(t, 1, true)
12 | 	w.testCheck(t, 1, false)
13 | 	w.testCheck(t, 0, false)
14 | 	w.testCheck(t, 3, true)
15 | 	w.testCheck(t, 2, true)
16 | 	w.testCheck(t, 2, false)
17 | 	w.testCheck(t, 3, false)
18 | 	w.testCheck(t, 30, true)
19 | 	w.testCheck(t, 29, true)
20 | 	w.testCheck(t, 28, true)
21 | 	w.testCheck(t, 30, false)
22 | 	w.testCheck(t, 28, false)
23 | 	w.testCheck(t, WindowSize, true)
24 | 	w.testCheck(t, WindowSize, false)
25 | 	w.testCheck(t, WindowSize+1, true)
26 | 
27 | 	w.Reset()
28 | 	w.testCheck(t, 0, true)
29 | 	w.testCheck(t, 1, true)
30 | 	w.testCheck(t, WindowSize, true)
31 | 
32 | 	w.Reset()
33 | 	w.testCheck(t, WindowSize+1, true)
34 | 	w.testCheck(t, 0, false)
35 | 	w.testCheck(t, 1, true)
36 | 	w.testCheck(t, WindowSize+3, true)
37 | 	w.testCheck(t, 1, false)
38 | 	w.testCheck(t, 2, false)
39 | 	w.testCheck(t, WindowSize*3, true)
40 | 	w.testCheck(t, WindowSize*2-1, false)
41 | 	w.testCheck(t, WindowSize*2, true)
42 | 	w.testCheck(t, WindowSize*3, false)
43 | }
44 | 
45 | // testCheck calls w.Check(index) and stops the test immediately, reporting
46 | // the caller's line, if the result differs from expected.
47 | func (w *Window) testCheck(t *testing.T, index uint64, expected bool) {
48 | 	t.Helper()
49 | 	result := w.Check(index)
50 | 	t.Log(index, "->", result)
51 | 	if result != expected {
52 | 		t.Fatalf("Check(%d) = %v, want %v", index, result, expected)
53 | 	}
54 | }
55 | 
--------------------------------------------------------------------------------
/client/auth/auth.go:
--------------------------------------------------------------------------------
1 | package auth
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 
6 | 	"github.com/flynn/noise"
7 | 	"github.com/malcolmseyd/natpunch-go/antireplay"
8 | 	"golang.org/x/crypto/curve25519"
9 | )
10 | 
11 | // noiseConfig holds the Noise parameters shared by every handshake: the IK
12 | // pattern with X25519, ChaChaPoly and BLAKE2s. This side is the initiator.
13 | // NewConfig copies it and fills in the static keys.
14 | var noiseConfig = noise.Config{
15 | 	CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s),
16 | 	Random:      rand.Reader,
17 | 	Pattern:     noise.HandshakeIK,
18 | 	Initiator:   true,
19 | 	Prologue:    []byte("natpunch-go is the best :)"),
20 | }
21 | 
22 | // CipherState is an alternate implementation of noise.CipherState that
23 | // exposes the nonce (Nonce/SetNonce) for manual control and keeps an
24 | // anti-replay window of received nonces (CheckNonce).
25 | type CipherState struct {
26 | 	c noise.Cipher      // AEAD cipher doing the actual work
27 | 	n uint64            // nonce for the next Encrypt/Decrypt call
28 | 	w antireplay.Window // nonces already accepted by CheckNonce
29 | }
30 | 
31 | // NewCipherState wraps c in a CipherState whose nonce starts at 0.
32 | func NewCipherState(c noise.Cipher) *CipherState {
33 | 	return &CipherState{c: c}
34 | }
35 | 
36 | // Encrypt encrypts plaintext with the current nonce and associated data ad,
37 | // appends the result to out and returns it, then increments the nonce.
38 | func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
39 | 	out = s.c.Encrypt(out, s.n, ad, plaintext)
40 | 	s.n++
41 | 	return out
42 | }
43 | 
44 | // Decrypt authenticates and decrypts ciphertext with the current nonce and
45 | // associated data ad, appending the plaintext to out. The nonce is
46 | // incremented whether or not decryption succeeds.
47 | func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
48 | 	out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
49 | 	s.n++
50 | 	return out, err
51 | }
52 | 
53 | // Nonce returns the nonce that the next Encrypt/Decrypt call will use.
54 | func (s *CipherState) Nonce() uint64 {
55 | 	return s.n
56 | }
57 | 
58 | // SetNonce sets the nonce that the next Encrypt/Decrypt call will use.
59 | func (s *CipherState) SetNonce(n uint64) {
60 | 	s.n = n
61 | }
62 | 
63 | // CheckNonce records n in the anti-replay window and returns true if it has
64 | // not been seen before and is inside the sliding window; it returns false if
65 | // the nonce is reused or too old.
66 | func (s *CipherState) CheckNonce(n uint64) bool {
67 | 	return s.w.Check(n)
68 | }
69 | 
70 | // NewConfig returns a copy of the shared Noise configuration with our static
71 | // keypair (the public half derived from privkey via X25519) and the peer's
72 | // static public key filled in.
73 | func NewConfig(privkey, theirPubkey [32]byte) (config noise.Config, err error) {
74 | 	config = noiseConfig
75 | 	config.StaticKeypair = noise.DHKey{
76 | 		Private: privkey[:],
77 | 	}
78 | 	config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint)
79 | 	if err != nil {
80 | 		return config, err
81 | 	}
82 | 	config.PeerStatic = theirPubkey[:]
83 | 	return config, nil
84 | }
85 | 
--------------------------------------------------------------------------------
/server/auth/auth.go:
--------------------------------------------------------------------------------
1 | package auth
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 
6 | 	"github.com/flynn/noise"
7 | 	"github.com/malcolmseyd/natpunch-go/antireplay"
8 | 	"golang.org/x/crypto/curve25519"
9 | )
10 | 
11 | // noiseConfig holds the Noise parameters shared by every handshake: the IK
12 | // pattern with X25519, ChaChaPoly and BLAKE2s. This side is the responder.
13 | // NewConfig copies it and fills in the static keys.
14 | var noiseConfig = noise.Config{
15 | 	CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s),
16 | 	Random:      rand.Reader,
17 | 	Pattern:     noise.HandshakeIK,
18 | 	Initiator:   false,
19 | 	Prologue:    []byte("natpunch-go is the best :)"),
20 | }
21 | 
22 | // CipherState is an alternate implementation of noise.CipherState that
23 | // exposes the nonce (Nonce/SetNonce) for manual control and keeps an
24 | // anti-replay window of received nonces (CheckNonce).
25 | type CipherState struct {
26 | 	c noise.Cipher      // AEAD cipher doing the actual work
27 | 	n uint64            // nonce for the next Encrypt/Decrypt call
28 | 	w antireplay.Window // nonces already accepted by CheckNonce
29 | }
30 | 
31 | // NewCipherState wraps c in a CipherState whose nonce starts at 0.
32 | func NewCipherState(c noise.Cipher) *CipherState {
33 | 	return &CipherState{c: c}
34 | }
35 | 
36 | // Encrypt encrypts plaintext with the current nonce and associated data ad,
37 | // appends the result to out and returns it, then increments the nonce.
38 | func (s *CipherState) Encrypt(out, ad, plaintext []byte) []byte {
39 | 	out = s.c.Encrypt(out, s.n, ad, plaintext)
40 | 	s.n++
41 | 	return out
42 | }
43 | 
44 | // Decrypt authenticates and decrypts ciphertext with the current nonce and
45 | // associated data ad, appending the plaintext to out. The nonce is
46 | // incremented whether or not decryption succeeds.
47 | func (s *CipherState) Decrypt(out, ad, ciphertext []byte) ([]byte, error) {
48 | 	out, err := s.c.Decrypt(out, s.n, ad, ciphertext)
49 | 	s.n++
50 | 	return out, err
51 | }
52 | 
53 | // Nonce returns the nonce that the next Encrypt/Decrypt call will use.
54 | func (s *CipherState) Nonce() uint64 {
55 | 	return s.n
56 | }
57 | 
58 | // SetNonce sets the nonce that the next Encrypt/Decrypt call will use.
59 | func (s *CipherState) SetNonce(n uint64) {
60 | 	s.n = n
61 | }
62 | 
63 | // CheckNonce records n in the anti-replay window and returns true if it has
64 | // not been seen before and is inside the sliding window; it returns false if
65 | // the nonce is reused or too old.
66 | func (s *CipherState) CheckNonce(n uint64) bool {
67 | 	return s.w.Check(n)
68 | }
69 | 
70 | // NewConfig returns a copy of the shared Noise configuration with our static
71 | // keypair (the public half derived from privkey via X25519) and the peer's
72 | // static public key filled in.
73 | func NewConfig(privkey, theirPubkey [32]byte) (config noise.Config, err error) {
74 | 	config = noiseConfig
75 | 	config.StaticKeypair = noise.DHKey{
76 | 		Private: privkey[:],
77 | 	}
78 | 	config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint)
79 | 	if err != nil {
80 | 		return config, err
81 | 	}
82 | 	config.PeerStatic = theirPubkey[:]
83 | 	return config, nil
84 | }
85 | 
--------------------------------------------------------------------------------
/client/cmd/cmd.go:
--------------------------------------------------------------------------------
1 | package cmd
2 | 
3 | import (
4 | 	"encoding/base64"
5 | 	"log"
6 | 	"net"
7 | 	"os/exec"
8 | 	"strconv"
9 | 	"strings"
10 | 
11 | 	"github.com/malcolmseyd/natpunch-go/client/network"
12 | )
13 | 
14 | // NOTE(review): unused in this package; SetPeer takes the keepalive as a
15 | // parameter instead.
16 | const persistentKeepalive = "25"
17 | 
18 | // RunCmd runs command with args and returns its standard output. Any error
19 | // from exec.Cmd.Output (including a non-zero exit status) is returned
20 | // together with an empty string.
21 | func RunCmd(command string, args ...string) (string, error) {
22 | 	outBytes, err := exec.Command(command, args...).Output()
23 | 	if err != nil {
24 | 		return "", err
25 | 	}
26 | 	return string(outBytes), nil
27 | }
28 | 
29 | // GetClientPort returns the listening port of the Wireguard interface iface,
30 | // as reported by `wg show <iface> listen-port`. It exits on failure.
31 | func GetClientPort(iface string) uint16 {
32 | 	output, err := RunCmd("wg", "show", iface, "listen-port")
33 | 	if err != nil {
34 | 		log.Fatalln("Error getting listen port:", err)
35 | 	}
36 | 	// bitSize 16 guarantees the parsed value fits in a uint16
37 | 	port, err := strconv.ParseUint(strings.TrimSpace(output), 10, 16)
38 | 	if err != nil {
39 | 		log.Fatalln("Error parsing listen port:", err)
40 | 	}
41 | 	return uint16(port)
42 | }
43 | 
44 | // GetPeers returns the base64-encoded public keys of the peers on the
45 | // Wireguard interface iface, one per line of `wg show <iface> peers`.
46 | // An interface with no peers yields an empty slice. It exits on failure.
47 | func GetPeers(iface string) []string {
48 | 	output, err := RunCmd("wg", "show", iface, "peers")
49 | 	if err != nil {
50 | 		log.Fatalln("Error getting peers:", err)
51 | 	}
52 | 	output = strings.TrimSpace(output)
53 | 	if output == "" {
54 | 		// strings.Split would return [""], i.e. one bogus empty key
55 | 		return nil
56 | 	}
57 | 	return strings.Split(output, "\n")
58 | }
59 | 
60 | // GetClientPubkey returns the public key of the Wireguard interface iface.
61 | // It exits if the key cannot be read or is not a 32-byte base64 value.
62 | func GetClientPubkey(iface string) network.Key {
63 | 	return getInterfaceKey(iface, "public-key", "pubkey")
64 | }
65 | 
66 | // GetClientPrivkey returns the private key of the Wireguard interface iface.
67 | // It exits if the key cannot be read or is not a 32-byte base64 value.
68 | func GetClientPrivkey(iface string) network.Key {
69 | 	return getInterfaceKey(iface, "private-key", "privkey")
70 | }
71 | 
72 | // getInterfaceKey reads `wg show <iface> <field>` and decodes the output as a
73 | // base64 key; label ("pubkey"/"privkey") is used in error messages.
74 | func getInterfaceKey(iface, field, label string) network.Key {
75 | 	var key network.Key
76 | 	output, err := RunCmd("wg", "show", iface, field)
77 | 	if err != nil {
78 | 		log.Fatalln("Error getting client "+label+":", err)
79 | 	}
80 | 	keyBytes, err := base64.StdEncoding.DecodeString(strings.TrimSpace(output))
81 | 	if err != nil {
82 | 		log.Fatalln("Error parsing client "+label+":", err)
83 | 	}
84 | 	if len(keyBytes) != len(key) {
85 | 		log.Fatalln("Error parsing client "+label+": expected", len(key), "bytes, got", len(keyBytes))
86 | 	}
87 | 	copy(key[:], keyBytes)
88 | 	return key
89 | }
90 | 
91 | // SetPeer points the peer's endpoint at peer.IP:peer.Port and sets its
92 | // persistent keepalive (in seconds) with `wg set`. Failures are logged but
93 | // not fatal, so the caller can carry on with other peers.
94 | func SetPeer(peer *network.Peer, keepalive int, iface string) {
95 | 	keyString := base64.StdEncoding.EncodeToString(peer.Pubkey[:])
96 | 	// JoinHostPort brackets IPv6 literals; IPv4 output is unchanged
97 | 	endpoint := net.JoinHostPort(peer.IP.String(), strconv.FormatUint(uint64(peer.Port), 10))
98 | 	_, err := RunCmd("wg",
99 | 		"set", iface,
100 | 		"peer", keyString,
101 | 		"persistent-keepalive", strconv.Itoa(keepalive),
102 | 		"endpoint", endpoint,
103 | 	)
104 | 	if err != nil {
105 | 		log.Println("Error setting peer", keyString+":", err)
106 | 	}
107 | }
108 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # natpunch-go
2 | 
3 | This is a [NAT hole punching](https://en.wikipedia.org/wiki/UDP_hole_punching) tool designed for creating Wireguard mesh networks.
It was inspired by [Tailscale](https://www.tailscale.com/) and informed by [this example](https://git.zx2c4.com/wireguard-tools/tree/contrib/nat-hole-punching/). 4 | 5 | This tool allows you to connect to other Wireguard peers from behind a NAT, using a server for IP and port discovery. The client is Linux only. 6 | 7 | ## Usage 8 | 9 | The client cycles through each peer on the interface until they are all resolved. It requires root to run because it uses raw sockets. 10 | ``` 11 | Usage: ./client [OPTION]... WIREGUARD_INTERFACE SERVER_HOSTNAME:PORT SERVER_PUBKEY 12 | Flags: 13 | -c, --continuous=false: continuously resolve peers after they've already been resolved 14 | -d, --delay=2: time to wait between retries (in seconds) 15 | Example: 16 | ./client wg0 demo.wireguard.com:12345 1rwvlEQkF6vL4jA1gRzlTM7I3tuZHtdq8qkLMwBs8Uw= 17 | ``` 18 | 19 | The server associates each public key with an IP address and a port. It doesn't require root to run. 20 | ``` 21 | Usage: ./server PORT [PRIVATE_KEY] 22 | ``` 23 | 24 | ## Why 25 | 26 | I want to have a VPN so that I can access all of my devices even when I'm out of the house. Unfortunately, using a traditional client-server model creates additional latency. Peer-to-peer connections are ideal for each client; however, many of the devices are behind a NAT. The motivation for this tool was to allow p2p Wireguard connections through a NAT. 27 | 28 | ## How 29 | 30 | UDP NAT hole punching allows us to open a connection when both clients are behind a NAT. Modern NATs may employ source port randomization, which means that clients cannot predict which port to connect to in order to punch through that NAT. We need a way to discover the port that the client is using. 31 | 32 | We use a publicly facing server in order to determine the IP address and port of each client. Each client connects using the same port as Wireguard (by spoofing the source port using raw sockets), and its IP address and port are recorded.
It can also request the ip and port of another client using their public key. This breaks source port randomization and allows NAT punching on [every NAT type except symmetric](https://en.wikipedia.org/wiki/Network_address_translation#Methods_of_translation), as symmetric NATs may use a different external ip and port for each connection. 33 | 34 | Once each client gets the ip and port, they simply set the peer's endpoint to the ip and port it learned about and set a persistent keepalive to start the packet flow. With the keepalive, the peers will keep trying to contact each other which will create the hole on both sides and maintain the connection. 35 | 36 | ## Why Go? 37 | 38 | Go has great support for [raw sockets](https://pkg.go.dev/golang.org/x/net/ipv4?tab=doc), [packet filtering](https://pkg.go.dev/golang.org/x/net/bpf?tab=doc), and [packet construction/deconstruction](https://pkg.go.dev/github.com/google/gopacket?tab=doc). I plan on rewriting this in Rust one day but Go's library support is too good to pass up. 
39 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ= 2 | github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= 3 | github.com/google/gopacket v1.1.18 h1:lum7VRA9kdlvBi7/v2p7/zcbkduHaCH/SVVyurs7OpY= 4 | github.com/google/gopacket v1.1.18/go.mod h1:UdDNZ1OO62aGYVnPhxT1U6aI7ukYtA/kB8vaU0diBUM= 5 | github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= 6 | github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= 7 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 8 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 9 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 10 | github.com/ogier/pflag v0.0.1 h1:RW6JSWSu/RkSatfcLtogGfFgpim5p7ARQ10ECk5O750= 11 | github.com/ogier/pflag v0.0.1/go.mod h1:zkFki7tvTa0tafRvTBIZTvzYyAu6kQhPZFnshFFPE+g= 12 | github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= 13 | github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= 14 | github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= 15 | github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= 16 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 17 | golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= 18 | golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= 19 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 20 | 
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= 21 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 22 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 23 | golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 24 | golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 25 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= 26 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 28 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 29 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 30 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 31 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 32 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 33 | -------------------------------------------------------------------------------- /antireplay/window.go: -------------------------------------------------------------------------------- 1 | package antireplay 2 | 3 | // thank you again to Wireguard-Go for helping me understand this 4 | // most credit to https://git.zx2c4.com/wireguard-go/tree/replay/replay.go 5 | 6 | // We use uintptr as blocks because pointers' size are optimized for the 7 | // local CPU architecture. 
8 | const ( 9 | // a word filled with 1's 10 | blockMask = ^uintptr(0) 11 | // each word is 2**blockSizeLog bytes long 12 | // 1 if > 8 bit 1 if > 16 bit 1 if > 32 bit 13 | blockSizeLog = blockMask>>8&1 + blockMask>>16&1 + blockMask>>32&1 14 | // size of word in bytes 15 | blockSize = 1 << blockSizeLog 16 | ) 17 | 18 | const ( 19 | // total number of bits in the array 20 | // must be power of 2 21 | blocksTotalBits = 1024 22 | // bits in a block 23 | blockBits = blockSize * 8 24 | // log of bits in a block 25 | blockBitsLog = blockSizeLog + 3 26 | // WindowSize is the size of the range in which indicies are stored 27 | // W = M-1*blockSize 28 | // uint64 to avoid casting in comparisons 29 | WindowSize = uint64(blocksTotalBits - blockBits) 30 | 31 | numBlocks = blocksTotalBits / blockSize 32 | ) 33 | 34 | // Window is a sliding window that records which sequence numbers have been seen. 35 | // It implements the anti-replay algorithm described in RFC 6479 36 | type Window struct { 37 | highest uint64 38 | blocks [numBlocks]uintptr 39 | } 40 | 41 | // Reset resets the window to its initial state 42 | func (w *Window) Reset() { 43 | w.highest = 0 44 | // this is fine because higher blocks are cleared during Check() 45 | w.blocks[0] = 0 46 | } 47 | 48 | // Check records seeing index and returns true if the index is within the 49 | // window and has not been seen before. If it returns false, the index is 50 | // considered invalid. 51 | func (w *Window) Check(index uint64) bool { 52 | // check if too old 53 | if index+WindowSize < w.highest { 54 | return false 55 | } 56 | 57 | // bits outside the block size represent which block the index is in 58 | indexBlock := index >> blockBitsLog 59 | 60 | // move window if new index is higher 61 | if index > w.highest { 62 | currTopBlock := w.highest >> blockBitsLog 63 | // how many blocks ahead is indexBlock? 
64 | // cap it at a full circle around the array, at that point we clear the 65 | // whole thing 66 | newBlocks := min(indexBlock-currTopBlock, numBlocks) 67 | // clear each new block 68 | for i := uint64(1); i <= newBlocks; i++ { 69 | // mod index so it wraps around 70 | w.blocks[(currTopBlock+i)%numBlocks] = 0 71 | } 72 | w.highest = index 73 | } 74 | 75 | // we didn't mod until now because we needed to know the difference between 76 | // a lower index and wrapped higher index 77 | // we need to keep the index inside the array now 78 | indexBlock %= numBlocks 79 | 80 | // bits inside the block represent where in the block the bit is 81 | // mask it with the block size 82 | indexBit := index & uint64(blockBits-1) 83 | 84 | // finally check the index 85 | 86 | // save existing block to see if it changes 87 | oldBlock := w.blocks[indexBlock] 88 | // create updated block 89 | newBlock := oldBlock | (1 << indexBit) 90 | // set block to new value 91 | w.blocks[indexBlock] = newBlock 92 | 93 | // if the bit wasn't already 1, the values should be different and this should return true 94 | return oldBlock != newBlock 95 | } 96 | 97 | func min(a, b uint64) uint64 { 98 | if a < b { 99 | return a 100 | } 101 | return b 102 | } 103 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/hex" 6 | "fmt" 7 | "log" 8 | "net" 9 | "os" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/ogier/pflag" 15 | 16 | "github.com/malcolmseyd/natpunch-go/client/auth" 17 | "github.com/malcolmseyd/natpunch-go/client/cmd" 18 | "github.com/malcolmseyd/natpunch-go/client/network" 19 | "github.com/malcolmseyd/natpunch-go/client/util" 20 | ) 21 | 22 | const persistentKeepalive = 25 23 | 24 | func main() { 25 | pflag.Usage = printUsage 26 | 27 | continuous := pflag.BoolP("continuous", "c", 
false, "continuously resolve peers after they've already been resolved") 28 | delay := pflag.Float32P("delay", "d", 2.0, "time to wait between retries (in seconds)") 29 | 30 | pflag.Parse() 31 | args := pflag.Args() 32 | 33 | if len(args) < 3 { 34 | printUsage() 35 | os.Exit(1) 36 | } 37 | 38 | if os.Getuid() != 0 { 39 | fmt.Fprintln(os.Stderr, "Must be root!") 40 | os.Exit(1) 41 | } 42 | 43 | ifaceName := args[0] 44 | 45 | serverSplit := strings.Split(args[1], ":") 46 | serverHostname := serverSplit[0] 47 | if len(serverSplit) < 2 { 48 | fmt.Fprintln(os.Stderr, "Please include a port like this:", serverHostname+":PORT") 49 | os.Exit(1) 50 | } 51 | 52 | serverAddr := network.HostToAddr(serverHostname) 53 | 54 | serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) 55 | if err != nil { 56 | log.Fatalln("Error parsing server port:", err) 57 | } 58 | 59 | serverKey, err := base64.StdEncoding.DecodeString(args[2]) 60 | if err != nil || len(serverKey) != 32 { 61 | log.Fatalln("Server key has improper formatting") 62 | } 63 | var serverKeyArr network.Key 64 | copy(serverKeyArr[:], serverKey) 65 | 66 | server := network.Server{ 67 | Hostname: serverHostname, 68 | Addr: serverAddr, 69 | Port: uint16(serverPort), 70 | Pubkey: serverKeyArr, 71 | } 72 | 73 | run(ifaceName, server, *continuous, *delay) 74 | } 75 | 76 | func run(ifaceName string, server network.Server, continuous bool, delay float32) { 77 | // get the source ip that we'll send the packet from 78 | clientIP := network.GetClientIP(server.Addr.IP) 79 | 80 | cmd.RunCmd("wg-quick", "up", ifaceName) 81 | 82 | // get info about the Wireguard config 83 | clientPort := cmd.GetClientPort(ifaceName) 84 | clientPubkey := cmd.GetClientPubkey(ifaceName) 85 | clientPrivkey := cmd.GetClientPrivkey(ifaceName) 86 | 87 | client := network.Peer{ 88 | IP: clientIP, 89 | Port: clientPort, 90 | Pubkey: clientPubkey, 91 | } 92 | 93 | peerKeysStr := cmd.GetPeers(ifaceName) 94 | var peers []network.Peer = 
util.MakePeerSlice(peerKeysStr) 95 | 96 | // we're using raw sockets to spoof the source port, 97 | // which is already being used by Wireguard 98 | rawConn := network.SetupRawConn(&server, &client) 99 | defer rawConn.Close() 100 | 101 | // payload consists of client key + peer key 102 | payload := make([]byte, 64) 103 | copy(payload[0:32], clientPubkey[:]) 104 | 105 | totalPeers := len(peers) 106 | resolvedPeers := 0 107 | 108 | fmt.Println("Resolving", totalPeers, "peers") 109 | 110 | var sendCipher, recvCipher *auth.CipherState 111 | var index uint32 112 | var err error 113 | 114 | // we keep requesting if the server doesn't have one of our peers. 115 | // this keeps running until all connections are established. 116 | tryAgain := true 117 | for tryAgain { 118 | tryAgain = false 119 | for i, peer := range peers { 120 | if peer.Resolved && !continuous { 121 | continue 122 | } 123 | 124 | // Noise handshake w/ key rotation 125 | if time.Since(server.LastHandshake) > network.RekeyDuration { 126 | sendCipher, recvCipher, index, err = network.Handshake(rawConn, clientPrivkey, &server, &client) 127 | if err != nil { 128 | if err, ok := err.(net.Error); ok && err.Timeout() { 129 | fmt.Println("Connection to", server.Hostname, "timed out.") 130 | tryAgain = true 131 | break 132 | } 133 | fmt.Fprintln(os.Stderr, "Key rotation failed:", err) 134 | tryAgain = true 135 | break 136 | } 137 | } 138 | fmt.Printf("(%d/%d) %s: ", resolvedPeers, totalPeers, base64.RawStdEncoding.EncodeToString(peer.Pubkey[:])[:16]) 139 | copy(payload[32:64], peer.Pubkey[:]) 140 | 141 | err := network.SendDataPacket(sendCipher, index, payload, rawConn, &server, &client) 142 | if err != nil { 143 | log.Println("\nError sending packet:", err) 144 | continue 145 | } 146 | 147 | // throw away udp header, we have no use for it right now 148 | body, _, packetType, n, err := network.RecvDataPacket(recvCipher, rawConn, &server, &client) 149 | if err != nil { 150 | if err, ok := err.(net.Error); ok && 
err.Timeout() { 151 | fmt.Println("\nConnection to", server.Hostname, "timed out.") 152 | tryAgain = true 153 | continue 154 | } 155 | fmt.Println("\nError receiving packet:", err) 156 | continue 157 | } 158 | if packetType != network.PacketData { 159 | fmt.Println("\nExpected data packet, got", packetType) 160 | } 161 | 162 | if len(body) == 0 { 163 | fmt.Println("not found") 164 | tryAgain = true 165 | continue 166 | } else if len(body) != 4+2 { 167 | // expected packet size, 4 bytes for ip, 2 for port 168 | log.Println("\nError: invalid response of length", len(body)) 169 | // For debugging 170 | fmt.Println(hex.Dump(body[:n])) 171 | tryAgain = true 172 | continue 173 | } 174 | 175 | peer.IP, peer.Port = network.ParseResponse(body) 176 | if peer.IP == nil { 177 | log.Println("Error parsing packet: not a valid UDP packet") 178 | } 179 | if !peer.Resolved { 180 | peer.Resolved = true 181 | resolvedPeers++ 182 | } 183 | 184 | fmt.Println(peer.IP.String() + ":" + strconv.FormatUint(uint64(peer.Port), 10)) 185 | cmd.SetPeer(&peer, persistentKeepalive, ifaceName) 186 | 187 | peers[i] = peer 188 | 189 | if continuous { 190 | // always try again if continuous 191 | tryAgain = true 192 | } 193 | } 194 | if tryAgain { 195 | time.Sleep(time.Second * time.Duration(delay)) 196 | } 197 | } 198 | fmt.Print("Resolved ", resolvedPeers, " peer") 199 | if totalPeers != 1 { 200 | fmt.Print("s") 201 | } 202 | fmt.Print("\n") 203 | } 204 | 205 | func printUsage() { 206 | fmt.Fprintf(os.Stderr, 207 | "Usage: %s [OPTION]... 
WIREGUARD_INTERFACE SERVER_HOSTNAME:PORT SERVER_PUBKEY\n"+ 208 | "Flags:\n", os.Args[0], 209 | ) 210 | pflag.PrintDefaults() 211 | fmt.Fprintf(os.Stderr, 212 | "Example:\n"+ 213 | " %s wg0 demo.wireguard.com:12345 1rwvlEQkF6vL4jA1gRzlTM7I3tuZHtdq8qkLMwBs8Uw=\n", 214 | os.Args[0], 215 | ) 216 | } 217 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/base64" 7 | "encoding/binary" 8 | "encoding/hex" 9 | "errors" 10 | "fmt" 11 | "log" 12 | "net" 13 | "os" 14 | "time" 15 | 16 | "github.com/flynn/noise" 17 | "github.com/malcolmseyd/natpunch-go/server/auth" 18 | "golang.org/x/crypto/curve25519" 19 | ) 20 | 21 | const ( 22 | // PacketHandshakeInit identifies handhshake initiation packets 23 | PacketHandshakeInit byte = 1 24 | // PacketHandshakeResp identifies handhshake response packets 25 | PacketHandshakeResp byte = 2 26 | // PacketData identifies regular data packets. 
27 | PacketData byte = 3 28 | ) 29 | 30 | var ( 31 | // ErrPacketType is returned when an unexepcted packet type is enountered 32 | ErrPacketType = errors.New("server: incorrect packet type") 33 | // ErrPeerNotFound is returned when the requested peer is not found 34 | ErrPeerNotFound = errors.New("server: peer not found") 35 | // ErrPubkey is returned when the public key recieved does not match the one we expect 36 | ErrPubkey = errors.New("server: public key did not match expected one") 37 | // ErrOldTimestamp is returned when a handshake timestamp isn't newer than the previous one 38 | ErrOldTimestamp = errors.New("server: handshake timestamp isn't new") 39 | // ErrNoTimestamp is returned when the handshake packet doesn't contain a timestamp 40 | ErrNoTimestamp = errors.New("server: handshake had no timestamp") 41 | // ErrNonce is returned when the nonce on a packet isn't valid 42 | ErrNonce = errors.New("client/network: invalid nonce") 43 | 44 | timeout = 5 * time.Second 45 | 46 | noiseConfig = noise.Config{ 47 | CipherSuite: noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s), 48 | Random: rand.Reader, 49 | Pattern: noise.HandshakeIK, 50 | Initiator: false, 51 | Prologue: []byte("natpunch-go is the best :)"), 52 | } 53 | ) 54 | 55 | // Key stores a Wireguard key 56 | type Key [32]byte 57 | 58 | // we use pointers on these maps so that two maps can link to one object 59 | 60 | // PeerMap stores the peers by key 61 | type PeerMap map[Key]*Peer 62 | 63 | // IndexMap stores the Peers by index 64 | type IndexMap map[uint32]*Peer 65 | 66 | // Peer represents a Wireguard peer. 
67 | type Peer struct { 68 | ip net.IP 69 | port uint16 70 | pubkey Key 71 | 72 | index uint32 73 | send, recv *auth.CipherState 74 | // UnixNano cast to uint64 75 | lastHandshake uint64 76 | } 77 | 78 | type state struct { 79 | conn *net.UDPConn 80 | keyMap PeerMap 81 | indexMap IndexMap 82 | privKey Key 83 | } 84 | 85 | func main() { 86 | if len(os.Args) < 2 { 87 | fmt.Fprintln(os.Stderr, "Usage:", os.Args[0], "PORT [PRIVATE_KEY]") 88 | os.Exit(1) 89 | } 90 | 91 | s := state{} 92 | var err error 93 | 94 | port := os.Args[1] 95 | if len(os.Args) > 2 { 96 | priv, err := base64.StdEncoding.DecodeString(os.Args[2]) 97 | if err != nil || len(priv) != 32 { 98 | fmt.Fprintln(os.Stderr, "Error parsing public key") 99 | } 100 | copy(s.privKey[:], priv) 101 | } else { 102 | rand.Read(s.privKey[:]) 103 | s.privKey.clamp() 104 | } 105 | 106 | pubkey, _ := curve25519.X25519(s.privKey[:], curve25519.Basepoint) 107 | fmt.Println("Starting nat-punching server on port", port) 108 | fmt.Println("Public key:", base64.StdEncoding.EncodeToString(pubkey)) 109 | 110 | s.keyMap = make(PeerMap) 111 | s.indexMap = make(IndexMap) 112 | 113 | // the client can only handle IPv4 addresses right now. 
114 | listenAddr, err := net.ResolveUDPAddr("udp4", ":"+port) 115 | if err != nil { 116 | log.Panicln("Error getting UDP address", err) 117 | } 118 | 119 | s.conn, err = net.ListenUDP("udp4", listenAddr) 120 | if err != nil { 121 | log.Panicln("Error getting UDP listen connection") 122 | } 123 | 124 | for { 125 | err := s.handleConnection() 126 | if err != nil { 127 | fmt.Println("Error handling the connection", err) 128 | } 129 | } 130 | } 131 | 132 | func (s *state) handleConnection() error { 133 | packet := make([]byte, 4096) 134 | 135 | n, clientAddr, err := s.conn.ReadFromUDP(packet) 136 | if err != nil { 137 | return err 138 | } 139 | packet = packet[:n] 140 | 141 | packetType := packet[0] 142 | packet = packet[1:] 143 | 144 | if packetType == PacketHandshakeInit { 145 | return s.handshake(packet, clientAddr, timeout) 146 | } else if packetType == PacketData { 147 | return s.dataPacket(packet, clientAddr, timeout) 148 | } else { 149 | fmt.Println("Unknown packet type:", packetType) 150 | fmt.Println(hex.Dump(packet)) 151 | } 152 | 153 | return nil 154 | } 155 | 156 | func (s *state) dataPacket(packet []byte, clientAddr *net.UDPAddr, timeout time.Duration) (err error) { 157 | index := binary.BigEndian.Uint32(packet[:4]) 158 | packet = packet[4:] 159 | 160 | client, ok := s.indexMap[index] 161 | if !ok { 162 | return 163 | } 164 | 165 | nonce := binary.BigEndian.Uint64(packet[:8]) 166 | packet = packet[8:] 167 | // println("recving nonce", nonce) 168 | 169 | client.recv.SetNonce(nonce) 170 | plaintext, err := client.recv.Decrypt(nil, nil, packet) 171 | if err != nil { 172 | return 173 | } 174 | if !client.recv.CheckNonce(nonce) { 175 | // no need to throw an error, just return 176 | return 177 | } 178 | 179 | clientPubKey := plaintext[:32] 180 | plaintext = plaintext[32:] 181 | 182 | if !bytes.Equal(clientPubKey, client.pubkey[:]) { 183 | err = ErrPubkey 184 | return 185 | } 186 | 187 | var targetPubKey Key 188 | copy(targetPubKey[:], plaintext[:32]) 189 | // 
for later use 190 | plaintext = plaintext[:6] 191 | 192 | client.ip = clientAddr.IP 193 | client.port = uint16(clientAddr.Port) 194 | 195 | targetPeer, peerExists := s.keyMap[targetPubKey] 196 | if peerExists { 197 | // client must be ipv4 so this will never return nil 198 | copy(plaintext[:4], targetPeer.ip.To4()) 199 | binary.BigEndian.PutUint16(plaintext[4:6], targetPeer.port) 200 | } else { 201 | // return nothing if peer not found 202 | plaintext = plaintext[:0] 203 | } 204 | 205 | nonceBytes := make([]byte, 8) 206 | binary.BigEndian.PutUint64(nonceBytes, client.send.Nonce()) 207 | 208 | header := append([]byte{PacketData}, nonceBytes...) 209 | // println("sent nonce:", client.send.Nonce()) 210 | // println("sending", len(plaintext), "bytes") 211 | response := client.send.Encrypt(header, nil, plaintext) 212 | 213 | _, err = s.conn.WriteToUDP(response, clientAddr) 214 | if err != nil { 215 | return 216 | } 217 | 218 | fmt.Print( 219 | base64.StdEncoding.EncodeToString(client.pubkey[:])[:16], 220 | " ==> ", 221 | base64.StdEncoding.EncodeToString(targetPubKey[:])[:16], 222 | ": ", 223 | ) 224 | 225 | if peerExists { 226 | fmt.Println("CONNECTED") 227 | } else { 228 | fmt.Println("NOT FOUND") 229 | } 230 | 231 | return 232 | } 233 | 234 | func (s *state) handshake(packet []byte, clientAddr *net.UDPAddr, timeout time.Duration) (err error) { 235 | config := noiseConfig 236 | config.StaticKeypair = noise.DHKey{ 237 | Private: s.privKey[:], 238 | } 239 | config.StaticKeypair.Public, err = curve25519.X25519(config.StaticKeypair.Private, curve25519.Basepoint) 240 | if err != nil { 241 | return 242 | } 243 | 244 | handshake, err := noise.NewHandshakeState(config) 245 | if err != nil { 246 | return 247 | } 248 | 249 | indexBytes := packet[:4] 250 | index := binary.BigEndian.Uint32(indexBytes) 251 | packet = packet[4:] 252 | 253 | timestampBytes, _, _, err := handshake.ReadMessage(nil, packet) 254 | if err != nil { 255 | return 256 | } 257 | if len(timestampBytes) == 0 { 
258 | err = ErrNoTimestamp 259 | } 260 | timestamp := binary.BigEndian.Uint64(timestampBytes) 261 | 262 | var pubkey Key 263 | copy(pubkey[:], handshake.PeerStatic()) 264 | client, ok := s.keyMap[pubkey] 265 | if !ok { 266 | client = &Peer{ 267 | pubkey: pubkey, 268 | } 269 | s.keyMap[pubkey] = client 270 | } 271 | if timestamp <= client.lastHandshake { 272 | err = ErrOldTimestamp 273 | return 274 | } 275 | client.lastHandshake = timestamp 276 | // clear old entry 277 | s.indexMap[index] = nil 278 | client.ip = clientAddr.IP 279 | client.port = uint16(clientAddr.Port) 280 | // if index is aleady taken, set a new one 281 | for { 282 | _, ok = s.indexMap[index] 283 | if !ok { 284 | break 285 | } 286 | index++ 287 | } 288 | client.index = index 289 | binary.BigEndian.PutUint32(indexBytes, index) 290 | s.indexMap[index] = client 291 | 292 | header := append([]byte{PacketHandshakeResp}, indexBytes...) 293 | // recv and send are opposite order from client code 294 | packet, recv, send, err := handshake.WriteMessage(header, nil) 295 | if err != nil { 296 | return 297 | } 298 | client.send = auth.NewCipherState(send.Cipher()) 299 | client.recv = auth.NewCipherState(recv.Cipher()) 300 | 301 | _, err = s.conn.WriteTo(packet, clientAddr) 302 | 303 | return 304 | } 305 | 306 | func (k *Key) clamp() { 307 | k[0] &= 248 308 | k[31] = (k[31] & 127) | 64 309 | } 310 | -------------------------------------------------------------------------------- /client/network/network.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "encoding/binary" 7 | "errors" 8 | "fmt" 9 | "log" 10 | "net" 11 | "time" 12 | 13 | "github.com/flynn/noise" 14 | "github.com/google/gopacket" 15 | "github.com/google/gopacket/layers" 16 | "github.com/malcolmseyd/natpunch-go/client/auth" 17 | "github.com/vishvananda/netlink" 18 | "golang.org/x/net/bpf" 19 | "golang.org/x/net/ipv4" 20 | ) 21 | 22 | const ( 23 | 
udpProtocol = 17 24 | // EmptyUDPSize is the size of an empty UDP packet 25 | EmptyUDPSize = 28 26 | 27 | timeout = time.Second * 10 28 | 29 | // PacketHandshakeInit identifies handhshake initiation packets 30 | PacketHandshakeInit byte = 1 31 | // PacketHandshakeResp identifies handhshake response packets 32 | PacketHandshakeResp byte = 2 33 | // PacketData identifies regular data packets 34 | PacketData byte = 3 35 | ) 36 | 37 | var ( 38 | // ErrPacketType is returned when an unexepcted packet type is enountered 39 | ErrPacketType = errors.New("client/network: incorrect packet type") 40 | // ErrNonce is returned when the nonce on a packet isn't valid 41 | ErrNonce = errors.New("client/network: invalid nonce") 42 | 43 | // RekeyDuration is the time after which keys are invalid and a new handshake is required. 44 | RekeyDuration = 5 * time.Minute 45 | ) 46 | 47 | // EmptyUDPSize is the size of the IPv4 and UDP headers combined. 48 | 49 | // Key stores a 32 byte representation of a Wireguard key 50 | type Key [32]byte 51 | 52 | // Server stores data relating to the server 53 | type Server struct { 54 | Hostname string 55 | Addr *net.IPAddr 56 | Port uint16 57 | Pubkey Key 58 | 59 | LastHandshake time.Time 60 | } 61 | 62 | // Peer stores data about a peer's key and endpoint, whether it's another peer or the client 63 | // While Resolved == false, we consider IP and Port to be uninitialized 64 | // I could have done a nested struct with Endpoint containing IP and Port but that's 65 | // unnecessary right now. 
66 | type Peer struct { 67 | Resolved bool 68 | IP net.IP 69 | Port uint16 70 | Pubkey Key 71 | } 72 | 73 | // GetClientIP gets source ip address that will be used when sending data to dstIP 74 | func GetClientIP(dstIP net.IP) net.IP { 75 | // i wanted to use gopacket/routing but it breaks when the vpn iface is already up 76 | routes, err := netlink.RouteGet(dstIP) 77 | if err != nil { 78 | log.Fatalln("Error getting route:", err) 79 | } 80 | // pick the first one cuz why not 81 | return routes[0].Src 82 | } 83 | 84 | // HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr 85 | func HostToAddr(hostStr string) *net.IPAddr { 86 | remoteAddrs, err := net.LookupHost(hostStr) 87 | if err != nil { 88 | log.Fatalln("Error parsing remote address:", err) 89 | } 90 | 91 | for _, addrStr := range remoteAddrs { 92 | if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { 93 | return remoteAddr 94 | } 95 | } 96 | return nil 97 | } 98 | 99 | // SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering 100 | func SetupRawConn(server *Server, client *Peer) *ipv4.RawConn { 101 | packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) 102 | if err != nil { 103 | log.Fatalln("Error creating packetConn:", err) 104 | } 105 | 106 | rawConn, err := ipv4.NewRawConn(packetConn) 107 | if err != nil { 108 | log.Fatalln("Error creating rawConn:", err) 109 | } 110 | 111 | ApplyBPF(rawConn, server, client) 112 | 113 | return rawConn 114 | } 115 | 116 | // ApplyBPF constructs a BPF program and applies it to the RawConn 117 | func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *Peer) { 118 | const ipv4HeaderLen = 20 119 | 120 | const srcIPOffset = 12 121 | const srcPortOffset = ipv4HeaderLen + 0 122 | const dstPortOffset = ipv4HeaderLen + 2 123 | 124 | ipArr := []byte(server.Addr.IP.To4()) 125 | ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) 126 | 127 | // we don't need to 
filter packet type because the rawconn is ipv4-udp only 128 | // Skip values represent the number of instructions to skip if true or false 129 | // We can skip to the end if we get a !=, otherwise keep going 130 | bpfRaw, err := bpf.Assemble([]bpf.Instruction{ 131 | bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, //src ip is server 132 | bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, 133 | 134 | bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, //src port is server 135 | bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, 136 | 137 | bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, //dst port is client 138 | bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, 139 | 140 | bpf.RetConstant{Val: 1<<(8*4) - 1}, // max number that fits this value (entire packet) 141 | bpf.RetConstant{Val: 0}, 142 | }) 143 | 144 | err = rawConn.SetBPF(bpfRaw) 145 | if err != nil { 146 | log.Fatalln("Error setting BPF:", err) 147 | } 148 | } 149 | 150 | // MakePacket constructs a request packet to send to the server 151 | func MakePacket(payload []byte, server *Server, client *Peer) []byte { 152 | buf := gopacket.NewSerializeBuffer() 153 | 154 | // this does the hard stuff for us 155 | opts := gopacket.SerializeOptions{ 156 | FixLengths: true, 157 | ComputeChecksums: true, 158 | } 159 | 160 | ipHeader := layers.IPv4{ 161 | SrcIP: client.IP, 162 | DstIP: server.Addr.IP, 163 | Version: 4, 164 | TTL: 64, 165 | Protocol: layers.IPProtocolUDP, 166 | } 167 | 168 | udpHeader := layers.UDP{ 169 | SrcPort: layers.UDPPort(client.Port), 170 | DstPort: layers.UDPPort(server.Port), 171 | } 172 | 173 | payloadLayer := gopacket.Payload(payload) 174 | 175 | udpHeader.SetNetworkLayerForChecksum(&ipHeader) 176 | 177 | gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) 178 | 179 | return buf.Bytes() 180 | } 181 | 182 | // Handshake performs a Noise-IK handshake with the Server 183 | func Handshake(conn 
*ipv4.RawConn, privkey Key, server *Server, client *Peer) (sendCipher, recvCipher *auth.CipherState, index uint32, err error) { 184 | // we generate index on the client side 185 | indexBytes := make([]byte, 4) 186 | rand.Read(indexBytes) 187 | index = binary.BigEndian.Uint32(indexBytes) 188 | 189 | config, err := auth.NewConfig(privkey, server.Pubkey) 190 | if err != nil { 191 | return 192 | } 193 | 194 | handshake, err := noise.NewHandshakeState(config) 195 | if err != nil { 196 | return 197 | } 198 | 199 | header := append([]byte{PacketHandshakeInit}, indexBytes...) 200 | 201 | timestamp := make([]byte, 8) 202 | binary.BigEndian.PutUint64(timestamp, uint64(time.Now().UnixNano())) 203 | 204 | packet, _, _, err := handshake.WriteMessage(header, timestamp) 205 | if err != nil { 206 | return 207 | } 208 | err = SendPacket(packet, conn, server, client) 209 | if err != nil { 210 | return 211 | } 212 | 213 | response, n, err := RecvPacket(conn, server, client) 214 | if err != nil { 215 | return 216 | } 217 | response = response[EmptyUDPSize:n] 218 | packetType := response[0] 219 | response = response[1:] 220 | 221 | if packetType != PacketHandshakeResp { 222 | err = ErrPacketType 223 | return 224 | } 225 | index = binary.BigEndian.Uint32(response[:4]) 226 | response = response[4:] 227 | 228 | _, send, recv, err := handshake.ReadMessage(nil, response) 229 | // we use our own implementation for manual nonce control 230 | sendCipher = auth.NewCipherState(send.Cipher()) 231 | recvCipher = auth.NewCipherState(recv.Cipher()) 232 | 233 | server.LastHandshake = time.Now() 234 | 235 | return 236 | } 237 | 238 | // SendPacket sends packet to the Server 239 | func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *Peer) error { 240 | fullPacket := MakePacket(packet, server, client) 241 | _, err := conn.WriteToIP(fullPacket, server.Addr) 242 | return err 243 | } 244 | 245 | // SendDataPacket encrypts and sends packet to the Server 246 | func SendDataPacket(cipher 
*auth.CipherState, index uint32, data []byte, conn *ipv4.RawConn, server *Server, client *Peer) error { 247 | indexBytes := make([]byte, 4) 248 | binary.BigEndian.PutUint32(indexBytes, index) 249 | 250 | nonceBytes := make([]byte, 8) 251 | binary.BigEndian.PutUint64(nonceBytes, cipher.Nonce()) 252 | // println("sending nonce:", cipher.Nonce()) 253 | 254 | header := append([]byte{PacketData}, indexBytes...) 255 | header = append(header, nonceBytes...) 256 | 257 | packet := cipher.Encrypt(header, nil, data) 258 | 259 | return SendPacket(packet, conn, server, client) 260 | } 261 | 262 | // RecvPacket recieves a UDP packet from server 263 | func RecvPacket(conn *ipv4.RawConn, server *Server, client *Peer) ([]byte, int, error) { 264 | err := conn.SetReadDeadline(time.Now().Add(timeout)) 265 | if err != nil { 266 | return nil, 0, err 267 | } 268 | // TODO add length field to packet 269 | // this will allow us to use a growable buffer, or to read only when needed 270 | response := make([]byte, 4096) 271 | n, err := conn.Read(response) 272 | if err != nil { 273 | return nil, n, err 274 | } 275 | return response, n, nil 276 | } 277 | 278 | // RecvDataPacket recieves a UDP packet from server 279 | func RecvDataPacket(cipher *auth.CipherState, conn *ipv4.RawConn, server *Server, client *Peer) (body, header []byte, packetType byte, n int, err error) { 280 | response, n, err := RecvPacket(conn, server, client) 281 | if err != nil { 282 | return 283 | } 284 | header = response[:EmptyUDPSize] 285 | response = response[EmptyUDPSize:n] 286 | // println(hex.Dump(response)) 287 | 288 | packetType = response[0] 289 | response = response[1:] 290 | 291 | nonce := binary.BigEndian.Uint64(response[:8]) 292 | response = response[8:] 293 | cipher.SetNonce(nonce) 294 | // println("recving nonce:", nonce) 295 | 296 | body, err = cipher.Decrypt(nil, nil, response) 297 | if err != nil { 298 | return 299 | } 300 | 301 | // now that we're authenticated, see if the nonce is valid 302 | // the 
sliding window contains a generous 1000 packets, that should hold up 303 | // with plenty of peers. 304 | if !cipher.CheckNonce(nonce) { 305 | err = ErrNonce 306 | body = nil 307 | } 308 | 309 | return 310 | } 311 | 312 | // ParseResponse takes a response packet and parses it into an IP and port. 313 | // There's no error checking, we assume that data passed in is valid 314 | func ParseResponse(response []byte) (net.IP, uint16) { 315 | var ip net.IP 316 | var port uint16 317 | // packet := gopacket.NewPacket(response, layers.LayerTypeIPv4, gopacket.DecodeOptions{ 318 | // Lazy: true, 319 | // NoCopy: true, 320 | // }) 321 | // if packet.TransportLayer().LayerType() != layers.LayerTypeUDP { 322 | // return nil, 0 323 | // } 324 | // payload := packet.ApplicationLayer().LayerContents() 325 | 326 | // data := bytes.NewBuffer(payload) 327 | // // fmt.Println("Layer payload:\n", hex.Dump(data.Bytes())) 328 | 329 | // binary.Read(data, binary.BigEndian, &ipv4Slice) 330 | // ip = net.IP(ipv4Slice) 331 | // binary.Read(data, binary.BigEndian, &port) 332 | // // fmt.Println("ip:", ip.String(), "port:", port) 333 | ip = net.IP(response[:4]) 334 | port = binary.BigEndian.Uint16(response[4:6]) 335 | return ip, port 336 | } 337 | 338 | func testBPF(peers []Peer, client *Peer, server *Server, rawConn *ipv4.RawConn) { 339 | payload := make([]byte, 64) 340 | copy(payload[0:32], client.Pubkey[:]) 341 | 342 | response := make([]byte, 4096) 343 | 344 | // goroutine to read replies 345 | go func() { 346 | for { 347 | n, err := rawConn.Read(response) 348 | if err != nil { 349 | if err, ok := err.(net.Error); ok && err.Timeout() { 350 | fmt.Println("\nConnection to", server.Hostname, "timed out.") 351 | continue 352 | } 353 | fmt.Println("\nError receiving packet:", err) 354 | continue 355 | } 356 | // fmt.Println(n-28, "bytes read") 357 | if n != 28 && n != 28+6 { 358 | srcIP, srcPort, dstPort := parseForBPF(response) 359 | fmt.Println("\nInvalid response of", n, "bytes") 360 | 
fmt.Println("SRC IP:", srcIP, "\tEXPECTED:", server.Addr.IP) 361 | fmt.Println("SRC PORT:", srcPort, "\tEXPECTED:", server.Port) 362 | fmt.Println("DST PORT:", dstPort, "\tEXPECTED:", client.Port) 363 | fmt.Println() 364 | // fmt.Println(hex.Dump(response[:n])) 365 | } else { 366 | fmt.Print(".") 367 | } 368 | } 369 | }() 370 | 371 | // send packets on the main goroutine 372 | for { 373 | for _, peer := range peers { 374 | copy(payload[32:64], peer.Pubkey[:]) 375 | 376 | packet := MakePacket(payload, server, client) 377 | _, err := rawConn.WriteToIP(packet, server.Addr) 378 | if err != nil { 379 | log.Println("\nError sending packet:", err) 380 | continue 381 | } 382 | } 383 | } 384 | } 385 | 386 | func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) { 387 | srcIP = make([]byte, 4) 388 | srcIPBytes := bytes.NewBuffer(response[12:16]) 389 | srcPortBytes := bytes.NewBuffer(response[20:22]) 390 | dstPortBytes := bytes.NewBuffer(response[22:24]) 391 | 392 | binary.Read(srcIPBytes, binary.BigEndian, &srcIP) 393 | binary.Read(srcPortBytes, binary.BigEndian, &srcPort) 394 | binary.Read(dstPortBytes, binary.BigEndian, &dstPort) 395 | return 396 | } 397 | --------------------------------------------------------------------------------