├── .github
    └── workflows
    │   └── release.yml
├── .gitignore
├── LICENSE
├── README.md
├── cmd
    └── dwgd.go
├── commander.go
├── config.go
├── development
    ├── README.md
    ├── Vagrantfile
    └── e2e-tests
    │   ├── common.sh
    │   ├── test_ifname_mode.sh
    │   └── test_pubkey_mode.sh
├── driver.go
├── driver_test.go
├── dwgd.go
├── entities.go
├── entities_test.go
├── go.mod
├── go.sum
├── listener.go
├── log.go
├── migrations
    ├── 0000.sql
    └── 0001.sql
├── rootless.go
└── systemd
    └── dwgd.service


/.github/workflows/release.yml:
--------------------------------------------------------------------------------
name: Release dwgd

on:
  push:
    tags:
      - "v[0-9]+.[0-9]+.[0-9]+"

jobs:
  release:
    runs-on: ubuntu-latest
    name: Release dwgd tarball
    strategy:
      matrix:
        include:
          # cc is the C compiler used by cgo for the target architecture.
          - arch: amd64
            cc: gcc
          - arch: arm64
            cc: aarch64-linux-gnu-gcc
    env:
      GOOS: linux
      GOARCH: ${{ matrix.arch }}
      TAG: ${{ github.ref_name }}
      CC: ${{ matrix.cc }}

    steps:
      # v4/v5 run on Node 20; the v3 releases targeted the deprecated Node 16.
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup Go
        uses: actions/setup-go@v5
        with:
          go-version-file: 'go.mod'
          cache: true

      - id: release
        uses: bruceadams/get-release@v1.3.2
        env:
          GITHUB_TOKEN: ${{ github.token }}

      # gcc is preinstalled on the runner; only the arm64 cross compiler
      # referenced by the matrix has to be installed.
      - name: Install cross-compilers
        if: matrix.arch == 'arm64'
        run: |
          sudo apt-get update
          sudo apt-get install -y gcc-aarch64-linux-gnu

      # The job env (TAG, GOOS, GOARCH, CC) is read through shell variables
      # rather than ${{ }} expressions, so the tag name is never pasted into
      # the script text itself.
      - name: Build binary
        run: |
          rm -rf dist
          mkdir -p dist/systemd
          CGO_ENABLED=1 go build -tags 'osusergo,netgo,static' -ldflags "-linkmode=external -extldflags=-static -X main.Version=${TAG}" -o dist/dwgd cmd/dwgd.go
          cp systemd/* dist/systemd/
          tar -czvf "dwgd-${TAG}-${GOOS}-${GOARCH}.tar.gz" -C dist .

      # NOTE: actions/upload-release-asset is archived upstream; `gh release
      # upload` is the maintained alternative if this step ever breaks.
      - name: Upload release tarball
        uses: actions/upload-release-asset@v1.0.1
        env:
          GITHUB_TOKEN: ${{ github.token }}
        with:
          upload_url: ${{ steps.release.outputs.upload_url }}
          asset_path: ./dwgd-${{ env.TAG }}-${{ env.GOOS }}-${{ env.GOARCH }}.tar.gz
          asset_name: dwgd-${{ env.TAG }}-${{ env.GOOS }}-${{ env.GOARCH }}.tar.gz
          asset_content_type: application/gzip
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dwgd
.vagrant
.vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Leonardo Mosciatti

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# dwgd: Docker WireGuard Driver

**dwgd** is a Docker plugin that lets your containers connect to a WireGuard
network.
This is achieved by [moving a WireGuard network interface](https://www.wireguard.com/netns/)
from the namespace `dwgd` runs in into the designated container namespace.

## Usage

### 1. Start the daemon

Start the `dwgd` daemon:
```
$ sudo dwgd -d /var/lib/dwgd.db
[...]
```

### 2. Create the docker network

Depending on which [driver specific options](https://docs.docker.com/reference/cli/docker/network/create/#options)
(`-o`) you pass when creating the network, you can select one of two modes:

- [ifname mode](#ifname-mode): if you want to connect to a WireGuard
  interface that lives on the **same host** as the one you are running your
  containers on;

- [pubkey mode](#pubkey-mode): if you want to connect to a WireGuard interface
  that lives on a **different host** from the one you are running your
  containers on.


#### Ifname mode

In this mode, the name of a **local** WireGuard interface is passed as an option.
`dwgd` will create a new WireGuard interface that peers with the one you
passed and hand it to the containers.

Options that you need to pass:

- `dwgd.ifname`: the name of the **local** interface;
- `dwgd.seed`: secret seed that will be used to generate public and private keys
  by SHA256 hashing the `{IP, seed}` pair.

```
$ docker network create \
  --driver=dwgd \
  -o dwgd.ifname=wg0 \
  -o dwgd.seed=supersecretseed \
  --subnet=10.0.0.0/24 \
  --gateway=10.0.0.1 \
  dwgd-net
```

#### Pubkey mode

In this mode, you pass the endpoint and the public key of the WireGuard peer
that the containers should connect to.


**Note**

You will likely need to manually modify the configuration of the remote
WireGuard peer, adding each container as a peer.

This is possible because public and private keys are deterministically generated
by hashing the `{IP, seed}` pair.

You can generate the public key for an `{IP, seed}` pair using the following
command:

```
$ dwgd pubkey -s supersecretseed -i 10.0.0.2
oKetpvdq/I/c7hTW6/AtQPqVlSzgx3q2ClWCx/OXS00=
```

Options that you need to pass:

- `dwgd.pubkey`: the public key of the remote WireGuard interface;
- `dwgd.seed`: secret seed that will be used to generate public and private keys
  by SHA256 hashing the `{IP, seed}` pair;
- `dwgd.endpoint`: the endpoint of the WireGuard peer you want your docker
  containers to connect to.

Create the docker network with the same seed you used to generate the public
key:
```
$ docker network create \
  --driver=dwgd \
  -o dwgd.endpoint=example.com:51820 \
  -o dwgd.seed=supersecretseed \
  -o dwgd.pubkey="your remote WireGuard peer's public key" \
  --subnet=10.0.0.0/24 \
  --gateway=10.0.0.1 \
  dwgd_net
```

### 3. Start a container

Note that the IP address must be set manually with `--ip`.

```
$ docker run -it --rm --network=dwgd_net --ip=10.0.0.2 busybox
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
5: wg0: <POINTOPOINT,NOARP,UP,LOWER_UP> mtu 1420 qdisc noqueue qlen 1000
    link/[65534]
    inet 10.0.0.2/24 brd 10.0.0.255 scope global wg0
       valid_lft forever preferred_lft forever
/ # ping 10.0.0.1
PING 10.0.0.1 (10.0.0.1) 56(84) bytes of data.
64 bytes from 10.0.0.1: icmp_seq=1 ttl=54 time=9.98 ms
64 bytes from 10.0.0.1: icmp_seq=2 ttl=54 time=8.65 ms
64 bytes from 10.0.0.1: icmp_seq=3 ttl=54 time=8.34 ms
^C
--- 10.0.0.1 ping statistics ---
3 packets transmitted, 3 received, 0% packet loss, time 2003ms
rtt min/avg/max/mdev = 8.343/8.990/9.976/0.708 ms
```

## Installation

This software has been tested on a Linux machine running Debian 12, but it
will likely work on any reasonably recent Linux system that satisfies the
dependencies.

After cloning the repository you can build the binary and, optionally, install
the systemd unit:
```
$ go build -o dwgd ./cmd/dwgd.go
$ sudo install -m 755 dwgd /usr/bin/dwgd
$ sudo install -m 644 systemd/* /etc/systemd/system/
```

### Dependencies

You need WireGuard and the `iproute2` package installed on your system:
`dwgd` uses the `ip` command to create and delete the WireGuard interfaces.

You will also need the `nsenter` binary if you want `dwgd` to work with
rootless Docker.

## Development

Please refer to [the development directory](development/README.md).

## Credits

This is a rewrite of the proof of concept presented in [this great article](https://www.bestov.io/blog/using-wireguard-as-the-network-for-a-docker-container).
-------------------------------------------------------------------------------- /cmd/dwgd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "net" 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | 10 | "github.com/leomos/dwgd" 11 | ) 12 | 13 | var Version string 14 | 15 | var cfg = dwgd.NewConfig() 16 | 17 | func init() { 18 | flag.StringVar(&cfg.Db, "d", cfg.Db, "dwgd db path") 19 | flag.BoolVar(&cfg.Verbose, "v", cfg.Verbose, "verbose mode") 20 | flag.BoolVar(&cfg.Rootless, "r", cfg.Rootless, "run in rootless compatibility mode") 21 | } 22 | 23 | var versionFlag = flag.Bool("version", false, "print the version") 24 | 25 | var pubkeyCmd = flag.NewFlagSet("pubkey", flag.ExitOnError) 26 | var ipFlag = pubkeyCmd.String("i", "", "IP to generate public key") 27 | var seedFlag = pubkeyCmd.String("s", "", "seed to generate public key") 28 | 29 | func pubkey(args []string) { 30 | pubkeyCmd.Parse(args) 31 | 32 | seed := *seedFlag 33 | ip := *ipFlag 34 | 35 | if seed == "" { 36 | dwgd.EventsLog.Println("seed is required") 37 | pubkeyCmd.Usage() 38 | os.Exit(1) 39 | } 40 | if ip == "" { 41 | dwgd.EventsLog.Println("ip is required") 42 | os.Exit(1) 43 | } 44 | 45 | privkey := dwgd.GeneratePrivateKey([]byte(seed), net.ParseIP(ip)) 46 | 47 | dwgd.EventsLog.Printf("%s\n", privkey.PublicKey().String()) 48 | os.Exit(0) 49 | } 50 | 51 | func main() { 52 | if len(os.Args) >= 2 { 53 | switch os.Args[1] { 54 | case "pubkey": 55 | pubkey(os.Args[2:]) 56 | } 57 | } 58 | 59 | signalCh := make(chan os.Signal, 1) 60 | signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) 61 | 62 | flag.Parse() 63 | 64 | if cfg.Db == "" { 65 | cfg.Db = ":memory:" 66 | } 67 | 68 | if cfg.Verbose { 69 | dwgd.TraceLog.SetOutput(os.Stderr) 70 | } 71 | 72 | version := *versionFlag 73 | if version { 74 | if Version != "" { 75 | dwgd.EventsLog.Println(Version) 76 | } else { 77 | dwgd.EventsLog.Println("(unknown)") 78 | } 79 | 
os.Exit(0) 80 | } 81 | 82 | dwgd.TraceLog.Printf("Running with the following configuration: %+v\n", cfg) 83 | plugin, err := dwgd.NewDwgd(cfg) 84 | if err != nil { 85 | dwgd.DiagnosticsLog.Fatalf("Couldn't initialize plugin: %s\n", err) 86 | } 87 | err = plugin.Start() 88 | if err != nil { 89 | dwgd.DiagnosticsLog.Fatalf("Couldn't start plugin: %s\n", err) 90 | } 91 | 92 | sig := <-signalCh 93 | dwgd.DiagnosticsLog.Printf("Received signal: %s", sig.String()) 94 | signal.Stop(signalCh) 95 | 96 | err = plugin.Stop() 97 | if err != nil { 98 | dwgd.DiagnosticsLog.Printf("Couldn't stop plugin: %s\n", err) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /commander.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | import ( 4 | "io/fs" 5 | "os" 6 | "os/exec" 7 | ) 8 | 9 | // commander abstracts the os and os/exec stdlib packages. 10 | // This is needed to mock in unit tests. 11 | type commander interface { 12 | // os 13 | Chmod(name string, mode fs.FileMode) error 14 | MkdirAll(name string, perm fs.FileMode) error 15 | ReadFile(name string) ([]byte, error) 16 | ReadDir(name string) ([]fs.DirEntry, error) 17 | Remove(name string) error 18 | Symlink(oldname string, newname string) error 19 | // os/exec 20 | LookPath(file string) (string, error) 21 | Run(name string, arg ...string) error 22 | } 23 | 24 | type execCommander struct{} 25 | 26 | func (e *execCommander) Chmod(name string, mode fs.FileMode) error { 27 | return os.Chmod(name, mode) 28 | } 29 | 30 | func (e *execCommander) MkdirAll(path string, perm fs.FileMode) error { 31 | return os.MkdirAll(path, perm) 32 | } 33 | 34 | func (e *execCommander) ReadDir(name string) ([]fs.DirEntry, error) { 35 | return os.ReadDir(name) 36 | } 37 | 38 | func (e *execCommander) ReadFile(name string) ([]byte, error) { 39 | return os.ReadFile(name) 40 | } 41 | 42 | func (e *execCommander) Remove(name string) error { 43 | 
return os.Remove(name) 44 | } 45 | 46 | func (e *execCommander) Symlink(oldname string, newname string) error { 47 | return os.Symlink(oldname, newname) 48 | } 49 | 50 | func (e *execCommander) LookPath(file string) (string, error) { 51 | return exec.LookPath(file) 52 | } 53 | 54 | func (e *execCommander) Run(name string, arg ...string) error { 55 | cmd := exec.Command(name, arg...) 56 | return cmd.Run() 57 | } 58 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | // A Config represents the configuration of an instance of a dwgd driver. 4 | type Config struct { 5 | Db string // path to the database 6 | Verbose bool // whether to print debug logs or not 7 | Rootless bool // whether to run in rootless compatibility mode or not 8 | } 9 | 10 | func NewConfig() *Config { 11 | return &Config{ 12 | Db: "/var/lib/dwgd.db", 13 | Verbose: false, 14 | Rootless: true, 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /development/README.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | This folder contains some utilities that can aid the development of `dwgd`: 4 | - `e2e-tests/` contains tests that check the whole `dwgd` lifecycle by creating 5 | the necessary resources (WireGuard interface and docker network), sending a ping 6 | from the container and finally removing everything. The tests can be run like: 7 | `sudo ./development/e2e-tests/test_ifname_mode.sh` and everything should be `OK`. 8 | - `Vagrantfile` a simple Vagrant box that has everything it's needed to run the 9 | `e2e-tests`. 
10 | 11 | ## Developing on local machine 12 | 13 | You can develop on your own machine by compiling `dwgd`, creating a WireGuard network and starting `dwgd`: 14 | 15 | ```sh 16 | go build ./cmd/dwgd.go 17 | # create server keys 18 | SERVER_PRIVATE_KEY=$(wg genkey) 19 | SERVER_PUBLIC_KEY=$(echo $SERVER_PRIVATE_KEY | wg pubkey) 20 | # create new dwgd0 wireguard interface 21 | sudo ip link add dwgd0 type wireguard 22 | echo $SERVER_PRIVATE_KEY | sudo wg set dwgd0 private-key /dev/fd/0 listen-port 51820 23 | sudo ip address add 10.0.0.1/24 dev dwgd0 24 | # bring interface up 25 | sudo ip link set up dev dwgd0 26 | # generate your container's public key with a specific seed 27 | CLIENT_PUBLIC_KEY=$(./dwgd pubkey -i 10.0.0.2 -s supersecretseed) 28 | sudo wg set dwgd0 peer $CLIENT_PUBLIC_KEY allowed-ips 10.0.0.2/32 29 | # run dwgd driver 30 | sudo ./dwgd -v & 31 | # create docker network with the previously set server public key and seed 32 | docker network create --driver=dwgd -o dwgd.endpoint=localhost:51820 -o dwgd.seed=supersecretseed -o dwgd.pubkey=$SERVER_PUBLIC_KEY --subnet="10.0.0.0/24" --gateway=10.0.0.1 dwgd-net 33 | # run your container 34 | docker run -it --rm --network=dwgd-net --ip=10.0.0.2 busybox 35 | ``` -------------------------------------------------------------------------------- /development/Vagrantfile: -------------------------------------------------------------------------------- 1 | # -*- mode: ruby -*- 2 | # vi: set ft=ruby : 3 | 4 | Vagrant.configure("2") do |config| 5 | config.vm.box = "debian/bookworm64" 6 | 7 | config.vm.hostname = "dwgd-box" 8 | 9 | config.vm.provision "shell", inline: <<-SHELL 10 | apt-get update 11 | apt-get install -y docker.io wireguard 12 | usermod -aG docker vagrant 13 | SHELL 14 | end 15 | -------------------------------------------------------------------------------- /development/e2e-tests/common.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # from: 
https://stackoverflow.com/questions/5947742/how-to-change-the-output-color-of-echo-in-linux 4 | RED='\033[0;31m' 5 | GREEN='\033[0;32m' 6 | NC='\033[0m' # No Color 7 | 8 | # from: https://tldp.org/LDP/abs/html/debugging.html#ASSERT 9 | assert() { 10 | E_PARAM_ERR=98 11 | E_ASSERT_FAILED=99 12 | 13 | if [ -z "$2" ]; then 14 | return $E_PARAM_ERR 15 | fi 16 | 17 | assertion=$1 18 | error_message=$2 19 | 20 | if [ ! $assertion ]; then 21 | echo -e "${RED}KO${NC} $error_message" 22 | exit $E_ASSERT_FAILED 23 | else 24 | echo -e "${GREEN}OK${NC}" 25 | return 26 | fi 27 | } 28 | 29 | # from: https://cedwards.xyz/defer-for-shell/ 30 | DEFER= 31 | defer() { 32 | DEFER="$*; ${DEFER}" 33 | trap "{ $DEFER }" EXIT 34 | } 35 | 36 | # from: https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script 37 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) 38 | 39 | REPO_DIR=${SCRIPT_DIR}/../.. 40 | DWGD=${REPO_DIR}/dwgd 41 | TEST_DIR=/tmp/dwgd/run 42 | RESULT_DIR=/tmp/dwgd/results 43 | START_DATE=$(date +%+4Y%m%d_%H%M%S) 44 | 45 | #******* 46 | # dwgd specific definitions 47 | #******* 48 | DWGD_DB_FILE=${TEST_DIR}/dwgd.db 49 | DWGD_PID_FILE=${TEST_DIR}/dwgd.pid 50 | DWGD_STDOUT_FILE=${TEST_DIR}/dwgd.stdout 51 | DWGD_STDERR_FILE=${TEST_DIR}/dwgd.stderr 52 | 53 | # from: https://stackoverflow.com/questions/692000/how-do-i-write-standard-error-to-a-file-while-using-tee-with-a-pipe 54 | dup_stds_to_test_env() { 55 | exec 1> >(tee $TEST_DIR/stdout.out) 2> >(tee $TEST_DIR/stderr.out >&2) 56 | } 57 | 58 | setup_test_env() { 59 | rm -rf $TEST_DIR 60 | mkdir -p $TEST_DIR 61 | dup_stds_to_test_env 62 | systemctl stop dwgd 63 | $DWGD -v -d $DWGD_DB_FILE >$DWGD_STDOUT_FILE 2>$DWGD_STDERR_FILE & 64 | pid=$! 
65 | echo $pid >$DWGD_PID_FILE 66 | } 67 | 68 | teardown_test_env() { 69 | kill $(cat ${DWGD_PID_FILE}) 70 | mkdir -p $RESULT_DIR/run-$START_DATE 71 | cp -r $TEST_DIR/* $RESULT_DIR/run-$START_DATE 72 | rm -r $TEST_DIR 73 | } 74 | 75 | #******* 76 | # network to which the container will connect to 77 | #******* 78 | NETWORK_IFNAME=dwgd0 79 | NETWORK_PRIVATE_KEY=$(wg genkey) 80 | NETWORK_PUBLIC_KEY=$(echo $NETWORK_PRIVATE_KEY | wg pubkey) 81 | NETWORK_LISTEN_PORT=51820 82 | NETWORK_IP="10.0.0.1" 83 | NETWORK_CIDR="24" 84 | NETWORK_SEED="supersecretseed" 85 | 86 | create_wireguard_interface() { 87 | ip link add name $NETWORK_IFNAME type wireguard 88 | echo $NETWORK_PRIVATE_KEY | wg set $NETWORK_IFNAME \ 89 | private-key /dev/fd/0 \ 90 | listen-port $NETWORK_LISTEN_PORT 91 | ip address add $NETWORK_IP/$NETWORK_CIDR dev $NETWORK_IFNAME 92 | ip link set up dev $NETWORK_IFNAME 93 | } 94 | 95 | remove_wireguard_interface() { 96 | ip link del dev $NETWORK_IFNAME 97 | } 98 | 99 | #******* 100 | # client and docker definitions 101 | #******* 102 | CLIENT_IP="10.0.0.2" 103 | CLIENT_PUBLIC_KEY=$(${DWGD} pubkey -i ${CLIENT_IP} -s ${NETWORK_SEED}) 104 | DOCKER_NETWORK_NAME="dwgd-net" 105 | -------------------------------------------------------------------------------- /development/e2e-tests/test_ifname_mode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) 4 | 5 | source ${SCRIPT_DIR}/common.sh 6 | 7 | setup_test_env 8 | defer teardown_test_env 9 | 10 | create_wireguard_interface 11 | defer remove_wireguard_interface 12 | 13 | docker network create \ 14 | --driver=dwgd \ 15 | -o dwgd.ifname=$NETWORK_IFNAME \ 16 | -o dwgd.seed=$NETWORK_SEED \ 17 | --subnet="${NETWORK_IP}/${NETWORK_CIDR}" \ 18 | --gateway=$NETWORK_IP \ 19 | dwgd-net 20 | assert "$? 
-eq 0" "Could not create network ${DOCKER_NETWORK_NAME}" 21 | 22 | docker run \ 23 | -it \ 24 | --rm \ 25 | --network=$DOCKER_NETWORK_NAME \ 26 | --ip=$CLIENT_IP \ 27 | busybox \ 28 | ping -c 3 $NETWORK_IP 29 | assert "$? -eq 0" "Could not ping ${NETWORK_IP}" 30 | 31 | docker network rm $DOCKER_NETWORK_NAME 32 | assert "$? -eq 0" "Could not remove network ${DOCKER_NETWORK_NAME}" 33 | -------------------------------------------------------------------------------- /development/e2e-tests/test_pubkey_mode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) 4 | 5 | source ${SCRIPT_DIR}/common.sh 6 | 7 | setup_test_env 8 | defer teardown_test_env 9 | 10 | create_wireguard_interface 11 | defer remove_wireguard_interface 12 | 13 | wg set $NETWORK_IFNAME peer $CLIENT_PUBLIC_KEY allowed-ips $CLIENT_IP/32 14 | assert "$? -eq 0" "Could not create wireguard interface ${NETWORK_IFNAME}" 15 | 16 | docker network create \ 17 | --driver=dwgd \ 18 | -o dwgd.pubkey=$NETWORK_PUBLIC_KEY \ 19 | -o dwgd.endpoint=localhost:$NETWORK_LISTEN_PORT \ 20 | -o dwgd.seed=$NETWORK_SEED \ 21 | --subnet="${NETWORK_IP}/${NETWORK_CIDR}" \ 22 | --gateway=$NETWORK_IP \ 23 | dwgd-net 24 | assert "$? -eq 0" "Could not create network ${DOCKER_NETWORK_NAME}" 25 | 26 | docker run \ 27 | -it \ 28 | --rm \ 29 | --network=$DOCKER_NETWORK_NAME \ 30 | --ip=$CLIENT_IP \ 31 | busybox \ 32 | ping -c 3 $NETWORK_IP 33 | assert "$? -eq 0" "Could not ping ${NETWORK_IP}" 34 | 35 | docker network rm $DOCKER_NETWORK_NAME 36 | assert "$? 
-eq 0" "Could not remove network ${DOCKER_NETWORK_NAME}" 37 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "os" 8 | 9 | "github.com/docker/go-plugins-helpers/network" 10 | _ "github.com/mattn/go-sqlite3" 11 | "golang.zx2c4.com/wireguard/wgctrl" 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | type wgController interface { 16 | Device(name string) (*wgtypes.Device, error) 17 | ConfigureDevice(name string, cfg wgtypes.Config) error 18 | } 19 | 20 | // Docker WireGuard Driver 21 | type Driver struct { 22 | network.Driver 23 | 24 | c commander 25 | wgc wgController 26 | s *Storage 27 | } 28 | 29 | func NewDriver(dbPath string, c commander, wgc wgController) (*Driver, error) { 30 | if c == nil { 31 | c = &execCommander{} 32 | } 33 | 34 | path, err := c.LookPath("ip") 35 | if err != nil { 36 | TraceLog.Printf("Couldn't find 'ip' utility: %s", err) 37 | } else { 38 | TraceLog.Printf("Using 'ip' utility at the following path: %s", path) 39 | } 40 | 41 | if wgc == nil { 42 | wgc, err = wgctrl.New() 43 | if err != nil { 44 | return nil, err 45 | } 46 | } 47 | 48 | s := &Storage{} 49 | err = s.Open(dbPath) 50 | if err != nil { 51 | return nil, err 52 | } 53 | 54 | return &Driver{ 55 | c: c, 56 | wgc: wgc, 57 | s: s, 58 | }, nil 59 | } 60 | 61 | func (d *Driver) Close() error { 62 | return d.s.Close() 63 | } 64 | 65 | func (d *Driver) GetCapabilities() (*network.CapabilitiesResponse, error) { 66 | TraceLog.Printf("GetCapabilities\n") 67 | return &network.CapabilitiesResponse{Scope: network.LocalScope, ConnectivityScope: network.LocalScope}, nil 68 | } 69 | 70 | func (d *Driver) CreateNetwork(r *network.CreateNetworkRequest) error { 71 | TraceLog.Printf("CreateNetwork: %+v\n", Jsonify(r)) 72 | var err error 73 | 74 | n := &Network{} 75 | m := 
r.Options["com.docker.network.generic"].(map[string]interface{}) 76 | 77 | // The following two ifs are used to discern whether we are working in 78 | // ifname mode or pubkey mode. 79 | // By default we expect to work in pubkey mode, which is why if the ifname 80 | // parameter is not present we do not return an error. 81 | var iface *wgtypes.Device 82 | ifname, ok := m["dwgd.ifname"].(string) 83 | if !ok { 84 | n.ifname = "" 85 | } else { 86 | iface, err = d.wgc.Device(ifname) 87 | if errors.Is(err, os.ErrNotExist) { 88 | TraceLog.Printf("Interface %s not recognized\n", ifname) 89 | return err 90 | } 91 | TraceLog.Printf("Using %s as the WireGuard server interface\n", iface.Name) 92 | n.ifname = iface.Name 93 | } 94 | 95 | if iface != nil { 96 | n.pubkey = iface.PublicKey 97 | } else { 98 | payload, ok := m["dwgd.pubkey"].(string) 99 | if !ok { 100 | return fmt.Errorf("dwgd.pubkey option missing") 101 | } 102 | n.pubkey, err = wgtypes.ParseKey(payload) 103 | if err != nil { 104 | return err 105 | } 106 | } 107 | 108 | // From this point on we get all the other parameters needed for both modes. 
109 | endpoint, ok := m["dwgd.endpoint"].(string) 110 | if !ok { 111 | if iface != nil { 112 | endpoint = fmt.Sprintf("localhost:%d", iface.ListenPort) 113 | } else { 114 | return fmt.Errorf("dwgd.endpoint option missing") 115 | } 116 | } 117 | n.endpoint, err = net.ResolveUDPAddr("udp", endpoint) 118 | if err != nil { 119 | return err 120 | } 121 | 122 | seed, ok := m["dwgd.seed"].(string) 123 | if !ok { 124 | return fmt.Errorf("dwgd.seed option missing") 125 | } 126 | n.seed = []byte(seed) 127 | 128 | route, ok := m["dwgd.route"].(string) 129 | if !ok { 130 | route = "" 131 | } 132 | n.route = route 133 | 134 | n.id = r.NetworkID 135 | return d.s.AddNetwork(n) 136 | } 137 | 138 | func (d *Driver) DeleteNetwork(r *network.DeleteNetworkRequest) error { 139 | TraceLog.Printf("DeleteNetwork: %+v\n", Jsonify(r)) 140 | return d.s.RemoveNetwork(r.NetworkID) 141 | } 142 | 143 | func (d *Driver) CreateEndpoint(r *network.CreateEndpointRequest) (*network.CreateEndpointResponse, error) { 144 | TraceLog.Printf("CreateEndpoint: %+v\n", Jsonify(r)) 145 | 146 | n, err := d.s.GetNetwork(r.NetworkID) 147 | if err != nil { 148 | return nil, err 149 | } 150 | if n == nil { 151 | return nil, fmt.Errorf("NetworkID %s not found", r.NetworkID) 152 | } 153 | 154 | c, err := d.s.GetClient(r.EndpointID) 155 | if err != nil { 156 | return nil, err 157 | } 158 | if c != nil { 159 | return nil, fmt.Errorf("EndpointID %s already exists", r.EndpointID) 160 | } 161 | 162 | ip, _, err := net.ParseCIDR(r.Interface.Address) 163 | if err != nil { 164 | return nil, err 165 | } 166 | 167 | endpointIdMaxLen := 12 168 | if len(r.EndpointID) < 12 { 169 | endpointIdMaxLen = len(r.EndpointID) 170 | } 171 | c = &Client{ 172 | id: r.EndpointID, 173 | ip: ip, 174 | ifname: "wg-" + r.EndpointID[:endpointIdMaxLen], 175 | network: n, 176 | } 177 | 178 | err = d.s.AddClient(c) 179 | if err != nil { 180 | return nil, err 181 | } 182 | 183 | return &network.CreateEndpointResponse{}, nil 184 | } 185 | 186 | func (d 
*Driver) DeleteEndpoint(r *network.DeleteEndpointRequest) error { 187 | TraceLog.Printf("DeleteEndpoint: %+v\n", Jsonify(r)) 188 | c, err := d.s.GetClient(r.EndpointID) 189 | if err != nil { 190 | return err 191 | } 192 | if c == nil { 193 | return fmt.Errorf("EndpointID %s not found", r.EndpointID) 194 | } 195 | 196 | if err := d.c.Run("ip", "link", "delete", c.ifname); err != nil { 197 | return err 198 | } 199 | 200 | return d.s.RemoveClient(r.EndpointID) 201 | } 202 | 203 | func (d *Driver) EndpointInfo(r *network.InfoRequest) (*network.InfoResponse, error) { 204 | TraceLog.Printf("EndpointInfo: %+v\n", Jsonify(r)) 205 | return &network.InfoResponse{Value: make(map[string]string)}, nil 206 | } 207 | 208 | func (d *Driver) Join(r *network.JoinRequest) (*network.JoinResponse, error) { 209 | TraceLog.Printf("Join: %+v\n", Jsonify(r)) 210 | 211 | c, err := d.s.GetClient(r.EndpointID) 212 | if err != nil { 213 | return nil, err 214 | } 215 | if c == nil { 216 | return nil, fmt.Errorf("EndpointID %s not found", r.EndpointID) 217 | } 218 | 219 | if err := d.c.Run("ip", "link", "add", "name", c.ifname, "type", "wireguard"); err != nil { 220 | return nil, err 221 | } 222 | 223 | cfg := c.Config() 224 | 225 | err = d.wgc.ConfigureDevice(c.ifname, cfg) 226 | if err != nil { 227 | return nil, err 228 | } 229 | 230 | if c.network.ifname != "" { 231 | TraceLog.Printf("Adding peer to: %s\n", c.network.ifname) 232 | iface, err := d.wgc.Device(c.network.ifname) 233 | if err != nil { 234 | return nil, err 235 | } 236 | 237 | peers := make([]wgtypes.PeerConfig, 1) 238 | peers[0] = c.PeerConfig() 239 | 240 | newNetworkIfaceCfg := wgtypes.Config{ 241 | PrivateKey: &iface.PrivateKey, 242 | ListenPort: &iface.ListenPort, 243 | FirewallMark: &iface.FirewallMark, 244 | ReplacePeers: false, 245 | Peers: peers, 246 | } 247 | TraceLog.Printf("Updating configuration for %s:\n%+v\n", iface.Name, newNetworkIfaceCfg) 248 | 249 | err = d.wgc.ConfigureDevice(iface.Name, newNetworkIfaceCfg) 250 | 
if err != nil { 251 | return nil, err 252 | } 253 | } 254 | 255 | err = moveToRootlessNamespaceIfNecessary(d.c, r.SandboxKey, c.ifname) 256 | if err != nil { 257 | return nil, err 258 | } 259 | 260 | staticRoutes := make([]*network.StaticRoute, 0) 261 | if c.network.route != "" { 262 | staticRoutes = append(staticRoutes, &network.StaticRoute{ 263 | Destination: c.network.route, 264 | RouteType: 1, 265 | }) 266 | } 267 | 268 | return &network.JoinResponse{ 269 | InterfaceName: network.InterfaceName{ 270 | SrcName: c.ifname, 271 | DstPrefix: "wg", 272 | }, 273 | StaticRoutes: staticRoutes, 274 | DisableGatewayService: true, 275 | }, nil 276 | } 277 | 278 | func (d *Driver) Leave(r *network.LeaveRequest) error { 279 | TraceLog.Printf("Leave: %+v\n", Jsonify(r)) 280 | 281 | c, err := d.s.GetClient(r.EndpointID) 282 | if err != nil { 283 | return err 284 | } 285 | if c == nil { 286 | return fmt.Errorf("EndpointID %s not found", r.EndpointID) 287 | } 288 | 289 | if c.network.ifname != "" { 290 | TraceLog.Printf("Removing peer from: %s\n", c.network.ifname) 291 | iface, err := d.wgc.Device(c.network.ifname) 292 | if err != nil { 293 | return err 294 | } 295 | 296 | peers := make([]wgtypes.PeerConfig, 1) 297 | clientPeer := c.PeerConfig() 298 | clientPeer.Remove = true 299 | peers[0] = clientPeer 300 | 301 | newNetworkIfaceCfg := wgtypes.Config{ 302 | PrivateKey: &iface.PrivateKey, 303 | ListenPort: &iface.ListenPort, 304 | FirewallMark: &iface.FirewallMark, 305 | ReplacePeers: false, 306 | Peers: peers, 307 | } 308 | TraceLog.Printf("Updating configuration for %s:\n%+v\n", iface.Name, Jsonify(newNetworkIfaceCfg)) 309 | 310 | err = d.wgc.ConfigureDevice(iface.Name, newNetworkIfaceCfg) 311 | if err != nil { 312 | return err 313 | } 314 | } 315 | 316 | return nil 317 | } 318 | -------------------------------------------------------------------------------- /driver_test.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 
| import ( 4 | "fmt" 5 | "io/fs" 6 | "testing" 7 | 8 | "github.com/docker/go-plugins-helpers/network" 9 | "github.com/google/go-cmp/cmp" 10 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 11 | ) 12 | 13 | func DeviceFixture() *wgtypes.Device { 14 | network := NetworkFixture() 15 | return &wgtypes.Device{ 16 | Name: "dwgd0", 17 | ListenPort: network.endpoint.Port, 18 | PublicKey: network.pubkey, 19 | } 20 | } 21 | 22 | type testCommander struct { 23 | ChmodFunc func(name string, mode fs.FileMode) error 24 | MkdirAllFunc func(path string, perm fs.FileMode) error 25 | ReadDirFunc func(name string) ([]fs.DirEntry, error) 26 | ReadFileFunc func(name string) ([]byte, error) 27 | RemoveFunc func(name string) error 28 | SymlinkFunc func(oldname string, newname string) error 29 | LookPathFunc func(file string) (string, error) 30 | RunFunc func(name string, arg ...string) error 31 | RunHistory [][]string 32 | } 33 | 34 | func (t *testCommander) Chmod(name string, mode fs.FileMode) error { 35 | return t.ChmodFunc(name, mode) 36 | } 37 | 38 | func (t *testCommander) MkdirAll(path string, perm fs.FileMode) error { 39 | return t.MkdirAllFunc(path, perm) 40 | } 41 | 42 | func (t *testCommander) ReadDir(name string) ([]fs.DirEntry, error) { 43 | return t.ReadDirFunc(name) 44 | } 45 | 46 | func (t *testCommander) ReadFile(name string) ([]byte, error) { 47 | return t.ReadFileFunc(name) 48 | } 49 | 50 | func (t *testCommander) Remove(name string) error { 51 | return t.RemoveFunc(name) 52 | } 53 | 54 | func (t *testCommander) Symlink(oldname string, newname string) error { 55 | return t.SymlinkFunc(oldname, newname) 56 | } 57 | 58 | func (t *testCommander) LookPath(file string) (string, error) { 59 | return t.LookPathFunc(file) 60 | } 61 | 62 | func (t *testCommander) Run(name string, arg ...string) error { 63 | return t.RunFunc(name, arg...) 
64 | } 65 | 66 | func CommanderFixture() *testCommander { 67 | t := &testCommander{} 68 | t.ChmodFunc = func(name string, mode fs.FileMode) error { 69 | return nil 70 | } 71 | t.MkdirAllFunc = func(path string, perm fs.FileMode) error { 72 | return nil 73 | } 74 | t.ReadDirFunc = func(name string) ([]fs.DirEntry, error) { 75 | return []fs.DirEntry{}, nil 76 | } 77 | t.ReadFileFunc = func(name string) ([]byte, error) { 78 | return []byte{}, nil 79 | } 80 | t.RemoveFunc = func(name string) error { 81 | return nil 82 | } 83 | t.SymlinkFunc = func(oldname, newname string) error { 84 | return nil 85 | } 86 | t.LookPathFunc = func(file string) (string, error) { 87 | return file, nil 88 | } 89 | t.RunFunc = func(name string, arg ...string) error { 90 | if t.RunHistory == nil { 91 | t.RunHistory = make([][]string, 0) 92 | } 93 | fullCommand := append([]string{name}, arg...) 94 | t.RunHistory = append(t.RunHistory, fullCommand) 95 | return nil 96 | } 97 | return t 98 | } 99 | 100 | type testWgController struct { 101 | ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error 102 | DeviceFunc func(name string) (*wgtypes.Device, error) 103 | Devices map[string]*wgtypes.Device 104 | } 105 | 106 | // ConfigureDevice implements wgController. 107 | func (t *testWgController) ConfigureDevice(name string, cfg wgtypes.Config) error { 108 | return t.ConfigureDeviceFunc(name, cfg) 109 | } 110 | 111 | // Device implements wgController. 
112 | func (t *testWgController) Device(name string) (*wgtypes.Device, error) { 113 | return t.DeviceFunc(name) 114 | } 115 | 116 | func WgControllerFixture() *testWgController { 117 | wgc := &testWgController{ 118 | Devices: make(map[string]*wgtypes.Device), 119 | } 120 | 121 | wgc.ConfigureDeviceFunc = func(name string, cfg wgtypes.Config) error { 122 | return nil 123 | } 124 | wgc.DeviceFunc = func(name string) (*wgtypes.Device, error) { 125 | df := DeviceFixture() 126 | if name != df.Name { 127 | return nil, fmt.Errorf("device %s does not exist", name) 128 | } 129 | return df, nil 130 | } 131 | return wgc 132 | } 133 | 134 | func TestDriver(t *testing.T) { 135 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | err = d.Close() 140 | if err != nil { 141 | t.Fatal(err) 142 | } 143 | } 144 | 145 | func MustCreateNetwork(t *testing.T, d *Driver, ifnameMode bool) *Network { 146 | net := NetworkFixture() 147 | options := map[string]interface{}{ 148 | "dwgd.seed": string(net.seed), 149 | "dwgd.endpoint": net.endpoint.String(), 150 | "dwgd.route": net.route, 151 | } 152 | if ifnameMode { 153 | options["dwgd.ifname"] = net.ifname 154 | } else { 155 | options["dwgd.pubkey"] = net.pubkey.String() 156 | net.ifname = "" 157 | } 158 | 159 | err := d.CreateNetwork(&network.CreateNetworkRequest{ 160 | NetworkID: net.id, 161 | Options: map[string]interface{}{ 162 | "com.docker.network.generic": options, 163 | }, 164 | }) 165 | if err != nil { 166 | t.Fatal(err) 167 | } 168 | 169 | other, err := d.s.GetNetwork(net.id) 170 | if err != nil { 171 | t.Fatal(err) 172 | } 173 | if !cmp.Equal(net, other, cmp.AllowUnexported(Network{})) { 174 | t.Fatalf("mismatch: %#v != %#v", net, other) 175 | } 176 | 177 | return other 178 | } 179 | 180 | func MustCreateEndpoint(t *testing.T, d *Driver) *Client { 181 | net := NetworkFixture() 182 | client := ClientFixture(net) 183 | _, err := 
d.CreateEndpoint(&network.CreateEndpointRequest{ 184 | NetworkID: net.id, 185 | EndpointID: client.id, 186 | Interface: &network.EndpointInterface{ 187 | Address: fmt.Sprintf("%s/32", client.ip.String()), 188 | }, 189 | }) 190 | if err != nil { 191 | t.Fatal(err) 192 | } 193 | 194 | other, err := d.s.GetClient(client.id) 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | if !cmp.Equal(client, other, cmp.AllowUnexported(Network{}), cmp.AllowUnexported(Client{})) { 199 | t.Fatalf("mismatch: %#v != %#v", net, other) 200 | } 201 | 202 | return other 203 | } 204 | 205 | func TestDriver_CreateNetwork(t *testing.T) { 206 | t.Run("ifname mode", func(t *testing.T) { 207 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | defer d.Close() 212 | MustCreateNetwork(t, d, true) 213 | }) 214 | 215 | t.Run("pubkey mode", func(t *testing.T) { 216 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 217 | if err != nil { 218 | t.Fatal(err) 219 | } 220 | defer d.Close() 221 | MustCreateNetwork(t, d, false) 222 | }) 223 | } 224 | 225 | func TestDriver_DeleteNetwork(t *testing.T) { 226 | t.Run("ifname mode", func(t *testing.T) { 227 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 228 | if err != nil { 229 | t.Fatal(err) 230 | } 231 | defer d.Close() 232 | 233 | net := MustCreateNetwork(t, d, true) 234 | 235 | err = d.DeleteNetwork(&network.DeleteNetworkRequest{ 236 | NetworkID: net.id, 237 | }) 238 | if err != nil { 239 | t.Fatal(err) 240 | } 241 | 242 | n, err := d.s.GetNetwork(net.id) 243 | if err != nil { 244 | t.Fatal(err) 245 | } 246 | if n != nil { 247 | t.Fatalf("mismatch: nil != %#v", n) 248 | } 249 | }) 250 | 251 | t.Run("pubkey mode", func(t *testing.T) { 252 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 253 | if err != nil { 254 | t.Fatal(err) 255 | } 256 | defer d.Close() 257 | 258 | net := 
MustCreateNetwork(t, d, false) 259 | 260 | err = d.DeleteNetwork(&network.DeleteNetworkRequest{ 261 | NetworkID: net.id, 262 | }) 263 | if err != nil { 264 | t.Fatal(err) 265 | } 266 | 267 | n, err := d.s.GetNetwork(net.id) 268 | if err != nil { 269 | t.Fatal(err) 270 | } 271 | if n != nil { 272 | t.Fatalf("mismatch: nil != %#v", n) 273 | } 274 | }) 275 | } 276 | 277 | func TestDriver_CreateEndpoint(t *testing.T) { 278 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 279 | if err != nil { 280 | t.Fatal(err) 281 | } 282 | 283 | MustCreateNetwork(t, d, true) 284 | MustCreateEndpoint(t, d) 285 | } 286 | 287 | func TestDriver_DeleteEndpoint(t *testing.T) { 288 | d, err := NewDriver(DbPathFixture(), CommanderFixture(), WgControllerFixture()) 289 | if err != nil { 290 | t.Fatal(err) 291 | } 292 | 293 | net := MustCreateNetwork(t, d, true) 294 | client := MustCreateEndpoint(t, d) 295 | 296 | err = d.DeleteEndpoint(&network.DeleteEndpointRequest{ 297 | NetworkID: net.id, 298 | EndpointID: client.id, 299 | }) 300 | if err != nil { 301 | t.Fatal(err) 302 | } 303 | 304 | other, err := d.s.GetClient(client.id) 305 | if err != nil { 306 | t.Fatal(err) 307 | } 308 | if other != nil { 309 | t.Fatalf("mismatch: nil != %#v", other) 310 | } 311 | } 312 | 313 | func TestDriver_Join(t *testing.T) { 314 | t.Run("non rootless", func(t *testing.T) { 315 | tc := CommanderFixture() 316 | d, err := NewDriver(DbPathFixture(), tc, WgControllerFixture()) 317 | if err != nil { 318 | t.Fatal(err) 319 | } 320 | 321 | net := MustCreateNetwork(t, d, true) 322 | client := MustCreateEndpoint(t, d) 323 | 324 | _, err = d.Join(&network.JoinRequest{ 325 | NetworkID: net.id, 326 | EndpointID: client.id, 327 | SandboxKey: "/foo/bar", 328 | }) 329 | if err != nil { 330 | t.Fatal(err) 331 | } 332 | 333 | expectedHistory := [][]string{ 334 | {"ip", "link", "add", "name", client.ifname, "type", "wireguard"}, 335 | } 336 | if !cmp.Equal(tc.RunHistory, expectedHistory) { 337 | 
t.Fatalf("mismatch: %#v != %#v", tc.RunHistory, expectedHistory) 338 | } 339 | }) 340 | 341 | t.Run("rootless", func(t *testing.T) { 342 | tc := CommanderFixture() 343 | tc.ReadFileFunc = func(name string) ([]byte, error) { return []byte("1000"), nil } 344 | 345 | d, err := NewDriver(DbPathFixture(), tc, WgControllerFixture()) 346 | if err != nil { 347 | t.Fatal(err) 348 | } 349 | 350 | net := MustCreateNetwork(t, d, true) 351 | client := MustCreateEndpoint(t, d) 352 | 353 | _, err = d.Join(&network.JoinRequest{ 354 | NetworkID: net.id, 355 | EndpointID: client.id, 356 | SandboxKey: "/run/user/1000", 357 | }) 358 | if err != nil { 359 | t.Fatal(err) 360 | } 361 | 362 | expectedHistory := [][]string{ 363 | {"ip", "link", "add", "name", client.ifname, "type", "wireguard"}, 364 | {"ip", "link", "set", client.ifname, "netns", "1000"}, 365 | } 366 | if !cmp.Equal(tc.RunHistory, expectedHistory) { 367 | t.Fatalf("mismatch: %#v != %#v", tc.RunHistory, expectedHistory) 368 | } 369 | }) 370 | } 371 | 372 | func TestDriver_Leave(t *testing.T) { 373 | tc := CommanderFixture() 374 | wgc := WgControllerFixture() 375 | 376 | d, err := NewDriver(DbPathFixture(), tc, wgc) 377 | if err != nil { 378 | t.Fatal(err) 379 | } 380 | 381 | net := MustCreateNetwork(t, d, true) 382 | client := MustCreateEndpoint(t, d) 383 | 384 | _, err = d.Join(&network.JoinRequest{ 385 | NetworkID: net.id, 386 | EndpointID: client.id, 387 | SandboxKey: "/foo/bar", 388 | }) 389 | if err != nil { 390 | t.Fatal(err) 391 | } 392 | 393 | err = d.Leave(&network.LeaveRequest{ 394 | NetworkID: net.id, 395 | EndpointID: client.id, 396 | }) 397 | if err != nil { 398 | t.Fatal(err) 399 | } 400 | 401 | expectedHistory := [][]string{ 402 | {"ip", "link", "add", "name", client.ifname, "type", "wireguard"}, 403 | } 404 | if !cmp.Equal(tc.RunHistory, expectedHistory) { 405 | t.Fatalf("mismatch: %#v != %#v", tc.RunHistory, expectedHistory) 406 | } 407 | } 408 | 
-------------------------------------------------------------------------------- /dwgd.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/docker/go-plugins-helpers/network" 7 | ) 8 | 9 | type Dwgd struct { 10 | driver *Driver 11 | handler *network.Handler 12 | listener net.Listener 13 | symlinker *RootlessSymlinker 14 | } 15 | 16 | func NewDwgd(cfg *Config) (*Dwgd, error) { 17 | driver, err := NewDriver(cfg.Db, nil, nil) 18 | if err != nil { 19 | return nil, err 20 | } 21 | 22 | handler := network.NewHandler(driver) 23 | 24 | listener, err := NewUnixListener(nil) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | var symlinker *RootlessSymlinker 30 | if cfg.Rootless { 31 | symlinker, err = NewRootlessSymlinker(nil) 32 | if err != nil { 33 | return nil, err 34 | } 35 | } 36 | 37 | return &Dwgd{ 38 | driver: driver, 39 | handler: handler, 40 | listener: listener, 41 | symlinker: symlinker, 42 | }, nil 43 | } 44 | 45 | func (d *Dwgd) Start() error { 46 | go func() { 47 | err := d.handler.Serve(d.listener) 48 | if err != nil { 49 | TraceLog.Printf("Couldn't serve on unix socket: %s\n", err) 50 | } 51 | }() 52 | 53 | if d.symlinker != nil { 54 | go func() { 55 | err := d.symlinker.Start() 56 | if err != nil { 57 | TraceLog.Printf("Couldn't start symlinker: %s\n", err) 58 | } 59 | }() 60 | } 61 | 62 | return nil 63 | } 64 | 65 | func (d *Dwgd) Stop() error { 66 | TraceLog.Println("Closing driver") 67 | err := d.driver.Close() 68 | if err != nil { 69 | TraceLog.Printf("Error during driver close: %s\n", err) 70 | } 71 | 72 | TraceLog.Println("Closing listener") 73 | err = d.listener.Close() 74 | if err != nil { 75 | TraceLog.Printf("Error during listener close: %s\n", err) 76 | } 77 | 78 | if d.symlinker != nil { 79 | TraceLog.Println("Closing symlinker") 80 | err := d.symlinker.Stop() 81 | if err != nil { 82 | TraceLog.Printf("Error during symlinker close: %s\n", err) 
83 | } 84 | } else { 85 | TraceLog.Println("Symlinker not set, skipping closing") 86 | } 87 | 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /entities.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | import ( 4 | "crypto/sha256" 5 | "database/sql" 6 | "embed" 7 | "errors" 8 | "fmt" 9 | "io/fs" 10 | "net" 11 | "sort" 12 | "time" 13 | 14 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 15 | ) 16 | 17 | type Network struct { 18 | id string 19 | endpoint *net.UDPAddr 20 | seed []byte 21 | pubkey wgtypes.Key 22 | route string 23 | ifname string 24 | } 25 | 26 | func (n *Network) PeerConfig() wgtypes.PeerConfig { 27 | keepalive := 25 * time.Second 28 | 29 | _, ipnet, _ := net.ParseCIDR("0.0.0.0/0") 30 | allowedIPs := []net.IPNet{*ipnet} 31 | 32 | return wgtypes.PeerConfig{ 33 | Endpoint: n.endpoint, 34 | PublicKey: n.pubkey, 35 | PersistentKeepaliveInterval: &keepalive, 36 | AllowedIPs: allowedIPs, 37 | ReplaceAllowedIPs: true, 38 | } 39 | } 40 | 41 | type Client struct { 42 | id string 43 | ip net.IP 44 | ifname string 45 | network *Network 46 | } 47 | 48 | func (c *Client) Config() wgtypes.Config { 49 | privkey := GeneratePrivateKey(c.network.seed, c.ip) 50 | 51 | peers := make([]wgtypes.PeerConfig, 1) 52 | peers[0] = c.network.PeerConfig() 53 | 54 | return wgtypes.Config{ 55 | PrivateKey: privkey, 56 | Peers: peers, 57 | } 58 | } 59 | 60 | func (c *Client) PeerConfig() wgtypes.PeerConfig { 61 | keepalive := 25 * time.Second 62 | 63 | ipnet := net.IPNet{ 64 | IP: c.ip, 65 | Mask: []byte{255, 255, 255, 255}, 66 | } 67 | allowedIPs := []net.IPNet{ipnet} 68 | 69 | privkey := GeneratePrivateKey(c.network.seed, c.ip) 70 | 71 | return wgtypes.PeerConfig{ 72 | PublicKey: privkey.PublicKey(), 73 | Remove: false, 74 | UpdateOnly: false, 75 | PresharedKey: nil, 76 | Endpoint: nil, 77 | PersistentKeepaliveInterval: &keepalive, 78 | ReplaceAllowedIPs: true, 79 | 
AllowedIPs: allowedIPs, 80 | } 81 | } 82 | 83 | func GeneratePrivateKey(seed []byte, ip net.IP) *wgtypes.Key { 84 | h := sha256.New() 85 | h.Write(seed) 86 | h.Write(ip) 87 | 88 | // since the size of a SHA256 checksum is 32 bytes by default, 89 | // wgtypes.NewKey cannot return error 90 | priv, _ := wgtypes.NewKey(h.Sum(nil)) 91 | 92 | // Modify random bytes using algorithm described at: 93 | // https://cr.yp.to/ecdh.html. 94 | priv[0] &= 248 95 | priv[31] &= 127 96 | priv[31] |= 64 97 | 98 | return &priv 99 | } 100 | 101 | type Storage struct { 102 | db *sql.DB 103 | } 104 | 105 | func (s *Storage) Open(path string) error { 106 | db, err := sql.Open("sqlite3", path) 107 | if err != nil { 108 | return err 109 | } 110 | s.db = db 111 | 112 | // Enable foreign key checks. 113 | if _, err := db.Exec(`PRAGMA foreign_keys = ON;`); err != nil { 114 | return fmt.Errorf("foreign keys pragma: %w", err) 115 | } 116 | 117 | if err := s.migrate(); err != nil { 118 | return fmt.Errorf("migrate: %w", err) 119 | } 120 | 121 | return err 122 | } 123 | 124 | func (s *Storage) Close() error { 125 | return s.db.Close() 126 | } 127 | 128 | //go:embed migrations/*.sql 129 | var migrationFS embed.FS 130 | 131 | // migrate sets up migration tracking and executes pending migration files. 132 | // 133 | // Migration files are embedded in the sqlite/migration folder and are executed 134 | // in lexigraphical order. 135 | // 136 | // Once a migration is run, its name is stored in the 'migrations' table so it 137 | // is not re-executed. Migrations run in a transaction to prevent partial 138 | // migrations. 139 | func (s *Storage) migrate() error { 140 | // Ensure the 'migrations' table exists so we don't duplicate migrations. 141 | if _, err := s.db.Exec(`CREATE TABLE IF NOT EXISTS migrations (name TEXT PRIMARY KEY);`); err != nil { 142 | return fmt.Errorf("cannot create migrations table: %w", err) 143 | } 144 | 145 | // Read migration files from our embedded file system. 
146 | // This uses Go 1.16's 'embed' package. 147 | names, err := fs.Glob(migrationFS, "migrations/*.sql") 148 | if err != nil { 149 | return err 150 | } 151 | sort.Strings(names) 152 | 153 | // Loop over all migration files and execute them in order. 154 | for _, name := range names { 155 | if err := s.migrateFile(name); err != nil { 156 | return fmt.Errorf("migration error: name=%q err=%w", name, err) 157 | } 158 | } 159 | return nil 160 | } 161 | 162 | // migrate runs a single migration file within a transaction. On success, the 163 | // migration file name is saved to the "migrations" table to prevent re-running. 164 | func (s *Storage) migrateFile(name string) error { 165 | tx, err := s.db.Begin() 166 | if err != nil { 167 | return err 168 | } 169 | defer tx.Rollback() 170 | 171 | // Ensure migration has not already been run. 172 | var n int 173 | if err := tx.QueryRow(`SELECT COUNT(*) FROM migrations WHERE name = ?`, name).Scan(&n); err != nil { 174 | return err 175 | } else if n != 0 { 176 | return nil // already run migration, skip 177 | } 178 | 179 | // Read and execute migration file. 180 | if buf, err := fs.ReadFile(migrationFS, name); err != nil { 181 | return err 182 | } else if _, err := tx.Exec(string(buf)); err != nil { 183 | return err 184 | } 185 | 186 | // Insert record into migrations to prevent re-running migration. 
187 | if _, err := tx.Exec(`INSERT INTO migrations (name) VALUES (?)`, name); err != nil { 188 | return err 189 | } 190 | 191 | return tx.Commit() 192 | } 193 | 194 | func (s *Storage) AddNetwork(n *Network) error { 195 | tx, err := s.db.Begin() 196 | if err != nil { 197 | return err 198 | } 199 | defer tx.Rollback() 200 | 201 | stm, err := s.db.Prepare("INSERT INTO network(id, endpoint, seed, pubkey, route, ifname) VALUES(?, ?, ?, ?, ?, ?)") 202 | if err != nil { 203 | return err 204 | } 205 | defer stm.Close() 206 | 207 | r, err := stm.Exec(n.id, n.endpoint.String(), n.seed, n.pubkey[:], n.route, n.ifname) 208 | if err != nil { 209 | return err 210 | } 211 | 212 | num, err := r.RowsAffected() 213 | if err != nil { 214 | return err 215 | } 216 | if num != 1 { 217 | return fmt.Errorf("number of inserted rows: %d is not 1", num) 218 | } 219 | 220 | return tx.Commit() 221 | } 222 | 223 | func (s *Storage) RemoveNetwork(id string) error { 224 | tx, err := s.db.Begin() 225 | if err != nil { 226 | return err 227 | } 228 | defer tx.Rollback() 229 | 230 | stm, err := tx.Prepare("DELETE FROM network WHERE id = ?") 231 | if err != nil { 232 | return err 233 | } 234 | defer stm.Close() 235 | 236 | r, err := stm.Exec(id) 237 | if err != nil { 238 | return err 239 | } 240 | 241 | num, err := r.RowsAffected() 242 | if err != nil { 243 | return err 244 | } 245 | if num != 1 { 246 | return fmt.Errorf("number of deleted rows: %d is not 1", num) 247 | } 248 | 249 | return tx.Commit() 250 | } 251 | 252 | func (s *Storage) GetNetwork(id string) (*Network, error) { 253 | tx, err := s.db.Begin() 254 | if err != nil { 255 | return nil, err 256 | } 257 | defer tx.Rollback() 258 | 259 | stmt, err := tx.Prepare("SELECT id, endpoint, seed, pubkey, route, ifname FROM network WHERE id = ?") 260 | if err != nil { 261 | return nil, err 262 | } 263 | defer stmt.Close() 264 | 265 | n := &Network{} 266 | var endpoint string 267 | var pubkey []byte 268 | 269 | err = stmt.QueryRow(id).Scan(&n.id, 
&endpoint, &n.seed, &pubkey, &n.route, &n.ifname) 270 | if errors.Is(err, sql.ErrNoRows) { 271 | return nil, nil 272 | } 273 | if err != nil { 274 | return nil, err 275 | } 276 | n.endpoint, err = net.ResolveUDPAddr("udp", endpoint) 277 | if err != nil { 278 | return nil, err 279 | } 280 | n.pubkey, err = wgtypes.NewKey(pubkey) 281 | if err != nil { 282 | return nil, err 283 | } 284 | 285 | return n, nil 286 | } 287 | 288 | func (s *Storage) AddClient(c *Client) error { 289 | tx, err := s.db.Begin() 290 | if err != nil { 291 | return err 292 | } 293 | defer tx.Rollback() 294 | 295 | stm, err := tx.Prepare("INSERT INTO client(id, network_id, ip, ifname) VALUES(?, ?, ?, ?)") 296 | if err != nil { 297 | return err 298 | } 299 | defer stm.Close() 300 | 301 | r, err := stm.Exec(c.id, c.network.id, c.ip.String(), c.ifname) 302 | if err != nil { 303 | return err 304 | } 305 | 306 | num, err := r.RowsAffected() 307 | if err != nil { 308 | return err 309 | } 310 | if num != 1 { 311 | return fmt.Errorf("number of inserted rows: %d is not 1", num) 312 | } 313 | 314 | return tx.Commit() 315 | } 316 | 317 | func (s *Storage) RemoveClient(id string) error { 318 | tx, err := s.db.Begin() 319 | if err != nil { 320 | return err 321 | } 322 | defer tx.Rollback() 323 | 324 | stm, err := tx.Prepare("DELETE FROM client WHERE id = ?") 325 | if err != nil { 326 | return err 327 | } 328 | defer stm.Close() 329 | 330 | r, err := stm.Exec(id) 331 | if err != nil { 332 | return err 333 | } 334 | 335 | num, err := r.RowsAffected() 336 | if err != nil { 337 | return err 338 | } 339 | if num != 1 { 340 | return fmt.Errorf("number of deleted rows: %d is not 1", num) 341 | } 342 | 343 | return tx.Commit() 344 | } 345 | 346 | func (s *Storage) GetClient(id string) (*Client, error) { 347 | q := ` 348 | SELECT 349 | client.id, 350 | client.network_id, 351 | client.ip, 352 | client.ifname, 353 | network.endpoint, 354 | network.seed, 355 | network.pubkey, 356 | network.route, 357 | network.ifname 358 
| FROM 359 | client 360 | INNER JOIN network 361 | ON client.network_id = network.id 362 | WHERE client.id = ? 363 | ` 364 | tx, err := s.db.Begin() 365 | if err != nil { 366 | return nil, err 367 | } 368 | defer tx.Rollback() 369 | 370 | stmt, err := tx.Prepare(q) 371 | if err != nil { 372 | return nil, err 373 | } 374 | defer stmt.Close() 375 | 376 | c := &Client{} 377 | c.network = &Network{} 378 | var endpoint string 379 | var ip string 380 | var pubkey []byte 381 | err = stmt.QueryRow(id).Scan(&c.id, &c.network.id, &ip, &c.ifname, &endpoint, &c.network.seed, &pubkey, &c.network.route, &c.network.ifname) 382 | if errors.Is(err, sql.ErrNoRows) { 383 | return nil, nil 384 | } 385 | if err != nil { 386 | return nil, err 387 | } 388 | c.ip = net.ParseIP(ip) 389 | c.network.endpoint, err = net.ResolveUDPAddr("udp", endpoint) 390 | if err != nil { 391 | return nil, err 392 | } 393 | c.network.pubkey, err = wgtypes.NewKey(pubkey) 394 | if err != nil { 395 | return nil, err 396 | } 397 | 398 | return c, nil 399 | } 400 | -------------------------------------------------------------------------------- /entities_test.go: -------------------------------------------------------------------------------- 1 | package dwgd 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "net" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | func DbPathFixture() string { 16 | r := rand.Int31n(math.MaxInt32) 17 | return fmt.Sprintf("file:test.%d?mode=memory&cache=shared", r) 18 | } 19 | 20 | func NetworkFixture() *Network { 21 | endpoint, _ := net.ResolveUDPAddr("udp", "localhost:51820") 22 | pubkey, _ := wgtypes.ParseKey("BR1A+UneCu1FVBW/zPI/UVKA4gcNMUroj72LwFMMUUs=") 23 | network := &Network{ 24 | id: "n1", 25 | endpoint: endpoint, 26 | seed: []byte("supersecretseed"), 27 | pubkey: pubkey, 28 | route: "0.0.0.0/0", 29 | ifname: "dwgd0", 30 | } 31 | return network 32 | } 33 | 34 | func 
ClientFixture(network *Network) *Client { 35 | client := &Client{ 36 | id: "c1", 37 | ip: []byte{10, 0, 0, 2}, 38 | ifname: "wg-c1", 39 | network: network, 40 | } 41 | return client 42 | } 43 | 44 | func MustOpenDB(t *testing.T) *Storage { 45 | t.Helper() 46 | 47 | s := &Storage{} 48 | err := s.Open(DbPathFixture()) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | return s 54 | } 55 | 56 | func MustCloseDB(t *testing.T, s *Storage) { 57 | t.Helper() 58 | 59 | err := s.Close() 60 | if err != nil { 61 | t.Fatal(err) 62 | } 63 | } 64 | 65 | func MustExistNetwork(t *testing.T, s *Storage, n *Network) { 66 | t.Helper() 67 | 68 | err := s.AddNetwork(n) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | } 73 | 74 | func TestStorage(t *testing.T) { 75 | s := MustOpenDB(t) 76 | MustCloseDB(t, s) 77 | } 78 | 79 | func TestStorage_Network(t *testing.T) { 80 | network := NetworkFixture() 81 | 82 | t.Run("AddNetwork", func(t *testing.T) { 83 | s := MustOpenDB(t) 84 | defer MustCloseDB(t, s) 85 | 86 | err := s.AddNetwork(network) 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | 91 | other, err := s.GetNetwork(network.id) 92 | if err != nil { 93 | t.Fatal(err) 94 | } 95 | if !cmp.Equal(network, other, cmp.AllowUnexported(Network{})) { 96 | t.Fatalf("mismatch: %#v != %#v", network, other) 97 | } 98 | }) 99 | 100 | t.Run("RemoveNetwork", func(t *testing.T) { 101 | s := MustOpenDB(t) 102 | defer MustCloseDB(t, s) 103 | err := s.AddNetwork(network) 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | err = s.RemoveNetwork(network.id) 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | 112 | other, err := s.GetNetwork(network.id) 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | if other != nil { 117 | t.Fatalf("mismatch: nil != %#v", other) 118 | } 119 | }) 120 | } 121 | 122 | func TestStorage_Client(t *testing.T) { 123 | network := NetworkFixture() 124 | client := ClientFixture(network) 125 | 126 | t.Run("AddClientNotExistingNetwork", func(t *testing.T) { 127 | s 
:= MustOpenDB(t) 128 | defer MustCloseDB(t, s) 129 | 130 | err := s.AddClient(client) 131 | if !strings.Contains(err.Error(), "FOREIGN KEY constraint failed") { 132 | t.Fatal(err) 133 | } 134 | }) 135 | 136 | t.Run("AddClient", func(t *testing.T) { 137 | s := MustOpenDB(t) 138 | defer MustCloseDB(t, s) 139 | MustExistNetwork(t, s, network) 140 | 141 | err := s.AddClient(client) 142 | if err != nil { 143 | t.Fatal(err) 144 | } 145 | 146 | other, err := s.GetClient(client.id) 147 | if err != nil { 148 | t.Fatal(err) 149 | } 150 | if !cmp.Equal(client, other, cmp.AllowUnexported(Client{}), cmp.AllowUnexported(Network{})) { 151 | t.Fatalf("mismatch: %#v != %#v", client, other) 152 | } 153 | }) 154 | 155 | t.Run("RemoveClient", func(t *testing.T) { 156 | s := MustOpenDB(t) 157 | defer MustCloseDB(t, s) 158 | MustExistNetwork(t, s, network) 159 | 160 | err := s.AddClient(client) 161 | if err != nil { 162 | t.Fatal(err) 163 | } 164 | err = s.RemoveClient(client.id) 165 | if err != nil { 166 | t.Fatal(err) 167 | } 168 | other, err := s.GetClient(client.id) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | if other != nil { 173 | t.Fatalf("mismatch: nil != %#v", other) 174 | } 175 | }) 176 | 177 | t.Run("RemoveNetwork", func(t *testing.T) { 178 | // removing the network associated with the client should remove the 179 | // client too 180 | 181 | s := MustOpenDB(t) 182 | defer MustCloseDB(t, s) 183 | MustExistNetwork(t, s, network) 184 | 185 | err := s.AddClient(client) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | 190 | err = s.RemoveNetwork(network.id) 191 | if err != nil { 192 | t.Fatal(err) 193 | } 194 | 195 | other, err := s.GetClient(client.id) 196 | if err != nil { 197 | t.Fatal(err) 198 | } 199 | if other != nil { 200 | t.Fatalf("mismatch: nil != %#v", other) 201 | } 202 | }) 203 | } 204 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/leomos/dwgd 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/docker/go-connections v0.4.0 7 | github.com/docker/go-plugins-helpers v0.0.0-20211224144127-6eecb7beb651 8 | github.com/google/go-cmp v0.6.0 9 | github.com/illarion/gonotify/v2 v2.0.0 10 | github.com/mattn/go-sqlite3 v1.14.16 11 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 12 | ) 13 | 14 | require ( 15 | github.com/Microsoft/go-winio v0.6.0 // indirect 16 | github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect 17 | github.com/josharian/native v1.1.0 // indirect 18 | github.com/mdlayher/genetlink v1.3.2 // indirect 19 | github.com/mdlayher/netlink v1.7.2 // indirect 20 | github.com/mdlayher/socket v0.4.1 // indirect 21 | golang.org/x/crypto v0.8.0 // indirect 22 | golang.org/x/mod v0.7.0 // indirect 23 | golang.org/x/net v0.9.0 // indirect 24 | golang.org/x/sync v0.1.0 // indirect 25 | golang.org/x/sys v0.7.0 // indirect 26 | golang.org/x/tools v0.5.0 // indirect 27 | golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect 28 | ) 29 | 30 | // replace github.com/leomos/dwgd => ./dwgd 31 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2yDvg= 2 | github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= 3 | github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= 4 | github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= 5 | github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= 6 | github.com/docker/go-connections v0.4.0/go.mod 
h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= 7 | github.com/docker/go-plugins-helpers v0.0.0-20211224144127-6eecb7beb651 h1:YcvzLmdrP/b8kLAGJ8GT7bdncgCAiWxJZIlt84D+RJg= 8 | github.com/docker/go-plugins-helpers v0.0.0-20211224144127-6eecb7beb651/go.mod h1:LFyLie6XcDbyKGeVK6bHe+9aJTYCxWLBg5IrJZOaXKA= 9 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 10 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 11 | github.com/illarion/gonotify/v2 v2.0.0 h1:KNbALXt1hm3SmHNFUrYLoRsXxKfegH9XRNRbb6xxLZs= 12 | github.com/illarion/gonotify/v2 v2.0.0/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= 13 | github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= 14 | github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= 15 | github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= 16 | github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= 17 | github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= 18 | github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= 19 | github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= 20 | github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= 21 | github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= 22 | github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= 23 | github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= 24 | golang.org/x/crypto v0.8.0 h1:pd9TJtTueMTVQXzk8E2XESSMQDj/U7OUu0PqJqPXQjQ= 25 | golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= 26 | golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= 27 | golang.org/x/mod v0.7.0/go.mod 
h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
28 | golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
29 | golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
30 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
31 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
32 | golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
33 | golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
34 | golang.org/x/tools v0.5.0 h1:+bSpV5HIeWkuvgaMfI3UmKRThoTA5ODJTUd8T17NO+4=
35 | golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k=
36 | golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
37 | golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
38 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
39 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
40 |
--------------------------------------------------------------------------------
/listener.go:
--------------------------------------------------------------------------------
1 | package dwgd
2 |
3 | import (
4 | 	"net"
5 | 	"path"
6 |
7 | 	"github.com/docker/go-connections/sockets"
8 | )
9 |
10 | const (
11 | 	dockerPluginSockDir = "/run/docker/plugins"
12 | 	dwgdRunDir          = "/run/dwgd"
13 | 	dwgdSockName        = "dwgd.sock"
14 | )
15 |
16 | // UnixListener is a net.Listener on the dwgd Unix socket. Closing it also
17 | // removes the socket file and the symlink that exposes it to Docker.
18 | type UnixListener struct {
19 | 	sock net.Listener
20 | 	c    commander
21 | }
22 |
23 | // Accept waits for and returns the next connection on the socket.
24 | func (u *UnixListener) Accept() (net.Conn, error) {
25 | 	return u.sock.Accept()
26 | }
27 |
28 | // Close stops listening, then removes the plugin symlink and the socket
29 | // file. The removals are best-effort and their results are ignored.
30 | func (u *UnixListener) Close() error {
31 | 	if err := u.sock.Close(); err != nil {
32 | 		return err
33 | 	}
34 |
35 | 	u.c.Remove(path.Join(dockerPluginSockDir, dwgdSockName))
36 | 	u.c.Remove(path.Join(dwgdRunDir, dwgdSockName))
37 |
38 | 	return nil
39 | }
40 |
41 | // Addr returns the listener's network address.
42 | func (u *UnixListener) Addr() net.Addr {
43 | 	return u.sock.Addr()
44 | }
45 |
46 | // NewUnixListener creates the dwgd socket under dwgdRunDir, makes it
47 | // accessible to every user (mode 0777) and symlinks it into Docker's plugin
48 | // directory, where the Docker daemon discovers plugin sockets. A nil
49 | // commander defaults to execCommander.
50 | //
51 | // If a step fails after the socket has been created, the socket is closed
52 | // before returning so it does not leak.
53 | func NewUnixListener(c commander) (net.Listener, error) {
54 | 	if c == nil {
55 | 		c = &execCommander{}
56 | 	}
57 |
58 | 	if err := c.MkdirAll(dwgdRunDir, 0777); err != nil {
59 | 		return nil, err
60 | 	}
61 |
62 | 	if err := c.MkdirAll(dockerPluginSockDir, 0755); err != nil {
63 | 		return nil, err
64 | 	}
65 |
66 | 	fullDwgdSockPath := path.Join(dwgdRunDir, dwgdSockName)
67 | 	listener, err := sockets.NewUnixSocket(fullDwgdSockPath, 0)
68 | 	if err != nil {
69 | 		return nil, err
70 | 	}
71 | 	if err := c.Chmod(fullDwgdSockPath, 0777); err != nil {
72 | 		listener.Close()
73 | 		return nil, err
74 | 	}
75 |
76 | 	// Remove a symlink left behind by a run that did not shut down cleanly;
77 | 	// presumably c.Symlink, like os.Symlink, fails if the path exists.
78 | 	// Best-effort, as in Close: a missing symlink is the normal case.
79 | 	dockerPluginSockPath := path.Join(dockerPluginSockDir, dwgdSockName)
80 | 	c.Remove(dockerPluginSockPath)
81 | 	if err := c.Symlink(fullDwgdSockPath, dockerPluginSockPath); err != nil {
82 | 		listener.Close()
83 | 		return nil, err
84 | 	}
85 |
86 | 	return &UnixListener{
87 | 		sock: listener,
88 | 		c:    c,
89 | 	}, nil
90 | }
91 |
--------------------------------------------------------------------------------
/log.go:
--------------------------------------------------------------------------------
1 | package dwgd
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"log"
6 | 	"os"
7 | )
8 |
9 | // EventsLog is used for everything that can be considered a "result".
10 | // It writes to standard output without timestamps.
11 | var EventsLog = log.New(os.Stdout, "", log.Lmsgprefix)
12 |
13 | // DiagnosticsLog is used for messages that give the user context about
14 | // what the software is doing. It writes to standard error with UTC
15 | // timestamps.
16 | var DiagnosticsLog = log.New(os.Stderr, "", log.LstdFlags|log.LUTC)
17 |
18 | // TraceLog is used for very detailed messages and should not be used in a
19 | // production environment. It is disabled by default: it writes to an
20 | // EmptyWriter, so its output is discarded unless it is redirected (for
21 | // example with TraceLog.SetOutput).
22 | var TraceLog = log.New(&EmptyWriter{}, "", log.LstdFlags|log.LUTC)
23 |
24 | // EmptyWriter is an io.Writer that discards everything written to it.
25 | type EmptyWriter struct{}
26 |
27 | // Write discards p and reports it as fully written.
28 | func (e *EmptyWriter) Write(p []byte) (n int, err error) {
29 | 	return len(p), nil
30 | }
31 |
32 | // Jsonify returns the JSON encoding of data as a string, or "" if data
33 | // cannot be marshalled.
34 | func Jsonify(data interface{}) string {
35 | 	j, err := json.Marshal(data)
36 | 	if err != nil {
37 | 		return ""
38 | 	}
39 | 	return string(j)
40 | }
41 |
--------------------------------------------------------------------------------
/migrations/0000.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS network (
2 |     id TEXT PRIMARY KEY,
3 |     endpoint TEXT,
4 |     seed BLOB,
5 |     pubkey BLOB[32],
6 |     route TEXT
7 | );
8 |
9 | CREATE TABLE IF NOT EXISTS client (
10 |     id TEXT PRIMARY KEY,
11 |     network_id TEXT,
12 |     ip TEXT,
13 |     ifname TEXT,
14 |
15 |     FOREIGN KEY(network_id) REFERENCES network(id) ON DELETE CASCADE
16 | );
17 |
--------------------------------------------------------------------------------
/migrations/0001.sql:
--------------------------------------------------------------------------------
1 | ALTER TABLE network ADD COLUMN ifname TEXT;
--------------------------------------------------------------------------------
/rootless.go:
--------------------------------------------------------------------------------
1 | package dwgd
2 |
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"path"
7 | 	"regexp"
8 | 	"strconv"
9 | 	"strings"
10 | 	"time"
11 |
12 | 	"github.com/illarion/gonotify/v2"
13 | )
14 |
15 | const (
16 | 	xdgRuntimeRoot    = "/run/user/"
17 | 	dockerPidFileName = "docker.pid"
18 | )
19 |
20 | var (
21 | 	userXdgRuntimeDirRegex = regexp.MustCompile(xdgRuntimeRoot + `\d+`)
22 | )
23 |
24 | // readDockerPid reads a docker.pid file through c and parses the PID it
25 | // contains, ignoring surrounding whitespace such as a trailing newline.
26 | func readDockerPid(c commander, pidFilePath string) (int, error) {
27 | 	data, err := c.ReadFile(pidFilePath)
28 | 	if err != nil {
29 | 		return 0, err
30 | 	}
31 | 	return strconv.Atoi(strings.TrimSpace(string(data)))
32 | }
33 |
34 | // moveToRootlessNamespaceIfNecessary moves the interface ifname into the
35 | // network namespace of a rootless Docker daemon when sandboxKey lies under a
36 | // user's runtime directory (/run/user/<uid>). The daemon's PID is read from
37 | // the docker.pid file in that directory. For any other sandboxKey it does
38 | // nothing and returns nil.
39 | func moveToRootlessNamespaceIfNecessary(c commander, sandboxKey string, ifname string) error {
40 | 	match := userXdgRuntimeDirRegex.FindString(sandboxKey)
41 | 	if match == "" {
42 | 		return nil
43 | 	}
44 |
45 | 	pid, err := readDockerPid(c, path.Join(match, dockerPidFileName))
46 | 	if err != nil {
47 | 		return err
48 | 	}
49 |
50 | 	TraceLog.Printf("Moving %s to rootless namespace with PID %d\n", ifname, pid)
51 | 	if err := c.Run("ip", "link", "set", ifname, "netns", fmt.Sprint(pid)); err != nil {
52 | 		return err
53 | 	}
54 |
55 | 	return nil
56 | }
57 |
58 | // generateSockSymlinkFromDockerPidFile reads a rootless Docker daemon's PID
59 | // from dockerPidFileFullPath and, inside that process's user, network and
60 | // mount namespaces, symlinks Docker's plugin socket path to the dwgd socket.
61 | //
62 | // It returns (pid, symlink path inside the namespace, error).
63 | func generateSockSymlinkFromDockerPidFile(c commander, dockerPidFileFullPath string) (int, string, error) {
64 | 	pid, err := readDockerPid(c, dockerPidFileFullPath)
65 | 	if err != nil {
66 | 		return 0, "", err
67 | 	}
68 |
69 | 	fullDwgdSockPath := path.Join(dwgdRunDir, dwgdSockName)
70 | 	dockerPluginSockPath := path.Join(dockerPluginSockDir, dwgdSockName)
71 | 	if err := c.Run("nsenter", "-U", "-n", "-m", "-t", fmt.Sprint(pid), "ln", "-s", "-f", fullDwgdSockPath, dockerPluginSockPath); err != nil {
72 | 		TraceLog.Printf("Couldn't create symlink on rootless ns (PID: %d): %s\n", pid, err)
73 | 		return 0, "", err
74 | 	}
75 |
76 | 	TraceLog.Printf("Created symlink for namespace with PID %d\n", pid)
77 | 	return pid, dockerPluginSockPath, nil
78 | }
79 |
80 | // RootlessSymlinker watches /run/user/ for rootless Docker daemons and makes
81 | // the dwgd socket reachable from each one it finds. Create it with
82 | // NewRootlessSymlinker, run Start (which blocks) once, and call Stop once to
83 | // end the watch and remove the symlinks created so far.
84 | type RootlessSymlinker struct {
85 | 	c                  commander
86 | 	socketSymlinkPerNs map[int]string // daemon PID -> symlink path in its namespace
87 | 	stopCh             chan int       // closed by Stop to end Start's loop
88 | 	doneCh             chan struct{}  // closed when Start returns
89 | 	inotify            *gonotify.Inotify
90 | }
91 |
92 | // NewRootlessSymlinker returns a RootlessSymlinker that runs commands through
93 | // c (an execCommander when c is nil). It fails if the 'nsenter' utility,
94 | // used to enter the daemons' namespaces, cannot be found.
95 | func NewRootlessSymlinker(c commander) (*RootlessSymlinker, error) {
96 | 	if c == nil {
97 | 		c = &execCommander{}
98 | 	}
99 |
100 | 	nsenterPath, err := c.LookPath("nsenter")
101 | 	if err != nil {
102 | 		TraceLog.Printf("Couldn't find 'nsenter' utility: %s", err)
103 | 		return nil, err
104 | 	}
105 | 	TraceLog.Printf("Using 'nsenter' utility at the following path: %s", nsenterPath)
106 |
107 | 	return &RootlessSymlinker{
108 | 		c:                  c,
109 | 		socketSymlinkPerNs: make(map[int]string),
110 | 		stopCh:             make(chan int),
111 | 		doneCh:             make(chan struct{}),
112 | 	}, nil
113 | }
114 |
115 | // handleEvent reacts to one inotify event, real or synthesised by Start:
116 | //   - a directory created under /run/user/ whose path matches
117 | //     userXdgRuntimeDirRegex gets an IN_CLOSE_WRITE watch, so that a
118 | //     docker.pid written inside it is noticed;
119 | //   - a docker.pid closed after writing triggers creation of the socket
120 | //     symlink, attempted up to 5 times with 1s, 2s, 4s and 8s pauses.
121 | //
122 | // Every other event is ignored. Other files written in a user's runtime
123 | // directory must not reach the retry loop: it runs on the event loop, so
124 | // its pauses delay every later event and Stop.
125 | func (r *RootlessSymlinker) handleEvent(ev gonotify.InotifyEvent) {
126 | 	const dirCreated = gonotify.IN_CREATE | gonotify.IN_ISDIR
127 |
128 | 	switch {
129 | 	case ev.Mask&dirCreated == dirCreated:
130 | 		if !userXdgRuntimeDirRegex.MatchString(ev.Name) {
131 | 			return
132 | 		}
133 | 		if err := r.inotify.AddWatch(ev.Name, gonotify.IN_CLOSE_WRITE); err != nil {
134 | 			TraceLog.Printf("Couldn't watch %s: %s\n", ev.Name, err)
135 | 		}
136 |
137 | 	case ev.Mask&gonotify.IN_CLOSE_WRITE != 0:
138 | 		if path.Base(ev.Name) != dockerPidFileName || !userXdgRuntimeDirRegex.MatchString(ev.Name) {
139 | 			return
140 | 		}
141 |
142 | 		TraceLog.Printf("Creating symlink from %s\n", ev.Name)
143 | 		retries := 5
144 | 		for i := 0; i < retries; i++ {
145 | 			pid, sockPath, err := generateSockSymlinkFromDockerPidFile(r.c, ev.Name)
146 | 			if err == nil {
147 | 				r.socketSymlinkPerNs[pid] = sockPath
148 | 				return
149 | 			}
150 | 			TraceLog.Printf("Error during creation of socket symlink: %s\n", err)
151 | 			// No point in waiting after the last attempt.
152 | 			if i+1 < retries {
153 | 				wait := time.Second << i
154 | 				TraceLog.Printf("[%d/%d] Waiting %d seconds\n", i+1, retries, int64(wait/time.Second))
155 | 				time.Sleep(wait)
156 | 			}
157 | 		}
158 | 		TraceLog.Printf("Giving up creating symlink from %s after %d attempts\n", ev.Name, retries)
159 | 	}
160 | }
161 |
162 | // Start watches /run/user/ for rootless Docker daemons until Stop is called.
163 | // It first handles daemons that are already running, by synthesising the
164 | // events their runtime directory and docker.pid would have produced, then
165 | // processes inotify events as they arrive, polling with a 200ms deadline so
166 | // that a Stop request is noticed promptly.
167 | //
168 | // Start blocks and must be called at most once. It returns an error if the
169 | // setup fails; it returns nil once stopped, and also when reading inotify
170 | // events fails (that error is only logged to TraceLog).
171 | func (r *RootlessSymlinker) Start() error {
172 | 	// Tell Stop the loop is over, however Start returns. Deferred first so
173 | 	// it runs last, after cancel below.
174 | 	defer close(r.doneCh)
175 |
176 | 	// We create a context to handle inotify's lifecycle: when Start
177 | 	// returns, cancel shuts the inotify instance down cleanly.
178 | 	ctx, cancel := context.WithCancel(context.Background())
179 | 	defer cancel()
180 |
181 | 	inotify, err := gonotify.NewInotify(ctx)
182 | 	if err != nil {
183 | 		return err
184 | 	}
185 | 	r.inotify = inotify
186 |
187 | 	// Watch for new runtime directories before scanning the existing ones,
188 | 	// so that a directory created during the scan is not missed. Seeing a
189 | 	// directory twice is harmless: the watch is re-added and `ln -s -f`
190 | 	// replaces the symlink.
191 | 	err = r.inotify.AddWatch(xdgRuntimeRoot, gonotify.IN_CREATE|gonotify.IN_ISDIR)
192 | 	if err != nil {
193 | 		return err
194 | 	}
195 |
196 | 	// If some instances of rootless Docker are already running, handle them
197 | 	// by synthesising the events their runtime directory and docker.pid
198 | 	// would have produced.
199 | 	entries, err := r.c.ReadDir(xdgRuntimeRoot)
200 | 	if err != nil {
201 | 		return err
202 | 	}
203 | 	for _, entry := range entries {
204 | 		if !entry.IsDir() {
205 | 			continue
206 | 		}
207 | 		fullPath := path.Join(xdgRuntimeRoot, entry.Name())
208 | 		if !userXdgRuntimeDirRegex.MatchString(fullPath) {
209 | 			continue
210 | 		}
211 | 		r.handleEvent(gonotify.InotifyEvent{
212 | 			Name: fullPath,
213 | 			Mask: gonotify.IN_CREATE | gonotify.IN_ISDIR,
214 | 		})
215 |
216 | 		subEntries, err := r.c.ReadDir(fullPath)
217 | 		if err != nil {
218 | 			return err
219 | 		}
220 | 		for _, subEntry := range subEntries {
221 | 			if subEntry.Name() == dockerPidFileName {
222 | 				r.handleEvent(gonotify.InotifyEvent{
223 | 					Name: path.Join(fullPath, subEntry.Name()),
224 | 					Mask: gonotify.IN_CLOSE_WRITE,
225 | 				})
226 | 			}
227 | 		}
228 | 	}
229 |
230 | 	TraceLog.Println("Starting to listen for events")
231 | 	for {
232 | 		raw, err := r.inotify.ReadDeadline(time.Now().Add(time.Millisecond * 200))
233 | 		select {
234 | 		case <-r.stopCh:
235 | 			return nil
236 | 		default:
237 | 		}
238 |
239 | 		if err != nil {
240 | 			if err == gonotify.TimeoutError {
241 | 				continue
242 | 			}
243 | 			TraceLog.Printf("Error during inotify reading: %s\n", err)
244 | 			return nil
245 | 		}
246 |
247 | 		for _, event := range raw {
248 | 			r.handleEvent(event)
249 | 		}
250 | 	}
251 | }
252 |
253 | // Stop ends Start's event loop, waits for Start to return, then removes
254 | // every socket symlink created so far from the corresponding daemon's
255 | // namespaces. Removal failures are logged to TraceLog and skipped; Stop
256 | // always returns nil. Call it once; it only returns after Start has.
257 | func (r *RootlessSymlinker) Stop() error {
258 | 	// Closing, unlike sending, never blocks, so Stop cannot hang when Start
259 | 	// has already returned on its own (e.g. after a setup error).
260 | 	close(r.stopCh)
261 | 	// After Start has returned nothing writes socketSymlinkPerNs any more.
262 | 	<-r.doneCh
263 |
264 | 	for pid, symlink := range r.socketSymlinkPerNs {
265 | 		if err := r.c.Run("nsenter", "-U", "-n", "-m", "-t", fmt.Sprint(pid), "rm", "-f", symlink); err != nil {
266 | 			TraceLog.Printf("Couldn't remove symlink on rootless ns (PID: %d): %s\n", pid, err)
267 | 		}
268 | 	}
269 | 	return nil
270 | }
271 |
--------------------------------------------------------------------------------
/systemd/dwgd.service:
--------------------------------------------------------------------------------
1 | [Unit]
2 | Description=dwgd
3 | Before=docker.service
4 | After=network.target
5 | Requires=docker.service
6 |
7 | [Service]
8 | ExecStart=/usr/bin/dwgd -d /var/lib/dwgd.db
9 |
10 | [Install]
11 | WantedBy=multi-user.target
12 |
--------------------------------------------------------------------------------