├── .gitattributes ├── .github ├── codecov.yml ├── dependabot.yml └── workflows │ ├── codeql-analysis.yml │ ├── go-fuzz.yml │ └── go.yml ├── .gitignore ├── .golangci.yml ├── .vscode ├── dlv-sudo.sh ├── launch.json ├── settings.json └── tasks.json ├── LICENSE ├── PROTOCOL.md ├── README.md ├── TODO.md ├── apply ├── allowed-ips.go ├── allowed-ips_test.go ├── auto-ip.go ├── auto-ip_test.go ├── doc.go ├── health.go ├── health_test.go ├── local-iface.go ├── local-iface_test.go ├── state.go └── state_test.go ├── autopeer ├── autoaddress.go └── autoaddress_test.go ├── cmd ├── cmd.go ├── cmd_generic.go ├── cmd_test.go ├── cmd_unix.go ├── cmd_vnet_test.go └── wirevlink.wgNotRealLongerThanIFNAMSIZ.json ├── config ├── configuration-trust.go ├── configuration-trust_test.go ├── doc.go ├── flags.go ├── flags_test.go ├── peer-config-set.go ├── peer-config-set_test.go ├── peer-config.go ├── peer-data.go ├── peer-data_test.go ├── server-config.go ├── server-config_test.go ├── server-data.go ├── server-data_test.go └── utils_test.go ├── detect ├── router-detect.go └── router-detect_test.go ├── device ├── auto-ip.go ├── auto-ip_test.go ├── config.go ├── device.go └── doc.go ├── fact ├── accumulator.go ├── accumulator_test.go ├── doc.go ├── fact-set.go ├── fact-set_test.go ├── fact.go ├── member-metadata.go ├── member-metadata_test.go ├── parse-signed-group_test.go ├── parse.go ├── parse_test.go ├── testdata │ ├── fuzz │ │ └── FuzzDecodeFrom │ │ │ ├── 0cfd23bfe624f2c78c0930c5fe5ea7b1db43d35285a0920bde0280f4f836b0fe │ │ │ ├── 2bc31826454fd1ee42d6a1449698b0a5d6395c8972eb15eb58843e7e80665d06 │ │ │ ├── 68ef958b5121ee27405ff1e28c59a6e871748abc58724342c6b2590a37ea23f3 │ │ │ ├── 6b3f3671c028200ef9f1dc642c4f7b358cf46817000b2bb2f4f550233df3d467 │ │ │ ├── 7fbe8f1ab90c5846075daef7abd135cdec284f2f1eebd274b8d9cfb024d62df2 │ │ │ ├── 87c7c0b972f6d43523d679b0107adf883ceb97fe527a496b6193389733bd25df │ │ │ ├── 9ee44aada7de1402f44d18fdfc7949aee504f7233cd4610664d31432a5d5f9e2 │ │ │ └── 
e7eac19df17a16b6835603e9b1763904ecfd2a2308c8e8ea52dbf043052b2f80 │ └── vnet-packets.txt ├── time-scale.go ├── types-signedgroupvalue.go ├── types-signedgroupvalue_test.go ├── types-subjects.go ├── types-values.go ├── types.go └── utils_test.go ├── go.mod ├── go.sum ├── go.work ├── go.work.sum ├── internal ├── WgClient.go ├── channels │ ├── broadcast.go │ ├── doc.go │ ├── filter.go │ ├── process.go │ └── types.go ├── deps.go ├── doc.go ├── lru.go ├── lru_test.go ├── mocks │ ├── .gitignore │ └── doc.go ├── networking │ ├── darwin │ │ └── darwin-ifconfig.go │ ├── doc.go │ ├── host │ │ ├── factory-darwin.go │ │ ├── factory-generic.go │ │ ├── factory-linux.go │ │ └── factory.go │ ├── linux │ │ ├── linux-netlink.go │ │ └── linux-netlink_test.go │ ├── mocks │ │ ├── .gitignore │ │ └── builder.go │ ├── native │ │ ├── constants_darwin.go │ │ ├── constants_generic.go │ │ ├── constants_linux.go │ │ ├── constants_windows.go │ │ ├── go-environment.go │ │ ├── go-environment_test.go │ │ ├── go-interface.go │ │ ├── go-interface_test.go │ │ ├── go-udpconn.go │ │ ├── go-udpconn_test.go │ │ └── utils_test.go │ ├── types.go │ └── vnet │ │ ├── doc.go │ │ ├── host-environment.go │ │ ├── host-wgclient.go │ │ ├── host.go │ │ ├── interface-wrap.go │ │ ├── interface.go │ │ ├── network.go │ │ ├── packet.go │ │ ├── phy.go │ │ ├── smoke_test.go │ │ ├── socket-udpconn.go │ │ ├── socket.go │ │ ├── tunnel.go │ │ ├── util.go │ │ └── world.go ├── testutils │ ├── ci-scaling.go │ ├── ci-scaling_test.go │ ├── doc.go │ ├── facts │ │ └── facts.go │ ├── ip.go │ ├── log.go │ ├── must-bytes.go │ ├── output-capture.go │ ├── paths.go │ └── rand.go └── version.go.in ├── log └── log.go ├── magefiles ├── .gitignore ├── ci.go ├── debug.go ├── dirty.go ├── generate.go ├── go.mod ├── go.sum ├── install.go ├── lint.go ├── magefile.go ├── test.go ├── tools.go ├── vars.go └── versions.go ├── packaging ├── checkinstall │ ├── description-pak │ ├── postinstall-pak │ ├── postremove-pak │ └── preremove-pak ├── deploy.sh 
├── wg-go-checkinstall │ └── description-pak ├── wirelink@.service └── wl-quick@.service ├── peerfacts ├── devicefacts.go ├── devicefacts_test.go ├── doc.go ├── peerfacts.go └── peerfacts_test.go ├── server ├── device.go ├── device_test.go ├── errors-generic.go ├── errors-linux.go ├── interface-cache.go ├── peer-config-set.go ├── peer-config-set_test.go ├── peer-config.go ├── peer-config_test.go ├── peer-knowledge.go ├── peer-knowledge_test.go ├── peer-lookup.go ├── received-fact.go ├── udp-inbound.go ├── udp-inbound_test.go ├── udp-outbound.go ├── udp-outbound_test.go ├── udp-server.go ├── udp-server_test.go ├── util.go ├── util_test.go └── utils_test.go ├── signing ├── doc.go ├── sign.go ├── signer.go ├── signer_test.go └── verify.go ├── trust ├── .gitignore ├── composite-trust.go ├── composite-trust_test.go ├── known-peer-trust.go ├── known-peer-trust_test.go ├── route-based-trust.go ├── route-based-trust_test.go ├── trust.go └── trust_test.go ├── util ├── bytereader.go ├── decodable.go ├── decodable_test.go ├── doc.go ├── errors.go ├── errors_test.go ├── ip.go ├── ip_test.go ├── must.go ├── net.go ├── slice.go ├── ternary.go ├── ternary_test.go ├── time.go └── time_test.go ├── wirelink.go └── wirelink.test.json /.gitattributes: -------------------------------------------------------------------------------- 1 | *.go text eol=lf 2 | version.go.in text eol=lf 3 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: false 3 | 4 | coverage: 5 | status: 6 | project: 7 | default: 8 | target: auto 9 | # don't fail for tiny decreases in project-level coverage, rely on 10 | # patch coverage to maintain overall levels 11 | threshold: 1% 12 | # precision: 2 13 | # round: down 14 | # range: "70...100" 15 | 16 | # parsers: 17 | # gcov: 18 | # branch_detection: 19 | # conditional: yes 20 | # loop: yes 21 | # 
method: no 22 | # macro: no 23 | 24 | # comment: 25 | # layout: "reach,diff,flags,tree" 26 | # behavior: default 27 | # require_changes: no 28 | 29 | ignore: 30 | # mocks are test code, don't care about coverage there 31 | # go reports coverage file names package-qualified, so these patterns must match anywhere in the path 32 | - ".*/internal/mocks/" 33 | - ".*/internal/networking/mocks/" 34 | - ".*_test.go$" 35 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | groups: 6 | go-everything: 7 | patterns: 8 | - "*" 9 | schedule: 10 | interval: "weekly" 11 | day: sunday 12 | - package-ecosystem: github-actions 13 | directory: "/" 14 | schedule: 15 | interval: "weekly" 16 | day: sunday 17 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | name: "CodeQL" 7 | 8 | on: 9 | push: 10 | branches: [main] 11 | pull_request: 12 | schedule: 13 | - cron: '0 16 * * 0' 14 | 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | analyze: 21 | name: Analyze 22 | runs-on: ubuntu-latest 23 | 24 | strategy: 25 | fail-fast: false 26 | 27 | steps: 28 | - name: Checkout code 29 | uses: actions/checkout@v4 30 | with: 31 | # We must fetch at least the immediate parents so that if this is 32 | # a pull request then we can check out the head.
33 | # fetch more than that so git describe --tags will work 34 | fetch-depth: 20 35 | 36 | - name: Set up Go 1.24 37 | uses: actions/setup-go@v5 38 | with: 39 | go-version: "^1.24" 40 | id: go 41 | 42 | - name: Get tags to make version.go 43 | run: | 44 | go tool mage deepen 45 | 46 | # Initializes the CodeQL tools for scanning. 47 | - name: Initialize CodeQL 48 | uses: github/codeql-action/init@v3 49 | with: 50 | languages: go 51 | # If you wish to specify custom queries, you can do so here or in a config file. 52 | # By default, queries listed here will override any specified in a config file. 53 | # Prefix the list here with "+" to use these queries and those in the config file. 54 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | #- name: Autobuild 59 | # uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 
62 | # 📚 https://git.io/JvXDl 63 | 64 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 65 | # and modify them (or add more) to build your code if your project 66 | # uses a compiled language 67 | 68 | - run: | 69 | go tool mage compile 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v3 73 | env: 74 | CODEQL_EXTRACTOR_GO_BUILD_COMMAND: "go tool mage compile" 75 | -------------------------------------------------------------------------------- /.github/workflows/go-fuzz.yml: -------------------------------------------------------------------------------- 1 | name: Go Fuzz 2 | on: [push] 3 | concurrency: 4 | group: fuzz-${{ github.workflow }}-${{ github.head_ref || github.ref }} 5 | cancel-in-progress: true 6 | jobs: 7 | fuzz: 8 | name: fuzz 9 | strategy: 10 | fail-fast: false 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out code 14 | uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 20 17 | 18 | - name: Set up Go 1.24 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: "^1.24" 22 | id: go 23 | 24 | - name: Get tags to make version.go 25 | run: | 26 | go tool mage deepen 27 | 28 | - name: Run fuzz tests 29 | run: go tool mage test:fuzz 30 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | on: [push] 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} 5 | cancel-in-progress: true 6 | jobs: 7 | build: 8 | name: Build 9 | strategy: 10 | matrix: 11 | os: 12 | - ubuntu-22.04 13 | - ubuntu-24.04 14 | - ubuntu-24.04-arm 15 | - macos-13 16 | - macos-14 17 | # macos-15 is arm 18 | - macos-15 19 | - windows-latest 20 | fail-fast: false 21 | runs-on: ${{ matrix.os }} 22 | steps: 23 | - name: Check out code 24 | uses: actions/checkout@v4 25 | with: 26 | fetch-depth: 20 27 | 28 | - name: Set up 
Go 1.24 29 | uses: actions/setup-go@v5 30 | with: 31 | go-version: "^1.24" 32 | id: go 33 | 34 | - name: Get tags to make version.go 35 | run: | 36 | go tool mage deepen 37 | 38 | - name: Get dependencies 39 | run: | 40 | go mod download 41 | go mod verify 42 | git diff --exit-code 43 | - name: Verify tidy 44 | run: | 45 | go mod tidy 46 | ( cd magefiles && go mod tidy ) 47 | go work sync 48 | git diff --exit-code 49 | 50 | - name: Build 51 | run: | 52 | go tool mage compile 53 | git diff --exit-code 54 | 55 | - name: golangci-lint 56 | uses: golangci/golangci-lint-action@v8 57 | with: 58 | version: latest 59 | - name: govulncheck 60 | run: go tool mage lint:vulncheck 61 | 62 | - name: Test with coverage 63 | # building the coverage output runs the tests 64 | run: go tool mage test:coverCI 65 | - name: Generate coverage report 66 | # even if tests fail, generate the report 67 | if: ${{ always() }} 68 | run: go tool mage test:coverHTML 69 | - name: Upload tests to codecov 70 | # upload the test results (not the coverage report) to codecov 71 | if: ${{ always() }} 72 | uses: codecov/test-results-action@v1 73 | with: 74 | token: ${{ secrets.CODECOV_TOKEN }} 75 | - name: Upload coverage (codecov) 76 | if: ${{ always() }} 77 | uses: codecov/codecov-action@v5 78 | with: 79 | token: ${{ secrets.CODECOV_TOKEN }} 80 | disable_search: true 81 | files: ./coverage.out 82 | fail_ci_if_error: true 83 | - name: Upload coverage (artifact) 84 | if: ${{ always() }} 85 | uses: actions/upload-artifact@v4.6.2 86 | with: 87 | name: coverage-${{ matrix.os }}.html 88 | path: ./coverage.html 89 | 90 | - name: Test with race detector 91 | # combination of macos runners being slow and race detector messing with 92 | # timing makes this combination too flaky for now 93 | if: "!startsWith(matrix.os, 'macos-')" 94 | run: go tool mage test:goRace 95 | -------------------------------------------------------------------------------- /.gitignore:
-------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | # debug binary from vscode 11 | __debug_bin 12 | 13 | # Output of the go coverage tool, specifically when used with LiteIDE 14 | *.out 15 | 16 | # binaries 17 | wirelink 18 | wirelink-cross-* 19 | 20 | packaging/*checkinstall/*.deb 21 | # everything in doc-pak/ is copied from elsewhere 22 | packaging/*checkinstall/doc-pak/* 23 | 24 | coverage.out 25 | coverage.html 26 | junit.xml 27 | 28 | version.go 29 | 30 | # try folder has temp test code 31 | try/ 32 | 33 | # macOS cruft 34 | .DS_Store 35 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | enable: 4 | - gocyclo 5 | - misspell 6 | - prealloc 7 | - revive 8 | - staticcheck 9 | - unconvert 10 | - unparam 11 | settings: 12 | gocyclo: 13 | min-complexity: 20 14 | revive: 15 | rules: 16 | - name: superfluous-else 17 | disabled: true 18 | exclusions: 19 | generated: lax 20 | rules: 21 | - path: "(.+)\\.go$" 22 | # errcheck: Almost all programs ignore errors on these functions and in most cases it's ok 23 | text: Error return value of .((os\.)?Std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*printf?|os\.(Unsetenv|Setenv)).
is not checked 24 | linters: 25 | - errcheck 26 | paths: 27 | - try 28 | - third_party$ 29 | - builtin$ 30 | - examples$ 31 | issues: 32 | max-same-issues: 10 33 | formatters: 34 | enable: 35 | - gofmt 36 | - gofumpt 37 | - goimports 38 | settings: 39 | gofumpt: 40 | extra-rules: true 41 | exclusions: 42 | generated: lax 43 | paths: 44 | - try 45 | - third_party$ 46 | - builtin$ 47 | - examples$ 48 | -------------------------------------------------------------------------------- /.vscode/dlv-sudo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | if ! which dlv ; then 3 | export PATH="${GOPATH}/bin:$PATH" 4 | fi 5 | if [ "$WIRELINK_DEBUG_AS_ROOT" = "true" ]; then 6 | # sudo may not obey "our" $PATH, so need to look up the binary ourselves 7 | exec sudo "$(which dlv)" --only-same-user=false "$@" 8 | else 9 | exec dlv "$@" 10 | fi 11 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "Run test iface", 6 | "type": "go", 7 | "request": "launch", 8 | "mode": "auto", 9 | "program": "${workspaceFolder}", 10 | "env": { 11 | "__DEBUG_BIN_CONFIG_PATH": ".", 12 | "__DEBUG_BIN_CHATTY": "true" 13 | }, 14 | "args": [ 15 | "--iface=test", 16 | "--dump" 17 | ] 18 | }, 19 | { 20 | "name": "Run for real", 21 | "type": "go", 22 | "request": "launch", 23 | "mode": "exec", 24 | "preLaunchTask": "build app", 25 | "program": "${workspaceFolder}/wirelink", 26 | "env": { 27 | "WIRELINK_DEBUG_AS_ROOT": "true" 28 | }, 29 | "args": [ 30 | "--iface=wg0", 31 | "--debug" 32 | ] 33 | }, 34 | { 35 | "name": "Test current package", 36 | "type": "go", 37 | "request": "launch", 38 | "mode": "test", 39 | "program": "${fileDirname}", 40 | }, 41 | { 42 | "name": "Debug mage", 43 | "type":"go", 44 | "request":"launch", 45 | "mode":"exec", 46 | 
"preLaunchTask": "build mage build", 47 | "program":"${workspaceFolder}/magefiles/build", 48 | "env": {}, 49 | "args": [ 50 | "-v", 51 | "${input:mageTarget}", 52 | ] 53 | } 54 | ], 55 | "inputs": [ 56 | { 57 | "type": "promptString", 58 | "id": "mageTarget", 59 | "description": "Mage target", 60 | } 61 | ] 62 | } 63 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cSpell.allowCompoundWords": true, 3 | "cSpell.words": [ 4 | "aips", 5 | "Cidr", 6 | "codecov", 7 | "Decodable", 8 | "deconfigure", 9 | "deconfigured", 10 | "deconfiguring", 11 | "dedupe", 12 | "deserialization", 13 | "deserialize", 14 | "deserialized", 15 | "ENOKEY", 16 | "enqueues", 17 | "Equalf", 18 | "Falsef", 19 | "fuzzer", 20 | "GOBIN", 21 | "gocyclo", 22 | "goimports", 23 | "golangci", 24 | "golint", 25 | "GOPATH", 26 | "gopls", 27 | "govulncheck", 28 | "Iface", 29 | "Ifaces", 30 | "laddr", 31 | "Newf", 32 | "NICs", 33 | "nolint", 34 | "pcfg", 35 | "pflag", 36 | "pflags", 37 | "SGFs", 38 | "stretchr", 39 | "stringifying", 40 | "systemctl", 41 | "systemd", 42 | "testutils", 43 | "Typef", 44 | "uint", 45 | "unencrypted", 46 | "Unmarshaller", 47 | "unmarshalling", 48 | "unregister", 49 | "unregisters", 50 | "Unsetenv", 51 | "Untrusted", 52 | "uuid", 53 | "uvarint", 54 | "vcfg", 55 | "veth", 56 | "vishvananda", 57 | "vnet", 58 | "vulncheck", 59 | "wgctrl", 60 | "wgtypes", 61 | "wirevlink", 62 | "Wrapf", 63 | "xyzzy" 64 | ], 65 | "go.alternateTools": { 66 | "dlv": "${workspaceFolder}/.vscode/dlv-sudo.sh" 67 | }, 68 | "go.generateTestsFlags": [ 69 | "-template=testify" 70 | ], 71 | "go.testFlags": ["-short"], 72 | "go.lintTool": "golangci-lint", 73 | "go.lintOnSave": "workspace", 74 | "go.coverOnSave": true, 75 | "go.vetOnSave": "workspace", 76 | "go.testTimeout": "10s", 77 | "go.testOnSave": true, 78 | "go.useLanguageServer": true, 79 | "[go]": { 80 | 
"editor.insertSpaces": false, 81 | "editor.formatOnSave": true, 82 | // work around https://github.com/golang/vscode-go/issues/691 83 | "editor.formatOnSaveMode": "file", 84 | "editor.codeActionsOnSave": { 85 | "source.organizeImports": "explicit" 86 | } 87 | }, 88 | "go.formatTool": "gofumpt", 89 | "files.insertFinalNewline": true, 90 | "files.trimFinalNewlines": true 91 | } 92 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "echoCommand": true, 4 | "tasks": [ 5 | { 6 | "label": "build", 7 | "type": "shell", 8 | "command": "go", 9 | "args": [ 10 | "build", 11 | "-v", 12 | "./..." 13 | ], 14 | "group": { 15 | "kind": "build", 16 | "isDefault": true 17 | }, 18 | "problemMatcher": "$go" 19 | }, 20 | { 21 | "label": "build app", 22 | "type": "shell", 23 | "command": "go", 24 | "args": [ 25 | "build", 26 | "-v", 27 | "." 28 | ], 29 | "group": "build", 30 | "problemMatcher": "$go" 31 | }, 32 | { 33 | "label": "test", 34 | "type": "shell", 35 | "command": "go", 36 | "args": [ 37 | "test", 38 | "./..." 
39 | ], 40 | "group": { 41 | "kind": "test", 42 | "isDefault": true 43 | } 44 | }, 45 | { 46 | "label": "build mage build", 47 | "type": "shell", 48 | "command": "go", 49 | "args": [ 50 | "tool", 51 | "mage", 52 | "preDebug", 53 | ] 54 | } 55 | ] 56 | } 57 | -------------------------------------------------------------------------------- /apply/allowed-ips.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/autopeer" 7 | "github.com/fastcat/wirelink/fact" 8 | "github.com/fastcat/wirelink/log" 9 | "github.com/fastcat/wirelink/util" 10 | 11 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 12 | ) 13 | // allowedIPFlag is a bit set recording where an AllowedIPs entry was seen and whether it is wanted. 14 | type allowedIPFlag int 15 | 16 | const ( 17 | aipNone allowedIPFlag = 0 18 | aipCurrent = 1 << 0 // entry is in the peer's current AllowedIPs 19 | aipAdding = 1 << 1 // entry is already queued in the pending PeerConfig 20 | aipValid = 1 << 2 // entry is backed by a (trusted) fact, or is the peer's auto address 21 | // aipDeleting = 1 << 3 22 | 23 | aipAlreadyMask = aipCurrent | aipAdding // present now or about to be, so no need to add it again 24 | // trust values others started adding, e.g. auto ip 25 | aipRebuildMask = aipAdding | aipValid 26 | ) 27 | // ipNetKey returns the binary IPNetValue encoding of ipNet, for use as a map key. 28 | func ipNetKey(ipNet net.IPNet) string { 29 | return string(util.MustBytes(fact.IPNetValue{IPNet: ipNet}.MarshalBinary())) 30 | } 31 | // fvKey returns the binary encoding of an arbitrary fact value, for use as a map key. 32 | func fvKey(value fact.Value) string { 33 | return string(util.MustBytes(value.MarshalBinary())) 34 | } 35 | // keyIPNet decodes a key produced from an IPNetValue back into a net.IPNet, panicking if it does not decode. 36 | func keyIPNet(key string) net.IPNet { 37 | v := &fact.IPNetValue{} 38 | err := v.UnmarshalBinary([]byte(key)) 39 | // this should never happen, as long as every key passed in was encoded from an IPNetValue 40 | if err != nil { 41 | panic(err) 42 | } 43 | return v.IPNet 44 | } 45 | 46 | // EnsureAllowedIPs updates the device config if needed to add all the 47 | // AllowedIPs from the facts to the peer. This assumes that facts have already 48 | // been filtered to be just the trusted ones.
49 | func EnsureAllowedIPs( 50 | peer *wgtypes.Peer, 51 | facts []*fact.Fact, 52 | cfg *wgtypes.PeerConfig, 53 | allowDeconfigure bool, 54 | ) *wgtypes.PeerConfig { 55 | aipFlags := make(map[string]allowedIPFlag) 56 | for _, aip := range peer.AllowedIPs { 57 | aipFlags[ipNetKey(aip)] |= aipCurrent 58 | } 59 | if cfg != nil { 60 | for _, aip := range cfg.AllowedIPs { 61 | aipFlags[ipNetKey(aip)] |= aipAdding 62 | } 63 | } 64 | // autoaddr is always valid 65 | aipFlags[ipNetKey(autopeer.AutoAddressNet(peer.PublicKey))] |= aipValid 66 | 67 | for _, f := range facts { 68 | switch f.Attribute { 69 | case fact.AttributeAllowedCidrV4, fact.AttributeAllowedCidrV6: 70 | key := fvKey(f.Value) 71 | aipFlags[key] |= aipValid 72 | if aipFlags[key]&aipAlreadyMask != aipNone { 73 | continue 74 | } 75 | if ipn, ok := f.Value.(*fact.IPNetValue); ok { 76 | if cfg == nil { 77 | cfg = &wgtypes.PeerConfig{PublicKey: peer.PublicKey} 78 | } 79 | cfg.AllowedIPs = append(cfg.AllowedIPs, ipn.IPNet) 80 | aipFlags[key] |= aipAdding 81 | } else { 82 | log.Error("AIP Fact has wrong value type: %v => %T: %v", f.Attribute, f.Value, f.Value) 83 | } 84 | } 85 | } 86 | 87 | if allowDeconfigure { 88 | replace := false 89 | for _, f := range aipFlags { 90 | if f&(aipCurrent|aipValid) == aipCurrent { 91 | // peer has a current AIP that it should not 92 | // we need to convert the config to a _replace_ mode 93 | replace = true 94 | break 95 | } 96 | } 97 | if replace { 98 | // rebuild 99 | if cfg == nil { 100 | cfg = &wgtypes.PeerConfig{PublicKey: peer.PublicKey} 101 | } 102 | cfg.ReplaceAllowedIPs = true 103 | cfg.AllowedIPs = nil 104 | for k, f := range aipFlags { 105 | // add anything that was already being added, or determined to be valid to have (present or not) 106 | if f&aipRebuildMask == aipNone { 107 | continue 108 | } 109 | ipn := keyIPNet(k) 110 | cfg.AllowedIPs = append(cfg.AllowedIPs, ipn) 111 | } 112 | } 113 | } 114 | 115 | return cfg 116 | } 117 | 
-------------------------------------------------------------------------------- /apply/auto-ip.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/autopeer" 7 | 8 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 9 | ) 10 | // hasAutoIP reports whether aips contains autoaddr with a full /128 mask. 11 | func hasAutoIP(autoaddr net.IP, aips []net.IPNet) bool { 12 | for _, aip := range aips { 13 | if aip.IP.Equal(autoaddr) { 14 | ones, bits := aip.Mask.Size() 15 | if ones == 8*net.IPv6len && bits == 8*net.IPv6len { 16 | return true 17 | } 18 | } 19 | } 20 | return false 21 | } 22 | 23 | // EnsurePeerAutoIP ensures that the config (if any) for the given peer key includes 24 | // its automatic IPv6-LL address. 25 | func EnsurePeerAutoIP(peer *wgtypes.Peer, cfg *wgtypes.PeerConfig) (peerConfig *wgtypes.PeerConfig, added bool) { 26 | autoaddr := autopeer.AutoAddress(peer.PublicKey) 27 | hasNow := hasAutoIP(autoaddr, peer.AllowedIPs) 28 | var alreadyAdding bool 29 | var rebuilding bool 30 | if cfg != nil { 31 | alreadyAdding = hasAutoIP(autoaddr, cfg.AllowedIPs) 32 | rebuilding = cfg.ReplaceAllowedIPs 33 | } 34 | // we can skip this if we're already adding it, or if we have it and aren't rebuilding 35 | if alreadyAdding || hasNow && !rebuilding { 36 | return cfg, false 37 | } 38 | 39 | if cfg == nil { 40 | cfg = &wgtypes.PeerConfig{ 41 | PublicKey: peer.PublicKey, 42 | } 43 | } 44 | 45 | cfg.AllowedIPs = append(cfg.AllowedIPs, net.IPNet{ 46 | IP: autoaddr, 47 | Mask: net.CIDRMask(8*net.IPv6len, 8*net.IPv6len), 48 | }) 49 | // don't "say" we added it (for logging purposes) if we are "re-adding" it 50 | // as part of a rebuild 51 | return cfg, !hasNow 52 | } 53 | 54 | // OnlyAutoIP configures a peer to have _only_ its IPv6-LL IP in its AllowedIPs. 55 | // It returns the config to apply: cfg as-is (possibly nil) when no change is needed, otherwise cfg (allocated if nil) set to replace the peer's AllowedIPs with just that /128 address. 56 | func OnlyAutoIP(peer *wgtypes.Peer, cfg *wgtypes.PeerConfig) *wgtypes.PeerConfig { 57 | autoaddr := 
autopeer.AutoAddress(peer.PublicKey) 58 | // skip the change only if cfg already does exactly this, or the peer already has exactly this and cfg won't alter it 59 | if cfg != nil && cfg.ReplaceAllowedIPs && len(cfg.AllowedIPs) == 1 && hasAutoIP(autoaddr, cfg.AllowedIPs) { 60 | // already set to apply this config 61 | return cfg 62 | } 63 | if len(peer.AllowedIPs) == 1 && hasAutoIP(autoaddr, peer.AllowedIPs) { 64 | // if peer is already configured properly and config isn't going to change it, do nothing 65 | if cfg == nil { 66 | return cfg 67 | } 68 | if len(cfg.AllowedIPs) == 0 && !cfg.ReplaceAllowedIPs { 69 | return cfg 70 | } 71 | } 72 | 73 | // peer either isn't set up right, or config is set to change it to something else 74 | if cfg == nil { 75 | cfg = &wgtypes.PeerConfig{PublicKey: peer.PublicKey} 76 | } 77 | cfg.ReplaceAllowedIPs = true 78 | cfg.AllowedIPs = []net.IPNet{{ 79 | IP: autoaddr, 80 | Mask: net.CIDRMask(8*net.IPv6len, 8*net.IPv6len), 81 | }} 82 | 83 | return cfg 84 | } 85 | -------------------------------------------------------------------------------- /apply/doc.go: -------------------------------------------------------------------------------- 1 | // Package apply contains code for applying changes to network interfaces and 2 | // wireguard configurations.
3 | package apply 4 | -------------------------------------------------------------------------------- /apply/health.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "time" 5 | 6 | "golang.zx2c4.com/wireguard/device" 7 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 8 | ) 9 | 10 | // TODO: the timing constants here should be moved somewhere more general 11 | 12 | // HealthHysteresisBandaid is an extra delay to add before considering a peer 13 | // unhealthy, based on as-yet undiagnosed observations of handshakes not 14 | // refreshing as often as documentation seems to suggest they should 15 | const HealthHysteresisBandaid = 30 * time.Second 16 | 17 | // HandshakeValidityBase is the base amount of time we think a handshake should be valid for, 18 | // without accounting for tolerances: the sum of the wireguard device's rekey-after, rekey-timeout, keepalive-timeout, and maximum rekey jitter durations 19 | const HandshakeValidityBase = device.RekeyAfterTime + 20 | device.RekeyTimeout + 21 | device.KeepaliveTimeout + 22 | device.RekeyTimeoutJitterMaxMs*time.Millisecond 23 | 24 | // HandshakeValidity is how long we think a handshake should be valid for, 25 | // including tolerances (HealthHysteresisBandaid) 26 | const HandshakeValidity = HandshakeValidityBase + HealthHysteresisBandaid 27 | 28 | // isHealthy checks the state of a peer to see if connectivity to it is probably 29 | // healthy (and thus we shouldn't change its config), or if it is unhealthy and 30 | // we should consider updating its config to try to find a working setup.
31 | // note that this is separate from being "Alive", which means that we have heard 32 | // fact packet(s) from it recently. state may be nil, in which case only the endpoint and handshake age are considered. 33 | func isHealthy(state *PeerConfigState, peer *wgtypes.Peer) bool { 34 | // if the peer doesn't have an endpoint, it's not healthy 35 | if peer.Endpoint == nil { 36 | return false 37 | } 38 | // if the peer handshake is still valid, the peer is healthy 39 | if peer.LastHandshakeTime.Add(HandshakeValidity).After(time.Now()) { 40 | return true 41 | } 42 | // if the peer handshake has moved forwards since we last saw it, probably healthy 43 | if state != nil && peer.LastHandshakeTime.After(state.lastHandshake) { 44 | return true 45 | } 46 | return false 47 | } 48 | 49 | // IsHandshakeHealthy returns whether the handshake looks recent enough that the 50 | // peer is likely to be in communication. It applies the same HandshakeValidity window as isHealthy, without the endpoint or state checks. 51 | func IsHandshakeHealthy(lastHandshake time.Time) bool { 52 | return lastHandshake.Add(HandshakeValidity).After(time.Now()) 53 | } 54 | -------------------------------------------------------------------------------- /apply/health_test.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/fastcat/wirelink/internal/testutils" 11 | 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | func Test_isHealthy(t *testing.T) { 16 | now := time.Now() 17 | then := now.Add(time.Duration(-1-rand.Intn(30)) * time.Second) 18 | longAgo := now.Add(-HandshakeValidity) 19 | longLongAgo := then.Add(-HandshakeValidity) 20 | 21 | type args struct { 22 | state *PeerConfigState 23 | peer *wgtypes.Peer 24 | } 25 | tests := []struct { 26 | name string 27 | args args 28 | want bool 29 | }{ 30 | { 31 | "no endpoint", 32 | args{ 33 | peer: &wgtypes.Peer{}, 34 | }, 35 | false, 36 | }, 37 | { 38 | "fresh handshake", 39 | args{ 40 | peer: &wgtypes.Peer{ 41 | Endpoint: testutils.RandUDP4Addr(t),
42 | LastHandshakeTime: now, 43 | }, 44 | }, 45 | true, 46 | }, 47 | { 48 | "changed stale handshake", 49 | args{ 50 | peer: &wgtypes.Peer{ 51 | Endpoint: testutils.RandUDP4Addr(t), 52 | LastHandshakeTime: longAgo, 53 | }, 54 | state: &PeerConfigState{ 55 | lastHandshake: longLongAgo, 56 | }, 57 | }, 58 | true, 59 | }, 60 | { 61 | "stable stale handshake", 62 | args{ 63 | peer: &wgtypes.Peer{ 64 | Endpoint: testutils.RandUDP4Addr(t), 65 | LastHandshakeTime: longAgo, 66 | }, 67 | state: &PeerConfigState{ 68 | lastHandshake: longAgo, 69 | }, 70 | }, 71 | false, 72 | }, 73 | } 74 | for _, tt := range tests { 75 | t.Run(tt.name, func(t *testing.T) { 76 | got := isHealthy(tt.args.state, tt.args.peer) 77 | t.Logf("deltas: %v, %v, %v", now.Sub(then), now.Sub(longAgo), now.Sub(longLongAgo)) 78 | assert.Equal(t, tt.want, got) 79 | }) 80 | } 81 | } 82 | 83 | func TestIsHandshakeHealthy(t *testing.T) { 84 | now := time.Now() 85 | then := now.Add(time.Duration(rand.Int63n(30)) * time.Second) // NOTE(review): non-negative offset, so "then" is now or in the future rather than in the past (unlike Test_isHealthy); still healthy, but confirm intent 86 | longAgo := now.Add(-HandshakeValidity) 87 | 88 | type args struct { 89 | lastHandshake time.Time 90 | } 91 | tests := []struct { 92 | name string 93 | args args 94 | want bool 95 | }{ 96 | {"very fresh", args{now}, true}, 97 | {"fresh", args{then}, true}, 98 | {"stale", args{longAgo}, false}, 99 | } 100 | for _, tt := range tests { 101 | t.Run(tt.name, func(t *testing.T) { 102 | got := IsHandshakeHealthy(tt.args.lastHandshake) 103 | assert.Equal(t, tt.want, got) 104 | }) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /apply/local-iface.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/autopeer" 8 | "github.com/fastcat/wirelink/internal/networking" 9 | "github.com/fastcat/wirelink/log" 10 | "github.com/fastcat/wirelink/util" 11 | 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | //
EnsureLocalAutoIP makes sure that the automatic IPv6 link-local IP is 16 | // present on the network interface named after the device (dev.Name), adding it with a /64 mask if no existing address matches. 17 | // It returns true only if it added the address; it returns false when util.IsIPv6LLMatch reports an existing match, or together with a non-nil error on any failure. 18 | func EnsureLocalAutoIP(env networking.Environment, dev *wgtypes.Device) (bool, error) { 19 | iface, err := env.InterfaceByName(dev.Name) 20 | if err != nil { 21 | return false, fmt.Errorf("unable to get interface info for %s: %w", dev.Name, err) 22 | } 23 | addrs, err := iface.Addrs() 24 | if err != nil { 25 | return false, fmt.Errorf("unable to get addresses for %s: %w", dev.Name, err) 26 | } 27 | 28 | autoaddr := autopeer.AutoAddress(dev.PublicKey) 29 | for _, addr := range addrs { 30 | if util.IsIPv6LLMatch(autoaddr, &addr, true) { 31 | return false, nil 32 | } 33 | } 34 | 35 | err = iface.AddAddr(net.IPNet{ 36 | IP: autoaddr, 37 | Mask: net.CIDRMask(4*net.IPv6len, 8*net.IPv6len), // 4*16 = 64 prefix bits out of 128, i.e. a /64 38 | }) 39 | if err != nil { 40 | return false, fmt.Errorf("unable to add %v to %s: %w", autoaddr, dev.Name, err) 41 | } 42 | 43 | log.Debug("Added local IPv6-LL %v to %s", autoaddr, dev.Name) 44 | 45 | return true, nil 46 | } 47 | -------------------------------------------------------------------------------- /apply/local-iface_test.go: -------------------------------------------------------------------------------- 1 | package apply 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "net" 7 | "testing" 8 | 9 | "github.com/fastcat/wirelink/autopeer" 10 | "github.com/fastcat/wirelink/internal/networking/mocks" 11 | "github.com/fastcat/wirelink/internal/testutils" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | 16 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 17 | ) 18 | 19 | func TestEnsureLocalAutoIP(t *testing.T) { 20 | in1 := fmt.Sprintf("wg%d", rand.Int31()) 21 | k1 := testutils.MustKey(t) 22 | 23 | type args struct { 24 | env func(*testing.T) *mocks.Environment 25 | dev *wgtypes.Device 26 | } 27 | tests := []struct { 28 | name string 29
| args args 30 | want bool 31 | wantErr bool 32 | }{ 33 | { 34 | "already configured", 35 | args{ 36 | env: func(t *testing.T) *mocks.Environment { 37 | ret := &mocks.Environment{} 38 | ret.WithSimpleInterfaces(map[string]net.IPNet{ 39 | in1: { 40 | IP: autopeer.AutoAddress(k1), 41 | Mask: net.CIDRMask(64, 128), 42 | }, 43 | }) 44 | return ret 45 | }, 46 | dev: &wgtypes.Device{ 47 | Name: in1, 48 | PublicKey: k1, 49 | }, 50 | }, 51 | false, 52 | false, 53 | }, 54 | { 55 | "do configure", 56 | args{ 57 | env: func(t *testing.T) *mocks.Environment { 58 | ret := &mocks.Environment{} 59 | ii := ret.WithSimpleInterfaces(map[string]net.IPNet{ 60 | in1: testutils.RandIPNet(t, net.IPv4len, nil, nil, 24), 61 | }) 62 | ii[in1].On("AddAddr", net.IPNet{ 63 | IP: autopeer.AutoAddress(k1), 64 | Mask: net.CIDRMask(64, 128), 65 | }).Return(nil) 66 | return ret 67 | }, 68 | dev: &wgtypes.Device{ 69 | Name: in1, 70 | PublicKey: k1, 71 | }, 72 | }, 73 | true, 74 | false, 75 | }, 76 | } 77 | for _, tt := range tests { 78 | t.Run(tt.name, func(t *testing.T) { 79 | env := tt.args.env(t) 80 | env.Test(t) 81 | got, err := EnsureLocalAutoIP(env, tt.args.dev) 82 | if tt.wantErr { 83 | require.NotNil(t, err) 84 | } else { 85 | require.Nil(t, err) 86 | } 87 | assert.Equal(t, tt.want, got) 88 | env.AssertExpectations(t) 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /autopeer/autoaddress.go: -------------------------------------------------------------------------------- 1 | // Package autopeer provides code to compute a peer's automatic IPv6-LL address 2 | // derived from its public key. 
3 | package autopeer 4 | 5 | import ( 6 | "crypto/sha1" 7 | "net" 8 | 9 | "github.com/fastcat/wirelink/internal" 10 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 11 | ) 12 | 13 | // autoAddress computes the IPv6 link-local address that should be assigned to 14 | // peer based on its public key: fe80:: in the top 64 bits, then the first 8 bytes of SHA-1(key) 15 | func autoAddress(key wgtypes.Key) net.IP { 16 | keySum := sha1.Sum(key[:]) 17 | ip := make(net.IP, 16) 18 | copy(ip[0:2], []byte{0xfe, 0x80}) 19 | copy(ip[8:], keySum[:8]) 20 | return ip 21 | } 22 | 23 | // TODO: this should be done with awareness of the number of peers we're going 24 | // to have 25 | var aaMemo = internal.Memoize(50, autoAddress) 26 | 27 | // AutoAddress returns the IPv6 link-local address that should be assigned to 28 | // peer based on its public key. The result is memoized, so the returned slice may be shared with other callers and must not be modified 29 | func AutoAddress(key wgtypes.Key) net.IP { 30 | return aaMemo(key) 31 | } 32 | 33 | // AutoAddressNet returns the peer's AutoAddress with a /128 netmask 34 | func AutoAddressNet(key wgtypes.Key) net.IPNet { 35 | return net.IPNet{ 36 | IP: AutoAddress(key), 37 | Mask: net.CIDRMask(8*net.IPv6len, 8*net.IPv6len), 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /autopeer/autoaddress_test.go: -------------------------------------------------------------------------------- 1 | package autopeer 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | func TestAutoAddressNet(t *testing.T) { 13 | type args struct { 14 | key wgtypes.Key 15 | } 16 | ka := func(ks string) args { 17 | k, err := wgtypes.ParseKey(ks) 18 | require.Nil(t, err) 19 | return args{k} 20 | } 21 | np := func(ip string) net.IPNet { 22 | _, n, err := net.ParseCIDR(ip) 23 | require.Nil(t, err) 24 | return *n 25 | } 26 | tests := []struct { 27 | name string 28 | args args 29 | want net.IPNet 30 | }{ 31 | {"1",
ka("6X/iz1GyW9euj9JIdP7PUl14eoWyoQiAa+BDTB38GhE="), np("fe80::afec:ee83:716b:51ac/128")}, 32 | } 33 | for _, tt := range tests { 34 | t.Run(tt.name, func(t *testing.T) { 35 | got := AutoAddressNet(tt.args.key) 36 | assert.Equal(t, tt.want, got) 37 | }) 38 | } 39 | } 40 | 41 | func Benchmark_autoAddress(b *testing.B) { 42 | k, err := wgtypes.GeneratePrivateKey() 43 | require.NoError(b, err) 44 | b.ResetTimer() 45 | 46 | for i := 0; i < b.N; i++ { 47 | autoAddress(k) // uncached: measures the raw SHA-1 derivation 48 | } 49 | } 50 | 51 | func BenchmarkAutoAddress(b *testing.B) { 52 | k, err := wgtypes.GeneratePrivateKey() 53 | require.NoError(b, err) 54 | // seed the cache 55 | AutoAddress(k) 56 | b.ResetTimer() 57 | 58 | for i := 0; i < b.N; i++ { 59 | AutoAddress(k) // cached: measures the memoized lookup 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /cmd/cmd.go: -------------------------------------------------------------------------------- 1 | // Package cmd provides the main implementation of the wirelink command line.
2 | package cmd 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | 11 | "github.com/fastcat/wirelink/config" 12 | "github.com/fastcat/wirelink/internal" 13 | "github.com/fastcat/wirelink/internal/networking" 14 | "github.com/fastcat/wirelink/log" 15 | "github.com/fastcat/wirelink/server" 16 | ) 17 | 18 | // WirelinkCmd represents an instance of the app command line 19 | type WirelinkCmd struct { 20 | args []string 21 | wgc internal.WgClient 22 | Config *config.Server 23 | Server *server.LinkServer 24 | signals chan os.Signal 25 | } 26 | 27 | // New creates a new command instance using the given os.Args value 28 | func New(args []string) *WirelinkCmd { 29 | ret := &WirelinkCmd{ 30 | args: args, 31 | } 32 | 33 | return ret 34 | } 35 | 36 | // Init prepares the command instance: it creates the wgctrl client, parses the flags and config, and creates the server. It returns nil without creating a server when --help/--version or a config dump was requested 37 | func (w *WirelinkCmd) Init(env networking.Environment) error { 38 | var err error 39 | w.wgc, err = env.NewWgClient() 40 | if err != nil { 41 | return fmt.Errorf("unable to initialize wgctrl: %w", err) 42 | } 43 | 44 | flags, vcfg := config.Init(w.args) 45 | var configData *config.ServerData 46 | if configData, err = config.Parse(flags, vcfg, w.args); err != nil { 47 | return fmt.Errorf("unable to parse configuration: %w", err) 48 | } 49 | // configData comes back nil if we ran --help or --version 50 | if configData == nil { 51 | return nil 52 | } 53 | 54 | if w.Config, err = configData.Parse(vcfg, w.wgc); err != nil { 55 | flags.Usage() 56 | return fmt.Errorf("unable to load configuration: %w", err) 57 | } 58 | if w.Config == nil { 59 | // config dump was requested 60 | return nil 61 | } 62 | 63 | w.Server, err = server.Create(env, w.wgc, w.Config) 64 | if err != nil { 65 | return fmt.Errorf("unable to create server for interface %s: %w", w.Config.Iface, err) 66 | } 67 | 68 | return nil 69 | } 70 | 71 | // Run starts the server, installs SIGINT/SIGTERM (plus platform) signal handling, and blocks until the server stops 72 | func (w *WirelinkCmd) Run() error { 73 | defer w.Server.Close() // NOTE(review): Server is nil if Init returned early (--help, --version, config dump); callers must check before calling Run 74 | err := w.Server.Start() 75 | if err != nil { 76 |
return fmt.Errorf("unable to start server for interface %s: %w", w.Config.Iface, err) 77 | } 78 | 79 | w.signals = make(chan os.Signal, 5) // buffered: os/signal never blocks when sending, so a full channel would drop signals 80 | w.Server.AddHandler(func(ctx context.Context) error { 81 | signal.Notify(w.signals, syscall.SIGINT, syscall.SIGTERM) // NOTE(review): never paired with signal.Stop, so these signals stay captured after this handler returns 82 | w.addPlatformSignalHandlers() 83 | for { 84 | select { 85 | case sig := <-w.signals: 86 | if sig == syscall.SIGINT || sig == syscall.SIGTERM { 87 | log.Info("Received signal %v, stopping", sig) 88 | // this will just initiate the shutdown, not block waiting for it 89 | w.Server.RequestStop() 90 | 91 | // also give platform handler an opportunity to do things 92 | w.handlePlatformSignal(sig) 93 | } else if !w.handlePlatformSignal(sig) { 94 | log.Error("Received unexpected signal %v, ignoring", sig) 95 | } 96 | case <-ctx.Done(): 97 | return nil 98 | } 99 | } 100 | }) 101 | 102 | log.Info("Server running: %s", w.Server.Describe()) 103 | 104 | // server.Close is handled by defer above 105 | return w.Server.Wait() 106 | } 107 | -------------------------------------------------------------------------------- /cmd/cmd_generic.go: -------------------------------------------------------------------------------- 1 | //go:build js || nacl || plan9 || windows || zos 2 | // +build js nacl plan9 windows zos 3 | 4 | package cmd 5 | 6 | import "os" 7 | 8 | // no SIGUSR1 support here, so these are no-ops 9 | 10 | func (w *WirelinkCmd) addPlatformSignalHandlers() { 11 | } 12 | 13 | func (w *WirelinkCmd) handlePlatformSignal(os.Signal) bool { 14 | return false 15 | } 16 | 17 | func (w *WirelinkCmd) sendPrintRequestSignal() { 18 | } 19 | 20 | // unverified for these platforms: this just copies the Linux value 21 | const platformIFNAMSIZ = 16 22 | -------------------------------------------------------------------------------- /cmd/cmd_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !js && !nacl && !plan9 && !windows && !zos 2 | // +build !js,!nacl,!plan9,!windows,!zos
3 | 4 | package cmd 5 | 6 | import ( 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | ) 11 | 12 | func (w *WirelinkCmd) addPlatformSignalHandlers() { 13 | signal.Notify(w.signals, syscall.SIGUSR1) 14 | } 15 | 16 | func (w *WirelinkCmd) handlePlatformSignal(sig os.Signal) bool { 17 | if sig == syscall.SIGUSR1 { 18 | w.Server.RequestPrint(false) 19 | return true 20 | } 21 | return false 22 | } 23 | 24 | func (w *WirelinkCmd) sendPrintRequestSignal() { 25 | w.signals <- syscall.SIGUSR1 26 | } 27 | 28 | const platformIFNAMSIZ = syscall.IFNAMSIZ 29 | -------------------------------------------------------------------------------- /cmd/wirevlink.wgNotRealLongerThanIFNAMSIZ.json: -------------------------------------------------------------------------------- 1 | { 2 | "peers": [ 3 | { 4 | "PublicKey": "invalidKey" 5 | } 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /config/configuration-trust.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/autopeer" 7 | "github.com/fastcat/wirelink/fact" 8 | "github.com/fastcat/wirelink/log" 9 | "github.com/fastcat/wirelink/trust" 10 | "github.com/fastcat/wirelink/util" 11 | 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | type configEvaluator struct { 16 | Peers 17 | // peerIPs map[wgtypes.Key]net.IP 18 | ipToPeer map[[net.IPv6len]byte]wgtypes.Key 19 | } 20 | 21 | var _ trust.Evaluator = &configEvaluator{} 22 | 23 | // CreateTrustEvaluator maps a peer config map into an evaluator that returns the 24 | // configured trust levels 25 | func CreateTrustEvaluator(peers Peers) trust.Evaluator { 26 | ret := &configEvaluator{ 27 | Peers: peers, 28 | ipToPeer: make(map[[net.IPv6len]byte]wgtypes.Key, len(peers)), 29 | // peerIPs: make(map[wgtypes.Key]net.IP, len(peers)), 30 | } 31 | for peer := range peers { 32 | pip := autopeer.AutoAddress(peer) 33 | 
ret.ipToPeer[util.IPToBytes(pip)] = peer 34 | // ret.peerIPs[peer] = pip 35 | } 36 | return ret 37 | } 38 | 39 | func (c *configEvaluator) IsKnown(_ fact.Subject) bool { 40 | // peers are never known to the config evaluator, only trusted 41 | return false 42 | } 43 | 44 | // TrustLevel looks up the fact's source IP in the list of known peers' 45 | // IPv6-LL addresses, and returns the configured trust level for that peer, 46 | // if found and configured 47 | func (c *configEvaluator) TrustLevel(_ *fact.Fact, source net.UDPAddr) *trust.Level { 48 | // we evaluate the trust level based on the _source_, not the _subject_ 49 | // source port evaluation is left to route-based-trust 50 | pk, ok := c.ipToPeer[util.IPToBytes(source.IP)] 51 | if !ok { 52 | // having valid peers in the config is fine 53 | log.Debug("No configured peer found for source: %v", source) 54 | return nil 55 | } 56 | pc, ok := c.Peers[pk] 57 | if !ok { 58 | log.Error("WAT: no configuration for recognized source %v = %v", source, pk) 59 | return nil 60 | } 61 | if pc.Trust == nil { 62 | return nil 63 | } 64 | // make a copy so caller can't modify it 65 | ret := *pc.Trust 66 | return &ret 67 | } 68 | -------------------------------------------------------------------------------- /config/configuration-trust_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/fastcat/wirelink/autopeer" 8 | "github.com/fastcat/wirelink/fact" 9 | "github.com/fastcat/wirelink/internal/testutils" 10 | "github.com/fastcat/wirelink/trust" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_configEvaluator_TrustLevel(t *testing.T) { 16 | k1 := testutils.MustKey(t) 17 | k2 := testutils.MustKey(t) 18 | k1u := &net.UDPAddr{IP: autopeer.AutoAddress(k1)} 19 | 20 | u1 := testutils.RandUDP4Addr(t) 21 | 22 | // type fields struct { 23 | // Peers Peers 24 | // ipToPeer 
map[[net.IPv6len]byte]wgtypes.Key 25 | // } 26 | type args struct { 27 | f *fact.Fact 28 | source net.UDPAddr 29 | } 30 | tests := []struct { 31 | name string 32 | // fields fields 33 | peers Peers 34 | args args 35 | want *trust.Level 36 | }{ 37 | { 38 | "no known peers", 39 | Peers{}, 40 | args{source: *u1}, 41 | nil, 42 | }, 43 | { 44 | "known peer un-configured", 45 | Peers{k1: &Peer{}}, 46 | args{source: *k1u}, 47 | nil, 48 | }, 49 | { 50 | "known peer configured", 51 | Peers{k1: &Peer{Trust: trust.Ptr(trust.Membership)}}, 52 | args{source: *k1u}, 53 | trust.Ptr(trust.Membership), 54 | }, 55 | { 56 | "unknown peer", 57 | Peers{k2: &Peer{Trust: trust.Ptr(trust.Membership)}}, 58 | args{source: *k1u}, 59 | nil, 60 | }, 61 | } 62 | for _, tt := range tests { 63 | t.Run(tt.name, func(t *testing.T) { 64 | // c := &configEvaluator{ 65 | // Peers: tt.fields.Peers, 66 | // ipToPeer: tt.fields.ipToPeer, 67 | // } 68 | c := CreateTrustEvaluator(tt.peers) 69 | got := c.TrustLevel(tt.args.f, tt.args.source) 70 | assert.Equal(t, tt.want, got) 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /config/doc.go: -------------------------------------------------------------------------------- 1 | // Package config provides code for working with both the wirelink command line 2 | // arguments and its configuration file. 
3 | package config 4 | -------------------------------------------------------------------------------- /config/flags_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "os" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/fastcat/wirelink/internal" 11 | "github.com/fastcat/wirelink/internal/testutils" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestParse(t *testing.T) { 18 | fakeName := "wirelink_test" 19 | envArg := func(arg, value string) []string { 20 | return []string{ 21 | strings.ToUpper(fmt.Sprintf("%s_%s", fakeName, arg)), 22 | value, 23 | } 24 | } 25 | wgIface := fmt.Sprintf("wgFake%d", rand.Int()) 26 | configPath := "./testdata/" 27 | 28 | tests := []struct { 29 | name string 30 | args []string 31 | env [][]string 32 | wantRet *ServerData 33 | outputCheck func(*testing.T, []byte) 34 | assertion require.ErrorAssertionFunc 35 | }{ 36 | { 37 | "empty", 38 | nil, 39 | nil, 40 | &ServerData{Iface: "wg0"}, 41 | nil, 42 | require.NoError, 43 | }, 44 | { 45 | "arg iface", 46 | []string{"--iface", wgIface}, 47 | nil, 48 | &ServerData{Iface: wgIface}, 49 | nil, 50 | require.NoError, 51 | }, 52 | { 53 | "env iface", 54 | nil, 55 | [][]string{envArg("iface", wgIface)}, 56 | &ServerData{Iface: wgIface}, 57 | nil, 58 | require.NoError, 59 | }, 60 | { 61 | "help", 62 | []string{"--help"}, 63 | nil, 64 | nil, 65 | func(t *testing.T, output []byte) { 66 | text := string(output) 67 | assert.Contains(t, text, "Usage of "+fakeName) 68 | assert.Contains(t, text, internal.Version) 69 | }, 70 | require.NoError, 71 | }, 72 | { 73 | "bogus arg", 74 | []string{fmt.Sprintf("--garbage-%d", rand.Int())}, 75 | nil, 76 | nil, 77 | // TODO: want to validate output here, but it goes to stderr and so far only capturing stdout 78 | nil, 79 | require.Error, 80 | }, 81 | { 82 | "router", 83 | []string{"--router"}, 84 |
nil, 85 | &ServerData{Iface: "wg0", Router: boolPtr(true)}, 86 | nil, 87 | require.NoError, 88 | }, 89 | { 90 | "router=true", 91 | []string{"--router=true"}, 92 | nil, 93 | &ServerData{Iface: "wg0", Router: boolPtr(true)}, 94 | nil, 95 | require.NoError, 96 | }, 97 | { 98 | "router=false", 99 | []string{"--router=false"}, 100 | nil, 101 | &ServerData{Iface: "wg0", Router: boolPtr(false)}, 102 | nil, 103 | require.NoError, 104 | }, 105 | // TODO: more tests 106 | } 107 | for _, tt := range tests { 108 | t.Run(tt.name, func(t *testing.T) { 109 | // make sure it doesn't use any real configs on the system 110 | configPathKey := fmt.Sprintf("%s_%s", strings.ToUpper(fakeName), "CONFIG_PATH") 111 | os.Setenv(configPathKey, configPath) 112 | defer os.Unsetenv(configPathKey) 113 | if tt.wantRet != nil { 114 | tt.wantRet.ConfigPath = configPath 115 | } 116 | 117 | // set per-test env, clear it at the end 118 | for _, envPair := range tt.env { 119 | os.Setenv(envPair[0], envPair[1]) 120 | defer os.Unsetenv(envPair[0]) 121 | } 122 | 123 | var gotRet *ServerData 124 | var err error 125 | outputData := testutils.CaptureOutput(t, func() { 126 | args := append([]string{fakeName}, tt.args...)
127 | flags, vcfg := Init(args) 128 | gotRet, err = Parse(flags, vcfg, args) 129 | }) 130 | 131 | tt.assertion(t, err) 132 | assert.Equal(t, tt.wantRet, gotRet) 133 | 134 | if tt.outputCheck != nil { 135 | tt.outputCheck(t, outputData) 136 | } 137 | }) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /config/peer-config-set.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/trust" 7 | 8 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 9 | ) 10 | 11 | // Peers represents a set of peer configs, with handy access methods that avoid 12 | // boilerplate for peers that are not configured (missing keys and nil entries alike) 13 | type Peers map[wgtypes.Key]*Peer 14 | 15 | // Has checks if the given peer key has any configuration data 16 | func (p Peers) Has(peer wgtypes.Key) bool { 17 | val, ok := p[peer] 18 | return ok && val != nil 19 | } 20 | 21 | // Name returns the name of the peer, if configured, or else the empty string 22 | func (p Peers) Name(peer wgtypes.Key) string { 23 | if config := p[peer]; config != nil { 24 | return config.Name 25 | } 26 | return "" 27 | } 28 | 29 | // Trust returns the configured trust level (if present and valid) or else the 30 | // provided default 31 | func (p Peers) Trust(peer wgtypes.Key, def trust.Level) trust.Level { 32 | if config := p[peer]; config != nil && config.Trust != nil { 33 | return *config.Trust 34 | } 35 | // else fall through 36 | return def 37 | } 38 | 39 | // AnyTrustedAt returns whether any peer is configured with a trust level of 40 | // at least the given level 41 | func (p Peers) AnyTrustedAt(level trust.Level) bool { 42 | for _, c := range p { 43 | if c != nil && c.Trust != nil && *c.Trust >= level { 44 | return true 45 | } 46 | } 47 | return false 48 | } 49 | 50 | // IsFactExchanger returns true if the peer is configured as a FactExchanger 51 | func (p Peers) IsFactExchanger(peer wgtypes.Key) bool {
52 | config := p[peer] 53 | return config != nil && config.FactExchanger 54 | } 55 | 56 | // IsBasic returns true if the peer is explicitly configured as a basic peer, 57 | // or false otherwise 58 | func (p Peers) IsBasic(peer wgtypes.Key) bool { 59 | config := p[peer] 60 | return config != nil && config.Basic 61 | } 62 | 63 | // AllowedIPs returns the array of AllowedIPs explicitly configured for the peer, if any 64 | func (p Peers) AllowedIPs(peer wgtypes.Key) []net.IPNet { 65 | if config := p[peer]; config != nil { 66 | return config.AllowedIPs 67 | } 68 | return nil 69 | } 70 | 71 | // Endpoints returns the array of Endpoints explicitly configured for the peer, if any 72 | func (p Peers) Endpoints(peer wgtypes.Key) []PeerEndpoint { 73 | if config := p[peer]; config != nil { 74 | return config.Endpoints 75 | } 76 | return nil 77 | } 78 | -------------------------------------------------------------------------------- /config/peer-config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/trust" 8 | ) 9 | 10 | // PeerEndpoint represents a single endpoint (possibly by hostname) for a peer 11 | type PeerEndpoint struct { 12 | // Host may be either an IP or a hostname 13 | Host string 14 | Port int 15 | } 16 | 17 | // Peer represents the parsed info about a peer read from the config file 18 | type Peer struct { 19 | Name string 20 | Trust *trust.Level 21 | FactExchanger bool 22 | Endpoints []PeerEndpoint 23 | AllowedIPs []net.IPNet 24 | Basic bool 25 | } 26 | 27 | func (p *Peer) String() string { 28 | trustStr := "nil" 29 | if p.Trust != nil { 30 | trustStr = p.Trust.String() 31 | } 32 | return fmt.Sprintf("{Name:%s Trust:%s Exch:%v EPs:%d AIPs:%d B:%t}", 33 | p.Name, trustStr, p.FactExchanger, len(p.Endpoints), len(p.AllowedIPs), p.Basic) 34 | } 35 | -------------------------------------------------------------------------------- /config/peer-data.go:
-------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/trust" 8 | 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | // PeerData represents the raw data to configure a peer read from the config file 13 | type PeerData struct { 14 | PublicKey string 15 | Name string 16 | Trust string 17 | FactExchanger bool 18 | Endpoints []string 19 | AllowedIPs []string 20 | Basic bool 21 | } 22 | 23 | // Parse validates the info in the PeerData and returns the parsed tuple + error 24 | func (p *PeerData) Parse() (key wgtypes.Key, peer Peer, err error) { 25 | if key, err = wgtypes.ParseKey(p.PublicKey); err != nil { 26 | return 27 | } 28 | peer.Name = p.Name 29 | if p.Trust != "" { 30 | val, ok := trust.Values[p.Trust] 31 | if !ok { 32 | err = fmt.Errorf("invalid trust level '%s'", p.Trust) 33 | return 34 | } 35 | peer.Trust = &val 36 | } 37 | peer.FactExchanger = p.FactExchanger 38 | peer.Endpoints = make([]PeerEndpoint, 0, len(p.Endpoints)) 39 | // we don't store resolved addresses here because we want them to refresh 40 | // periodically, esp.
if we move across a split horizon boundary 41 | // we do want to validate the host/port split however 42 | for _, ep := range p.Endpoints { 43 | var host, portString string 44 | var port int 45 | if host, portString, err = net.SplitHostPort(ep); err != nil { 46 | err = fmt.Errorf("bad endpoint '%s' for '%s'='%s': %w", ep, p.PublicKey, p.Name, err) 47 | return 48 | } 49 | if port, err = net.LookupPort("udp", portString); err != nil { 50 | err = fmt.Errorf("bad endpoint port in '%s' for '%s'='%s': %w", ep, p.PublicKey, p.Name, err) 51 | return 52 | } 53 | 54 | // try to resolve the host, ignoring DNS errors and just looking for parse errors 55 | // NOTE: it's actually really hard to get anything other than a DNSError out of LookupIP 56 | // anything that doesn't parse as an IP is more or less assumed to be a hostname, 57 | // even if it is not actually valid as such (e.g. all numbers), and then a lookup attempted, 58 | // and if the lookup fails, we get a DNS error 59 | _, err = net.LookupIP(host) 60 | if _, ok := err.(*net.DNSError); ok { 61 | // ignore DNS errors ...
which actually ends up as basically everything except for a 62 | // parse error for giving the empty string 63 | err = nil 64 | } else if err != nil { 65 | // this branch is very hard, if not impossible, to reach in a test or in the real world 66 | err = fmt.Errorf("bad endpoint host in '%s' for '%s'='%s': %w", ep, p.PublicKey, p.Name, err) 67 | return 68 | } 69 | 70 | // NOTE: the LookupIP call above is the only host validation; since DNS errors are 71 | // ignored, it rejects little more than an empty host 72 | 73 | peer.Endpoints = append(peer.Endpoints, PeerEndpoint{ 74 | Host: host, 75 | Port: port, 76 | }) 77 | } 78 | 79 | peer.AllowedIPs = make([]net.IPNet, 0, len(p.AllowedIPs)) 80 | for _, aip := range p.AllowedIPs { 81 | var ipn *net.IPNet 82 | _, ipn, err = net.ParseCIDR(aip) 83 | if err != nil { 84 | err = fmt.Errorf("bad AllowedIP '%s' for '%s'='%s': %w", aip, p.PublicKey, p.Name, err) 85 | return 86 | } 87 | // NOTE: we don't need to run ipn.IP through ipn.Mask, as ParseCIDR does that for us 88 | // we want to do that so config can contain minor harmless mistakes, or also use the 89 | // masked-out bits as documentation for the peer's own/primary IP within that network 90 | // ipn is returned by reference, should never be returned nil 91 | peer.AllowedIPs = append(peer.AllowedIPs, *ipn) 92 | } 93 | 94 | peer.Basic = p.Basic 95 | 96 | return 97 | } 98 | -------------------------------------------------------------------------------- /config/server-config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "path/filepath" 5 | 6 | "github.com/fastcat/wirelink/log" 7 | ) 8 | 9 | // Server describes the configuration for the server, after parsing from various sources 10 | type Server struct { 11 | Iface string 12 | Port int 13 | Chatty bool 14 | 15 | AutoDetectRouter bool 16 | IsRouterNow bool 17 | 18 | ReportIfaces []string 19 | HideIfaces []string 20 | 21 | Peers Peers 22 | 23 | Debug bool 24 | } 25 | 26 | //
ShouldReportIface checks a given local network interface name against the config 27 | // for whether we should tell other peers about our configuration on it 28 | func (s *Server) ShouldReportIface(name string) bool { 29 | // TODO: report any broken globs found here _once_ (startup checks can't detect all broken globs) 30 | 31 | // Never tell peers about the wireguard interface itself 32 | if name == s.Iface { 33 | return false 34 | } 35 | 36 | // MUST NOT match any excludes 37 | for _, glob := range s.HideIfaces { 38 | if matched, err := filepath.Match(glob, name); matched && err == nil { 39 | log.Debug("Hiding iface '%s' because it matches exclude '%s'", name, glob) 40 | return false 41 | } 42 | } 43 | if len(s.ReportIfaces) == 0 { 44 | log.Debug("Including iface '%s' because no includes are configured", name) 45 | return true 46 | } 47 | // if any includes are specified, name MUST match one of them 48 | for _, glob := range s.ReportIfaces { 49 | if matched, err := filepath.Match(glob, name); matched && err == nil { 50 | log.Debug("Including iface '%s' because it matches include '%s'", name, glob) 51 | return true 52 | } 53 | } 54 | log.Debug("Hiding iface '%s' because it doesn't match any includes", name) 55 | return false 56 | } 57 | -------------------------------------------------------------------------------- /config/server-config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestServer_ShouldReportIface(t *testing.T) { 12 | ifName := func(prefix string) string { 13 | return fmt.Sprintf("%s%d", prefix, rand.Int31()) 14 | } 15 | self := ifName("wg") 16 | matchEth := []string{"eth*"} 17 | 18 | type fields struct { 19 | Iface string 20 | ReportIfaces []string 21 | HideIfaces []string 22 | } 23 | type args struct { 24 | name string 25 | } 26 | tests := []struct { 27 | name
string 28 | fields fields 29 | args args 30 | want bool 31 | }{ 32 | {"self", fields{Iface: self}, args{self}, false}, 33 | {"self included", fields{Iface: self, ReportIfaces: []string{self}}, args{self}, false}, 34 | {"default", fields{Iface: self}, args{ifName("eth")}, true}, 35 | {"included", fields{Iface: self, ReportIfaces: matchEth}, args{ifName("eth")}, true}, 36 | {"excluded", fields{Iface: self, HideIfaces: matchEth}, args{ifName("eth")}, false}, 37 | {"exclude priority", fields{Iface: self, ReportIfaces: matchEth, HideIfaces: matchEth}, args{ifName("eth")}, false}, 38 | {"not included", fields{Iface: self, ReportIfaces: matchEth}, args{ifName("wl")}, false}, 39 | } 40 | for _, tt := range tests { 41 | t.Run(tt.name, func(t *testing.T) { 42 | s := &Server{ 43 | Iface: tt.fields.Iface, 44 | ReportIfaces: tt.fields.ReportIfaces, 45 | HideIfaces: tt.fields.HideIfaces, 46 | } 47 | got := s.ShouldReportIface(tt.args.name) 48 | assert.Equal(t, tt.want, got) 49 | }) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /config/utils_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | func letter() rune { 8 | return 'a' + int32(rand.Intn(26)) 9 | } 10 | 11 | func boolean() bool { 12 | return rand.Intn(2) == 1 13 | } 14 | 15 | func boolPtr(value bool) *bool { 16 | return &value 17 | } 18 | -------------------------------------------------------------------------------- /detect/router-detect.go: -------------------------------------------------------------------------------- 1 | // Package detect provides utilities to detect if the local system is a router, 2 | // or if some remote peer is one.
3 | package detect 4 | 5 | import ( 6 | "github.com/fastcat/wirelink/log" 7 | "github.com/fastcat/wirelink/util" 8 | 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | // IsPeerRouter considers a router to be a peer that has a global unicast allowed 13 | // IP with a CIDR mask shorter than the full IP length (i.e. a subnet, not a single host) 14 | func IsPeerRouter(peer *wgtypes.Peer) bool { 15 | for _, aip := range peer.AllowedIPs { 16 | // don't ignore private (RFC1918-ish) subnets here, as they are a valid 17 | // thing to be a router for 18 | if !aip.IP.IsGlobalUnicast() { 19 | continue 20 | } 21 | aipNorm := util.NormalizeIP(aip.IP) 22 | ones, size := aip.Mask.Size() 23 | if len(aipNorm)*8 == size && ones < size { 24 | return true 25 | } 26 | } 27 | return false 28 | } 29 | 30 | // IsDeviceRouter tries to detect whether the local device is a router for other peers. 31 | // Currently it does this by assuming that it is a router if and only if nobody else is. 32 | // TODO: try to check local networking config for signs of routing configuration 33 | func IsDeviceRouter(dev *wgtypes.Device) bool { 34 | otherRouters := false 35 | for _, p := range dev.Peers { 36 | if IsPeerRouter(&p) { 37 | log.Debug("Router autodetect: found router peer %v", p.PublicKey) 38 | otherRouters = true 39 | break 40 | } 41 | } 42 | 43 | return !otherRouters 44 | } 45 | -------------------------------------------------------------------------------- /detect/router-detect_test.go: -------------------------------------------------------------------------------- 1 | package detect 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/fastcat/wirelink/internal/testutils" 8 | 9 | "github.com/stretchr/testify/assert" 10 | 11 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 12 | ) 13 | 14 | func TestIsPeerRouter(t *testing.T) { 15 | type args struct { 16 | peer *wgtypes.Peer 17 | } 18 | tests := []struct { 19 | name string 20 | args args 21 | want bool 22 | }{ 23 | {"empty", args{&wgtypes.Peer{}}, false}, 24 | { 25 |
"non-routable v4 net", 26 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 27 | testutils.RandIPNet(t, net.IPv4len, []byte{169, 254}, nil, 24), 28 | }}}, 29 | false, 30 | }, 31 | { 32 | "non-routable v4 host", 33 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 34 | testutils.RandIPNet(t, net.IPv4len, []byte{169, 254}, nil, 32), 35 | }}}, 36 | false, 37 | }, 38 | { 39 | "routable v4 net", 40 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 41 | testutils.RandIPNet(t, net.IPv4len, []byte{100}, nil, 24), 42 | }}}, 43 | true, 44 | }, 45 | { 46 | "routable v4 host", 47 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 48 | testutils.RandIPNet(t, net.IPv4len, []byte{100}, nil, 32), 49 | }}}, 50 | false, 51 | }, 52 | { 53 | "non-routable v6 net", 54 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 55 | testutils.RandIPNet(t, net.IPv6len, []byte{0xfe, 0x80}, nil, 64), 56 | }}}, 57 | false, 58 | }, 59 | { 60 | "non-routable v6 host", 61 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 62 | testutils.RandIPNet(t, net.IPv6len, []byte{0xfe, 0x80}, nil, 128), 63 | }}}, 64 | false, 65 | }, 66 | { 67 | "routable v6 net", 68 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 69 | testutils.RandIPNet(t, net.IPv6len, []byte{0x20}, nil, 64), 70 | }}}, 71 | true, 72 | }, 73 | { 74 | "routable v6 host", 75 | args{&wgtypes.Peer{AllowedIPs: []net.IPNet{ 76 | testutils.RandIPNet(t, net.IPv6len, []byte{0x20}, nil, 128), 77 | }}}, 78 | false, 79 | }, 80 | } 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | got := IsPeerRouter(tt.args.peer) 84 | assert.Equal(t, tt.want, got) 85 | }) 86 | } 87 | } 88 | 89 | func TestIsDeviceRouter(t *testing.T) { 90 | router := func() wgtypes.Peer { 91 | return wgtypes.Peer{ 92 | AllowedIPs: []net.IPNet{ 93 | testutils.RandIPNet(t, net.IPv4len, []byte{100}, nil, 24), 94 | }, 95 | } 96 | } 97 | leaf := func() wgtypes.Peer { 98 | return wgtypes.Peer{ 99 | AllowedIPs: []net.IPNet{ 100 | testutils.RandIPNet(t, net.IPv4len, []byte{100}, nil, 32), 101 | }, 102
| } 103 | } 104 | dev := func(peers ...wgtypes.Peer) *wgtypes.Device { 105 | return &wgtypes.Device{ 106 | Peers: peers, 107 | } 108 | } 109 | type args struct { 110 | dev *wgtypes.Device 111 | } 112 | tests := []struct { 113 | name string 114 | args args 115 | want bool 116 | }{ 117 | {"empty", args{dev()}, true}, 118 | {"other leaf", args{dev(leaf())}, true}, 119 | {"other router", args{dev(leaf(), router())}, false}, 120 | } 121 | for _, tt := range tests { 122 | t.Run(tt.name, func(t *testing.T) { 123 | got := IsDeviceRouter(tt.args.dev) 124 | assert.Equal(t, tt.want, got) 125 | }) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /device/auto-ip.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/fastcat/wirelink/apply" 7 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 8 | ) 9 | 10 | // EnsurePeersAutoIP updates the config of the device, if needed, to ensure all 11 | // peers have their IPv6-LL IP listed in their AllowedIPs.
12 | // It returns the number of peers modified and any error that happens 13 | func (d *Device) EnsurePeersAutoIP() (int, error) { 14 | state, err := d.State() 15 | if err != nil { 16 | return 0, err 17 | } 18 | 19 | var cfg wgtypes.Config 20 | for _, peer := range state.Peers { 21 | pcfg, _ := apply.EnsurePeerAutoIP(&peer, nil) 22 | if pcfg != nil { 23 | cfg.Peers = append(cfg.Peers, *pcfg) 24 | } 25 | } 26 | 27 | if len(cfg.Peers) == 0 { 28 | return 0, nil 29 | } 30 | 31 | err = d.ConfigureDevice(cfg) 32 | if err != nil { 33 | return 0, fmt.Errorf( 34 | "unable to configure %s with %d new peer IPv6-LL AllowedIPs: %w", d.iface, len(cfg.Peers), err) 35 | } 36 | 37 | return len(cfg.Peers), nil 38 | } 39 | -------------------------------------------------------------------------------- /device/config.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 4 | 5 | // ConfigureDevice wraps the underlying wgctrl method 6 | func (d *Device) ConfigureDevice(cfg wgtypes.Config) error { 7 | d.mu.Lock() 8 | defer d.mu.Unlock() 9 | 10 | d.dirty = true 11 | err := d.ctrl.ConfigureDevice(d.iface, cfg) 12 | return err 13 | } 14 | -------------------------------------------------------------------------------- /device/device.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/fastcat/wirelink/internal" 7 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 8 | ) 9 | 10 | // Device wraps a WgClient and an associated device name 11 | type Device struct { 12 | mu sync.Mutex 13 | ctrl internal.WgClient 14 | iface string 15 | ownCtrl bool 16 | 17 | state *wgtypes.Device 18 | dirty bool 19 | lastErr error 20 | } 21 | 22 | // New creates a new device from an already open control interface. The Close() 23 | // method of the returned Device will be a no-op. 
24 | func New( 25 | ctrl internal.WgClient, 26 | iface string, 27 | ) (*Device, error) { 28 | dev := &Device{ctrl: ctrl, iface: iface} 29 | if err := dev.read(); err != nil { 30 | return nil, err 31 | } 32 | return dev, nil 33 | } 34 | 35 | // Take is the same as New, except that it takes ownership of the control 36 | // interface if initialization is successful. 37 | func Take( 38 | ctrl internal.WgClient, 39 | iface string, 40 | ) (*Device, error) { 41 | dev := &Device{ctrl: ctrl, iface: iface, ownCtrl: true} 42 | if err := dev.read(); err != nil { 43 | return nil, err 44 | } 45 | return dev, nil 46 | } 47 | 48 | // TODO: Open() 49 | 50 | // Close closes the underlying control interface, if the Device owns it. It is 51 | // not safe to call Close if other goroutines using the device are active. 52 | func (d *Device) Close() error { 53 | d.mu.Lock() 54 | defer d.mu.Unlock() 55 | if d.ownCtrl { 56 | if err := d.ctrl.Close(); err != nil { 57 | return err 58 | } 59 | d.ctrl = nil 60 | } 61 | return nil 62 | } 63 | 64 | // read updates the internal cache of the device state. It must be called with 65 | // the state mutex locked, or from some context such as initialization where it 66 | // is not possible for another goroutine to be accessing the device 67 | // concurrently. 68 | func (d *Device) read() error { 69 | s, err := d.ctrl.Device(d.iface) 70 | d.lastErr = err 71 | if err != nil { 72 | d.dirty = true 73 | return err 74 | } 75 | d.state = s 76 | d.dirty = false 77 | return nil 78 | } 79 | 80 | // State gets the current device state. It will attempt to refresh it if dirty. 81 | // If refresh fails, it will return the last state along with the refresh error. 82 | // The returned state will never be nil, because New/Open will refuse to create 83 | // a Device if they cannot read an initial state. 
84 | func (d *Device) State() (*wgtypes.Device, error) { 85 | d.mu.Lock() 86 | defer d.mu.Unlock() 87 | if d.dirty { 88 | // ignore the error, if the read fails we just return the last state 89 | _ = d.read() 90 | } 91 | return d.state, d.lastErr 92 | } 93 | 94 | // Refresh is like State(), but always re-reads the device state, even if it 95 | // isn't dirty. 96 | // func (d *Device) Refresh() (*wgtypes.Device, error) { 97 | // d.mu.Lock() 98 | // defer d.mu.Unlock() 99 | // _ = d.read() 100 | // return d.state, d.lastErr 101 | // } 102 | 103 | // Dirty marks the device state dirty forcing the next read to refresh the data. 104 | func (d *Device) Dirty() { 105 | // for tests 106 | if d == nil { 107 | return 108 | } 109 | 110 | d.mu.Lock() 111 | d.dirty = true 112 | d.mu.Unlock() 113 | } 114 | -------------------------------------------------------------------------------- /device/doc.go: -------------------------------------------------------------------------------- 1 | // Package device wraps the wgctrl library with concurrency safety and dirty 2 | // tracking to simplify interactions from the wirelink server. 3 | package device 4 | -------------------------------------------------------------------------------- /fact/accumulator.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/fastcat/wirelink/signing" 8 | 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | // GroupAccumulator is a helper to aggregate individual facts into (signed) 13 | // groups of a max size 14 | type GroupAccumulator struct { 15 | maxGroupLen int 16 | groups [][]byte 17 | now time.Time 18 | } 19 | 20 | // NewAccumulator initializes a new GroupAccumulator with a given max inner 21 | // size per group. 
22 | func NewAccumulator(maxGroupLen int, now time.Time) *GroupAccumulator { 23 | return &GroupAccumulator{ 24 | maxGroupLen: maxGroupLen, 25 | groups: make([][]byte, 1), 26 | now: now, 27 | } 28 | } 29 | 30 | // AddFact appends the given fact into the accumulator 31 | func (ga *GroupAccumulator) AddFact(f *Fact) error { 32 | b, err := f.MarshalBinaryNow(ga.now) 33 | if err != nil { 34 | return fmt.Errorf("unable to convert fact to packet bytes: %w", err) 35 | } 36 | lgi := len(ga.groups) - 1 37 | lg := ga.groups[lgi] 38 | if len(lg)+len(b) > ga.maxGroupLen { 39 | // make another group 40 | ga.groups = append(ga.groups, b) 41 | } else { 42 | ga.groups[lgi] = append(lg, b...) 43 | } 44 | return nil 45 | } 46 | 47 | // AddFactIfRoom conditionally adds the fact if and only if it won't result in 48 | // creating a new group 49 | func (ga *GroupAccumulator) AddFactIfRoom(f *Fact) (added bool, err error) { 50 | b, err := f.MarshalBinaryNow(ga.now) 51 | if err != nil { 52 | return false, fmt.Errorf("unable to convert fact to packet bytes: %w", err) 53 | } 54 | lgi := len(ga.groups) - 1 55 | lg := ga.groups[lgi] 56 | if len(lg)+len(b) > ga.maxGroupLen || len(lg) == 0 { 57 | return false, nil 58 | } 59 | ga.groups[lgi] = append(lg, b...) 60 | return true, nil 61 | } 62 | 63 | // MakeSignedGroups converts all the accumulated facts into SignedGroups of no 64 | // more than the specified max inner size. 
65 | func (ga *GroupAccumulator) MakeSignedGroups( 66 | s *signing.Signer, 67 | recipient *wgtypes.Key, 68 | ) ([]*Fact, error) { 69 | ret := make([]*Fact, 0, len(ga.groups)) 70 | subject := PeerSubject{Key: s.PublicKey} 71 | for _, g := range ga.groups { 72 | if len(g) == 0 { 73 | continue 74 | } 75 | // TODO: have signer cache shared key 76 | nonce, tag, err := s.SignFor(g, recipient) 77 | if err != nil { 78 | return nil, fmt.Errorf("unable to sign group data: %w", err) 79 | } 80 | value := SignedGroupValue{ 81 | Nonce: nonce, 82 | Tag: tag, 83 | InnerBytes: g, 84 | } 85 | ret = append(ret, &Fact{ 86 | Attribute: AttributeSignedGroup, 87 | // zero time will turn into a TTL of zero 88 | Expires: time.Time{}, 89 | Subject: &subject, 90 | Value: &value, 91 | }) 92 | } 93 | return ret, nil 94 | } 95 | -------------------------------------------------------------------------------- /fact/accumulator_test.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | 9 | "github.com/fastcat/wirelink/internal/testutils" 10 | "github.com/fastcat/wirelink/signing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestAccumulatorLimits(t *testing.T) { 17 | ef, ep := mustMockAlivePacket(t, nil, nil) 18 | 19 | a := NewAccumulator(len(ep)*4-1, time.Now()) 20 | 21 | for i := 0; i < 4; i++ { 22 | err := a.AddFact(ef) 23 | require.Nil(t, err) 24 | } 25 | 26 | require.Len(t, a.groups, 2, "Should split into 2 groups") 27 | assert.Len(t, a.groups[0], len(ep)*3, "Should have 3 packets in group 1") 28 | assert.Len(t, a.groups[1], len(ep), "Should have 1 packet in group 2") 29 | } 30 | 31 | func TestAccumulatorSigning(t *testing.T) { 32 | ef, ep := mustMockAlivePacket(t, nil, nil) 33 | 34 | a := NewAccumulator(len(ep)*4-1, time.Now()) 35 | for i := 0; i < 4; i++ { 36 | err := a.AddFact(ef) 37 | 
require.Nil(t, err) 38 | } 39 | 40 | priv, signer := testutils.MustKeyPair(t) 41 | _, pub := testutils.MustKeyPair(t) 42 | 43 | s := signing.New(priv) 44 | 45 | facts, err := a.MakeSignedGroups(s, &pub) 46 | require.Nil(t, err) 47 | 48 | require.Len(t, facts, 2, "Should have two SignedGroupValues") 49 | 50 | for i, sf := range facts { 51 | assert.Equal(t, AttributeSignedGroup, sf.Attribute, "Signing output should be SignedGroups") 52 | assert.IsType(t, &PeerSubject{}, sf.Subject) 53 | // the subject must be the public key of the signer, _not the recipient_ 54 | assert.Equal(t, signer, sf.Subject.(*PeerSubject).Key) 55 | assert.False(t, sf.Expires.After(time.Now()), "Expiration should be <= now") 56 | require.IsType(t, &SignedGroupValue{}, sf.Value, "SG Value should be an SGV") 57 | sgv := sf.Value.(*SignedGroupValue) 58 | 59 | // signing checks are handled elsewhere 60 | switch i { 61 | case 0: 62 | assert.Len(t, sgv.InnerBytes, len(ep)*3, "Should have 3 facts in first packet") 63 | case 1: 64 | assert.Len(t, sgv.InnerBytes, len(ep)*1, "Should have 1 fact in second packet") 65 | default: 66 | require.FailNow(t, "WAT?!") 67 | } 68 | } 69 | } 70 | 71 | func TestGroupAccumulator_AddFactIfRoom_OneByteTooSmall(t *testing.T) { 72 | f := &Fact{ 73 | Attribute: AttributeAlive, 74 | Subject: &PeerSubject{testutils.MustKey(t)}, 75 | Value: &UUIDValue{uuid.Must(uuid.NewRandom())}, 76 | Expires: time.Now(), 77 | } 78 | b, err := f.MarshalBinary() 79 | require.Nil(t, err) 80 | ga := NewAccumulator(len(b)*3-1, time.Now()) 81 | 82 | err = ga.AddFact(f) 83 | require.Nil(t, err) 84 | 85 | added, err := ga.AddFactIfRoom(f) 86 | require.Nil(t, err) 87 | assert.True(t, added) 88 | 89 | added, err = ga.AddFactIfRoom(f) 90 | require.Nil(t, err) 91 | assert.False(t, added) 92 | } 93 | 94 | func TestGroupAccumulator_AddFactIfRoom_JustRight(t *testing.T) { 95 | f := &Fact{ 96 | Attribute: AttributeAlive, 97 | Subject: &PeerSubject{testutils.MustKey(t)}, 98 | Value: 
&UUIDValue{uuid.Must(uuid.NewRandom())}, 99 | Expires: time.Now(), 100 | } 101 | b, err := f.MarshalBinary() 102 | require.Nil(t, err) 103 | ga := NewAccumulator(len(b)*2, time.Now()) 104 | 105 | err = ga.AddFact(f) 106 | require.Nil(t, err) 107 | 108 | added, err := ga.AddFactIfRoom(f) 109 | require.Nil(t, err) 110 | assert.True(t, added) 111 | } 112 | -------------------------------------------------------------------------------- /fact/doc.go: -------------------------------------------------------------------------------- 1 | // Package fact provides the core code for representing facts, and their 2 | // serialization and deserialization. 3 | package fact 4 | -------------------------------------------------------------------------------- /fact/fact.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "bytes" 5 | "encoding" 6 | "encoding/binary" 7 | "fmt" 8 | "math" 9 | "time" 10 | 11 | "github.com/fastcat/wirelink/util" 12 | ) 13 | 14 | // fact types, denoted as attributes of a subject 15 | const ( 16 | AttributeUnknown Attribute = 0 17 | AttributeAlive Attribute = '!' 18 | AttributeEndpointV4 Attribute = 'e' 19 | AttributeEndpointV6 Attribute = 'E' 20 | AttributeAllowedCidrV4 Attribute = 'a' 21 | AttributeAllowedCidrV6 Attribute = 'A' 22 | AttributeMember Attribute = 'm' 23 | AttributeMemberMetadata Attribute = 'M' 24 | // A signed group is a bit different from other facts 25 | // in this case, the subject is actually the source, 26 | // and the value is a signed aggregate of other facts. 
27 | AttributeSignedGroup Attribute = 'S' 28 | ) 29 | 30 | // Fact represents a single piece of information about a subject, with an 31 | // associated expiration time 32 | type Fact struct { 33 | encoding.BinaryMarshaler 34 | util.Decodable 35 | 36 | Attribute Attribute 37 | Expires time.Time 38 | Subject Subject 39 | Value Value 40 | } 41 | 42 | func (f *Fact) String() string { 43 | return f.FancyString(func(s Subject) string { return s.String() }, time.Now()) 44 | } 45 | 46 | // FancyString formats the fact as a string using a custom helper to format 47 | // the subject, most commonly to replace peer keys with names 48 | func (f *Fact) FancyString( 49 | subjectFormatter func(s Subject) string, 50 | now time.Time, 51 | ) string { 52 | if f == nil { 53 | return fmt.Sprintf("%v", nil) 54 | } 55 | return fmt.Sprintf( 56 | "{a:%c s:%s v:%s ttl:%.3f}", 57 | f.Attribute, 58 | subjectFormatter(f.Subject), 59 | f.Value, 60 | f.Expires.Sub(now).Seconds(), 61 | ) 62 | } 63 | 64 | // MarshalBinary serializes a Fact to its on-wire format 65 | func (f *Fact) MarshalBinary() ([]byte, error) { 66 | return f.MarshalBinaryNow(time.Now()) 67 | } 68 | 69 | // MarshalBinaryNow is like MarshalBinary, except it uses a provided value of 70 | // `now` so that the output is deterministic 71 | func (f *Fact) MarshalBinaryNow(now time.Time) ([]byte, error) { 72 | var buf bytes.Buffer 73 | var tmp [binary.MaxVarintLen64]byte 74 | var tmpLen int 75 | 76 | buf.WriteByte(byte(f.Attribute)) 77 | 78 | ttl := f.Expires.Sub(now) / timeScale 79 | // clamp ttl to uint16 range 80 | // TODO: warn if we somehow get outside this range 81 | if ttl < 0 { 82 | ttl = 0 83 | } else if ttl > math.MaxUint16 { 84 | ttl = math.MaxUint16 85 | } 86 | tmpLen = binary.PutUvarint(tmp[:], uint64(ttl)) 87 | if n, err := buf.Write(tmp[0:tmpLen]); err != nil || n != tmpLen { 88 | return buf.Bytes(), util.WrapOrNewf(err, "failed to write ttl bytes, wrote %d of %d", n, tmpLen) 89 | } 90 | 91 | // these should never return 
errors, but ... 92 | 93 | subjectData, err := f.Subject.MarshalBinary() 94 | if err != nil { 95 | return buf.Bytes(), fmt.Errorf("failed to marshal Subject: %w", err) 96 | } 97 | if n, err := buf.Write(subjectData); err != nil || n != len(subjectData) { 98 | return buf.Bytes(), util.WrapOrNewf(err, "failed to write subject to buffer, wrote %d of %d", n, len(subjectData)) 99 | } 100 | 101 | valueData, err := f.Value.MarshalBinary() 102 | if err != nil { 103 | return buf.Bytes(), fmt.Errorf("failed to marshal Value: %w", err) 104 | } 105 | if n, err := buf.Write(valueData); err != nil || n != len(valueData) { 106 | return buf.Bytes(), util.WrapOrNewf(err, "failed to write Value to buffer, wrote %d of %d", n, len(valueData)) 107 | } 108 | 109 | return buf.Bytes(), nil 110 | } 111 | -------------------------------------------------------------------------------- /fact/parse.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "math" 7 | "net" 8 | "time" 9 | 10 | "github.com/fastcat/wirelink/util" 11 | ) 12 | 13 | // a decodeHinter is expected to initialize the Subject and Value fields of the 14 | // given Fact to the correct types, and return the expected length (in bytes) 15 | // of the encoded value, or -1 if that length is unknown or variable. 
16 | type decodeHinter = func(*Fact) (valueLength int) 17 | 18 | // decodeHints provides a lookup table for how to decode each valid attribute value 19 | var decodeHints = map[Attribute]decodeHinter{ 20 | AttributeAlive: func(f *Fact) int { 21 | // Modern ping packet with boot id embedded in value 22 | f.Subject = &PeerSubject{} 23 | f.Value = &UUIDValue{} 24 | return uuidLen 25 | }, 26 | AttributeEndpointV4: func(f *Fact) int { 27 | f.Subject = &PeerSubject{} 28 | f.Value = &IPPortValue{} 29 | return net.IPv4len + 2 30 | }, 31 | AttributeEndpointV6: func(f *Fact) int { 32 | f.Subject = &PeerSubject{} 33 | f.Value = &IPPortValue{} 34 | return net.IPv6len + 2 35 | }, 36 | AttributeAllowedCidrV4: func(f *Fact) int { 37 | f.Subject = &PeerSubject{} 38 | f.Value = &IPNetValue{} 39 | return net.IPv4len + 1 40 | }, 41 | AttributeAllowedCidrV6: func(f *Fact) int { 42 | f.Subject = &PeerSubject{} 43 | f.Value = &IPNetValue{} 44 | return net.IPv6len + 1 45 | }, 46 | 47 | AttributeMember: func(f *Fact) int { 48 | // member attrs don't have a value 49 | f.Subject = &PeerSubject{} 50 | f.Value = &EmptyValue{} 51 | return 0 52 | }, 53 | AttributeMemberMetadata: func(f *Fact) int { 54 | f.Subject = &PeerSubject{} 55 | f.Value = &MemberMetadata{} 56 | return 0 57 | }, 58 | 59 | AttributeSignedGroup: func(f *Fact) int { 60 | f.Subject = &PeerSubject{} 61 | f.Value = &SignedGroupValue{} 62 | // this is a variable length, expected to consume everything until EOF 63 | return -1 64 | }, 65 | } 66 | 67 | // DecodeFrom implements Decodable 68 | func (f *Fact) DecodeFrom(_ int, now time.Time, reader util.ByteReader) error { 69 | var err error 70 | 71 | attrByte, err := reader.ReadByte() 72 | if err != nil { 73 | return fmt.Errorf("unable to read attribute byte from packet: %w", err) 74 | } 75 | f.Attribute = Attribute(attrByte) 76 | 77 | hinter, ok := decodeHints[f.Attribute] 78 | if !ok { 79 | if f.Attribute == AttributeUnknown { 80 | // AttributeUnknown used to be used for ping 
packets, this has been removed 81 | return fmt.Errorf("legacy AttributeUnknown ping packet not supported") 82 | } 83 | return fmt.Errorf("unrecognized attribute 0x%02x", byte(f.Attribute)) 84 | } 85 | 86 | ttl, err := binary.ReadUvarint(reader) 87 | if err != nil { 88 | return fmt.Errorf("unable to read ttl from packet: %w", err) 89 | } 90 | // clamp TTL to valid range 91 | if ttl > math.MaxUint16 { 92 | return fmt.Errorf("received TTL outside range: %v", ttl) 93 | } 94 | f.Expires = now.Add(time.Duration(ttl) * timeScale) 95 | 96 | valueLength := hinter(f) 97 | 98 | err = f.Subject.DecodeFrom(0, reader) 99 | if err != nil { 100 | return fmt.Errorf("failed to unmarshal fact subject from packet for %v: %w", f.Attribute, err) 101 | } 102 | err = f.Value.DecodeFrom(valueLength, reader) 103 | if err != nil { 104 | return fmt.Errorf("failed to unmarshal fact value from packet for %v: %w", f.Attribute, err) 105 | } 106 | 107 | return nil 108 | } 109 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/0cfd23bfe624f2c78c0930c5fe5ea7b1db43d35285a0920bde0280f4f836b0fe: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("M000000000000000000000000000000000\x140\xf9՛\xe3\xc6\xf5\xde\xde\xde\x01000000000") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/2bc31826454fd1ee42d6a1449698b0a5d6395c8972eb15eb58843e7e80665d06: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("E000000000000000000000000000000000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff000000") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/68ef958b5121ee27405ff1e28c59a6e871748abc58724342c6b2590a37ea23f3: -------------------------------------------------------------------------------- 1 | 
go test fuzz v1 2 | []byte("M000000000000000000000000000000000\x80\x00") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/6b3f3671c028200ef9f1dc642c4f7b358cf46817000b2bb2f4f550233df3d467: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("a0000000000000000000000000000000000000000000000000") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/7fbe8f1ab90c5846075daef7abd135cdec284f2f1eebd274b8d9cfb024d62df2: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("!\xbe\x00000000000000000000000000000000000000000000000000") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/87c7c0b972f6d43523d679b0107adf883ceb97fe527a496b6193389733bd25df: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("A000000000000000000000000000000000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff00000") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/9ee44aada7de1402f44d18fdfc7949aee504f7233cd4610664d31432a5d5f9e2: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("M000000000000000000000000000000000\x040\x00 \x00") 3 | -------------------------------------------------------------------------------- /fact/testdata/fuzz/FuzzDecodeFrom/e7eac19df17a16b6835603e9b1763904ecfd2a2308c8e8ea52dbf043052b2f80: -------------------------------------------------------------------------------- 1 | go test fuzz v1 2 | []byte("M00000000000000000000000000000000\xf0\xf0\xf0\xf0\xf0\xf0\xd50") 3 | -------------------------------------------------------------------------------- /fact/time-scale.go: 
-------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import "time" 4 | 5 | // timeScale is the quantum of measurement for serializing and parsing facts onto the wire. 6 | // It is only present as a mutable parameter for use in tests where we want to run scenarios 7 | // faster than normal realtime would permit. Changing this on a live service is a breaking 8 | // change to the wire protocol. 9 | var timeScale = time.Second 10 | 11 | // ScaleExpirationQuantumForTests reconfigures how the fact TTL is represented on the wire to permit 12 | // faster than normal tests 13 | func ScaleExpirationQuantumForTests(factor uint) { 14 | if factor < 1 || factor > 1000 { 15 | panic("Test time scale must be in the range [1, 1000]") 16 | } 17 | timeScale = time.Second / time.Duration(factor) 18 | } 19 | -------------------------------------------------------------------------------- /fact/types-subjects.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | "github.com/fastcat/wirelink/util" 8 | 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | // PeerSubject is a subject that is a peer identified via its public key 13 | type PeerSubject struct { 14 | wgtypes.Key 15 | } 16 | 17 | // MarshalBinary implements encoding.BinaryMarshaler 18 | func (s *PeerSubject) MarshalBinary() ([]byte, error) { 19 | return s.Key[:], nil 20 | } 21 | 22 | // UnmarshalBinary implements BinaryUnmarshaler 23 | func (s *PeerSubject) UnmarshalBinary(data []byte) error { 24 | if len(data) != wgtypes.KeyLen { 25 | return fmt.Errorf("data len wrong for peer subject") 26 | } 27 | copy(s.Key[:], data) 28 | return nil 29 | } 30 | 31 | // DecodeFrom implements Decodable 32 | func (s *PeerSubject) DecodeFrom(_ int, reader io.Reader) error { 33 | return util.DecodeFrom(s, wgtypes.KeyLen, reader) 34 | } 35 | 36 | // IsSubject implements Subject 37 | func (s 
*PeerSubject) IsSubject() {} 38 | 39 | // *PeerSubject must implement Subject 40 | // we do this with the pointer because if we do it with the struct, the pointer 41 | // matches too, and that confuses things, and critically because unmarshalling 42 | // and decoding require mutation of the value 43 | var _ Subject = &PeerSubject{} 44 | -------------------------------------------------------------------------------- /fact/types.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "encoding" 5 | "fmt" 6 | 7 | "github.com/fastcat/wirelink/util" 8 | ) 9 | 10 | // Subject is the subject of a Fact 11 | type Subject interface { 12 | fmt.Stringer 13 | encoding.BinaryMarshaler 14 | util.Decodable 15 | // IsSubject tags Subjects as semantically different from Values 16 | IsSubject() 17 | } 18 | 19 | // Value represents the value of a Fact 20 | type Value interface { 21 | fmt.Stringer 22 | encoding.BinaryMarshaler 23 | util.Decodable 24 | } 25 | 26 | // Attribute is a byte identifying what aspect of a Subject a Fact describes 27 | type Attribute byte 28 | -------------------------------------------------------------------------------- /fact/utils_test.go: -------------------------------------------------------------------------------- 1 | package fact 2 | 3 | import ( 4 | "bytes" 5 | "math/rand" 6 | "net" 7 | "testing" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/fastcat/wirelink/internal/testutils" 15 | 16 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 17 | ) 18 | 19 | func mustMockAlivePacket(t require.TestingT, subject *wgtypes.Key, id *uuid.UUID) (*Fact, []byte) { 20 | if subject == nil { 21 | sk := testutils.MustKey(t) 22 | subject = &sk 23 | } 24 | if id == nil { 25 | u := uuid.Must(uuid.NewRandom()) 26 | id = &u 27 | } 28 | return mustSerialize(t, &Fact{ 29 | Attribute: AttributeAlive, 30 | Subject: &PeerSubject{Key: *subject}, 31 
| Expires: time.Time{}, 32 | Value: &UUIDValue{UUID: *id}, 33 | }) 34 | } 35 | 36 | func mustMockAllowedV4Packet(t require.TestingT, subject *wgtypes.Key) (*Fact, []byte) { 37 | if subject == nil { 38 | sk := testutils.MustKey(t) 39 | subject = &sk 40 | } 41 | return mustSerialize(t, &Fact{ 42 | Attribute: AttributeAllowedCidrV4, 43 | Subject: &PeerSubject{Key: *subject}, 44 | Expires: time.Time{}, 45 | Value: &IPNetValue{makeIPNet(t)}, 46 | }) 47 | } 48 | 49 | // TODO: share with package apply 50 | func makeIPNet(t require.TestingT) net.IPNet { 51 | return net.IPNet{ 52 | IP: testutils.MustRandBytes(t, make([]byte, net.IPv4len)), 53 | Mask: net.CIDRMask(1+rand.Intn(8*net.IPv4len), 8*net.IPv4len), 54 | } 55 | } 56 | 57 | func mustSerialize(t require.TestingT, f *Fact) (*Fact, []byte) { 58 | p, err := f.MarshalBinary() 59 | require.Nil(t, err) 60 | return f, p 61 | } 62 | 63 | func mustDeserialize(t *testing.T, p []byte, now time.Time) (f *Fact) { 64 | f = &Fact{} 65 | err := f.DecodeFrom(len(p), now, bytes.NewReader(p)) 66 | require.Nil(t, err) 67 | // to help verify data races, randomize the input buffer after its consumed, 68 | // so that any code that hangs onto it will show clear test failures 69 | testutils.MustRandBytes(t, p) 70 | return 71 | } 72 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fastcat/wirelink 2 | 3 | go 1.24.0 4 | 5 | require ( 6 | github.com/google/uuid v1.6.0 7 | github.com/hashicorp/golang-lru/arc/v2 v2.0.7 8 | github.com/spf13/pflag v1.0.6 9 | github.com/spf13/viper v1.20.1 10 | github.com/stretchr/testify v1.10.0 11 | github.com/vektra/mockery/v2 v2.53.4 12 | github.com/vishvananda/netlink v1.3.1 13 | golang.org/x/crypto v0.38.0 14 | golang.org/x/sync v0.14.0 15 | golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 16 | golang.zx2c4.com/wireguard/wgctrl 
v0.0.0-20241231184526-a9ab2273dd10 17 | ) 18 | 19 | require ( 20 | github.com/chigopher/pathlib v0.19.1 // indirect 21 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 22 | github.com/fsnotify/fsnotify v1.8.0 // indirect 23 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 24 | github.com/google/go-cmp v0.7.0 // indirect 25 | github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect 26 | github.com/huandu/xstrings v1.4.0 // indirect 27 | github.com/iancoleman/strcase v0.3.0 // indirect 28 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 29 | github.com/jinzhu/copier v0.4.0 // indirect 30 | github.com/josharian/native v1.1.0 // indirect 31 | github.com/mattn/go-colorable v0.1.14 // indirect 32 | github.com/mattn/go-isatty v0.0.20 // indirect 33 | github.com/mdlayher/genetlink v1.3.2 // indirect 34 | github.com/mdlayher/netlink v1.7.2 // indirect 35 | github.com/mdlayher/socket v0.5.1 // indirect 36 | github.com/mitchellh/go-homedir v1.1.0 // indirect 37 | github.com/mitchellh/mapstructure v1.5.0 // indirect 38 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 39 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 40 | github.com/rogpeppe/go-internal v1.13.1 // indirect 41 | github.com/rs/zerolog v1.34.0 // indirect 42 | github.com/sagikazarmark/locafero v0.7.0 // indirect 43 | github.com/sourcegraph/conc v0.3.0 // indirect 44 | github.com/spf13/afero v1.12.0 // indirect 45 | github.com/spf13/cast v1.7.1 // indirect 46 | github.com/spf13/cobra v1.8.1 // indirect 47 | github.com/stretchr/objx v0.5.2 // indirect 48 | github.com/subosito/gotenv v1.6.0 // indirect 49 | github.com/vishvananda/netns v0.0.5 // indirect 50 | go.uber.org/multierr v1.11.0 // indirect 51 | golang.org/x/mod v0.24.0 // indirect 52 | golang.org/x/net v0.39.0 // indirect 53 | golang.org/x/sys v0.33.0 // indirect 54 | golang.org/x/term v0.32.0 // indirect 55 | golang.org/x/text v0.25.0 // indirect 56 | golang.org/x/time v0.11.0 
// indirect 57 | golang.org/x/tools v0.32.0 // indirect 58 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect 59 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect 60 | gopkg.in/yaml.v3 v3.0.1 // indirect 61 | ) 62 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.24.0 2 | 3 | toolchain go1.24.3 4 | 5 | use ( 6 | . 7 | ./magefiles 8 | ) 9 | -------------------------------------------------------------------------------- /internal/WgClient.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "io" 5 | 6 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 7 | ) 8 | 9 | // WgClient is a copy-pasta of wginternal.Client because we can't import that 10 | type WgClient interface { 11 | io.Closer 12 | Devices() ([]*wgtypes.Device, error) 13 | Device(name string) (*wgtypes.Device, error) 14 | ConfigureDevice(name string, cfg wgtypes.Config) error 15 | } 16 | -------------------------------------------------------------------------------- /internal/channels/broadcast.go: -------------------------------------------------------------------------------- 1 | package channels 2 | 3 | // Broadcast relays all the messages received on input to outputs. It stops when 4 | // input is closed, and closes outputs at that point. For efficient 5 | // functionality, the outputs should be buffered. 6 | func Broadcast[T any](input <-chan T, outputs ...chan<- T) { 7 | for _, output := range outputs { 8 | defer close(output) 9 | } 10 | 11 | for chunk := range input { 12 | for _, output := range outputs { 13 | output <- chunk 14 | } 15 | } 16 | } 17 | 18 | // Broadcaster wraps Broadcast to simplify errgroup setup 19 | func Broadcaster[T any](input <-chan T, outputs ...chan<- T) errGroupFunc { 20 | return func() error { 21 | Broadcast(input, outputs...) 
22 | return nil 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /internal/channels/doc.go: -------------------------------------------------------------------------------- 1 | // Package channels contains various generic helpers for working with channels 2 | package channels 3 | -------------------------------------------------------------------------------- /internal/channels/filter.go: -------------------------------------------------------------------------------- 1 | package channels 2 | 3 | // Filter reads each item from input and passes it to filter. If filter returns 4 | // an error, it stops and returns that, otherwise it forwards filter's result to 5 | // output. When input ends, output is closed and Filter returns nil. 6 | func Filter[T, U any]( 7 | input <-chan T, 8 | filter func(T) (U, error), 9 | output chan<- U, 10 | ) error { 11 | defer close(output) 12 | for item := range input { 13 | out, err := filter(item) 14 | if err != nil { 15 | return err 16 | } 17 | output <- out 18 | } 19 | return nil 20 | } 21 | 22 | // FilterMany is like Filter, but allows the filter to return many output items 23 | // for each input item. 
24 | func FilterMany[T, U any]( 25 | input <-chan T, 26 | filter func(T) ([]U, error), 27 | output chan<- U, 28 | ) error { 29 | defer close(output) 30 | for item := range input { 31 | out, err := filter(item) 32 | if err != nil { 33 | return err 34 | } 35 | for _, o := range out { 36 | output <- o 37 | } 38 | } 39 | return nil 40 | } 41 | 42 | // Filterer wraps Filter for easy errgroup setup 43 | func Filterer[T, U any]( 44 | input <-chan T, 45 | filter func(T) (U, error), 46 | output chan<- U, 47 | ) errGroupFunc { 48 | return func() error { 49 | return Filter(input, filter, output) 50 | } 51 | } 52 | 53 | // FiltererMany wraps FilterMany for easy errgroup setup 54 | func FiltererMany[T, U any]( 55 | input <-chan T, 56 | filter func(T) ([]U, error), 57 | output chan<- U, 58 | ) errGroupFunc { 59 | return func() error { 60 | return FilterMany(input, filter, output) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /internal/channels/process.go: -------------------------------------------------------------------------------- 1 | package channels 2 | 3 | // Process reads each item from input and passes it to action. If action returns 4 | // an error, it stops and returns that, otherwise it returns nil when input is 5 | // closed.
6 | func Process[T any](input <-chan T, action func(T) error) error { 7 | for item := range input { 8 | if err := action(item); err != nil { 9 | return err 10 | } 11 | } 12 | return nil 13 | } 14 | 15 | // Processor wraps Process to simplify errgroup setup 16 | func Processor[T any](input <-chan T, action func(T) error) errGroupFunc { 17 | return func() error { 18 | return Process(input, action) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /internal/channels/types.go: -------------------------------------------------------------------------------- 1 | package channels 2 | 3 | type errGroupFunc = func() error 4 | -------------------------------------------------------------------------------- /internal/deps.go: -------------------------------------------------------------------------------- 1 | // this file never builds, it just exists to keep tools we need for code 2 | // generation present in go.mod 3 | 4 | //go:build never 5 | // +build never 6 | 7 | package internal 8 | 9 | import ( 10 | _ "github.com/vektra/mockery/v2" 11 | ) 12 | -------------------------------------------------------------------------------- /internal/doc.go: -------------------------------------------------------------------------------- 1 | // Package internal has bits and bobs for the internal implementation details 2 | // of wirelink, esp. test utilities, mocks, and abstractions. 3 | package internal 4 | -------------------------------------------------------------------------------- /internal/lru.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "github.com/hashicorp/golang-lru/arc/v2" 5 | ) 6 | 7 | // Memoize uses a cache to memoize the given function. 
8 | func Memoize[K comparable, V any](size int, f func(K) V) func(K) V { 9 | m, err := arc.NewARC[K, V](size) 10 | if err != nil { 11 | panic(err) 12 | } 13 | 14 | return func(k K) V { 15 | if v, ok := m.Get(k); ok { 16 | return v 17 | } 18 | v := f(k) 19 | m.Add(k, v) 20 | return v 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /internal/lru_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | -------------------------------------------------------------------------------- /internal/mocks/.gitignore: -------------------------------------------------------------------------------- 1 | WgClient.go 2 | -------------------------------------------------------------------------------- /internal/mocks/doc.go: -------------------------------------------------------------------------------- 1 | // Package mocks has just test files for use by other packages' tests 2 | package mocks 3 | 4 | //go:generate go run github.com/vektra/mockery/v2 --dir ../ --output ./ --name WgClient 5 | -------------------------------------------------------------------------------- /internal/networking/darwin/darwin-ifconfig.go: -------------------------------------------------------------------------------- 1 | // Package darwin provides an implementation of networking.Environment for the 2 | // host darwin (macOS) system, leveraging the Go native package, and then 3 | // filling in the gaps by executing command line tools such as ifconfig.
4 | package darwin 5 | 6 | import ( 7 | "fmt" 8 | "net" 9 | "os/exec" 10 | 11 | "github.com/fastcat/wirelink/internal/networking" 12 | "github.com/fastcat/wirelink/internal/networking/native" 13 | "github.com/fastcat/wirelink/log" 14 | ) 15 | 16 | // CreateDarwin makes an environment for the host using ifconfig 17 | func CreateDarwin() (networking.Environment, error) { 18 | return &darwinEnvironment{}, nil 19 | } 20 | 21 | type darwinEnvironment struct { 22 | native.GoEnvironment 23 | } 24 | 25 | var _ networking.Environment = (*darwinEnvironment)(nil) 26 | 27 | func (e *darwinEnvironment) Interfaces() ([]networking.Interface, error) { 28 | ifaces, err := e.GoEnvironment.Interfaces() 29 | if err != nil { 30 | return nil, err 31 | } 32 | ret := make([]networking.Interface, len(ifaces)) 33 | for i := range ifaces { 34 | // TODO: may be faster to fetch all links and join them? 35 | ret[i] = e.interfaceFromGo(ifaces[i].(*native.GoInterface)) 36 | } 37 | return ret, nil 38 | } 39 | 40 | func (e *darwinEnvironment) InterfaceByName(name string) (networking.Interface, error) { 41 | iface, err := e.GoEnvironment.InterfaceByName(name) 42 | if err != nil { 43 | return nil, err 44 | } 45 | return e.interfaceFromGo(iface.(*native.GoInterface)), nil 46 | } 47 | 48 | func (e *darwinEnvironment) interfaceFromGo(iface *native.GoInterface) *darwinInterface { 49 | return &darwinInterface{*iface, e} 50 | } 51 | 52 | func (e *darwinEnvironment) Close() error { 53 | return e.GoEnvironment.Close() 54 | } 55 | 56 | type darwinInterface struct { 57 | native.GoInterface 58 | env *darwinEnvironment 59 | } 60 | 61 | var _ networking.Interface = (*darwinInterface)(nil) 62 | 63 | func (i *darwinInterface) AddAddr(addr net.IPNet) error { 64 | family := "inet6" 65 | if addr.IP.To4() != nil { 66 | family = "inet" 67 | } 68 | // probably have to run as root because of this 69 | cmd := exec.Command("ifconfig", i.Name(), family, addr.String(), "alias") 70 | output, err := cmd.CombinedOutput() 71 | if 
err != nil { 72 | return fmt.Errorf("unable to add %v to %s: %s: %w", addr, i.Name(), string(output), err) 73 | } 74 | log.Debug("ifconfig results: %s", string(output)) 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /internal/networking/doc.go: -------------------------------------------------------------------------------- 1 | // Package networking provides abstraction layers for accessing networking 2 | // resources both across platform specific details, and for virtual / mock 3 | // configurations for testing 4 | package networking 5 | -------------------------------------------------------------------------------- /internal/networking/host/factory-darwin.go: -------------------------------------------------------------------------------- 1 | //go:build darwin 2 | // +build darwin 3 | 4 | package host 5 | 6 | import ( 7 | "github.com/fastcat/wirelink/internal/networking" 8 | "github.com/fastcat/wirelink/internal/networking/darwin" 9 | ) 10 | 11 | // CreateHost creates the default Environment implementation for the host OS 12 | func CreateHost() (networking.Environment, error) { 13 | return darwin.CreateDarwin() 14 | } 15 | -------------------------------------------------------------------------------- /internal/networking/host/factory-generic.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !darwin 2 | // +build !linux,!darwin 3 | 4 | package host 5 | 6 | import ( 7 | "github.com/fastcat/wirelink/internal/networking" 8 | "github.com/fastcat/wirelink/internal/networking/native" 9 | ) 10 | 11 | // CreateHost creates the default Environment implementation for the host OS 12 | func CreateHost() (networking.Environment, error) { 13 | return &native.GoEnvironment{}, nil 14 | } 15 | -------------------------------------------------------------------------------- /internal/networking/host/factory-linux.go: 
-------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package host 5 | 6 | import ( 7 | "github.com/fastcat/wirelink/internal/networking" 8 | "github.com/fastcat/wirelink/internal/networking/linux" 9 | ) 10 | 11 | // CreateHost creates the default Environment implementation for the host OS 12 | func CreateHost() (networking.Environment, error) { 13 | return linux.CreateLinux() 14 | } 15 | -------------------------------------------------------------------------------- /internal/networking/host/factory.go: -------------------------------------------------------------------------------- 1 | // Package host provides a generic accessor factory to create the appropriate 2 | // platform-specific interface to the host networking APIs. 3 | package host 4 | 5 | import "github.com/fastcat/wirelink/internal/networking" 6 | 7 | // MustCreateHost calls CreateHost and panics if it fails 8 | func MustCreateHost() networking.Environment { 9 | env, err := CreateHost() 10 | if err != nil { 11 | panic(err) 12 | } 13 | return env 14 | } 15 | -------------------------------------------------------------------------------- /internal/networking/linux/linux-netlink.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | // Package linux provides an implementation of networking.Environment for the 5 | // host Linux system, leveraging the Go native package, and then filling in the 6 | // gaps using netlink APIs. 
7 | package linux 8 | 9 | import ( 10 | "fmt" 11 | "net" 12 | 13 | "github.com/vishvananda/netlink" 14 | 15 | "github.com/fastcat/wirelink/internal/networking" 16 | "github.com/fastcat/wirelink/internal/networking/native" 17 | ) 18 | 19 | // CreateLinux makes an environment for the host using netlink 20 | func CreateLinux() (networking.Environment, error) { 21 | nlh, err := netlink.NewHandle() 22 | if err != nil { 23 | return nil, fmt.Errorf("unable to make a new netlink handle: %w", err) 24 | } 25 | 26 | ret := &linuxEnvironment{ 27 | nlh: nlh, 28 | } 29 | 30 | return ret, nil 31 | } 32 | 33 | type linuxEnvironment struct { 34 | native.GoEnvironment 35 | nlh *netlink.Handle 36 | } 37 | 38 | // linuxEnvironment implements networking.Environment 39 | var _ networking.Environment = (*linuxEnvironment)(nil) 40 | 41 | // Interfaces implements Environment 42 | func (e *linuxEnvironment) Interfaces() ([]networking.Interface, error) { 43 | ifaces, err := e.GoEnvironment.Interfaces() 44 | if err != nil { 45 | return nil, err 46 | } 47 | ret := make([]networking.Interface, len(ifaces)) 48 | for i := range ifaces { 49 | // TODO: may be faster to fetch all links and join them? 
50 | ret[i], err = e.interfaceFromGo(ifaces[i].(*native.GoInterface)) 51 | if err != nil { 52 | return nil, err 53 | } 54 | } 55 | return ret, nil 56 | } 57 | 58 | func (e *linuxEnvironment) InterfaceByName(name string) (networking.Interface, error) { 59 | iface, err := e.GoEnvironment.InterfaceByName(name) 60 | if err != nil { 61 | return nil, err 62 | } 63 | return e.interfaceFromGo(iface.(*native.GoInterface)) 64 | } 65 | 66 | func (e *linuxEnvironment) interfaceFromGo(iface *native.GoInterface) (*linuxInterface, error) { 67 | link, err := e.nlh.LinkByName(iface.Name()) 68 | if err != nil { 69 | return nil, fmt.Errorf("unable to get netlink info for interface %s: %w", iface.Name(), err) 70 | } 71 | return &linuxInterface{*iface, link, e}, nil 72 | } 73 | 74 | func (e *linuxEnvironment) Close() error { 75 | if e.nlh != nil { 76 | e.nlh.Close() 77 | e.nlh = nil 78 | } 79 | return e.GoEnvironment.Close() 80 | } 81 | 82 | type linuxInterface struct { 83 | native.GoInterface 84 | link netlink.Link 85 | env *linuxEnvironment 86 | } 87 | 88 | var _ networking.Interface = (*linuxInterface)(nil) 89 | 90 | func (i *linuxInterface) AddAddr(addr net.IPNet) error { 91 | err := i.env.nlh.AddrAdd(i.link, &netlink.Addr{ 92 | IPNet: &addr, 93 | }) 94 | if err != nil { 95 | return fmt.Errorf("unable to add %v to %s: %w", addr, i.Name(), err) 96 | } 97 | return nil 98 | } 99 | -------------------------------------------------------------------------------- /internal/networking/mocks/.gitignore: -------------------------------------------------------------------------------- 1 | Environment.go 2 | Interface.go 3 | UDPConn.go 4 | -------------------------------------------------------------------------------- /internal/networking/native/constants_darwin.go: -------------------------------------------------------------------------------- 1 | //go:build darwin 2 | // +build darwin 3 | 4 | package native 5 | 6 | const localhostInterfaceName = "lo0" 7 | 
-------------------------------------------------------------------------------- /internal/networking/native/constants_generic.go: -------------------------------------------------------------------------------- 1 | //go:build !linux && !darwin && !windows 2 | // +build !linux,!darwin,!windows 3 | 4 | package native 5 | 6 | // this is probably wrong 7 | const localhostInterfaceName = "lo" 8 | -------------------------------------------------------------------------------- /internal/networking/native/constants_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package native 5 | 6 | const localhostInterfaceName = "lo" 7 | -------------------------------------------------------------------------------- /internal/networking/native/constants_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package native 5 | 6 | const localhostInterfaceName = "Loopback Pseudo-Interface 1" 7 | -------------------------------------------------------------------------------- /internal/networking/native/go-environment.go: -------------------------------------------------------------------------------- 1 | // Package native provides common base implementations of the 2 | // networking.Environment and related interfaces, or at least the portions that 3 | // can be implemented using common native Go APIs.
4 | package native 5 | 6 | import ( 7 | "net" 8 | 9 | "github.com/fastcat/wirelink/internal" 10 | "github.com/fastcat/wirelink/internal/networking" 11 | 12 | "golang.zx2c4.com/wireguard/wgctrl" 13 | ) 14 | 15 | // GoEnvironment is a partial implementation of Environment which provides the 16 | // methods and types that the go runtime can answer 17 | type GoEnvironment struct{} 18 | 19 | // GoEnvironment does not fully implement Environment 20 | // var _ networking.Environment = &GoEnvironment{} 21 | 22 | // Interfaces implements Environment 23 | func (e *GoEnvironment) Interfaces() ([]networking.Interface, error) { 24 | ifaces, err := net.Interfaces() 25 | if err != nil { 26 | return nil, err 27 | } 28 | ret := make([]networking.Interface, len(ifaces)) 29 | for i, iface := range ifaces { 30 | ret[i] = &GoInterface{iface} 31 | } 32 | 33 | return ret, nil 34 | } 35 | 36 | // InterfaceByName implements Environment by wrapping net.InterfaceByName 37 | func (e *GoEnvironment) InterfaceByName(name string) (networking.Interface, error) { 38 | iface, err := net.InterfaceByName(name) 39 | if err != nil { 40 | return nil, err 41 | } 42 | return &GoInterface{*iface}, nil 43 | } 44 | 45 | // ListenUDP implements Environment by wrapping net.ListenUDP 46 | func (e *GoEnvironment) ListenUDP(network string, laddr *net.UDPAddr) (networking.UDPConn, error) { 47 | conn, err := net.ListenUDP(network, laddr) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return &GoUDPConn{*conn}, nil 52 | } 53 | 54 | // NewWgClient implements Environment by wrapping wgctrl.New() 55 | func (e *GoEnvironment) NewWgClient() (internal.WgClient, error) { 56 | return wgctrl.New() 57 | } 58 | 59 | // Close implements Environment by doing nothing and returning nil 60 | func (e *GoEnvironment) Close() error { 61 | return nil 62 | } 63 | 64 | var _ networking.Environment = (*GoEnvironment)(nil) 65 | -------------------------------------------------------------------------------- 
/internal/networking/native/go-environment_test.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | 8 | "golang.zx2c4.com/wireguard/wgctrl" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestGoEnvironment_Interfaces(t *testing.T) { 15 | tests := []struct { 16 | name string 17 | e *GoEnvironment 18 | want []*GoInterface 19 | wantErr bool 20 | wantCheck func(*testing.T, []*GoInterface, error) 21 | }{ 22 | { 23 | "can retrieve interfaces", 24 | &GoEnvironment{}, 25 | nil, 26 | false, 27 | func(t *testing.T, ifaces []*GoInterface, err error) { 28 | assert.True(t, containsIface(ifaces, func(iface *GoInterface) bool { 29 | require.NotNil(t, iface) 30 | return iface.Name() == localhostInterfaceName 31 | }), "Should find a localhost interface") 32 | assert.GreaterOrEqual(t, len(ifaces), 2) 33 | }, 34 | }, 35 | } 36 | for _, tt := range tests { 37 | t.Run(tt.name, func(t *testing.T) { 38 | e := &GoEnvironment{} 39 | got, err := e.Interfaces() 40 | if tt.wantErr { 41 | require.NotNil(t, err) 42 | } else { 43 | require.Nil(t, err) 44 | } 45 | var gotGo []*GoInterface 46 | if got != nil { 47 | gotGo = make([]*GoInterface, len(got)) 48 | for i, iface := range got { 49 | if assert.IsType(t, &GoInterface{}, iface) { 50 | gotGo[i] = iface.(*GoInterface) 51 | } 52 | } 53 | } 54 | if tt.wantCheck != nil { 55 | tt.wantCheck(t, gotGo, err) 56 | } else { 57 | assert.Equal(t, tt.want, got) 58 | } 59 | }) 60 | } 61 | } 62 | 63 | func TestGoEnvironment_InterfaceByName(t *testing.T) { 64 | type args struct { 65 | name string 66 | } 67 | tests := []struct { 68 | name string 69 | e *GoEnvironment 70 | args args 71 | want *GoInterface 72 | wantErr bool 73 | wantCheck func(*testing.T, *GoInterface, error) 74 | }{ 75 | { 76 | "reasonable localhost results", 77 | &GoEnvironment{}, 78 | args{localhostInterfaceName}, 79 | 
nil, 80 | false, 81 | func(t *testing.T, iface *GoInterface, err error) { 82 | require.NotNil(t, iface) 83 | assert.Equal(t, iface.Name(), localhostInterfaceName) 84 | assert.True(t, iface.IsUp()) 85 | }, 86 | }, 87 | { 88 | "not found error", 89 | &GoEnvironment{}, 90 | args{fmt.Sprintf("xyzzy%d", rand.Int63())}, 91 | nil, 92 | true, 93 | nil, 94 | }, 95 | } 96 | for _, tt := range tests { 97 | t.Run(tt.name, func(t *testing.T) { 98 | e := &GoEnvironment{} 99 | got, err := e.InterfaceByName(tt.args.name) 100 | if tt.wantErr { 101 | require.NotNil(t, err) 102 | } else { 103 | require.Nil(t, err) 104 | } 105 | var gotGo *GoInterface 106 | if got != nil { 107 | gotGo = got.(*GoInterface) 108 | } 109 | if tt.wantCheck != nil { 110 | tt.wantCheck(t, gotGo, err) 111 | } else { 112 | assert.Equal(t, tt.want, gotGo) 113 | } 114 | }) 115 | } 116 | } 117 | 118 | func TestGoEnvironment_NewWgClient(t *testing.T) { 119 | e := &GoEnvironment{} 120 | got, err := e.NewWgClient() 121 | require.NoError(t, err) 122 | if assert.NotNil(t, got) { 123 | defer got.Close() 124 | } 125 | assert.IsType(t, &wgctrl.Client{}, got) 126 | } 127 | -------------------------------------------------------------------------------- /internal/networking/native/go-interface.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/internal/networking" 7 | "github.com/fastcat/wirelink/log" 8 | ) 9 | 10 | // GoInterface provides as much of Interface as the go runtime can 11 | type GoInterface struct { 12 | net.Interface 13 | } 14 | 15 | // Name implements Interface, gets its name 16 | func (i *GoInterface) Name() string { 17 | return i.Interface.Name 18 | } 19 | 20 | // IsUp implements Interface, checks for FlagUp 21 | func (i *GoInterface) IsUp() bool { 22 | return i.Flags&net.FlagUp == net.FlagUp 23 | } 24 | 25 | // Addrs implements Interface, looks up the IP addresses for the interface 26 | 
func (i *GoInterface) Addrs() ([]net.IPNet, error) { 27 | addrs, err := i.Interface.Addrs() 28 | if err != nil { 29 | return nil, err 30 | } 31 | ret := make([]net.IPNet, 0, len(addrs)) 32 | for _, a := range addrs { 33 | if ipn, ok := a.(*net.IPNet); ok { 34 | ret = append(ret, *ipn) 35 | } else { 36 | // this should never happen 37 | log.Error("Got a %T from interface.Addrs, not a net.IPNet", a) 38 | } 39 | } 40 | return ret, nil 41 | } 42 | 43 | // AddAddr implements Interface, returns ErrAddAddrUnsupported 44 | func (i *GoInterface) AddAddr(net.IPNet) error { 45 | return networking.ErrAddAddrUnsupported 46 | } 47 | 48 | var _ networking.Interface = (*GoInterface)(nil) 49 | -------------------------------------------------------------------------------- /internal/networking/native/go-interface_test.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/fastcat/wirelink/internal/testutils" 8 | 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestGoInterface_Addrs(t *testing.T) { 14 | type fields struct { 15 | Interface net.Interface 16 | } 17 | tests := []struct { 18 | name string 19 | fields fields 20 | want []net.IPNet 21 | wantErr bool 22 | wantCheck func(*testing.T, []net.IPNet, error) 23 | }{ 24 | { 25 | "localhost", 26 | fields{*mustNetInterface(t)(net.InterfaceByName(localhostInterfaceName))}, 27 | nil, 28 | false, 29 | func(t *testing.T, addrs []net.IPNet, err error) { 30 | assert.True(t, testutils.ContainsIPNet(addrs, func(addr net.IPNet) bool { 31 | ones, bits := addr.Mask.Size() 32 | // check for 127.0.0.1/8 33 | return net.IPv4(127, 0, 0, 1).Equal(addr.IP) && ones == 8 && bits == 32 34 | })) 35 | }, 36 | }, 37 | } 38 | for _, tt := range tests { 39 | t.Run(tt.name, func(t *testing.T) { 40 | i := &GoInterface{ 41 | Interface: tt.fields.Interface, 42 | } 43 | got, err := i.Addrs() 44 | if 
tt.wantErr { 45 | require.NotNil(t, err) 46 | } else { 47 | require.Nil(t, err) 48 | } 49 | if tt.wantCheck != nil { 50 | tt.wantCheck(t, got, err) 51 | } else { 52 | assert.Equal(t, tt.want, got) 53 | } 54 | }) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /internal/networking/native/go-udpconn.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "time" 8 | 9 | "github.com/fastcat/wirelink/internal/networking" 10 | ) 11 | 12 | // GoUDPConn implements networking.UDPConn by wrapping net.UDPConn 13 | type GoUDPConn struct { 14 | net.UDPConn 15 | } 16 | 17 | // GoUDPConn implements networking.UDPConn 18 | var _ networking.UDPConn = &GoUDPConn{} 19 | 20 | // ReadPackets implements networking.UDPConn. 21 | // TODO: the cancellation context won't be obeyed very well. 22 | // Methodology loosely adapted from: 23 | // https://medium.com/@zombiezen/canceling-i-o-in-go-capn-proto-5ae8c09c5b29 24 | // via https://github.com/golang/go/issues/20280#issue-227074518 25 | // UDP makes this simpler however as partial reads are not a concern 26 | func (c *GoUDPConn) ReadPackets( 27 | ctx context.Context, 28 | maxSize int, 29 | packets chan<- *networking.UDPPacket, 30 | ) error { 31 | defer close(packets) 32 | 33 | ctxDone := ctx.Done() 34 | monitorDone := make(chan struct{}) 35 | readsDone := make(chan struct{}) 36 | buffer := make([]byte, maxSize) 37 | 38 | // start a goroutine to monitor the context channel, and interrupt the read 39 | // whenever it closes 40 | go func() { 41 | defer close(monitorDone) 42 | for { 43 | select { 44 | case <-ctxDone: 45 | //nolint:errcheck // interrupt any ongoing read, don't care about error 46 | c.SetReadDeadline(time.Now()) 47 | return 48 | case <-readsDone: 49 | return 50 | } 51 | } 52 | }() 53 | 54 | READLOOP: 55 | for { 56 | select { 57 | case <-ctxDone: 58 | break READLOOP 59 | 
default: 60 | deadline, ok := ctx.Deadline() 61 | if ok { 62 | //nolint:errcheck // don't care about error 63 | c.SetReadDeadline(deadline) 64 | } else { 65 | //nolint:errcheck // don't care about error 66 | c.SetReadDeadline(time.Time{}) 67 | } 68 | n, addr, err := c.ReadFromUDP(buffer) 69 | now := time.Now() 70 | 71 | if err != nil { 72 | if errors.Is(err, net.ErrClosed) { 73 | // the socket has been closed, we're done 74 | break READLOOP 75 | } 76 | 77 | if netErr, ok := err.(net.Error); ok && netErr.Timeout() { 78 | // ignore timeouts, generally this is us poking ourselves 79 | continue 80 | } 81 | 82 | // else fall-through will send the error to the caller 83 | } 84 | 85 | var data []byte 86 | if n > 0 { 87 | data = make([]byte, n) 88 | copy(data, buffer[:n]) 89 | } else if err == nil { 90 | // macOS sometimes gives us an empty read response and no error, ignore 91 | // those 92 | continue 93 | } 94 | 95 | packets <- &networking.UDPPacket{ 96 | Time: now, 97 | Addr: addr, 98 | Data: data, 99 | Err: err, 100 | } 101 | } 102 | } 103 | 104 | close(readsDone) 105 | <-monitorDone 106 | 107 | //nolint:errcheck // don't care about error // reset read deadline on the connection to be safe 108 | c.SetReadDeadline(time.Time{}) 109 | 110 | // TODO: should we return ctx.Err() or the network closing error here? 
111 | return nil 112 | } 113 | -------------------------------------------------------------------------------- /internal/networking/native/utils_test.go: -------------------------------------------------------------------------------- 1 | package native 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func containsIface(ifaces []*GoInterface, predicate func(*GoInterface) bool) bool { 11 | for _, iface := range ifaces { 12 | if predicate(iface) { 13 | return true 14 | } 15 | } 16 | return false 17 | } 18 | 19 | func mustNetInterface(t *testing.T) func(iface *net.Interface, err error) *net.Interface { 20 | return func(iface *net.Interface, err error) *net.Interface { 21 | require.Nil(t, err) 22 | require.NotNil(t, iface) 23 | return iface 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /internal/networking/types.go: -------------------------------------------------------------------------------- 1 | package networking 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "time" 9 | 10 | "github.com/fastcat/wirelink/internal" 11 | ) 12 | 13 | // Environment represents the top level abstraction of the system's networking 14 | // environment. 
15 | type Environment interface { 16 | io.Closer 17 | // Interfaces is typically a wrapper for net.Interfaces() 18 | Interfaces() ([]Interface, error) 19 | // InterfaceByName looks up an interface by its name 20 | InterfaceByName(string) (Interface, error) 21 | 22 | // ListenUDP abstracts net.ListenUDP 23 | ListenUDP(network string, laddr *net.UDPAddr) (UDPConn, error) 24 | 25 | // NewWgClient creates a wireguard client interface for the host 26 | NewWgClient() (internal.WgClient, error) 27 | } 28 | 29 | // Interface represents a single network interface 30 | type Interface interface { 31 | Name() string 32 | IsUp() bool 33 | Addrs() ([]net.IPNet, error) 34 | AddAddr(net.IPNet) error 35 | } 36 | 37 | // ErrAddAddrUnsupported is returned from Interface.AddAddr on platforms where a 38 | // network interface configuration API is not yet supported 39 | var ErrAddAddrUnsupported = errors.New("this platform doesn't support interface configuration yet") 40 | 41 | // UDPConn abstracts net.UDPConn 42 | type UDPConn interface { 43 | io.Closer 44 | 45 | SetReadDeadline(t time.Time) error 46 | SetWriteDeadline(t time.Time) error 47 | 48 | ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 49 | WriteToUDP(p []byte, addr *net.UDPAddr) (n int, err error) 50 | 51 | // ReadPackets reads packets from the connection until it is either closed, 52 | // or the passed context is cancelled. 53 | // Packets or errors (other than the connection being closed) will be sent 54 | // to the output channel, which will be closed when this routine finishes. 55 | // Closing the connection is always the responsibility of the caller. 56 | ReadPackets( 57 | ctx context.Context, 58 | maxSize int, 59 | output chan<- *UDPPacket, 60 | ) error 61 | } 62 | 63 | // UDPPacket represents a single result from ReadFromUDP, wrapped in a struct 64 | // so that it can be sent on a channel. 
65 | type UDPPacket struct { 66 | Time time.Time 67 | Data []byte 68 | Addr *net.UDPAddr 69 | Err error 70 | } 71 | -------------------------------------------------------------------------------- /internal/networking/vnet/doc.go: -------------------------------------------------------------------------------- 1 | // Package vnet provides a virtual (as opposed to mocked) implementation of the 2 | // abstracted UDP networking stack. Multiple virtual hosts can be created with 3 | // network linkages between them to simulate packet flows. 4 | package vnet 5 | -------------------------------------------------------------------------------- /internal/networking/vnet/host-environment.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/internal/networking" 8 | ) 9 | 10 | // hostEnvironment wraps a Host to make it function as a virtual UDP networking 11 | // Environment. 
12 | type hostEnvironment struct { 13 | h *Host 14 | } 15 | 16 | var _ networking.Environment = &hostEnvironment{} 17 | 18 | // Wrap provides an Environment view of a Host 19 | func (h *Host) Wrap() networking.Environment { 20 | return &hostEnvironment{h} 21 | } 22 | 23 | // Close implements Environment 24 | func (he *hostEnvironment) Close() error { 25 | h := he.h 26 | if h == nil { 27 | return errors.New("already closed") 28 | } 29 | h.Close() 30 | return nil 31 | } 32 | 33 | // Interfaces implements Environment 34 | func (he *hostEnvironment) Interfaces() ([]networking.Interface, error) { 35 | he.h.m.Lock() 36 | defer he.h.m.Unlock() 37 | ret := make([]networking.Interface, 0, len(he.h.interfaces)) 38 | for _, i := range he.h.interfaces { 39 | ret = append(ret, i.Wrap()) 40 | } 41 | return ret, nil 42 | } 43 | 44 | // InterfaceByName implements Environment 45 | func (he *hostEnvironment) InterfaceByName(name string) (networking.Interface, error) { 46 | he.h.m.Lock() 47 | defer he.h.m.Unlock() 48 | i := he.h.interfaces[name] 49 | if i != nil { 50 | return i.Wrap(), nil 51 | } 52 | return nil, &net.OpError{Op: "route", Net: "ip+net", Err: errors.New("no such network interface")} 53 | } 54 | 55 | // ListenUDP implements Environment 56 | func (he *hostEnvironment) ListenUDP(network string, laddr *net.UDPAddr) (networking.UDPConn, error) { 57 | // TODO: validate network & local address 58 | 59 | // if laddr specifies a zone, try to add the socket to a specific interface 60 | if laddr.Zone != "" { 61 | he.h.m.Lock() 62 | defer he.h.m.Unlock() 63 | i := he.h.interfaces[laddr.Zone] 64 | if i == nil { 65 | // TODO: this error probably isn't quite right 66 | return nil, &net.OpError{Op: "listen", Net: network, Addr: laddr, Err: errors.New("no such network interface")} 67 | } 68 | s := i.AddSocket(laddr) 69 | return s.Connect(), nil 70 | } 71 | 72 | // this one will take its own lock, so we don't take our own on this branch 73 | s := he.h.AddSocket(laddr) 74 | return 
s.Connect(), nil 75 | } 76 | -------------------------------------------------------------------------------- /internal/networking/vnet/interface-wrap.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/internal/networking" 7 | ) 8 | 9 | type wrappedPhy struct { 10 | i *PhysicalInterface 11 | } 12 | 13 | // Wrap implements Interface 14 | func (i *PhysicalInterface) Wrap() networking.Interface { 15 | return &wrappedPhy{i} 16 | } 17 | 18 | // Name implements Interface 19 | func (i *wrappedPhy) Name() string { 20 | i.i.m.Lock() 21 | defer i.i.m.Unlock() 22 | return i.i.name 23 | } 24 | 25 | // IsUp implements interface 26 | func (i *wrappedPhy) IsUp() bool { 27 | i.i.m.Lock() 28 | defer i.i.m.Unlock() 29 | return i.i.network != nil 30 | } 31 | 32 | // Addrs implements Interface 33 | func (i *wrappedPhy) Addrs() ([]net.IPNet, error) { 34 | i.i.m.Lock() 35 | defer i.i.m.Unlock() 36 | ret := make([]net.IPNet, 0, len(i.i.addrs)) 37 | for _, a := range i.i.addrs { 38 | ret = append(ret, a) 39 | } 40 | return ret, nil 41 | } 42 | 43 | // AddAddr implements Interface 44 | func (i *wrappedPhy) AddAddr(addr net.IPNet) error { 45 | i.i.AddAddr(addr) 46 | return nil 47 | } 48 | 49 | type wrappedTun struct { 50 | t *Tunnel 51 | } 52 | 53 | // Wrap implements Interface 54 | func (t *Tunnel) Wrap() networking.Interface { 55 | return &wrappedTun{t} 56 | } 57 | 58 | // Name implements Interface 59 | func (t *wrappedTun) Name() string { 60 | t.t.m.Lock() 61 | defer t.t.m.Unlock() 62 | return t.t.name 63 | } 64 | 65 | // IsUp implements interface 66 | func (t *wrappedTun) IsUp() bool { 67 | t.t.m.Lock() 68 | defer t.t.m.Unlock() 69 | return t.t.upstream != nil 70 | } 71 | 72 | // Addrs implements Interface 73 | func (t *wrappedTun) Addrs() ([]net.IPNet, error) { 74 | t.t.m.Lock() 75 | defer t.t.m.Unlock() 76 | ret := make([]net.IPNet, 0, len(t.t.addrs)) 77 | for _, a := 
range t.t.addrs { 78 | ret = append(ret, a) 79 | } 80 | return ret, nil 81 | } 82 | 83 | // AddAddr implements Interface 84 | func (t *wrappedTun) AddAddr(addr net.IPNet) error { 85 | t.t.AddAddr(addr) 86 | return nil 87 | } 88 | -------------------------------------------------------------------------------- /internal/networking/vnet/interface.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | 7 | "github.com/fastcat/wirelink/internal/networking" 8 | "github.com/fastcat/wirelink/util" 9 | ) 10 | 11 | // A SocketOwner can send packets 12 | type SocketOwner interface { 13 | OutboundPacket(*Packet) bool 14 | AddSocket(a *net.UDPAddr) *Socket 15 | DelSocket(*Socket) 16 | } 17 | 18 | // An Interface is any network interface, whether physical or virtual, attached 19 | // to a Host, which can be used to send and receive Packets. 20 | type Interface interface { 21 | SocketOwner 22 | Name() string 23 | DetachFromNetwork() 24 | Wrap() networking.Interface 25 | Addrs() []net.IPNet 26 | } 27 | 28 | // BaseInterface handles the common elements of both physical and tunnel 29 | // Interfaces 30 | type BaseInterface struct { 31 | m *sync.Mutex 32 | id string 33 | name string 34 | world *World 35 | host *Host 36 | addrs map[string]net.IPNet 37 | sockets map[string]*Socket 38 | self Interface 39 | } 40 | 41 | // Addrs fetches a list of the currently assigned addresses on the interface 42 | func (i *BaseInterface) Addrs() []net.IPNet { 43 | i.m.Lock() 44 | ret := make([]net.IPNet, 0, len(i.addrs)) 45 | for _, a := range i.addrs { 46 | ret = append(ret, util.CloneIPNet(a)) 47 | } 48 | i.m.Unlock() 49 | return ret 50 | } 51 | 52 | // Name gets the host-local name of the interface 53 | func (i *BaseInterface) Name() string { 54 | return i.name 55 | } 56 | 57 | // AddAddr adds an IP address to the interface on which it can receive packets 58 | // and from which it can send them 59 | func (i 
*BaseInterface) AddAddr(a net.IPNet) { 60 | i.m.Lock() 61 | defer i.m.Unlock() 62 | // TODO: clone address so caller can't break it 63 | i.addrs[a.String()] = a 64 | } 65 | 66 | // AddSocket creates a new socket on the interface 67 | func (i *BaseInterface) AddSocket(a *net.UDPAddr) *Socket { 68 | ret := &Socket{ 69 | m: &sync.Mutex{}, 70 | sender: i.self, 71 | addr: a, 72 | } 73 | i.m.Lock() 74 | // TODO: validate addr is unique-ish 75 | i.sockets[a.String()] = ret 76 | i.m.Unlock() 77 | // TODO: goroutines to process packets 78 | return ret 79 | } 80 | 81 | // DelSocket unregisters a socket from the interface 82 | func (i *BaseInterface) DelSocket(s *Socket) { 83 | i.m.Lock() 84 | defer i.m.Unlock() 85 | // TODO: verify socket matches map entry 86 | delete(i.sockets, s.addr.String()) 87 | } 88 | 89 | // InboundPacket inspects the packet to see if its destination matches any 90 | // address on the interface and any listening socket, and if so enqueues it 91 | // for that listener 92 | func (i *BaseInterface) InboundPacket(p *Packet) bool { 93 | i.m.Lock() 94 | var rs *Socket 95 | if !destinationAddrMatch(p, i.addrs) { 96 | // not for this interface, not implementing forwarding, so drop it 97 | i.m.Unlock() 98 | return false 99 | } 100 | rs = destinationSocket(p, i.sockets) 101 | h := i.host 102 | i.m.Unlock() 103 | 104 | // if we found an interface-specific socket that matched, send it there, 105 | // else forward it to the host to try non-specific sockets 106 | if rs != nil { 107 | return rs.InboundPacket(p) 108 | } 109 | if h != nil { 110 | return h.InboundPacket(p) 111 | } 112 | return false 113 | } 114 | -------------------------------------------------------------------------------- /internal/networking/vnet/network.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // A Network represents a connected region within which packets can pass among 8 | // Interfaces 9 | 
type Network struct { 10 | m *sync.Mutex 11 | id string 12 | world *World 13 | interfaces map[string]*PhysicalInterface 14 | } 15 | 16 | // EnqueuePacket enqueues a packet to deliver to the network 17 | func (n *Network) EnqueuePacket(p *Packet) bool { 18 | n.m.Lock() 19 | var dest *PhysicalInterface 20 | for _, i := range n.interfaces { 21 | if destinationAddrMatch(p, i.addrs) { 22 | dest = i 23 | break 24 | } 25 | } 26 | n.m.Unlock() 27 | if dest == nil { 28 | return false 29 | } 30 | return dest.InboundPacket(p) 31 | } 32 | -------------------------------------------------------------------------------- /internal/networking/vnet/packet.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import "net" 4 | 5 | // A Packet represents a UDP packet traveling on the virtual network 6 | type Packet struct { 7 | src, dest *net.UDPAddr 8 | data []byte 9 | encapsulated *Packet 10 | } 11 | -------------------------------------------------------------------------------- /internal/networking/vnet/phy.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | // An PhysicalInterface represents a network interface on a Host that is part of 8 | // some World. The interface may be attached and detached from various Networks 9 | // in the World. 10 | type PhysicalInterface struct { 11 | BaseInterface 12 | network *Network 13 | } 14 | 15 | var _ Interface = &PhysicalInterface{} 16 | 17 | // DetachFromNetwork disconnects the interface from its network, if any. 
18 | func (i *PhysicalInterface) DetachFromNetwork() { 19 | i.m.Lock() 20 | defer i.m.Unlock() 21 | if i.network == nil { 22 | return 23 | } 24 | i.network.m.Lock() 25 | defer i.network.m.Unlock() 26 | delete(i.network.interfaces, i.id) 27 | i.network = nil 28 | } 29 | 30 | // AttachToNetwork connects this interface to a given network, 31 | // allowing it to send packets to other hosts on the network 32 | func (i *PhysicalInterface) AttachToNetwork(n *Network) { 33 | i.DetachFromNetwork() 34 | 35 | i.m.Lock() 36 | defer i.m.Unlock() 37 | n.m.Lock() 38 | defer n.m.Unlock() 39 | // TODO: validate network is from the same world 40 | i.network = n 41 | n.interfaces[i.id] = i 42 | } 43 | 44 | // OutboundPacket enqueues the packet to be sent out the interface into the 45 | // network, if possible 46 | func (i *PhysicalInterface) OutboundPacket(p *Packet) bool { 47 | i.m.Lock() 48 | // we assume connectivity based on the network, don't really care about ip subnets 49 | if !destinationSubnetMatch(p, i.addrs) { 50 | i.m.Unlock() 51 | return false 52 | } 53 | // TODO: bogon detection (src addr match)? 
54 | n := i.network 55 | if n == nil { 56 | i.m.Unlock() 57 | return false 58 | } 59 | 60 | // fixup source addr 61 | if p.src.IP.Equal(net.IPv4zero) || p.src.IP.Equal(net.IPv6zero) { 62 | // TODO: makes assumptions about multiple addrs on interface 63 | for _, addr := range i.addrs { 64 | // TODO: this fails the race detector sometimes, vs the reads in destinationSocket 65 | p.src.IP = addr.IP 66 | break 67 | } 68 | } 69 | 70 | i.m.Unlock() 71 | 72 | return n.EnqueuePacket(p) 73 | } 74 | -------------------------------------------------------------------------------- /internal/networking/vnet/socket.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | ) 7 | 8 | // A Socket represents a listening UDP socket which can send and receive 9 | // Packets 10 | type Socket struct { 11 | m *sync.Mutex 12 | sender SocketOwner 13 | addr *net.UDPAddr 14 | rx func(*Packet) bool 15 | // TODO: encapsulation flag to filter allowed packet kinds 16 | } 17 | 18 | // InboundPacket enqueues a packet for the receive listener to process 19 | func (s *Socket) InboundPacket(p *Packet) bool { 20 | s.m.Lock() 21 | rx := s.rx 22 | s.m.Unlock() 23 | 24 | if rx == nil { 25 | return false 26 | } 27 | return rx(p) 28 | } 29 | 30 | // OutboundPacket sends a packet out the socket's interface 31 | func (s *Socket) OutboundPacket(p *Packet) bool { 32 | s.m.Lock() 33 | sender := s.sender 34 | s.m.Unlock() 35 | if sender != nil { 36 | return sender.OutboundPacket(p) 37 | } 38 | return false 39 | } 40 | 41 | // Close shuts down a socket 42 | func (s *Socket) Close() { 43 | s.m.Lock() 44 | s._close() 45 | s.m.Unlock() 46 | } 47 | 48 | // _close does the close work without the lock held 49 | func (s *Socket) _close() { 50 | if s.sender != nil { 51 | s.sender.DelSocket(s) 52 | s.sender = nil 53 | } 54 | s.rx = nil 55 | } 56 | -------------------------------------------------------------------------------- 
/internal/networking/vnet/util.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import "net" 4 | 5 | func subnetMatch(ip net.IP, addrs map[string]net.IPNet) bool { 6 | for _, a := range addrs { 7 | am := a.IP.Mask(a.Mask) 8 | pm := ip.Mask(a.Mask) 9 | if am.Equal(pm) { 10 | return true 11 | } 12 | } 13 | return false 14 | } 15 | 16 | func destinationAddrMatch(p *Packet, addrs map[string]net.IPNet) bool { 17 | ip := p.dest.IP 18 | for _, a := range addrs { 19 | if ip.Equal(a.IP) { 20 | return true 21 | } 22 | } 23 | return false 24 | } 25 | 26 | func destinationSubnetMatch(p *Packet, addrs map[string]net.IPNet) bool { 27 | return subnetMatch(p.dest.IP, addrs) 28 | } 29 | 30 | func sourceSubnetMatch(p *Packet, addrs map[string]net.IPNet) bool { 31 | return subnetMatch(p.src.IP, addrs) 32 | } 33 | 34 | func destinationSocket(p *Packet, sockets map[string]*Socket) *Socket { 35 | for _, s := range sockets { 36 | if s.addr.Port != p.dest.Port { 37 | continue 38 | } 39 | if s.addr.IP.Equal(net.IPv4zero) || s.addr.IP.Equal(net.IPv6zero) || s.addr.IP.Equal(p.dest.IP) { 40 | // TODO: differentiate v4 and v6 any addresses here 41 | return s 42 | } 43 | } 44 | return nil 45 | } 46 | -------------------------------------------------------------------------------- /internal/networking/vnet/world.go: -------------------------------------------------------------------------------- 1 | package vnet 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | ) 7 | 8 | // A World represents a global set of networks, hosts, and their interfaces. 9 | type World struct { 10 | m *sync.Mutex 11 | networks map[string]*Network 12 | hosts map[string]*Host 13 | } 14 | 15 | // NewWorld initializes a new empty world to which hosts and networks can be 16 | // added. 
17 | func NewWorld() *World { 18 | ret := &World{ 19 | m: &sync.Mutex{}, 20 | networks: map[string]*Network{}, 21 | hosts: map[string]*Host{}, 22 | } 23 | return ret 24 | } 25 | 26 | // CreateNetwork creates and attaches a new Network with the given id to the 27 | // world 28 | func (w *World) CreateNetwork(id string) *Network { 29 | w.m.Lock() 30 | defer w.m.Unlock() 31 | // TODO: validate id is unique 32 | ret := &Network{ 33 | m: &sync.Mutex{}, 34 | id: id, 35 | world: w, 36 | interfaces: map[string]*PhysicalInterface{}, 37 | } 38 | // TODO: more network initialization 39 | w.networks[id] = ret 40 | return ret 41 | } 42 | 43 | // CreateEmptyHost creates a new Host within the world, with no interfaces 44 | func (w *World) CreateEmptyHost(id string) *Host { 45 | w.m.Lock() 46 | defer w.m.Unlock() 47 | // TODO: validate id is unique 48 | ret := &Host{ 49 | m: &sync.Mutex{}, 50 | id: id, 51 | world: w, 52 | interfaces: map[string]Interface{}, 53 | sockets: map[string]*Socket{}, 54 | } 55 | // TODO: more host initialization 56 | w.hosts[id] = ret 57 | return ret 58 | } 59 | 60 | // CreateHost creates a simple host within the world, with its 'lo' (localhost) 61 | // interface pre-configured 62 | func (w *World) CreateHost(id string) *Host { 63 | h := w.CreateEmptyHost(id) 64 | n := h.AddPhy("lo") 65 | n.AddAddr(net.IPNet{ 66 | IP: net.IPv4(127, 0, 0, 1), 67 | Mask: net.CIDRMask(8, 32), 68 | }) 69 | n.AddAddr(net.IPNet{ 70 | IP: net.IPv6loopback, 71 | Mask: net.CIDRMask(128, 128), 72 | }) 73 | return h 74 | } 75 | -------------------------------------------------------------------------------- /internal/testutils/ci-scaling.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "math" 5 | "os" 6 | "runtime" 7 | "strconv" 8 | "time" 9 | ) 10 | 11 | // this file has runtime-computed constants to compensate for poor performance 12 | // of CI VMs compared to dedicated developer systems, when running 
performance 13 | // tests 14 | 15 | // empirical measure on the slowest system tested that passes tests 16 | const baseline = 1_000_000_000 17 | 18 | // CIScaleFactor is an approximate scaling factor by which to multiply time 19 | // deadlines in performance-sensitive tests to compensate for the running system 20 | // being slower than the reference system. 21 | var CIScaleFactor int 22 | 23 | // CIScaleFactorDuration is just CIScaleFactor cast to a time.Duration for 24 | // simplicity. 25 | var CIScaleFactorDuration time.Duration 26 | 27 | // CIScaleMs is time.Millisecond * CIScaleFactor 28 | var CIScaleMs time.Duration // nolint:revive,staticcheck // this is like `time.Millisecond` 29 | 30 | func measurePerf(target time.Duration) int { 31 | counter := 0 32 | start := time.Now() 33 | deadline := start.Add(target) 34 | now := start 35 | for ; now.Before(deadline); now = time.Now() { 36 | for i := 0; i < 1000; i++ { 37 | counter += i 38 | } 39 | } 40 | 41 | return int(float64(counter) * float64(target) / float64(now.Sub(start))) 42 | } 43 | 44 | func init() { 45 | CIScaleFactor = 1 46 | if ciScaleStr, ok := os.LookupEnv("CISCALE"); ok { 47 | if ciSFParsed, err := strconv.ParseInt(ciScaleStr, 0, 16); err != nil { 48 | panic(err) 49 | } else { 50 | CIScaleFactor = int(ciSFParsed) 51 | } 52 | } else if runtime.GOOS == "darwin" && os.Getenv("CI") != "" { 53 | // macOS runners seem to be super slow, try to compensate with a 54 | // micro-benchmark to compare vs. 
a reference system 55 | thisMachine := measurePerf(time.Millisecond) 56 | if thisMachine >= baseline { 57 | CIScaleFactor = 1 58 | } else { 59 | CIScaleFactor = int(math.Ceil(float64(baseline) / float64(thisMachine))) 60 | } 61 | } 62 | CIScaleFactorDuration = time.Duration(CIScaleFactor) 63 | CIScaleMs = time.Duration(CIScaleFactor) * time.Millisecond 64 | } 65 | -------------------------------------------------------------------------------- /internal/testutils/ci-scaling_test.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func Test_measurePerf(t *testing.T) { 9 | for c := 1; c <= 1000; c *= 10 { 10 | x := measurePerf(time.Millisecond * time.Duration(c)) 11 | f := float64(c) * float64(baseline) / float64(x) 12 | t.Logf("%d ms: %d = *%.1f", c, x, f) 13 | if x/c > baseline { 14 | t.Logf(" recommend increasing baseline to %d", x/c) 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /internal/testutils/doc.go: -------------------------------------------------------------------------------- 1 | // Package testutils provides utility code needed by multiple packages' test 2 | // suites, but which should not be referenced in any non-test code. 
3 | package testutils 4 | -------------------------------------------------------------------------------- /internal/testutils/ip.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "math/rand" 5 | "net" 6 | "testing" 7 | ) 8 | 9 | // MakeIPv6 helps build IPv6 values in a similar method to how the "::" marker 10 | // in an IPv6 literal works 11 | func MakeIPv6(left, right []byte) net.IP { 12 | ret := make([]byte, net.IPv6len) 13 | copy(ret, left) 14 | copy(ret[net.IPv6len-len(right):], right) 15 | return ret 16 | } 17 | 18 | // MakeIPv6Net uses MakeIPv6 to create a net.IPNet with the built IP and the given 19 | // CIDR mask length 20 | func MakeIPv6Net(left, right []byte, ones int) net.IPNet { 21 | return net.IPNet{ 22 | IP: MakeIPv6(left, right), 23 | Mask: net.CIDRMask(ones, 8*net.IPv6len), 24 | } 25 | } 26 | 27 | // MakeIPv4Net creates a net.IPNet with the given address and CIDR mask length 28 | func MakeIPv4Net(a, b, c, d byte, ones int) net.IPNet { 29 | return net.IPNet{ 30 | IP: net.IPv4(a, b, c, d).To4(), 31 | Mask: net.CIDRMask(ones, 8*net.IPv4len), 32 | } 33 | } 34 | 35 | // RandIPNet generates a random IPNet of the given size, and with optional fixed left/right bytes, 36 | // and with the given CIDR prefix length 37 | func RandIPNet(t *testing.T, size int, left, right []byte, ones int) net.IPNet { 38 | ipBytes := MustRandBytes(t, make([]byte, size)) 39 | if len(left) > 0 { 40 | copy(ipBytes, left) 41 | } 42 | if len(right) > 0 { 43 | copy(ipBytes[net.IPv6len-len(right):], right) 44 | } 45 | return net.IPNet{ 46 | IP: ipBytes, 47 | Mask: net.CIDRMask(ones, 8*size), 48 | } 49 | } 50 | 51 | // RandUDP4Addr generates a random IPv4 UDP address for test purposes 52 | func RandUDP4Addr(t *testing.T) *net.UDPAddr { 53 | return &net.UDPAddr{ 54 | IP: MustRandBytes(t, make([]byte, net.IPv4len)), 55 | Port: rand.Intn(65535), 56 | } 57 | } 58 | 59 | // RandUDP6Addr generates a random IPv6 UDP 
address for test purposes 60 | func RandUDP6Addr(t *testing.T) *net.UDPAddr { 61 | return &net.UDPAddr{ 62 | IP: MustRandBytes(t, make([]byte, net.IPv6len)), 63 | Port: rand.Intn(65535), 64 | } 65 | } 66 | 67 | // ContainsIPNet runs a predicate across a net.IPNet slice and returns if any match was found 68 | func ContainsIPNet(addrs []net.IPNet, predicate func(net.IPNet) bool) bool { 69 | for _, addr := range addrs { 70 | if predicate(addr) { 71 | return true 72 | } 73 | } 74 | return false 75 | } 76 | -------------------------------------------------------------------------------- /internal/testutils/log.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import "github.com/fastcat/wirelink/log" 4 | 5 | // assume that if the testutils package is imported, we're running a test, 6 | // and so should enable debug logging 7 | func init() { 8 | log.SetDebug(true) 9 | log.Debug("Auto-enabled debug logging for test mode") 10 | } 11 | -------------------------------------------------------------------------------- /internal/testutils/must-bytes.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "github.com/stretchr/testify/require" 5 | 6 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 7 | ) 8 | 9 | // MustKeyPair generates a real pair of private and public keys, 10 | // panicing (via require) if this fails 11 | func MustKeyPair(t require.TestingT) (privateKey, publicKey wgtypes.Key) { 12 | priv, err := wgtypes.GeneratePrivateKey() 13 | require.Nil(t, err) 14 | pub := priv.PublicKey() 15 | return priv, pub 16 | } 17 | 18 | // MustKey uses MustRandBytes to generate a random (not crypto-valid) key 19 | func MustKey(t require.TestingT) (key wgtypes.Key) { 20 | MustRandBytes(t, key[:]) 21 | return 22 | } 23 | 24 | // MustParseKey parses the string version of a wireguard key, panicing via 25 | // require if it fails, returning the parsed key 
otherwise 26 | func MustParseKey(t require.TestingT, s string) wgtypes.Key { 27 | k, err := wgtypes.ParseKey(s) 28 | require.NoError(t, err) 29 | return k 30 | } 31 | 32 | // MustRandBytes fills the given slice with random bytes using rand.Read 33 | func MustRandBytes(t require.TestingT, data []byte) []byte { 34 | n, err := Rand.Read(data) 35 | require.Nil(t, err) 36 | require.Equal(t, len(data), n) 37 | return data 38 | } 39 | -------------------------------------------------------------------------------- /internal/testutils/output-capture.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // CaptureOutput runs f with os.Stdout redirected to a temp file, 12 | // and then re-reads everything that is written and returns it 13 | func CaptureOutput(t *testing.T, f func()) []byte { 14 | originalOutput := os.Stdout 15 | 16 | tempfile, err := os.CreateTemp(t.TempDir(), "wirelink-test-output-capture") 17 | require.NoError(t, err) 18 | defer tempfile.Close() 19 | defer os.Remove(tempfile.Name()) 20 | 21 | func() { 22 | os.Stdout = tempfile 23 | defer func() { os.Stdout = originalOutput }() 24 | f() 25 | }() 26 | 27 | _, err = tempfile.Seek(0, 0) 28 | require.NoError(t, err) 29 | data, err := io.ReadAll(tempfile) 30 | require.NoError(t, err) 31 | 32 | return data 33 | } 34 | -------------------------------------------------------------------------------- /internal/testutils/paths.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "path" 5 | "runtime" 6 | ) 7 | 8 | // SrcDirectory uses the call stack to compute the directory of the caller's source file. 
9 | func SrcDirectory() string { 10 | _, filename, _, _ := runtime.Caller(1) 11 | return path.Dir(filename) 12 | } 13 | -------------------------------------------------------------------------------- /internal/testutils/rand.go: -------------------------------------------------------------------------------- 1 | package testutils 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "time" 7 | ) 8 | 9 | // make sure tests are really random 10 | func init() { 11 | seed := time.Now().UnixNano() 12 | fmt.Printf("Today's seed is %v\n", seed) 13 | Rand = rand.New(rand.NewSource(seed)) 14 | } 15 | 16 | // Rand is a per-run initialized non-crypto RNG 17 | var Rand *rand.Rand 18 | -------------------------------------------------------------------------------- /internal/version.go.in: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | // Version is the version from git describe used to build this release 4 | const Version = "__GIT_VERSION__" 5 | -------------------------------------------------------------------------------- /log/log.go: -------------------------------------------------------------------------------- 1 | // Package log is a bit of a ridiculous package to have, but the built in `log` package 2 | // always writes to stderr, and fancy structured logging isn't what this tool needs 3 | package log 4 | 5 | import ( 6 | "fmt" 7 | "os" 8 | "time" 9 | ) 10 | 11 | // TODO: copy efficient buffer management from core `log` package 12 | 13 | // Info writes a formatted string with an appended newline to Stdout. 14 | // errors are ignored. 15 | func Info(format string, a ...interface{}) { 16 | if format[len(format)-1] != '\n' { 17 | format = format + "\n" 18 | } 19 | if debugEnabled { 20 | // format = time.Now().Format(debugStampFormat) + " " + format 21 | format = debugOffset() + " " + format 22 | } 23 | fmt.Printf(format, a...) 24 | } 25 | 26 | // Error writes a formatted string with an appended newline to Stderr. 
27 | // errors are ignored. 28 | func Error(format string, a ...interface{}) { 29 | if format[len(format)-1] != '\n' { 30 | format = format + "\n" 31 | } 32 | if debugEnabled { 33 | // format = time.Now().Format(debugStampFormat) + " " + format 34 | format = debugOffset() + " " + format 35 | } 36 | fmt.Fprintf(os.Stderr, format, a...) 37 | } 38 | 39 | var ( 40 | debugEnabled bool 41 | debugReference time.Time 42 | ) 43 | 44 | // similar to RFC3339 45 | // const debugStampFormat = "2006-01-02T15:04:05.999" 46 | 47 | // SetDebug controls whether Debug does anything 48 | func SetDebug(enabled bool) { 49 | debugEnabled = enabled 50 | if enabled { 51 | debugReference = time.Now() 52 | } 53 | } 54 | 55 | func debugOffset() string { 56 | return time.Since(debugReference).Round(time.Millisecond).String() 57 | } 58 | 59 | // IsDebug returns whether debug logging is enabled 60 | func IsDebug() bool { 61 | return debugEnabled 62 | } 63 | 64 | // Debug writes a formatted string with an appended newline to Stdout, if enabled. 65 | // errors are ignored. 66 | func Debug(format string, a ...interface{}) { 67 | if !debugEnabled { 68 | return 69 | } 70 | if format[len(format)-1] != '\n' { 71 | format = format + "\n" 72 | } 73 | // format = time.Now().Format(debugStampFormat) + " " + format 74 | format = debugOffset() + " " + format 75 | fmt.Printf(format, a...) 
76 | } 77 | -------------------------------------------------------------------------------- /magefiles/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | -------------------------------------------------------------------------------- /magefiles/ci.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | 7 | "github.com/magefile/mage/sh" 8 | ) 9 | 10 | // Deepen fetches all tags, then progressively deepens a shallow clone until 11 | // getVersions (git describe) can find a tag to describe HEAD with. It gives up 12 | // with an error when ctx is cancelled, or when the clone is no longer shallow 13 | // and still has no usable tag, since deepening further cannot help then. 14 | func Deepen(ctx context.Context) error { 15 | if err := sh.RunV("git", "fetch", "origin", "+refs/tags/*:refs/tags/*"); err != nil { 16 | return err 17 | } 18 | d := 30 19 | for { 20 | _, verr := getVersions(ctx) 21 | if verr == nil { 22 | return nil 23 | } 24 | if err := ctx.Err(); err != nil { 25 | return err 26 | } 27 | shallow, err := sh.Output("git", "rev-parse", "--is-shallow-repository") 28 | if err != nil { 29 | return err 30 | } 31 | if shallow != "true" { 32 | // full history is already present: no tag will become reachable 33 | return verr 34 | } 35 | if err := sh.RunV("git", "fetch", "--deepen="+strconv.Itoa(d), "origin", "+refs/tags/*:refs/tags/*"); err != nil { 36 | return err 37 | } 38 | d += 10 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /magefiles/debug.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/magefile/mage/mg" 13 | "github.com/magefile/mage/sh" 14 | ) 15 | 16 | func PreDebug(ctx context.Context) error { 17 | // compile it once to get the intermediate file 18 | if err := sh.RunV("go", "tool", "mage", "-f", "-keep", "-compile", "build"); err != nil { 19 | return err 20 | } 21 | srcs, err := filepath.Glob("magefiles/*.go") 22 | if err != nil { 23 | return err 24 | } 25 | for i := range srcs { 26 | srcs[i] = filepath.Base(srcs[i]) 27 | } 28 | // compile it again for debugging 29 | c := exec.CommandContext(ctx, "go", "build", "-o", "build", "-gcflags=-N -l") 30 | c.Args = append(c.Args, srcs...)
31 | c.Stdin, c.Stdout, c.Stderr = nil, os.Stdout, os.Stderr 32 | c.Dir = "./magefiles" 33 | if mg.Verbose() { 34 | quoted := make([]string, 0, len(c.Args)-1) 35 | for _, a := range c.Args[1:] { 36 | quoted = append(quoted, strconv.Quote(a)) 37 | } 38 | fmt.Printf("exec: %s %s\n", filepath.Base(c.Path), strings.Join(quoted, " ")) 39 | } 40 | if err := c.Run(); err != nil { 41 | return err 42 | } 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /magefiles/dirty.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/magefile/mage/mg" 9 | "github.com/magefile/mage/target" 10 | ) 11 | 12 | type IfDirty struct { 13 | outputs []string 14 | inputs []string 15 | cmd func(context.Context) error 16 | } 17 | 18 | func ifDirty(outputs ...string) *IfDirty { 19 | return &IfDirty{outputs: outputs} 20 | } 21 | func (id *IfDirty) from(inputs ...string) *IfDirty { 22 | id.inputs = append(id.inputs, inputs...)
23 | return id 24 | } 25 | func (id *IfDirty) then(cmd func(context.Context) error) *IfDirty { 26 | id.cmd = cmd 27 | return id 28 | } 29 | 30 | func (id *IfDirty) run(ctx context.Context) error { 31 | anyDirty := false 32 | for _, dst := range id.outputs { 33 | if dirty, err := target.Dir(dst, id.inputs...); err != nil { 34 | return err 35 | } else if dirty { 36 | anyDirty = true 37 | break 38 | } 39 | } 40 | if !anyDirty { 41 | if mg.Verbose() { 42 | fmt.Printf("clean: %s\n", strings.Join(id.outputs, ", ")) 43 | } 44 | return nil 45 | } 46 | return id.cmd(ctx) 47 | } 48 | -------------------------------------------------------------------------------- /magefiles/generate.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "os" 7 | 8 | "github.com/magefile/mage/mg" 9 | "github.com/magefile/mage/sh" 10 | ) 11 | 12 | var generatedSources = []string{ 13 | "internal/version.go", 14 | } 15 | var goGeneratedSources = []string{ 16 | "internal/mocks/WgClient.go", 17 | "trust/mock_Evaluator_test.go", 18 | "internal/networking/mocks/Environment.go", 19 | "internal/networking/mocks/Interface.go", 20 | "internal/networking/mocks/UDPConn.go", 21 | } 22 | var generated = append(append([]string{}, generatedSources...), goGeneratedSources...) 23 | 24 | func Generate(ctx context.Context) error { 25 | runs := make([]any, len(generators)) 26 | for i := range generators { 27 | // can't pass it g.run because those will all look like the same target to 28 | // it and get deduplicated 29 | runs[i] = mg.F(runGen, i) 30 | } 31 | mg.CtxDeps(ctx, runs...) 32 | return nil 33 | } 34 | 35 | func runGen(ctx context.Context, i int) error { 36 | return generators[i].run(ctx) 37 | } 38 | 39 | var generators = []*IfDirty{ 40 | ifDirty("internal/version.go"). 41 | from("internal/version.go.in", ".git/HEAD", ".git/index").
42 | then(buildVersion), 43 | ifDirty(goGeneratedSources...).then(goGenerate), 44 | } 45 | 46 | func buildVersion(ctx context.Context) error { 47 | in, err := os.ReadFile("internal/version.go.in") 48 | if err != nil { 49 | return err 50 | } 51 | v, err := getVersions(ctx) 52 | if err != nil { 53 | return err 54 | } 55 | out := bytes.ReplaceAll(in, []byte("__GIT_VERSION__"), []byte(v.pkgVerRel)) 56 | if err := safeOverwrite("internal/version.go", out, 0666); err != nil { 57 | return err 58 | } 59 | return nil 60 | } 61 | 62 | func goGenerate(ctx context.Context) error { 63 | return sh.RunV("go", "generate", "./...") 64 | } 65 | 66 | func safeOverwrite(dst string, content []byte, perm os.FileMode) error { 67 | tmp := dst + ".tmp" 68 | if err := os.WriteFile(tmp, content, perm); err != nil { 69 | return err 70 | } 71 | return os.Rename(tmp, dst) 72 | } 73 | -------------------------------------------------------------------------------- /magefiles/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fastcat/wirelink/magefiles 2 | 3 | go 1.24.0 4 | 5 | require github.com/magefile/mage v1.15.0 6 | 7 | require ( 8 | github.com/bitfield/gotestdox v0.2.2 // indirect 9 | github.com/dnephin/pflag v1.0.7 // indirect 10 | github.com/fatih/color v1.18.0 // indirect 11 | github.com/fsnotify/fsnotify v1.8.0 // indirect 12 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect 13 | github.com/mattn/go-colorable v0.1.14 // indirect 14 | github.com/mattn/go-isatty v0.0.20 // indirect 15 | github.com/rogpeppe/go-internal v1.13.1 // indirect 16 | golang.org/x/mod v0.24.0 // indirect 17 | golang.org/x/sync v0.14.0 // indirect 18 | golang.org/x/sys v0.33.0 // indirect 19 | golang.org/x/telemetry v0.0.0-20240522233618-39ace7a40ae7 // indirect 20 | golang.org/x/term v0.32.0 // indirect 21 | golang.org/x/text v0.25.0 // indirect 22 | golang.org/x/tools v0.32.0 // indirect 23 | golang.org/x/vuln v1.1.4 // indirect
24 | gotest.tools/gotestsum v1.12.2 // indirect 25 | ) 26 | 27 | tool ( 28 | github.com/magefile/mage 29 | golang.org/x/vuln/cmd/govulncheck 30 | gotest.tools/gotestsum 31 | ) 32 | -------------------------------------------------------------------------------- /magefiles/install.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "os/exec" 8 | "path/filepath" 9 | 10 | "github.com/magefile/mage/mg" 11 | "github.com/magefile/mage/sh" 12 | ) 13 | 14 | func Install(ctx context.Context) error { 15 | mg.CtxDeps(ctx, Generate) 16 | return sh.RunV("go", "install", "-v") 17 | } 18 | 19 | func SysInstall(ctx context.Context) error { 20 | mg.CtxDeps(ctx, Wirelink) 21 | return sysinstall(ctx, "wirelink") 22 | } 23 | 24 | func SysInstallCross(ctx context.Context, arch string) error { 25 | mg.CtxDeps(ctx, mg.F(WirelinkCross, arch)) 26 | return sysinstall(ctx, "wirelink-cross-"+arch) 27 | } 28 | 29 | func sysinstall(ctx context.Context, src string) error { 30 | if err := sh.RunV("install", src, PREFIX+"/bin/wirelink"); err != nil { 31 | return err 32 | } 33 | if err := sh.RunV("install", "-m", "644", "packaging/wirelink@.service", "/lib/systemd/system/"); err != nil { 34 | return err 35 | } 36 | if err := sh.RunV("install", "-m", "644", "packaging/wl-quick@.service", "/lib/systemd/system/"); err != nil { 37 | return err 38 | } 39 | return nil 40 | } 41 | 42 | type Checkinstall mg.Namespace 43 | 44 | func (Checkinstall) Clean(ctx context.Context) error { 45 | for _, pattern := range []string{ 46 | "./packaging/*checkinstall/*.deb", 47 | "./packaging/*checkinstall/doc-pak", 48 | } { 49 | matches, err := filepath.Glob(pattern) 50 | if err != nil { 51 | return err 52 | } 53 | for _, match := range matches { 54 | if err := os.RemoveAll(match); err != nil { 55 | return err 56 | } 57 | } 58 | } 59 | return nil 60 | } 61 | 62 | func (Checkinstall) Prep(ctx context.Context, arch
string) error { 63 | mg.CtxDeps(ctx, mg.F(WirelinkCross, arch)) 64 | if err := sh.RunV("go", "mod", "tidy"); err != nil { 65 | return err 66 | } 67 | // checkinstall picks up doc-pak from its working directory, which Cross 68 | // sets to packaging/checkinstall, so the docs must be staged there. 69 | if err := os.MkdirAll("packaging/checkinstall/doc-pak", 0777); err != nil { 70 | return err 71 | } 72 | args := []string{"-m", "644"} 73 | args = append(args, DOCSFILES...) 74 | args = append(args, "./packaging/checkinstall/doc-pak/") 75 | if err := sh.RunV("install", args...); err != nil { 76 | return err 77 | } 78 | return nil 79 | } 80 | 81 | func (Checkinstall) Cross(ctx context.Context, arch string) error { 82 | mg.CtxDeps(ctx, mg.F(Checkinstall{}.Prep, arch)) 83 | // mage's sh package doesn't provide a way to override cwd, so we do this raw. 84 | // Some args have extra quoting to work around checkinstall bugs, see: 85 | // https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=785441 86 | vi, err := getVersions(ctx) 87 | if err != nil { 88 | return err 89 | } 90 | cmd := exec.CommandContext(ctx, "fakeroot", 91 | "checkinstall", 92 | "--type=debian", 93 | "--install=no", 94 | "--fstrans=yes", 95 | "--pkgarch="+arch, 96 | "--pkgname=wirelink", 97 | "--pkgversion="+vi.pkgVer, 98 | "--pkgrelease="+vi.pkgRel, 99 | "--pkglicense=AGPL-3", 100 | "--pkggroup=net", 101 | "--pkgsource=https://github.com/fastcat/wirelink", 102 | "--maintainer='Matthew Gabeler-Lee '", 103 | "--requires=wireguard-tools", 104 | "--recommends='wireguard-dkms | wireguard-modules'", 105 | "--reset-uids=yes", 106 | "--backup=no", 107 | // the real command, need to get back to the original directory 108 | "/bin/sh", "-c", 109 | fmt.Sprintf("cd ../../ && %s %s %s", os.Args[0], "sysInstallCross", arch), 110 | ) 111 | cmd.Dir = "./packaging/checkinstall" 112 | cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr 113 | return cmd.Run() 114 | } 115 | -------------------------------------------------------------------------------- /magefiles/lint.go: -------------------------------------------------------------------------------- 1 | package main 2
| 3 | import ( 4 | "context" 5 | 6 | "github.com/magefile/mage/mg" 7 | "github.com/magefile/mage/sh" 8 | ) 9 | 10 | func LintAll(ctx context.Context) error { 11 | mg.CtxDeps(ctx, Lint{}.GolangCI, Lint{}.Vulncheck) 12 | return nil 13 | } 14 | 15 | type Lint mg.Namespace 16 | 17 | func (Lint) GolangCI(ctx context.Context) error { 18 | mg.CtxDeps(ctx, Generate) 19 | return sh.RunV("golangci-lint", "run") 20 | } 21 | func (Lint) Fix(ctx context.Context) error { 22 | mg.CtxDeps(ctx, Generate) 23 | return sh.RunV("golangci-lint", "run", "--fix") 24 | } 25 | func (Lint) Vulncheck(ctx context.Context) error { 26 | mg.CtxDeps(ctx, Generate) 27 | return sh.RunV("go", "tool", "govulncheck", "-test", "./...") 28 | } 29 | -------------------------------------------------------------------------------- /magefiles/magefile.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "slices" 8 | "strings" 9 | 10 | "go/build" 11 | 12 | "github.com/magefile/mage/mg" 13 | "github.com/magefile/mage/sh" 14 | ) 15 | 16 | var Default = Everything 17 | 18 | func Compile(ctx context.Context) error { 19 | mg.CtxDeps(ctx, Generate) 20 | return sh.RunV("go", "build", "-v", "./...") 21 | } 22 | 23 | func Wirelink(ctx context.Context) error { 24 | mg.CtxDeps(ctx, Generate) 25 | return sh.RunV("go", "build", "-v", ".") 26 | } 27 | func WirelinkCross(ctx context.Context, arch string) error { 28 | mg.CtxDeps(ctx, Generate) 29 | // build these stripped 30 | return sh.RunWithV( 31 | map[string]string{ 32 | "CGO_ENABLED": "0", 33 | "GOARCH": arch, 34 | }, 35 | "go", "build", "-ldflags", "-s -w", "-o", "wirelink-cross-"+arch, "-v", ".", 36 | ) 37 | } 38 | 39 | func Run(ctx context.Context) error { 40 | mg.CtxDeps(ctx, Generate) 41 | return sh.RunV("go", "run", "-exec", "sudo", ".") 42 | } 43 | 44 | // DlvRunReal runs wirelink under a headless delve server as root. The dlv path 45 | // is computed here rather than left to expansion: mage's sh only expands $VAR 46 | // and ${VAR} (not make-style $(VAR)), and GOPATH is frequently unset, so the 47 | // absolute path in the Go bin directory is passed to sudo. 48 | func DlvRunReal(ctx context.Context) error { 49 | mg.CtxDeps(ctx, Compile, Wirelink) 50 | return
sh.RunV( 51 | "sudo", filepath.Join(goBinDir(), "dlv"), "debug", 52 | "--only-same-user=false", 53 | "--headless", 54 | "--listen=:2345", 55 | "--log", 56 | "--api-version=2", 57 | "--", 58 | "--debug", 59 | "--iface=wg0", 60 | ) 61 | } 62 | 63 | func Everything(ctx context.Context) { 64 | mg.CtxDeps(ctx, 65 | LintAll, 66 | Compile, 67 | Wirelink, 68 | TestDefault, 69 | ) 70 | } 71 | 72 | func Clean(ctx context.Context) error { 73 | mg.CtxDeps(ctx, Checkinstall{}.Clean) 74 | for _, pattern := range []string{ 75 | "./wirelink", 76 | "./wirelink-cross-*", 77 | "./coverage.out", 78 | "./coverage.html", 79 | } { 80 | matches, err := filepath.Glob(pattern) 81 | if err != nil { 82 | return err 83 | } 84 | for _, match := range matches { 85 | if err := os.RemoveAll(match); err != nil { 86 | return err 87 | } 88 | } 89 | } 90 | for _, l := range [][]string{generatedSources, goGeneratedSources} { 91 | for _, fn := range l { 92 | if err := os.RemoveAll(fn); err != nil { 93 | return err 94 | } 95 | if err := os.RemoveAll(fn + ".tmp"); err != nil { 96 | return err 97 | } 98 | } 99 | } 100 | return nil 101 | } 102 | 103 | var Aliases = map[string]any{ 104 | "lint": LintAll, 105 | "test": TestDefault, 106 | } 107 | 108 | func init() { 109 | // make sure GOBIN is in PATH 110 | gobin := goBinDir() 111 | pathEntries := filepath.SplitList(os.Getenv("PATH")) 112 | if !slices.Contains(pathEntries, gobin) { 113 | pathEntries = append([]string{gobin}, pathEntries...)
114 | os.Setenv("PATH", strings.Join(pathEntries, string(filepath.ListSeparator))) 115 | } 116 | } 117 | 118 | // goBinDir returns the directory `go install` writes binaries to: $GOBIN if 119 | // set, otherwise the bin directory of the first GOPATH entry, using the 120 | // toolchain's default GOPATH when the variable is unset. GOPATH may be a 121 | // list, so it must not be joined with "bin" wholesale. 122 | func goBinDir() string { 123 | if gobin := os.Getenv("GOBIN"); gobin != "" { 124 | return gobin 125 | } 126 | gopath := os.Getenv("GOPATH") 127 | if gopath == "" { 128 | gopath = build.Default.GOPATH 129 | } 130 | if entries := filepath.SplitList(gopath); len(entries) > 0 { 131 | gopath = entries[0] 132 | } 133 | return filepath.Join(gopath, "bin") 134 | } 135 | -------------------------------------------------------------------------------- /magefiles/test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "io/fs" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | 10 | "github.com/magefile/mage/mg" 11 | "github.com/magefile/mage/sh" 12 | ) 13 | 14 | func TestDefault(ctx context.Context) error { 15 | mg.CtxDeps(ctx, LintAll, Test{}.GoRace) 16 | return nil 17 | } 18 | 19 | type Test mg.Namespace 20 | 21 | func (Test) Go(ctx context.Context) error { 22 | mg.CtxDeps(ctx, Generate) 23 | return sh.RunV("go", "test", "-vet=off", "-timeout=20s", "./...") 24 | } 25 | func (Test) GoRace(ctx context.Context) error { 26 | mg.CtxDeps(ctx, Generate) 27 | return sh.RunV("go", "test", "-vet=off", "-timeout=1m", "-race", "./...") 28 | } 29 | 30 | func (Test) Stress(ctx context.Context) error { 31 | mg.CtxDeps(ctx, Test{}.StressGo, Test{}.StressRace) 32 | return nil 33 | } 34 | func (Test) StressGo(ctx context.Context) error { 35 | mg.CtxDeps(ctx, Generate) 36 | return sh.RunV("go", "test", "-vet=off", "-short", "-timeout=2m", "-count=1000", "./...") 37 | } 38 | func (Test) StressRace(ctx context.Context) error { 39 | mg.CtxDeps(ctx, Generate) 40 | return sh.RunV("go", "test", "-vet=off", "-short", "-timeout=5m", "-race", "-count=1000", "./...") 41 | } 42 | 43 | func (Test) Cover(ctx context.Context) error { 44 | mg.CtxDeps(ctx, Generate) 45 | return sh.RunV("go", "test", "-vet=off", "-timeout=1m", "-covermode=atomic", "-coverpkg=./...", "-coverprofile=coverage.out", "./...") 46 | } 47 | func (Test) CoverCI(ctx context.Context) error { 48 | mg.CtxDeps(ctx, Generate) 49 | return sh.RunV("go", "tool", "gotestsum", 50 | "--format=github-actions", 51 | "--junitfile=junit.xml", 52 | "--junitfile-project-name=wirelink", 53 | "--", 54 | "-vet=off", "-timeout=1m", "-covermode=atomic", "-coverpkg=./...",
"-coverprofile=coverage.out", "./...", 55 | ) 56 | } 57 | func (Test) CoverHTML(ctx context.Context) error { 58 | if err := ifDirty("coverage.html").from("coverage.out").then(func(ctx context.Context) error { 59 | return sh.RunV("go", "tool", "cover", "-html=coverage.out", "-o=coverage.html") 60 | }).run(ctx); err != nil { 61 | return err 62 | } 63 | return nil 64 | } 65 | 66 | func (Test) Fuzz(ctx context.Context) error { 67 | mg.CtxDeps(ctx, Generate) 68 | 69 | fuzzerDirs := map[string]bool{} 70 | if err := filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { 71 | if err != nil { 72 | return err 73 | } else if d.IsDir() { 74 | if filepath.Base(path) == ".git" { 75 | return filepath.SkipDir 76 | } 77 | // recurse 78 | return nil 79 | } else if fuzzerDirs[filepath.Dir(path)] { 80 | // already found a fuzzer in this directory, skip it 81 | return filepath.SkipDir 82 | } else if !strings.HasSuffix(path, "_test.go") { 83 | // not a test file 84 | return nil 85 | } 86 | // TODO: stream the file, not just the lines 87 | contents, err := os.ReadFile(path) 88 | if err != nil { 89 | return err 90 | } 91 | for l := range strings.Lines(string(contents)) { 92 | if strings.HasPrefix(l, "func Fuzz") { 93 | fuzzerDirs[filepath.Dir(path)] = true 94 | // done with this dir 95 | return filepath.SkipDir 96 | } 97 | } 98 | return nil 99 | }); err != nil { 100 | return err 101 | } 102 | 103 | deps := make([]any, 0, len(fuzzerDirs)) 104 | for fuzzerDir := range fuzzerDirs { 105 | deps = append(deps, mg.F(Test{}.fuzzDir, fuzzerDir)) 106 | } 107 | mg.CtxDeps(ctx, deps...)
108 | return nil 109 | } 110 | 111 | func (Test) fuzzDir(ctx context.Context, dir string) error { 112 | return sh.RunV("go", "test", "./"+dir, "-fuzz=.*", "-fuzztime=1m") 113 | } 114 | -------------------------------------------------------------------------------- /magefiles/tools.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/magefile/mage/sh" 7 | ) 8 | 9 | var toolsDev = []string{ 10 | "github.com/go-delve/delve/cmd/dlv@latest", 11 | "github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest", 12 | } 13 | 14 | func InstallToolsDev(ctx context.Context) error { 15 | for _, t := range toolsDev { 16 | if err := sh.RunV("go", "install", t); err != nil { 17 | return err 18 | } 19 | } 20 | return nil 21 | } 22 | -------------------------------------------------------------------------------- /magefiles/vars.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "os" 4 | 5 | var PREFIX = "/usr" 6 | var DOCSFILES = []string{"LICENSE", "README.md", "TODO.md"} 7 | 8 | func init() { 9 | if p := os.Getenv("PREFIX"); p != "" { 10 | PREFIX = p 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /magefiles/versions.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "strings" 8 | 9 | "github.com/magefile/mage/sh" 10 | ) 11 | 12 | type versions struct { 13 | pkgVerRel string 14 | pkgVer string 15 | pkgRel string 16 | } 17 | 18 | func getVersions(ctx context.Context) (versions, error) { 19 | out, err := sh.Output("git", "describe", "--long", "--dirty=+") 20 | if err != nil { 21 | return versions{}, err 22 | } 23 | pkgVerRel := strings.TrimPrefix(out, "v") 24 | pkgVer, pkgRel, _ := strings.Cut(pkgVerRel, "-") 25 | return versions{pkgVerRel, pkgVer, pkgRel}, nil
26 | } 27 | 28 | func Info(ctx context.Context) error { 29 | v, err := getVersions(ctx) 30 | if err != nil { 31 | return err 32 | } 33 | fmt.Printf("PKGVERREL=%s\n", v.pkgVerRel) 34 | fmt.Printf("PKGVER=%s\n", v.pkgVer) 35 | fmt.Printf("PKGREL=%s\n", v.pkgRel) 36 | fmt.Printf("GOBIN=%s\n", goBinDir()) 37 | fmt.Printf("PATH=%s\n", os.Getenv("PATH")) 38 | return nil 39 | } 40 | -------------------------------------------------------------------------------- /packaging/checkinstall/description-pak: -------------------------------------------------------------------------------- 1 | wirelink is an experimental tool to enable direct peer-to-peer connections in wireguard 2 | -------------------------------------------------------------------------------- /packaging/checkinstall/postinstall-pak: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -d /run/systemd/system ]; then 4 | systemctl --system daemon-reload 5 | 6 | # TODO: enable the service on first install 7 | # this requires some debhelper stuff to do right 8 | 9 | # restart any active instances 10 | # technique borrowed from the wireguard package 11 | units=$(systemctl list-units --state=active --plain --no-legend 'wl-quick@*.service' 'wirelink@*.service' | awk '{print $1}') 12 | if [ -n "$units" ]; then 13 | echo "Restarting active wirelink units ($units)..."
14 | systemctl restart $units 15 | fi 16 | fi 17 | 18 | -------------------------------------------------------------------------------- /packaging/checkinstall/postremove-pak: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -d /run/systemd/system ]; then 4 | systemctl --system daemon-reload 5 | fi 6 | -------------------------------------------------------------------------------- /packaging/checkinstall/preremove-pak: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -d /run/systemd/system ]; then 4 | # TODO: this bit is dpkg specific, make it work for RPMs too? 5 | if [ "$1" = "remove" ]; then 6 | echo "Stopping wirelink units..." 7 | systemctl --system stop 'wl-quick@*' 'wirelink@*' 8 | fi 9 | fi 10 | -------------------------------------------------------------------------------- /packaging/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -xeuo pipefail 4 | 5 | cd "$(git rev-parse --show-toplevel)" 6 | 7 | go tool mage clean everything checkinstall:cross amd64 checkinstall:cross arm64 8 | 9 | # filter this to avoid trying to re-deploy debs we already built 10 | # assume wirelink is already new, if it fails, we want to know 11 | newdebs=( 12 | ./packaging/checkinstall/wirelink*.deb 13 | ) 14 | for distro in bullseye bookworm jammy noble ; do 15 | # run these with no stdin so we do the index export only once 16 | adddebs release=$distro "${newdebs[@]}" 255 { 24 | return nil, fmt.Errorf("ttl out of range") 25 | } 26 | 27 | expiration := now.Add(ttl) 28 | 29 | addAttr := func(attr fact.Attribute, value fact.Value) { 30 | ret = append(ret, &fact.Fact{ 31 | Attribute: attr, 32 | Subject: &fact.PeerSubject{Key: peer.PublicKey}, 33 | Value: value, 34 | Expires: expiration, 35 | }) 36 | } 37 | 38 | // the endpoint is trustable if the last handshake age is less than the TTL 39 | if
peer.Endpoint != nil && peer.LastHandshakeTime.After(now.Add(-apply.HandshakeValidity)) { 40 | if peer.Endpoint.IP.To4() != nil { 41 | addAttr(fact.AttributeEndpointV4, &fact.IPPortValue{IP: peer.Endpoint.IP, Port: peer.Endpoint.Port}) 42 | } else if peer.Endpoint.IP.To16() != nil { 43 | addAttr(fact.AttributeEndpointV6, &fact.IPPortValue{IP: peer.Endpoint.IP, Port: peer.Endpoint.Port}) 44 | } 45 | } 46 | 47 | // don't publish the autoaddress, everyone can figure that out on their own, 48 | // and must already know it in order to receive the data anyways 49 | 50 | if trustLocalAIPs { 51 | for _, peerIP := range peer.AllowedIPs { 52 | if peerIP.IP.To4() != nil { 53 | addAttr(fact.AttributeAllowedCidrV4, &fact.IPNetValue{IPNet: peerIP}) 54 | } else if peerIP.IP.To16() != nil { 55 | // ignore link-local addresses (all of fe80::/10), particularly the auto-generated v6 one 56 | if peerIP.IP.IsLinkLocalUnicast() { 57 | continue 58 | } 59 | addAttr(fact.AttributeAllowedCidrV6, &fact.IPNetValue{IPNet: peerIP}) 60 | } 61 | } 62 | } 63 | 64 | if trustLocalMembership { 65 | // always use MemberMetadata even if we don't have any metadata 66 | addAttr(fact.AttributeMemberMetadata, &fact.MemberMetadata{}) 67 | } 68 | 69 | return ret, nil 70 | } 71 | -------------------------------------------------------------------------------- /server/errors-generic.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package server 5 | 6 | import ( 7 | "os" 8 | "syscall" 9 | ) 10 | 11 | func isSysErrUnreachable(err *os.SyscallError) bool { 12 | // EDESTADDRREQ and ENETUNREACH happen when we have a bad address for 13 | // talking to a peer, whether when inside the tunnel or for the tunnel 14 | // endpoint. EPERM and ENOKEY happens if we have no handshake. Except 15 | // ENOKEY is linux specific, so it's not checked here.
16 | return err.Err == syscall.EDESTADDRREQ || 17 | err.Err == syscall.ENETUNREACH || 18 | err.Err == syscall.EPERM 19 | } 20 | -------------------------------------------------------------------------------- /server/errors-linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package server 5 | 6 | import ( 7 | "os" 8 | "syscall" 9 | ) 10 | 11 | func isSysErrUnreachable(err *os.SyscallError) bool { 12 | // EDESTADDRREQ and ENETUNREACH happen when we have a bad address for 13 | // talking to a peer, whether when inside the tunnel or for the tunnel 14 | // endpoint. EPERM and ENOKEY happens if we have no handshake. 15 | return err.Err == syscall.EDESTADDRREQ || 16 | err.Err == syscall.ENETUNREACH || 17 | err.Err == syscall.EPERM || 18 | err.Err == syscall.ENOKEY 19 | } 20 | -------------------------------------------------------------------------------- /server/interface-cache.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | 7 | "github.com/fastcat/wirelink/internal/networking" 8 | "github.com/fastcat/wirelink/log" 9 | ) 10 | 11 | type interfaceCache struct { 12 | mu sync.RWMutex 13 | 14 | env networking.Environment 15 | iface string 16 | 17 | // IPNets assigned to the tunnel 18 | tunnelIPNets []net.IPNet 19 | // IPNets for local network interfaces other than the tunnel 20 | hostIPNets []net.IPNet 21 | 22 | dirty bool 23 | } 24 | 25 | func newInterfaceCache(env networking.Environment, iface string) (*interfaceCache, error) { 26 | cache := &interfaceCache{ 27 | env: env, 28 | iface: iface, 29 | } 30 | if err := cache.read(); err != nil { 31 | return nil, err 32 | } 33 | return cache, nil 34 | } 35 | 36 | func (ic *interfaceCache) read() error { 37 | ifaces, err := ic.env.Interfaces() 38 | if err != nil { 39 | log.Error("unable to load network interfaces: %v", err) 40 | // don't abort the
caller, continue with whatever info we have from last time 41 | return err 42 | } 43 | ic.tunnelIPNets = ic.tunnelIPNets[:0] 44 | ic.hostIPNets = ic.hostIPNets[:0] 45 | for _, iface := range ifaces { 46 | if addrs, err := iface.Addrs(); err != nil { 47 | log.Error("unable to fetch addresses from %s: %v", iface.Name(), err) 48 | } else if iface.Name() == ic.iface { 49 | ic.tunnelIPNets = append(ic.tunnelIPNets, addrs...) 50 | } else { 51 | ic.hostIPNets = append(ic.hostIPNets, addrs...) 52 | } 53 | } 54 | return nil 55 | } 56 | 57 | // Dirty marks the cached data dirty, so the next query will force a refresh. 58 | func (ic *interfaceCache) Dirty() { 59 | ic.mu.Lock() 60 | ic.dirty = true 61 | ic.mu.Unlock() 62 | } 63 | 64 | // WillTunnel heuristically checks if an IP is likely to route via the tunnel. 65 | // An IP that matches a non-tunnel interface subnet is expected not to tunnel. 66 | // An IP that doesn't match those but does match a tunnel subnet is expected to 67 | // tunnel. Any other IP is expected to not tunnel.
68 | func (ic *interfaceCache) WillTunnel(ip net.IP) bool { 69 | ic.mu.RLock() 70 | if ic.dirty { 71 | // upgrade to the write lock to refresh the cached addresses 72 | ic.mu.RUnlock() 73 | ic.mu.Lock() 74 | // re-check: another caller may have refreshed while we waited for the lock 75 | if ic.dirty { 76 | // only clear the flag once a refresh succeeds, so a failed read is 77 | // retried on the next query (read keeps the previous data on failure) 78 | if err := ic.read(); err == nil { 79 | ic.dirty = false 80 | } 81 | } 82 | ic.mu.Unlock() 83 | ic.mu.RLock() 84 | } 85 | defer ic.mu.RUnlock() 86 | 87 | isTunnel := false 88 | for _, ipn := range ic.tunnelIPNets { 89 | if ipn.Contains(ip) { 90 | isTunnel = true 91 | break 92 | } 93 | } 94 | if !isTunnel { 95 | return false 96 | } 97 | for _, ipn := range ic.hostIPNets { 98 | if ipn.Contains(ip) { 99 | return false 100 | } 101 | } 102 | return true 103 | } 104 | -------------------------------------------------------------------------------- /server/peer-config-set.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/fastcat/wirelink/apply" 7 | 8 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 9 | ) 10 | 11 | type peerConfigSet struct { 12 | peerStates map[wgtypes.Key]*apply.PeerConfigState 13 | psm *sync.Mutex 14 | } 15 | 16 | func newPeerConfigSet() *peerConfigSet { 17 | return &peerConfigSet{ 18 | peerStates: make(map[wgtypes.Key]*apply.PeerConfigState), 19 | psm: new(sync.Mutex), 20 | } 21 | } 22 | 23 | func (pcs *peerConfigSet) Trim(keep func(key wgtypes.Key) bool) { 24 | pcs.psm.Lock() 25 | defer pcs.psm.Unlock() 26 | for k := range pcs.peerStates { 27 | if !keep(k) { 28 | delete(pcs.peerStates, k) 29 | } 30 | } 31 | } 32 | 33 | func (pcs *peerConfigSet) Get(key wgtypes.Key) (ret *apply.PeerConfigState, ok bool) { 34 | pcs.psm.Lock() 35 | ret, ok = pcs.peerStates[key] 36 | pcs.psm.Unlock() 37 | return 38 | } 39 | 40 | func (pcs *peerConfigSet) Set(key wgtypes.Key, value *apply.PeerConfigState) { 41 | pcs.psm.Lock() 42 | if value == nil { 43 | delete(pcs.peerStates, key) 44 | } else { 45 | pcs.peerStates[key] = value 46 | } 47 | pcs.psm.Unlock() 48 | } 49 | 50 | func (pcs *peerConfigSet) ForEach(visitor func(key wgtypes.Key, value
*apply.PeerConfigState)) { 51 | pcs.psm.Lock() 52 | defer pcs.psm.Unlock() 53 | for k, v := range pcs.peerStates { 54 | visitor(k, v) 55 | } 56 | } 57 | 58 | // Clone makes a deep clone of the receiver 59 | func (pcs *peerConfigSet) Clone() *peerConfigSet { 60 | if pcs == nil { 61 | return nil 62 | } 63 | 64 | pcs.psm.Lock() 65 | defer pcs.psm.Unlock() 66 | 67 | ret := &peerConfigSet{ 68 | peerStates: make(map[wgtypes.Key]*apply.PeerConfigState), 69 | psm: &sync.Mutex{}, 70 | } 71 | for k, v := range pcs.peerStates { 72 | ret.peerStates[k] = v.Clone() 73 | } 74 | return ret 75 | } 76 | -------------------------------------------------------------------------------- /server/peer-config-set_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/fastcat/wirelink/apply" 9 | "github.com/fastcat/wirelink/internal/testutils" 10 | 11 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 12 | 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func Test_peerConfigSet_Clone(t *testing.T) { 18 | k1 := testutils.MustKey(t) 19 | 20 | type fields struct { 21 | peerStates map[wgtypes.Key]*apply.PeerConfigState 22 | } 23 | tests := []struct { 24 | name string 25 | fields *fields 26 | }{ 27 | { 28 | "nil", 29 | nil, 30 | }, 31 | { 32 | "empty", 33 | &fields{map[wgtypes.Key]*apply.PeerConfigState{}}, 34 | }, 35 | { 36 | "filled", 37 | &fields{map[wgtypes.Key]*apply.PeerConfigState{ 38 | k1: (*apply.PeerConfigState)(nil).EnsureNotNil(), 39 | }}, 40 | }, 41 | } 42 | for _, tt := range tests { 43 | t.Run(tt.name, func(t *testing.T) { 44 | var pcs *peerConfigSet 45 | if tt.fields != nil { 46 | pcs = &peerConfigSet{ 47 | peerStates: tt.fields.peerStates, 48 | psm: &sync.Mutex{}, 49 | } 50 | } 51 | got := pcs.Clone() 52 | if tt.fields == nil { 53 | assert.Nil(t, got) 54 | return 55 | } 56 | require.NotNil(t, got) 57 |
assert.Equal(t, tt.fields.peerStates, got.peerStates) 58 | assert.NotNil(t, got.psm) 59 | // make sure it really is different 60 | assert.False(t, pcs.psm == got.psm) 61 | // can't do == checks on maps 62 | assert.NotEqual(t, reflect.ValueOf(pcs.peerStates).Pointer(), reflect.ValueOf(got.peerStates).Pointer()) 63 | // if we get here, we already know the map keys are equal 64 | for k := range tt.fields.peerStates { 65 | assert.False(t, pcs.peerStates[k] == got.peerStates[k]) 66 | // inner portion of PCS is a separate test 67 | } 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /server/peer-lookup.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | 7 | "github.com/fastcat/wirelink/autopeer" 8 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 9 | ) 10 | 11 | type ( 12 | ipBytes = [net.IPv6len]byte 13 | peerLookup struct { 14 | mu sync.RWMutex 15 | ip map[wgtypes.Key]net.IP 16 | p map[ipBytes]wgtypes.Key 17 | // TODO: LRU trimming 18 | } 19 | ) 20 | 21 | func newPeerLookup() *peerLookup { 22 | return &peerLookup{ 23 | ip: map[wgtypes.Key]net.IP{}, 24 | p: map[ipBytes]wgtypes.Key{}, 25 | } 26 | } 27 | 28 | func (pl *peerLookup) GetPeer(ip net.IP) (peer wgtypes.Key, ok bool) { 29 | var k ipBytes 30 | copy(k[:], ip.To16()) 31 | pl.mu.RLock() 32 | peer, ok = pl.p[k] 33 | pl.mu.RUnlock() 34 | return 35 | } 36 | 37 | // func (pl *peerLookup) GetIP(peer wgtypes.Key) net.IP { 38 | // pl.mu.RLock() 39 | // if ip, ok := pl.ip[peer]; ok { 40 | // pl.mu.RUnlock() 41 | // return ip 42 | // } 43 | // pl.mu.RUnlock() 44 | // pl.mu.Lock() 45 | // ip := autopeer.AutoAddress(peer) 46 | // var k ipBytes 47 | // copy(k[:], ip.To16()) 48 | // pl.ip[peer] = ip 49 | // pl.p[k] = peer 50 | // return ip 51 | // } 52 | 53 | func (pl *peerLookup) addPeers(peers ...wgtypes.Peer) { 54 | pl.mu.Lock() 55 | for _, p := range peers { 56 | if _, ok :=
pl.ip[p.PublicKey]; !ok { 57 | ip := autopeer.AutoAddress(p.PublicKey) 58 | var k ipBytes 59 | copy(k[:], ip.To16()) 60 | pl.ip[p.PublicKey] = ip 61 | pl.p[k] = p.PublicKey 62 | } 63 | } 64 | pl.mu.Unlock() 65 | } 66 | 67 | func (pl *peerLookup) addKeys(peers ...wgtypes.Key) { 68 | pl.mu.Lock() 69 | for _, p := range peers { 70 | if _, ok := pl.ip[p]; !ok { 71 | ip := autopeer.AutoAddress(p) 72 | var k ipBytes 73 | copy(k[:], ip.To16()) 74 | pl.ip[p] = ip 75 | pl.p[k] = p 76 | } 77 | } 78 | pl.mu.Unlock() 79 | } 80 | -------------------------------------------------------------------------------- /server/received-fact.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/fact" 8 | ) 9 | 10 | // ReceivedFact is a tuple of a fact and its source. 11 | // It is used for the queue of parsed packets received over the network, 12 | // to hold them in a batch before evaluating them for acceptance 13 | type ReceivedFact struct { 14 | fact *fact.Fact 15 | source net.UDPAddr 16 | } 17 | 18 | func (rf *ReceivedFact) String() string { 19 | return fmt.Sprintf("RF{%v <- %v}", rf.fact, rf.source) 20 | } 21 | 22 | var _ fmt.Stringer = &ReceivedFact{} 23 | -------------------------------------------------------------------------------- /server/utils_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net" 5 | "testing" 6 | "time" 7 | 8 | "github.com/fastcat/wirelink/apply" 9 | "github.com/fastcat/wirelink/config" 10 | "github.com/fastcat/wirelink/internal/testutils" 11 | "github.com/stretchr/testify/assert" 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | 15 | func makePCS(t *testing.T, healthy, alive, aliveLong bool) *apply.PeerConfigState { 16 | ret := &apply.PeerConfigState{} 17 | now := time.Now() 18 | handshake := now 19 | var until time.Time 20 | if !healthy { 21 |
handshake = now.Add(-time.Hour) 22 | } 23 | if alive { 24 | until = now.Add(DefaultFactTTL) 25 | } 26 | if aliveLong { 27 | now = now.Add(-DefaultFactTTL * 2) 28 | } 29 | ret = ret.Update( 30 | &wgtypes.Peer{ 31 | LastHandshakeTime: handshake, 32 | Endpoint: testutils.RandUDP4Addr(t), 33 | }, 34 | "", 35 | alive, 36 | until, 37 | nil, 38 | now, 39 | nil, 40 | false, 41 | ) 42 | // make sure it worked 43 | assert.Equal(t, healthy, ret.IsHealthy()) 44 | assert.Equal(t, alive, ret.IsAlive()) 45 | return ret 46 | } 47 | 48 | type configBuilder config.Server 49 | 50 | func buildConfig(name string) *configBuilder { 51 | ret := &configBuilder{} 52 | ret.Iface = name 53 | return ret 54 | } 55 | 56 | func (c *configBuilder) withPeer(key wgtypes.Key, peer *config.Peer) *configBuilder { 57 | if c == nil { 58 | c = &configBuilder{} 59 | } 60 | if c.Peers == nil { 61 | c.Peers = make(config.Peers) 62 | } 63 | c.Peers[key] = peer 64 | return c 65 | } 66 | 67 | func (c *configBuilder) Build() *config.Server { 68 | return (*config.Server)(c) 69 | } 70 | 71 | func applyMask(ipn net.IPNet) net.IPNet { 72 | return net.IPNet{ 73 | IP: ipn.IP.Mask(ipn.Mask), 74 | Mask: ipn.Mask, 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /signing/doc.go: -------------------------------------------------------------------------------- 1 | // Package signing provides code for signing and verifying signatures using the 2 | // XChaCha20-Poly1305-Curve25519 construction.
3 | package signing 4 | -------------------------------------------------------------------------------- /signing/sign.go: -------------------------------------------------------------------------------- 1 | package signing 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | 7 | "golang.org/x/crypto/chacha20poly1305" 8 | 9 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 10 | ) 11 | 12 | // SignFor makes a signature to send data to a given peer: it seals an empty plaintext with XChaCha20-Poly1305 keyed by the X25519 shared key, passing data as additional authenticated data under a fresh random nonce, and returns that nonce together with the resulting authentication tag. 13 | func (s *Signer) SignFor( 14 | data []byte, 15 | peer *wgtypes.Key, 16 | ) ( 17 | nonce [chacha20poly1305.NonceSizeX]byte, 18 | tag [chacha20poly1305.Overhead]byte, 19 | err error, 20 | ) { 21 | sk, err := s.sharedKey(peer) 22 | if err != nil { 23 | return 24 | } 25 | if _, err = rand.Read(nonce[:]); err != nil { 26 | return 27 | } 28 | cipher, err := chacha20poly1305.NewX(sk[:]) 29 | if err != nil { 30 | return 31 | } 32 | out := cipher.Seal(nil, nonce[:], nil, data) // empty plaintext, so out should be exactly the tag 33 | if len(out) != len(tag) { 34 | err = fmt.Errorf("unexpected output length %d from AEAD, expected %d", len(out), len(tag)) 35 | return 36 | } 37 | copy(tag[:], out[:]) 38 | return 39 | } 40 | -------------------------------------------------------------------------------- /signing/signer.go: -------------------------------------------------------------------------------- 1 | package signing 2 | 3 | import ( 4 | "golang.org/x/crypto/curve25519" 5 | 6 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 7 | ) 8 | 9 | // Signer represents a helper that does signing and verification 10 | type Signer struct { 11 | privateKey wgtypes.Key 12 | PublicKey wgtypes.Key 13 | } 14 | 15 | // New creates a new Signer using the given private key, caching the matching public key in PublicKey 16 | func New(privateKey wgtypes.Key) *Signer { 17 | return &Signer{ 18 | privateKey: privateKey, 19 | PublicKey: privateKey.PublicKey(), 20 | } 21 | } 22 | // sharedKey computes the X25519 shared secret of s's private key and the peer's public key, returning any error from curve25519.X25519. 23 | func (s *Signer) sharedKey(peer *wgtypes.Key) ([]byte, error) { 24 | return curve25519.X25519(s.privateKey[:], peer[:]) 25 | } 26 |
-------------------------------------------------------------------------------- /signing/signer_test.go: -------------------------------------------------------------------------------- 1 | package signing 2 | 3 | import ( 4 | "crypto/rand" 5 | "testing" 6 | 7 | "github.com/fastcat/wirelink/internal/testutils" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 11 | 12 | "github.com/stretchr/testify/require" 13 | ) 14 | // TestSignAndVerify checks that two signers derive the same shared key, that a tag produced by SignFor verifies with VerifyFrom, and that a corrupted tag is rejected. 15 | func TestSignAndVerify(t *testing.T) { 16 | key1, pubkey1 := testutils.MustKeyPair(t) 17 | key2, pubkey2 := testutils.MustKeyPair(t) 18 | 19 | signer1 := New(key1) 20 | signer2 := New(key2) 21 | 22 | sk1, err := signer1.sharedKey(&pubkey2) 23 | require.NoError(t, err) 24 | sk2, err := signer2.sharedKey(&pubkey1) 25 | require.NoError(t, err) 26 | assert.Equal(t, sk1, sk2, "Shared keys should compute equal") 27 | 28 | // this is a "random" value 29 | const dataLen = 87 30 | data := make([]byte, dataLen) 31 | _, err = rand.Read(data) 32 | if err != nil { 33 | t.Fatalf("Unable to generate random data: %v", err) 34 | } 35 | 36 | nonce, tag, err := signer1.SignFor(data, &pubkey2) 37 | if err != nil { 38 | t.Fatalf("Failed to sign: %v", err) // nonce and tag are meaningless after a failed sign, so stop here 39 | } 40 | 41 | valid, err := signer2.VerifyFrom(nonce, tag, data, &pubkey1) 42 | if err != nil { 43 | t.Errorf("Failed to verify: %v", err) 44 | } 45 | if !valid { 46 | t.Error("Signed data didn't validate") 47 | } 48 | 49 | // verify it fails if we muck with a byte 50 | tag[12]++ 51 | 52 | valid, err = signer2.VerifyFrom(nonce, tag, data, &pubkey1) 53 | if valid || err == nil { 54 | t.Error("Incorrectly validated corrupted data") 55 | } 56 | } 57 | // TestSignErrors checks that sharedKey rejects an all-zero public key, whether the private key is also all-zero or a real one. 58 | func TestSignErrors(t *testing.T) { 59 | var badKey wgtypes.Key 60 | var badPub wgtypes.Key 61 | goodKey, _ := testutils.MustKeyPair(t) 62 | signer := New(badKey) 63 | _, err := signer.sharedKey(&badPub) 64 | assert.Error(t, err) 65 | signer = New(goodKey) 66 | _, err = signer.sharedKey(&badPub) 67 | assert.Error(t, err) 68 | }
69 | -------------------------------------------------------------------------------- /signing/verify.go: -------------------------------------------------------------------------------- 1 | package signing 2 | 3 | import ( 4 | "crypto/cipher" 5 | 6 | "golang.org/x/crypto/chacha20poly1305" 7 | 8 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 9 | ) 10 | 11 | // VerifyFrom checks the signature on received data from a given peer: it opens tag as the XChaCha20-Poly1305 ciphertext of an empty message, keyed by the X25519 shared key, with data as additional authenticated data. valid is true only if authentication succeeds; otherwise err reports why. 12 | func (s *Signer) VerifyFrom( 13 | nonce [chacha20poly1305.NonceSizeX]byte, 14 | tag [chacha20poly1305.Overhead]byte, 15 | data []byte, 16 | peer *wgtypes.Key, 17 | ) ( 18 | valid bool, 19 | err error, 20 | ) { 21 | sk, err := s.sharedKey(peer) 22 | if err != nil { 23 | return 24 | } 25 | var aead cipher.AEAD // named aead so it does not shadow the crypto/cipher package 26 | aead, err = chacha20poly1305.NewX(sk[:]) 27 | if err != nil { 28 | return 29 | } 30 | _, err = aead.Open(nil, nonce[:], tag[:], data) 31 | if err != nil { 32 | return false, err 33 | } 34 | valid = true 35 | return 36 | } 37 | -------------------------------------------------------------------------------- /trust/.gitignore: -------------------------------------------------------------------------------- 1 | mock_*.go 2 | -------------------------------------------------------------------------------- /trust/composite-trust.go: -------------------------------------------------------------------------------- 1 | package trust 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/fact" 8 | ) 9 | 10 | // CompositeMode is an enum for how a composite evaluator combines the results 11 | // of its member evaluators 12 | type CompositeMode int 13 | 14 | const ( 15 | // FirstOnly composites return the trust level from the first evaluator that 16 | // returns a non-nil trust level 17 | FirstOnly CompositeMode = iota 18 | // LeastPermission composites return the lowest trust level from the evaluators 19 | // that return a non-nil trust level 20 | LeastPermission 21 | // MostPermission composites return the highest trust level from the
evaluators 22 | // that return a non-nil trust level 23 | MostPermission 24 | ) 25 | // String returns the mode's name, or UNKNOWN(n) for unrecognized values. 26 | func (cm CompositeMode) String() string { 27 | switch cm { 28 | case FirstOnly: 29 | return "FirstOnly" 30 | case LeastPermission: 31 | return "LeastPermission" 32 | case MostPermission: 33 | return "MostPermission" 34 | default: 35 | return fmt.Sprintf("UNKNOWN(%d)", cm) 36 | } 37 | } 38 | 39 | // CreateComposite generates an evaluator which combines the results of others 40 | // using the specified mode 41 | func CreateComposite(mode CompositeMode, evaluators ...Evaluator) Evaluator { 42 | return &composite{ 43 | mode: mode, 44 | inner: evaluators, 45 | } 46 | } 47 | // composite is an Evaluator that combines its inner evaluators according to mode. 48 | type composite struct { 49 | mode CompositeMode 50 | inner []Evaluator 51 | } 52 | 53 | // *composite should implement Evaluator 54 | var _ Evaluator = &composite{} 55 | // IsKnown reports whether any inner evaluator knows the subject. 56 | func (c *composite) IsKnown(subject fact.Subject) bool { 57 | for _, e := range c.inner { 58 | if e.IsKnown(subject) { 59 | return true 60 | } 61 | } 62 | return false 63 | } 64 | // TrustLevel combines the non-nil levels returned by the inner evaluators according to c.mode, returning nil if none of them returns a level (or the mode is unrecognized). 65 | func (c *composite) TrustLevel(f *fact.Fact, source net.UDPAddr) (ret *Level) { 66 | for _, e := range c.inner { 67 | // IsKnown is orthogonal to TrustLevel, don't check it here 68 | l := e.TrustLevel(f, source) 69 | switch { 70 | case l == nil: 71 | continue 72 | case c.mode == FirstOnly: 73 | return l 74 | case c.mode == LeastPermission && (ret == nil || *l < *ret): 75 | ret = l 76 | case c.mode == MostPermission && (ret == nil || *l > *ret): 77 | ret = l 78 | } 79 | } 80 | return 81 | } 82 | -------------------------------------------------------------------------------- /trust/known-peer-trust.go: -------------------------------------------------------------------------------- 1 | package trust 2 | 3 | import ( 4 | "bytes" 5 | "net" 6 | 7 | "github.com/fastcat/wirelink/autopeer" 8 | "github.com/fastcat/wirelink/fact" 9 | "github.com/fastcat/wirelink/util" 10 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 11 | ) 12 | 13 | // CreateKnownPeerTrust creates
a trust Evaluator for the given set of peers, 14 | // where a known peer is allowed to tell us Endpoint facts, but not register new 15 | // peers. 16 | func CreateKnownPeerTrust(peers []wgtypes.Peer) Evaluator { 17 | // TODO: share some of the data structures here with RouteBasedTrust 18 | ret := knownPeerTrust{ 19 | peersByIP: make(map[[net.IPv6len]byte]*peerWithAddr), 20 | peersByKey: make(map[wgtypes.Key]*peerWithAddr), 21 | } 22 | for i := range peers { 23 | a := autopeer.AutoAddress(peers[i].PublicKey) 24 | // need to take the address of the array element not a local var 25 | pwa := peerWithAddr{ 26 | peer: &peers[i], 27 | ip: a, 28 | } 29 | ret.peersByIP[util.IPToBytes(a)] = &pwa 30 | ret.peersByKey[peers[i].PublicKey] = &pwa 31 | } 32 | return &ret 33 | } 34 | // knownPeerTrust indexes the known peers by their auto-generated address and by public key. 35 | type knownPeerTrust struct { 36 | peersByIP map[[net.IPv6len]byte]*peerWithAddr 37 | peersByKey map[wgtypes.Key]*peerWithAddr 38 | } 39 | 40 | // *knownPeerTrust should implement Evaluator 41 | var _ Evaluator = &knownPeerTrust{} 42 | // TrustLevel grants Endpoint trust to PeerSubject facts whose source address is the auto-generated address of a known peer, and returns nil (no opinion) otherwise. 43 | func (rbt *knownPeerTrust) TrustLevel(f *fact.Fact, source net.UDPAddr) *Level { 44 | ps, ok := f.Subject.(*fact.PeerSubject) 45 | // we only look at PeerSubject facts for this model 46 | if !ok { 47 | return nil 48 | } 49 | 50 | peer, ok := rbt.peersByIP[util.IPToBytes(source.IP)] 51 | if !ok { 52 | // strangely unrecognized, suggests a router peer is forwarding packets from 53 | // other peers' IPv6-LL address to us 54 | return nil 55 | } 56 | 57 | // peer is trusted to tell us its own endpoints, but not to tell us its 58 | // AllowedIPs (the fallthrough below grants the same Endpoint level for other peers) 59 | if bytes.Equal(ps.Key[:], peer.peer.PublicKey[:]) { 60 | ret := Endpoint 61 | return &ret 62 | } 63 | 64 | // actually, known peers are allowed to tell us endpoints for _any_ known peer 65 | ret := Endpoint 66 | return &ret 67 | } 68 | 69 | // IsKnown returns whether the subject is known to us, i.e.
whether the peer 70 | // is locally known 71 | func (rbt *knownPeerTrust) IsKnown(s fact.Subject) bool { 72 | ps, ok := s.(*fact.PeerSubject) 73 | // we only look at PeerSubject facts for this model 74 | if !ok { 75 | return false 76 | } 77 | _, ok = rbt.peersByKey[ps.Key] 78 | return ok 79 | } 80 | -------------------------------------------------------------------------------- /trust/known-peer-trust_test.go: -------------------------------------------------------------------------------- 1 | package trust 2 | 3 | import ( 4 | "math/rand" 5 | "net" 6 | "testing" 7 | 8 | "github.com/fastcat/wirelink/autopeer" 9 | "github.com/fastcat/wirelink/fact" 10 | "github.com/fastcat/wirelink/internal/testutils" 11 | "github.com/stretchr/testify/assert" 12 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 13 | ) 14 | // Test_knownPeerTrust_TrustLevel checks that facts sent from a known peer's auto-generated address get Endpoint trust whatever their subject, while facts from an unknown address get nil. 15 | func Test_knownPeerTrust_TrustLevel(t *testing.T) { 16 | k1 := testutils.MustKey(t) 17 | k2 := testutils.MustKey(t) 18 | k3 := testutils.MustKey(t) 19 | 20 | peerAddr := func(k wgtypes.Key) net.UDPAddr { 21 | return net.UDPAddr{IP: autopeer.AutoAddress(k), Port: rand.Intn(65535)} 22 | } 23 | factAbout := func(k wgtypes.Key) *fact.Fact { 24 | return &fact.Fact{ 25 | Subject: &fact.PeerSubject{Key: k}, 26 | } 27 | } 28 | peerList := func(ks ...wgtypes.Key) []wgtypes.Peer { 29 | ret := make([]wgtypes.Peer, len(ks)) 30 | for i, k := range ks { 31 | ret[i] = wgtypes.Peer{PublicKey: k} 32 | } 33 | return ret 34 | } 35 | 36 | // type fields struct { 37 | // peersByIP map[[net.IPv6len]byte]*peerWithAddr 38 | // peersByKey map[wgtypes.Key]*peerWithAddr 39 | // } 40 | type args struct { 41 | f *fact.Fact 42 | source net.UDPAddr 43 | } 44 | mkArgs := func(subject, source wgtypes.Key) args { 45 | return args{ 46 | factAbout(subject), 47 | peerAddr(source), 48 | } 49 | } 50 | tests := []struct { 51 | name string 52 | // fields fields 53 | peers []wgtypes.Peer 54 | args args 55 | want *Level 56 | }{ 57 | { 58 | "unknown", 59 | peerList(k1), 60 | mkArgs(k2, k3), 61 | nil, 62 |
}, 63 | { 64 | "known leaf self info", 65 | peerList(k1), 66 | mkArgs(k1, k1), 67 | Ptr(Endpoint), 68 | }, 69 | { 70 | "known leaf other info", 71 | peerList(k1), 72 | mkArgs(k2, k1), 73 | Ptr(Endpoint), 74 | }, 75 | } 76 | for _, tt := range tests { 77 | t.Run(tt.name, func(t *testing.T) { 78 | // rbt := &knownPeerTrust{ 79 | // peersByIP: tt.fields.peersByIP, 80 | // peersByKey: tt.fields.peersByKey, 81 | // } 82 | rbt := CreateKnownPeerTrust(tt.peers) 83 | got := rbt.TrustLevel(tt.args.f, tt.args.source) 84 | assert.Equal(t, tt.want, got, "want %v got %v", tt.want, got) 85 | }) 86 | } 87 | } 88 | // Test_knownPeerTrust_IsKnown checks that IsKnown is true for a subject whose key is in the peer list and false for one that is not. 89 | func Test_knownPeerTrust_IsKnown(t *testing.T) { 90 | k1 := testutils.MustKey(t) 91 | k2 := testutils.MustKey(t) 92 | 93 | ksa := func(k wgtypes.Key) struct{ s fact.Subject } { 94 | return struct{ s fact.Subject }{&fact.PeerSubject{Key: k}} 95 | } 96 | 97 | // type fields struct { 98 | // peersByIP map[[net.IPv6len]byte]*peerWithAddr 99 | // peersByKey map[wgtypes.Key]*peerWithAddr 100 | // } 101 | type args struct { 102 | s fact.Subject 103 | } 104 | tests := []struct { 105 | name string 106 | // fields fields 107 | peers []wgtypes.Peer 108 | args args 109 | want bool 110 | }{ 111 | {"known", []wgtypes.Peer{{PublicKey: k1}}, ksa(k1), true}, 112 | {"unknown", []wgtypes.Peer{{PublicKey: k1}}, ksa(k2), false}, 113 | } 114 | for _, tt := range tests { 115 | t.Run(tt.name, func(t *testing.T) { 116 | // rbt := &knownPeerTrust{ 117 | // peersByIP: tt.fields.peersByIP, 118 | // peersByKey: tt.fields.peersByKey, 119 | // } 120 | rbt := CreateKnownPeerTrust(tt.peers) 121 | got := rbt.IsKnown(tt.args.s) 122 | assert.Equal(t, tt.want, got) 123 | }) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /trust/route-based-trust.go: -------------------------------------------------------------------------------- 1 | package trust 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/fastcat/wirelink/autopeer" 7 |
"github.com/fastcat/wirelink/detect" 8 | "github.com/fastcat/wirelink/fact" 9 | "github.com/fastcat/wirelink/util" 10 | 11 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 12 | ) 13 | 14 | // CreateRouteBasedTrust creates a trust Evaluator for the given set of peers, 15 | // using the "routers are trusted" model, wherein "routers" (peers with an 16 | // AllowedIP whose CIDR mask is shorter than the IP length) are allowed to 17 | // provide AllowedIPs for other peers. 18 | func CreateRouteBasedTrust(peers []wgtypes.Peer) Evaluator { 19 | ret := routeBasedTrust{ 20 | peersByIP: make(map[[net.IPv6len]byte]*peerWithAddr), 21 | peersByKey: make(map[wgtypes.Key]*peerWithAddr), 22 | } 23 | for i := range peers { 24 | a := autopeer.AutoAddress(peers[i].PublicKey) 25 | // need to take the address of the array element not a local var 26 | pwa := peerWithAddr{ 27 | peer: &peers[i], 28 | ip: a, 29 | } 30 | ret.peersByIP[util.IPToBytes(a)] = &pwa 31 | ret.peersByKey[peers[i].PublicKey] = &pwa 32 | } 33 | return &ret 34 | } 35 | // peerWithAddr pairs a peer with its auto-generated address (autopeer.AutoAddress). 36 | type peerWithAddr struct { 37 | peer *wgtypes.Peer 38 | ip net.IP 39 | } 40 | // routeBasedTrust indexes the configured peers by their auto-generated address and by public key. 41 | type routeBasedTrust struct { 42 | peersByIP map[[net.IPv6len]byte]*peerWithAddr 43 | peersByKey map[wgtypes.Key]*peerWithAddr 44 | } 45 | 46 | // *routeBasedTrust should implement Evaluator 47 | var _ Evaluator = &routeBasedTrust{} 48 | // TrustLevel grants Membership trust to PeerSubject facts sent from the auto-generated address of a peer that detect.IsPeerRouter reports as a router, and returns nil otherwise so that other evaluators can decide. 49 | func (rbt *routeBasedTrust) TrustLevel(f *fact.Fact, source net.UDPAddr) *Level { 50 | _, ok := f.Subject.(*fact.PeerSubject) 51 | // we only look at PeerSubject facts for this model 52 | if !ok { 53 | return nil 54 | } 55 | 56 | peer, ok := rbt.peersByIP[util.IPToBytes(source.IP)] 57 | if !ok { 58 | // strangely unrecognized, suggests a router peer is forwarding packets from 59 | // other peers' IPv6-LL address to us 60 | return nil 61 | } 62 | 63 | // peers that are allowed to route traffic are trusted to tell us about 64 | // any other peer, as they are inferred to control the network. This has 65 | // to come before the peer
self check, because the routers need to be 66 | // permitted to tell us their own AllowedIPs, not just others. 67 | // re-evaluating this each time instead of caching it once at startup is 68 | // intentional as peer AIPs that drive this can change 69 | if detect.IsPeerRouter(peer.peer) { 70 | ret := Membership 71 | return &ret 72 | } 73 | 74 | // anything else, delegate to the next trust evaluator in the chain 75 | return nil 76 | } 77 | 78 | // IsKnown returns whether the subject is known to us, i.e. whether the peer 79 | // is locally known 80 | func (rbt *routeBasedTrust) IsKnown(s fact.Subject) bool { 81 | ps, ok := s.(*fact.PeerSubject) 82 | // we only look at PeerSubject facts for this model 83 | if !ok { 84 | return false 85 | } 86 | _, ok = rbt.peersByKey[ps.Key] 87 | return ok 88 | } 89 | -------------------------------------------------------------------------------- /trust/trust_test.go: -------------------------------------------------------------------------------- 1 | package trust 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/fastcat/wirelink/fact" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | // TestShouldAccept checks ShouldAccept over a matrix of attributes, known/unknown subjects and trust levels. 12 | func TestShouldAccept(t *testing.T) { 13 | type args struct { 14 | attr fact.Attribute 15 | known bool 16 | level *Level 17 | } 18 | type test struct { 19 | name string 20 | args args 21 | want bool 22 | } 23 | // create builds one named case; matrix builds the cross product of attrs and levels sharing one expectation 24 | create := func(name string, attr fact.Attribute, known bool, level Level, want bool) test { 25 | name = fmt.Sprintf("%s(%c,%v,%v)=%v", name, attr, known, level, want) 26 | return test{name, args{attr, known, &level}, want} 27 | } 28 | matrix := func(name string, attrs []fact.Attribute, known bool, levels []Level, want bool) []test { 29 | ret := make([]test, 0, len(attrs)*len(levels)) 30 | for _, attr := range attrs { 31 | for _, level := range levels { 32 | ret = append(ret, create(name, attr, known, level, want)) 33 | } 34 | } 35 | return ret 36 | } 37 | 38 | validAttrs := []fact.Attribute{ 39 |
fact.AttributeEndpointV4, 40 | fact.AttributeEndpointV6, 41 | fact.AttributeAllowedCidrV4, 42 | fact.AttributeAllowedCidrV6, 43 | fact.AttributeMember, 44 | fact.AttributeMemberMetadata, 45 | } 46 | invalidAttrs := []fact.Attribute{ 47 | fact.AttributeUnknown, 48 | // alive doesn't go through trust 49 | fact.AttributeAlive, 50 | // signed group is a transport structure and never directly evaluated for trust 51 | fact.AttributeSignedGroup, 52 | } 53 | epAttrs := []fact.Attribute{ 54 | fact.AttributeEndpointV4, 55 | fact.AttributeEndpointV6, 56 | } 57 | aipAttrs := []fact.Attribute{ 58 | fact.AttributeAllowedCidrV4, 59 | fact.AttributeAllowedCidrV6, 60 | } 61 | memberAttr := []fact.Attribute{ 62 | fact.AttributeMember, 63 | fact.AttributeMemberMetadata, 64 | } 65 | allLevels := []Level{Untrusted, Endpoint, AllowedIPs, Membership, DelegateTrust} 66 | // expectations: invalid attributes are always rejected; facts about unknown peers need Membership or higher; for known peers, endpoints need Endpoint, AllowedIPs need AllowedIPs, and membership facts need Membership 67 | tests := []test{ 68 | {"nil trust", args{fact.AttributeAlive, true, nil}, false}, 69 | } 70 | tests = append(tests, matrix("gigo", invalidAttrs, false, allLevels, false)...) 71 | tests = append(tests, matrix("gigo", invalidAttrs, true, allLevels, false)...) 72 | tests = append(tests, matrix("new peer", validAttrs, false, []Level{Untrusted, Endpoint, AllowedIPs}, false)...) 73 | tests = append(tests, matrix("new peer", validAttrs, false, []Level{Membership, DelegateTrust}, true)...) 74 | tests = append(tests, matrix("ep", epAttrs, true, []Level{Untrusted}, false)...) 75 | tests = append(tests, matrix("ep", epAttrs, true, []Level{Endpoint, AllowedIPs, Membership, DelegateTrust}, true)...) 76 | tests = append(tests, matrix("aip", aipAttrs, true, []Level{Untrusted, Endpoint}, false)...) 77 | tests = append(tests, matrix("aip", aipAttrs, true, []Level{AllowedIPs, Membership, DelegateTrust}, true)...) 78 | tests = append(tests, matrix("member", memberAttr, true, []Level{Untrusted, Endpoint, AllowedIPs}, false)...) 79 | tests = append(tests, matrix("member", memberAttr, true, []Level{Membership, DelegateTrust}, true)...)
80 | 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | got := ShouldAccept(tt.args.attr, tt.args.known, tt.args.level) 84 | assert.Equal(t, tt.want, got) 85 | }) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /util/bytereader.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "io" 4 | 5 | // ByteReader combines io.Reader and io.ByteReader 6 | type ByteReader interface { 7 | io.Reader 8 | io.ByteReader 9 | } 10 | -------------------------------------------------------------------------------- /util/decodable.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "encoding" 6 | "fmt" 7 | "io" 8 | ) 9 | 10 | // Decodable is an interface that mimics BinaryUnmarshaler, but sources from 11 | // an io.Reader instead of a slice 12 | type Decodable interface { 13 | // DecodeFrom reads just enough bytes from the reader to deserialize itself 14 | DecodeFrom(lengthHint int, reader io.Reader) error 15 | } 16 | 17 | // DecodeFrom provides an equivalent function to Decodable.DecodeFrom, but for 18 | // types that implement BinaryUnmarshaler and which have a fixed known length, 19 | // e.g.
to provide a default implementation for Decodable for such types. It reads exactly readLen bytes (taking them straight from the buffer when reader is a *bytes.Buffer) and hands them to value.UnmarshalBinary; a negative readLen is reported as an error instead of panicking. 20 | func DecodeFrom(value encoding.BinaryUnmarshaler, readLen int, reader io.Reader) error { 21 | if readLen < 0 { 22 | return fmt.Errorf("unable to read %T: invalid length %d", value, readLen) 23 | } 24 | var data []byte 25 | switch r := reader.(type) { 26 | case *bytes.Buffer: 27 | data = r.Next(readLen) 28 | if len(data) != readLen { 29 | return fmt.Errorf("unable to read %T: only %d of %d bytes available", value, len(data), readLen) 30 | } 31 | default: 32 | data = make([]byte, readLen) 33 | n, err := io.ReadFull(r, data) 34 | if err != nil { 35 | return fmt.Errorf("unable to read %T (read %d of %d bytes): %w", value, n, readLen, err) 36 | } 37 | } 38 | err := value.UnmarshalBinary(data) 39 | if err != nil { 40 | return fmt.Errorf("unable to read %T: unmarshal failed: %w", value, err) 41 | } 42 | return nil 43 | } 44 | -------------------------------------------------------------------------------- /util/decodable_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "io" 7 | "math/rand" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | // unmarshal is a fake encoding.BinaryUnmarshaler that records the bytes it receives and returns a preset error. 14 | type unmarshal struct { 15 | data []byte 16 | err error 17 | } 18 | // UnmarshalBinary records data and returns u.err. 19 | func (u *unmarshal) UnmarshalBinary(data []byte) error { 20 | u.data = data 21 | return u.err 22 | } 23 | 24 | // wrapper around a bytes buffer to prevent identification as one 25 | type alternateBuffer struct { 26 | bytes.Buffer 27 | err error 28 | } 29 | // Read fails with a.err when it is set, otherwise reads from the embedded buffer. 30 | func (a *alternateBuffer) Read(p []byte) (int, error) { 31 | if a.err != nil { 32 | return 0, a.err 33 | } 34 | return a.Buffer.Read(p) 35 | } 36 | // seedOffset is chosen randomly on first use, so data differs between runs but is reproducible per seed within a run. 37 | var seedOffset int64 38 | // bytesFromSeed returns len pseudo-random bytes, identical for repeated calls with the same seed within a run. 39 | func bytesFromSeed(seed int64, len int) []byte { 40 | if seedOffset == 0 { 41 | seedOffset = rand.Int63() 42 | } 43 | ret := make([]byte, len) 44 | r := rand.New(rand.NewSource(seed + seedOffset)) 45 | // rand.Read never fails, no need to check returns 46 | r.Read(ret) 47 | return ret 48 | } 49
| 50 | func bufFromSeed(seed int64, len int) *bytes.Buffer { 51 | return bytes.NewBuffer(bytesFromSeed(seed, len)) 52 | } 53 | // altFromSeed wraps the seeded bytes in an alternateBuffer so DecodeFrom takes its generic io.Reader path. 54 | func altFromSeed(seed int64, len int) *alternateBuffer { 55 | return &alternateBuffer{ 56 | Buffer: *bufFromSeed(seed, len), 57 | } 58 | } 59 | // errFromSeed is like altFromSeed, but every Read fails with a mock error. 60 | func errFromSeed(seed int64, len int) *alternateBuffer { 61 | return &alternateBuffer{ 62 | Buffer: *bufFromSeed(seed, len), 63 | err: errors.New("Mock buffer read error"), 64 | } 65 | } 66 | // TestDecodeFrom covers the *bytes.Buffer path, the generic io.Reader path, read errors, short reads and unmarshal errors. 67 | func TestDecodeFrom(t *testing.T) { 68 | readLen := 16 + rand.Intn(16) 69 | type args struct { 70 | value *unmarshal 71 | readLen int 72 | reader io.Reader 73 | } 74 | tests := []struct { 75 | name string 76 | args args 77 | wantValue []byte 78 | wantErr bool 79 | }{ 80 | { 81 | "simple read from byte buffer", 82 | args{ 83 | value: &unmarshal{}, 84 | readLen: readLen, 85 | reader: bufFromSeed(1, readLen), 86 | }, 87 | bytesFromSeed(1, readLen), 88 | false, 89 | }, 90 | { 91 | "simple read from reader", 92 | args{ 93 | value: &unmarshal{}, 94 | readLen: readLen, 95 | reader: altFromSeed(2, readLen), 96 | }, 97 | bytesFromSeed(2, readLen), 98 | false, 99 | }, 100 | { 101 | "read error", 102 | args{ 103 | value: &unmarshal{}, 104 | readLen: readLen, 105 | reader: errFromSeed(3, readLen), 106 | }, 107 | nil, 108 | true, 109 | }, 110 | { 111 | "short read", 112 | args{ 113 | value: &unmarshal{}, 114 | readLen: readLen, 115 | reader: bufFromSeed(4, readLen-1), 116 | }, 117 | nil, 118 | true, 119 | }, 120 | { 121 | "unmarshal error", 122 | args{ 123 | value: &unmarshal{nil, errors.New("Mock fail unmarshal")}, 124 | readLen: readLen, 125 | reader: bufFromSeed(5, readLen), 126 | }, 127 | bytesFromSeed(5, readLen), 128 | true, 129 | }, 130 | } 131 | for _, tt := range tests { 132 | t.Run(tt.name, func(t *testing.T) { 133 | err := DecodeFrom(tt.args.value, tt.args.readLen, tt.args.reader) 134 | if tt.wantErr { 135 | require.NotNil(t, err, "DecodeFrom() error") 136 | } else { 137 | require.Nil(t, err,
"DecodeFrom() error") 138 | } 139 | assert.Equal(t, tt.wantValue, tt.args.value.data) 140 | }) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /util/doc.go: -------------------------------------------------------------------------------- 1 | // Package util provides common utility functions used by multiple portions of 2 | // wirelink, primarily providing wrappers and helpers for core go libraries 3 | // and constructs. 4 | package util 5 | -------------------------------------------------------------------------------- /util/errors.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // WrapOrNewf calls fmt.Errorf with varying format depending on whether err is 8 | // nil, always returning an error value. A non-nil err is appended to args and wrapped via a trailing ": %w", so errors.Unwrap returns it; the caller's args slice is never written to. 9 | func WrapOrNewf(err error, format string, args ...interface{}) error { 10 | if err == nil { 11 | return fmt.Errorf(format, args...) 12 | } 13 | // clip the capacity first: when the caller passes args... from its own slice, a plain append could write err into that slice's spare capacity 14 | args = append(args[:len(args):len(args)], err) 15 | return fmt.Errorf(format+": %w", args...)
16 | } 17 | -------------------------------------------------------------------------------- /util/errors_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | // stringPtr returns a pointer to a copy of value. 11 | func stringPtr(value string) *string { 12 | return &value 13 | } 14 | // TestWrapOrNewf checks the formatted message and the innermost unwrapped cause, with and without an inner error. 15 | func TestWrapOrNewf(t *testing.T) { 16 | type args struct { 17 | err error 18 | format string 19 | args []interface{} 20 | } 21 | tests := []struct { 22 | name string 23 | args args 24 | wantError string 25 | wantUnwrap *string 26 | }{ 27 | { 28 | "not wrapping", 29 | args{ 30 | nil, 31 | "not wrapping", 32 | nil, 33 | }, 34 | "not wrapping", 35 | nil, 36 | }, 37 | { 38 | "simple wrapping", 39 | args{ 40 | errors.New("inner"), 41 | "simple wrapping", 42 | nil, 43 | }, 44 | "simple wrapping: inner", 45 | stringPtr("inner"), 46 | }, 47 | } 48 | for _, tt := range tests { 49 | t.Run(tt.name, func(t *testing.T) { 50 | err := WrapOrNewf(tt.args.err, tt.args.format, tt.args.args...)
51 | require.NotNil(t, err) 52 | assert.Equal(t, tt.wantError, err.Error()) 53 | cause := errors.Unwrap(err) 54 | // there may be multiple layers, dig and find the bottom 55 | for next := errors.Unwrap(cause); next != nil; next = errors.Unwrap(cause) { 56 | cause = next 57 | } 58 | if tt.wantUnwrap == nil { 59 | assert.Nil(t, cause) 60 | } else { 61 | require.NotNil(t, cause) 62 | assert.Equal(t, *tt.wantUnwrap, cause.Error()) 63 | } 64 | }) 65 | } 66 | } 67 | 68 | // TestWrapOrNewfKeepsCallerArgs checks that wrapping never writes err into spare capacity of the caller's args slice. 69 | func TestWrapOrNewfKeepsCallerArgs(t *testing.T) { 70 | args := make([]interface{}, 1, 2) 71 | args[0] = "value" 72 | err := WrapOrNewf(errors.New("inner"), "outer %v", args...) 73 | assert.Equal(t, "outer value: inner", err.Error()) 74 | assert.Nil(t, args[:2][1], "caller's spare capacity was overwritten") 75 | } 76 | -------------------------------------------------------------------------------- /util/ip.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "net" 5 | ) 6 | 7 | // NormalizeIP returns a version of the given ip normalized to its underlying 8 | // family, instead of the "always in IPv6 container" format that is often used, 9 | // so IPv4 values will have a length of 4 and IPv6 ones a length of 16; an ip that is neither yields nil 10 | func NormalizeIP(ip net.IP) net.IP { 11 | n := ip.To4() 12 | if n == nil { 13 | n = ip.To16() 14 | } 15 | return n 16 | } 17 | 18 | // IsIPv6LLMatch checks if a given expected IPv6 address matches an actual 19 | // address + mask, checking if the mask is of the expected form. 20 | // The mask is expected to be /128 if local is false, or /64 if it is true 21 | func IsIPv6LLMatch(expected net.IP, actual *net.IPNet, local bool) bool { 22 | expectedOnes := 8 * net.IPv6len 23 | if local { 24 | expectedOnes = 4 * net.IPv6len 25 | } 26 | ones, bits := actual.Mask.Size() 27 | return ones == expectedOnes && bits == 8*net.IPv6len && expected.Equal(actual.IP) 28 | } 29 | 30 | // IPToBytes returns the given IP normalized to a 16 byte array, 31 | // suitable for use as a map key among other things; an invalid ip yields the all-zero array 32 | func IPToBytes(ip net.IP) (ret [net.IPv6len]byte) { 33 | ip = ip.To16() 34 | copy(ret[:], ip) 35 | return 36 | } 37 | 38 | // IsGloballyRoutable checks if an IP address looks routable across the internet 39 | // or not.
It will return false for any IP that is not a Global Unicast address, 40 | // and also for certain special reserved subnets that are used within site-level 41 | // domains but are not meant to be routed on the internet. 42 | func IsGloballyRoutable(ip net.IP) bool { 43 | if !ip.IsGlobalUnicast() { 44 | return false 45 | } 46 | if ip.IsPrivate() { 47 | return false 48 | } 49 | if ip4 := ip.To4(); ip4 != nil { 50 | // ignore the CG-NAT subnet 100.64.0.0/10 (https://tools.ietf.org/html/rfc6598) 51 | if ip4[0] == 0x64 && ip4[1] >= 0x40 && ip4[1] <= 0x7f { 52 | return false 53 | } 54 | // TODO: more ranges? 55 | } 56 | if ip6 := ip.To16(); ip6 != nil { // also true for IPv4 input, whose 16-byte form starts with zero bytes and so matches neither check below 57 | // fc00::/7 ~ unique local addr ipv6 equivalent of 10/8-ish; NOTE(review): net.IP.IsPrivate above should already reject this range (RFC 4193), so this looks redundant - confirm before removing 58 | if ip6[0] == 0xfc || ip6[0] == 0xfd { 59 | return false 60 | } 61 | // fec0::/10 ~ deprecated site-local 62 | if ip6[0] == 0xfe && ip6[1] >= 0xc0 { 63 | return false 64 | } 65 | } 66 | 67 | // TODO: more ipv6 reserved ranges 68 | // https://www.iana.org/assignments/ipv6-address-space/ipv6-address-space.xhtml 69 | return true 70 | } 71 | -------------------------------------------------------------------------------- /util/must.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // MustBytes is a helper, esp. for BinaryMarshaler, that takes a tuple of 4 | // a byte slice and maybe an error and panics if there is an error, or else 5 | // returns the byte slice 6 | func MustBytes(value []byte, err error) []byte { 7 | if err != nil { 8 | panic(err) 9 | } 10 | return value 11 | } 12 | 13 | // MustByte is a helper, esp. for ByteReader, that takes a tuple of 14 | // a byte and maybe an error and panics if there is an error, or else 15 | // returns the byte 16 | func MustByte(value byte, err error) byte { 17 | if err != nil { 18 | panic(err) 19 | } 20 | return value 21 | } 22 | 23 | // MustInt64 is a helper, esp.
for encoding.binary, that takes a tuple of 24 | // an int64 and maybe an error and panics if there is an error, or else 25 | // returns the int64 26 | func MustInt64(value int64, err error) int64 { 27 | if err != nil { 28 | panic(err) 29 | } 30 | return value 31 | } 32 | -------------------------------------------------------------------------------- /util/net.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "net" 5 | "sort" 6 | "strings" 7 | ) 8 | 9 | // UDPEqualIPPort checks if to UDPAddrs are equal in terms of their IP and Port 10 | // fields, but ignoring any Zone value 11 | func UDPEqualIPPort(a, b *net.UDPAddr) bool { 12 | if a == nil && b == nil { 13 | return true 14 | } 15 | if a == nil || b == nil { 16 | return false 17 | } 18 | return a.IP.Equal(b.IP) && a.Port == b.Port 19 | } 20 | 21 | // SortIPNetSlice sorts a slice of IPNets by their string value, returning the 22 | // (modified in place) slice. 23 | // OMG want generics. 
24 | func SortIPNetSlice(slice []net.IPNet) []net.IPNet { 25 | sort.Slice(slice, func(i, j int) bool { 26 | return strings.Compare(slice[i].String(), slice[j].String()) < 0 27 | }) 28 | return slice 29 | } 30 | 31 | // CloneIPNet makes a deep copy of the given value 32 | func CloneIPNet(ipn net.IPNet) net.IPNet { 33 | var ret net.IPNet 34 | ret.IP = make(net.IP, len(ipn.IP)) 35 | copy(ret.IP, ipn.IP) 36 | ret.Mask = make(net.IPMask, len(ipn.Mask)) 37 | copy(ret.Mask, ipn.Mask) 38 | return ret 39 | } 40 | 41 | // CloneUDPAddr makes a deep copy of the given address 42 | func CloneUDPAddr(addr *net.UDPAddr) *net.UDPAddr { 43 | if addr == nil { 44 | return nil 45 | } 46 | 47 | ret := &net.UDPAddr{ 48 | Port: addr.Port, 49 | Zone: addr.Zone, 50 | } 51 | ret.IP = make(net.IP, len(addr.IP)) 52 | copy(ret.IP, addr.IP) 53 | return ret 54 | } 55 | -------------------------------------------------------------------------------- /util/slice.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // CloneBytes returns a new copy of the input data 4 | func CloneBytes(data []byte) []byte { 5 | if data == nil { 6 | return nil 7 | } 8 | ret := make([]byte, len(data)) 9 | copy(ret, data) 10 | return ret 11 | } 12 | -------------------------------------------------------------------------------- /util/ternary.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | // Ternary turns a trivial if/else into a function call 4 | func Ternary(value bool, trueResult, falseResult interface{}) interface{} { 5 | if value { 6 | return trueResult 7 | } 8 | return falseResult 9 | } 10 | -------------------------------------------------------------------------------- /util/ternary_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestTernary(t 
*testing.T) { 10 | type args struct { 11 | value bool 12 | trueResult interface{} 13 | falseResult interface{} 14 | } 15 | tests := []struct { 16 | name string 17 | args args 18 | want interface{} 19 | }{ 20 | { 21 | "true", 22 | args{true, true, false}, 23 | true, 24 | }, 25 | { 26 | "false", 27 | args{false, true, false}, 28 | false, 29 | }, 30 | } 31 | for _, tt := range tests { 32 | t.Run(tt.name, func(t *testing.T) { 33 | assert.Equal(t, tt.want, Ternary(tt.args.value, tt.args.trueResult, tt.args.falseResult)) 34 | }) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /util/time.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "time" 4 | 5 | // TimeMax is the maximum representable time in go. 6 | // see: https://stackoverflow.com/a/32620397/7649 7 | // see also `time.go` in the runtime 8 | func TimeMax() time.Time { 9 | return time.Unix(1<<63-62135596801, 999999999) 10 | } 11 | -------------------------------------------------------------------------------- /util/time_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_TimeMaxVsNow(t *testing.T) { 11 | timeMax := TimeMax() 12 | now := time.Now() 13 | 14 | assert.True(t, timeMax.After(now)) 15 | assert.True(t, now.Before(timeMax)) 16 | assert.True(t, now.Sub(timeMax) < 255*time.Second) 17 | assert.Less(t, int64(now.Sub(timeMax)), int64(255*time.Second)) 18 | } 19 | -------------------------------------------------------------------------------- /wirelink.go: -------------------------------------------------------------------------------- 1 | // The wirelink command is the main entrypoint. 
2 | package main 3 | 4 | import ( 5 | "errors" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/spf13/pflag" 10 | 11 | "github.com/fastcat/wirelink/cmd" 12 | "github.com/fastcat/wirelink/internal/networking/host" 13 | ) 14 | 15 | func main() { 16 | cmd := cmd.New(os.Args) 17 | err := cmd.Init(host.MustCreateHost()) 18 | // don't print on error just because help was requested 19 | if err != nil { 20 | if !errors.Is(err, pflag.ErrHelp) { 21 | fmt.Fprintf(os.Stderr, "Fatal init error: %v\n", err) 22 | defer os.Exit(1) 23 | } 24 | return 25 | } 26 | if cmd.Server == nil { 27 | // --dump or such 28 | return 29 | } 30 | err = cmd.Run() 31 | if err != nil { 32 | fmt.Fprintf(os.Stderr, "Fatal run error: %v\n", err) 33 | defer os.Exit(1) 34 | return 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /wirelink.test.json: -------------------------------------------------------------------------------- 1 | { 2 | "peers": [ 3 | { 4 | "PublicKey": "18dCY1rwCNo4aQaWFccKFfhbqG0FGBxDkCSbyQ3awkw=", 5 | "Name": "test", 6 | "Trust": "Membership", 7 | "AllowedIPs": [ 8 | "100.64.2.0/24" 9 | ], 10 | "Endpoints": [ 11 | "100.64.1.1:1234" 12 | ] 13 | }, 14 | { 15 | "publickey": "6rxPhfwiMXp40mit6bdCytVlTCubK/qYxdo6K+liS2o=" 16 | }, 17 | { 18 | "PublicKey": "vlejBUlxwt5c42bP1Y/W0QiDr0bfZR9iH+i/aanR61s=", 19 | "Trust": null 20 | } 21 | ] 22 | } 23 | --------------------------------------------------------------------------------