├── .dockerignore ├── .github └── workflows │ ├── release.yaml │ └── wgtunnel.yaml ├── .gitignore ├── .golangci.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── buildinfo └── buildinfo.go ├── cmd ├── tunnel │ ├── main.go │ ├── signal_unix.go │ └── signal_windows.go └── tunneld │ ├── main.go │ ├── signal_unix.go │ ├── signal_windows.go │ └── tracing.go ├── compose ├── .env.example ├── .gitignore ├── Makefile ├── caddy │ └── Dockerfile └── docker-compose.yml ├── go.mod ├── go.sum ├── scripts ├── check_unstaged.sh └── version.sh ├── tunneld ├── api.go ├── api_test.go ├── httpapi │ └── httpapi.go ├── httpmw │ ├── limitbody.go │ ├── limitbody_test.go │ └── ratelimit.go ├── options.go ├── options_test.go ├── tunneld.go └── tunneld_test.go └── tunnelsdk ├── api.go ├── client.go └── tunnel.go /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # Allow the tunnel binary 5 | !/build/tunneld 6 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | on: 3 | push: 4 | tags: 5 | - "v*" 6 | workflow_dispatch: 7 | inputs: 8 | snapshot: 9 | description: Force a dev version to be generated, implies dry_run. 10 | type: boolean 11 | required: true 12 | dry_run: 13 | description: Perform a dry-run release. 14 | type: boolean 15 | required: true 16 | 17 | permissions: 18 | # Required to publish a release 19 | contents: write 20 | # Necessary to push docker images to ghcr.io. 
21 | packages: write 22 | 23 | env: 24 | WGTUNNEL_RELEASE: ${{ inputs.snapshot && 'false' || 'true' }} # inputs.* keeps booleans; github.event.inputs.* are strings, so "false" would be truthy. 25 | 26 | jobs: 27 | release: 28 | runs-on: ubuntu-latest 29 | steps: 30 | - uses: actions/checkout@v3 31 | with: 32 | fetch-depth: 0 33 | 34 | # If the event that triggered the build was an annotated tag (which our 35 | # tags are supposed to be), actions/checkout has a bug where the tag in 36 | # question is only a lightweight tag and not a full annotated tag. This 37 | # command seems to fix it. 38 | # https://github.com/actions/checkout/issues/290 39 | - name: Fetch git tags 40 | run: git fetch --tags --force 41 | 42 | - name: Docker Login 43 | uses: docker/login-action@v2 44 | with: 45 | registry: ghcr.io 46 | username: ${{ github.actor }} 47 | password: ${{ secrets.GITHUB_TOKEN }} 48 | 49 | - uses: actions/setup-go@v3 50 | with: 51 | go-version: "~1.20" 52 | 53 | - name: Build tunneld and Docker images 54 | id: build 55 | run: | 56 | set -euo pipefail 57 | go mod download 58 | 59 | make clean 60 | make -j build/tunneld build/tunneld.tag 61 | 62 | image_tag=$(cat build/tunneld.tag) 63 | if [[ "$image_tag" == "" ]]; then 64 | echo "No tag found in build/tunneld.tag" 65 | exit 1 66 | fi 67 | 68 | echo "docker_tag=${image_tag}" >> $GITHUB_OUTPUT 69 | 70 | - name: Push Docker image 71 | if: ${{ !inputs.dry_run && !inputs.snapshot }} 72 | run: | 73 | set -euxo pipefail 74 | 75 | image_tag="${{ steps.build.outputs.docker_tag }}" 76 | docker push "$image_tag" 77 | 78 | latest_tag="ghcr.io/coder/wgtunnel/tunneld:latest" 79 | docker tag "$image_tag" "$latest_tag" 80 | docker push "$latest_tag" 81 | 82 | - name: ls build 83 | run: ls -lh build 84 | 85 | - name: Publish release 86 | if: ${{ !inputs.dry_run && !inputs.snapshot }} 87 | uses: ncipollo/release-action@v1 88 | with: 89 | artifacts: "build/tunneld" 90 | body: "Docker image: `${{ steps.build.outputs.docker_tag }}`" 91 | token: ${{
secrets.GITHUB_TOKEN }} 92 | 93 | - name: Upload artifacts to actions (if dry-run or snapshot) 94 | if: ${{ inputs.dry_run || inputs.snapshot }} 95 | uses: actions/upload-artifact@v4 96 | with: 97 | name: release-artifacts 98 | path: | 99 | ./build/tunneld 100 | retention-days: 7 101 | -------------------------------------------------------------------------------- /.github/workflows/wgtunnel.yaml: -------------------------------------------------------------------------------- 1 | name: wgtunnel 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | permissions: 11 | actions: none 12 | checks: none 13 | contents: read 14 | deployments: none 15 | issues: none 16 | packages: none 17 | pull-requests: none 18 | repository-projects: none 19 | security-events: none 20 | statuses: none 21 | 22 | # Cancel in-progress runs for pull requests when developers push additional 23 | # changes. 24 | concurrency: 25 | group: ${{ github.workflow }}-${{ github.ref }} 26 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 27 | 28 | jobs: 29 | fmt: 30 | runs-on: ubuntu-latest 31 | steps: 32 | - name: Checkout 33 | uses: actions/checkout@v2 34 | - name: Setup Go 35 | uses: actions/setup-go@v3 36 | with: 37 | go-version: "~1.20" 38 | - name: Check for unstaged files 39 | run: ./scripts/check_unstaged.sh 40 | 41 | lint: 42 | runs-on: ubuntu-latest 43 | steps: 44 | - name: Checkout 45 | uses: actions/checkout@v2 46 | - name: Setup Go 47 | uses: actions/setup-go@v3 48 | with: 49 | go-version: "~1.20" 50 | - name: golangci-lint 51 | uses: golangci/golangci-lint-action@v3.2.0 52 | with: 53 | version: v1.51.0 54 | 55 | test: 56 | runs-on: ubuntu-latest 57 | steps: 58 | - name: Checkout 59 | uses: actions/checkout@v2 60 | - name: Setup Go 61 | uses: actions/setup-go@v3 62 | with: 63 | go-version: "~1.20" 64 | - name: Install gotestsum 65 | uses: jaxxstorm/action-install-gh-release@v1.7.1 66 | env: 67 |
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 68 | with: 69 | repo: gotestyourself/gotestsum 70 | tag: v1.9.0 71 | - name: Test 72 | run: make test 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Go workspace file 15 | go.work 16 | 17 | # Build directory. 18 | build/ 19 | 20 | *.key 21 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | # This is copied from github.com/coder/coder. 2 | # 3 | # Changes: 4 | # - removed ruleguard 5 | 6 | linters-settings: 7 | gocognit: 8 | min-complexity: 46 # Min code complexity (def 30). 9 | 10 | goconst: 11 | min-len: 4 # Min length of string consts (def 3). 12 | min-occurrences: 3 # Min number of const occurrences (def 3). 
13 | 14 | gocritic: 15 | enabled-checks: 16 | # - appendAssign 17 | # - appendCombine 18 | - argOrder 19 | # - assignOp 20 | # - badCall 21 | - badCond 22 | - badLock 23 | - badRegexp 24 | - boolExprSimplify 25 | # - builtinShadow 26 | - builtinShadowDecl 27 | - captLocal 28 | - caseOrder 29 | - codegenComment 30 | # - commentedOutCode 31 | - commentedOutImport 32 | - commentFormatting 33 | - defaultCaseOrder 34 | - deferUnlambda 35 | # - deprecatedComment 36 | # - docStub 37 | - dupArg 38 | - dupBranchBody 39 | - dupCase 40 | - dupImport 41 | - dupSubExpr 42 | # - elseif 43 | - emptyFallthrough 44 | # - emptyStringTest 45 | # - equalFold 46 | # - evalOrder 47 | # - exitAfterDefer 48 | # - exposedSyncMutex 49 | # - filepathJoin 50 | - flagDeref 51 | - flagName 52 | - hexLiteral 53 | # - httpNoBody 54 | # - hugeParam 55 | # - ifElseChain 56 | # - importShadow 57 | - indexAlloc 58 | - initClause 59 | - ioutilDeprecated 60 | - mapKey 61 | - methodExprCall 62 | # - nestingReduce 63 | - newDeref 64 | - nilValReturn 65 | # - octalLiteral 66 | - offBy1 67 | # - paramTypeCombine 68 | # - preferStringWriter 69 | # - preferWriteByte 70 | # - ptrToRefParam 71 | # - rangeExprCopy 72 | # - rangeValCopy 73 | - regexpMust 74 | - regexpPattern 75 | # - regexpSimplify 76 | # - ruleguard 77 | - singleCaseSwitch 78 | - sloppyLen 79 | # - sloppyReassign 80 | - sloppyTypeAssert 81 | - sortSlice 82 | - sprintfQuotedString 83 | - sqlQuery 84 | # - stringConcatSimplify 85 | # - stringXbytes 86 | # - suspiciousSorting 87 | - switchTrue 88 | - truncateCmp 89 | - typeAssertChain 90 | # - typeDefFirst 91 | - typeSwitchVar 92 | # - typeUnparen 93 | - underef 94 | # - unlabelStmt 95 | # - unlambda 96 | # - unnamedResult 97 | # - unnecessaryBlock 98 | # - unnecessaryDefer 99 | # - unslice 100 | - valSwap 101 | - weakCond 102 | # - whyNoLint 103 | # - wrapperFunc 104 | # - yodaStyleExpr 105 | 106 | staticcheck: 107 | # https://staticcheck.io/docs/options#checks 108 | # We disable SA1019 because 
it gets angry about our usage of xerrors. We 109 | # intentionally use xerrors because stack frame support didn't make it into the 110 | # stdlib port. 111 | checks: ["all", "-SA1019"] 112 | 113 | goimports: 114 | local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder 115 | 116 | gocyclo: 117 | min-complexity: 50 118 | 119 | importas: 120 | no-unaliased: true 121 | 122 | misspell: 123 | locale: US 124 | 125 | nestif: 126 | min-complexity: 4 # Min complexity of if statements (def 5, goal 4) 127 | 128 | revive: 129 | # see https://github.com/mgechev/revive#available-rules for details. 130 | ignore-generated-header: true 131 | severity: warning 132 | rules: 133 | - name: atomic 134 | - name: bare-return 135 | - name: blank-imports 136 | - name: bool-literal-in-expr 137 | - name: call-to-gc 138 | - name: confusing-naming 139 | - name: confusing-results 140 | - name: constant-logical-expr 141 | - name: context-as-argument 142 | - name: context-keys-type 143 | - name: deep-exit 144 | - name: defer 145 | - name: dot-imports 146 | - name: duplicated-imports 147 | - name: early-return 148 | - name: empty-block 149 | - name: empty-lines 150 | - name: error-naming 151 | - name: error-return 152 | - name: error-strings 153 | - name: errorf 154 | - name: exported 155 | - name: flag-parameter 156 | - name: get-return 157 | - name: identical-branches 158 | - name: if-return 159 | - name: import-shadowing 160 | - name: increment-decrement 161 | - name: indent-error-flow 162 | # - name: modifies-parameter 163 | - name: modifies-value-receiver 164 | - name: package-comments 165 | - name: range 166 | - name: range-val-address 167 | - name: range-val-in-closure 168 | - name: receiver-naming 169 | - name: redefines-builtin-id 170 | - name: string-of-int 171 | - name: struct-tag 172 | - name: superfluous-else 173 | - name: time-naming 174 | - name: unconditional-recursion 175 | - name: unexported-naming 176 | - name: unexported-return 177 | - name: unhandled-error
178 | - name: unnecessary-stmt 179 | - name: unreachable-code 180 | - name: unused-parameter 181 | - name: unused-receiver 182 | - name: var-declaration 183 | - name: var-naming 184 | - name: waitgroup-by-value 185 | 186 | issues: 187 | # Rules listed here: https://github.com/securego/gosec#available-rules 188 | exclude-rules: 189 | - path: _test\.go 190 | linters: 191 | # We use assertions rather than explicitly checking errors in tests 192 | - errcheck 193 | 194 | fix: true 195 | max-issues-per-linter: 0 196 | max-same-issues: 0 197 | 198 | run: 199 | concurrency: 4 200 | skip-dirs: 201 | - node_modules 202 | skip-files: 203 | - scripts/rules.go 204 | timeout: 5m 205 | 206 | # Over time, add more and more linters from 207 | # https://golangci-lint.run/usage/linters/ as the code improves. 208 | linters: 209 | disable-all: true 210 | enable: 211 | - asciicheck 212 | - bidichk 213 | - bodyclose 214 | - dogsled 215 | - errcheck 216 | - errname 217 | - errorlint 218 | - exportloopref 219 | - forcetypeassert 220 | - gocritic 221 | - gocyclo 222 | - goimports 223 | - gomodguard 224 | - gosec 225 | - gosimple 226 | - govet 227 | - importas 228 | - ineffassign 229 | - makezero 230 | - misspell 231 | - nilnil 232 | - noctx 233 | - paralleltest 234 | - revive 235 | 236 | # These don't work until the following issue is solved. 237 | # https://github.com/golangci/golangci-lint/issues/2649 238 | # - rowserrcheck 239 | # - sqlclosecheck 240 | # - structcheck 241 | # - wastedassign 242 | 243 | - staticcheck 244 | - tenv 245 | # In Go, it's possible for a package to test its internal functionality 246 | # without testing any exported functions. This is enabled to promote 247 | # decomposing a package before testing its internals. A function caller 248 | # should be able to test most of the functionality from exported functions. 249 | # 250 | # There are edge cases to this rule, but they should be carefully considered 251 | # to avoid structural inconsistency.
252 | - testpackage 253 | - tparallel 254 | - typecheck 255 | - unconvert 256 | - unused 257 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine:latest 2 | 3 | ARG WGTUNNEL_VERSION 4 | LABEL \ 5 | org.opencontainers.image.title="wgtunnel" \ 6 | org.opencontainers.image.description="Simple HTTP tunnel over WireGuard." \ 7 | org.opencontainers.image.url="https://github.com/coder/wgtunnel" \ 8 | org.opencontainers.image.source="https://github.com/coder/wgtunnel" \ 9 | org.opencontainers.image.version="$WGTUNNEL_VERSION" 10 | 11 | RUN adduser -D -u 1000 tunneld 12 | USER tunneld 13 | 14 | COPY ./build/tunneld / 15 | 16 | CMD ["/tunneld"] 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Coder Technologies, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Use a single bash shell for each job, and immediately exit on failure 2 | SHELL := bash 3 | .SHELLFLAGS := -ceu 4 | .ONESHELL: 5 | 6 | # This doesn't work on directories. 7 | # See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets 8 | .DELETE_ON_ERROR: 9 | 10 | # Don't print the commands in the file unless you specify VERBOSE. This is 11 | # essentially the same as putting "@" at the start of each line. 12 | ifndef VERBOSE 13 | .SILENT: 14 | endif 15 | 16 | # Create the output directories if they do not exist. 17 | $(shell mkdir -p build) 18 | 19 | VERSION := $(shell ./scripts/version.sh) 20 | 21 | clean: 22 | rm -rf build 23 | .PHONY: clean 24 | 25 | fmt: 26 | go fmt ./... 27 | .PHONY: fmt 28 | 29 | lint: 30 | golangci-lint run 31 | .PHONY: lint 32 | 33 | build: build/tunneld build/tunnel 34 | .PHONY: build 35 | 36 | # build/tunneld and build/tunnel build the Go binary for the current 37 | # architecture. You can change the architecture by setting GOOS and GOARCH 38 | # manually before calling this target. 39 | build/tunneld build/tunnel: build/%: $(shell find . -type f -name '*.go') 40 | CGO_ENABLED=0 go build \ 41 | -o "$@" \ 42 | -tags urfave_cli_no_docs \ 43 | -ldflags "-s -w -X 'github.com/coder/wgtunnel/buildinfo.tag=$(VERSION)'" \ 44 | "./cmd/$*" 45 | 46 | # build/tunneld.tag generates the Docker image for tunneld. 47 | build/tunneld.tag: build/tunneld 48 | # Dev versions contain plus signs which are illegal in Docker tags. 
49 | version="$(VERSION)" 50 | tag="ghcr.io/coder/wgtunnel/tunneld:$${version//+/-}" 51 | 52 | docker build \ 53 | --file Dockerfile \ 54 | --build-arg "WGTUNNEL_VERSION=$(VERSION)" \ 55 | --tag "$$tag" \ 56 | . 57 | 58 | echo "$$tag" > "$@" 59 | 60 | test: 61 | go clean -testcache 62 | gotestsum -- -v -short ./... 63 | .PHONY: test 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wgtunnel 2 | 3 | wgtunnel is a simple WireGuard tunnel server. Clients can register themselves 4 | to the server with a single API request (done periodically in the background in 5 | case the server restarts), and then connect to a WireGuard endpoint on the 6 | server over UDP to tunnel. 7 | 8 | Generated URLs are unique and are based on the WireGuard public key. Wildcards 9 | for each tunnel are also semi-supported, using hyphens instead of periods to 10 | allow for TLS. 11 | 12 | This is used by [Coder](https://github.com/coder/coder) to create tunnels for 13 | trial/demo deployments with globally accessible URLs. 14 | 15 | ## Deployment 16 | 17 | Deploy `tunneld` onto your server and configure it with environment variables or 18 | flags. Point the DNS entries `${base_url}` and `*.${base_url}` to the server. If 19 | you want to use HTTPS, set up a proxy such as [Caddy](https://caddyserver.com/) 20 | in front of the server. 21 | 22 | `tunneld` is available on GitHub releases or can be installed with: 23 | 24 | ```console 25 | $ go install github.com/coder/wgtunnel/cmd/tunneld@latest 26 | ``` 27 | 28 | or by running `make build/tunneld`. 29 | 30 | You can also use the Docker image `ghcr.io/coder/wgtunnel/tunneld`. 31 | 32 | ## Usage 33 | 34 | Either use `tunnel` for easy usage from a terminal, or use the `tunnelsdk` 35 | package to initiate a tunnel against the given API server URL. Remember to
store the private key for future tunnel sessions in a safe place, otherwise you 37 | will get a new hostname! 38 | 39 | `tunnel` can be installed with: 40 | 41 | ```console 42 | $ go install github.com/coder/wgtunnel/cmd/tunnel@latest 43 | ``` 44 | 45 | or by running `make build/tunnel`. 46 | 47 | ## License 48 | 49 | Licensed under the MIT license. 50 | -------------------------------------------------------------------------------- /buildinfo/buildinfo.go: -------------------------------------------------------------------------------- 1 | package buildinfo 2 | 3 | import ( 4 | "runtime/debug" 5 | "sync" 6 | "time" 7 | 8 | "golang.org/x/mod/semver" 9 | ) 10 | 11 | var ( 12 | buildInfo *debug.BuildInfo 13 | buildInfoValid bool 14 | readBuildInfo sync.Once 15 | 16 | version string 17 | readVersion sync.Once 18 | 19 | // Injected with ldflags at build! 20 | tag string 21 | ) 22 | 23 | const ( 24 | // develPrefix is prefixed to developer versions of the application. 25 | develPrefix = "v0.0.0-devel" 26 | ) 27 | 28 | // Version returns the semantic version of the build. 29 | // Use golang.org/x/mod/semver to compare versions. 30 | func Version() string { 31 | readVersion.Do(func() { 32 | revision, valid := revision() 33 | if valid { 34 | revision = "+" + revision[:7] 35 | } 36 | if tag == "" { 37 | // This occurs when the tag hasn't been injected, 38 | // like when using "go run". 39 | version = develPrefix + revision 40 | return 41 | } 42 | version = "v" + tag 43 | // The tag must be prefixed with "v" otherwise the 44 | // semver library will return an empty string. 45 | if semver.Build(version) == "" { 46 | version += revision 47 | } 48 | }) 49 | return version 50 | } 51 | 52 | // Time returns when the Git revision was published.
53 | func Time() (time.Time, bool) { 54 | value, valid := find("vcs.time") 55 | if !valid { 56 | return time.Time{}, false 57 | } 58 | parsed, err := time.Parse(time.RFC3339, value) 59 | if err != nil { 60 | panic("couldn't parse time: " + err.Error()) 61 | } 62 | return parsed, true 63 | } 64 | 65 | // revision returns the Git hash of the build. 66 | func revision() (string, bool) { 67 | return find("vcs.revision") 68 | } 69 | 70 | // find returns the value of the build setting named key and whether it was 71 | // present ("", false when absent). It panics if the build info cannot be read. 72 | func find(key string) (string, bool) { 73 | readBuildInfo.Do(func() { 74 | buildInfo, buildInfoValid = debug.ReadBuildInfo() 75 | }) 76 | if !buildInfoValid { 77 | panic("couldn't read build info") 78 | } 79 | for _, setting := range buildInfo.Settings { 80 | if setting.Key != key { 81 | continue 82 | } 83 | return setting.Value, true 84 | } 85 | return "", false 86 | } 87 | -------------------------------------------------------------------------------- /cmd/tunnel/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "net/url" 10 | "os" 11 | "os/signal" 12 | "time" 13 | 14 | "github.com/urfave/cli/v2" 15 | "golang.org/x/xerrors" 16 | 17 | "cdr.dev/slog" 18 | "cdr.dev/slog/sloggers/sloghuman" 19 | "github.com/coder/wgtunnel/buildinfo" 20 | "github.com/coder/wgtunnel/tunnelsdk" 21 | ) 22 | 23 | func main() { 24 | cli.VersionFlag = &cli.BoolFlag{ 25 | Name: "version", 26 | Aliases: []string{"V"}, 27 | Usage: "Print the version.", 28 | } 29 | 30 | app := &cli.App{ 31 | Name: "tunnel", 32 | Usage: "run a wgtunnel client", 33 | ArgsUsage: "", 34 | Version: buildinfo.Version(), 35 | Flags: []cli.Flag{ 36 | &cli.BoolFlag{ 37 | Name: "verbose", 38 | Aliases: []string{"v"}, 39 | Usage: "Enable verbose logging.", 40 | EnvVars: []string{"TUNNEL_VERBOSE"}, 41 | }, 42 | &cli.StringFlag{ 43 | Name: "api-url", 44 | Usage: "The base
URL of the tunnel API.", 45 | EnvVars: []string{"TUNNEL_API_URL"}, 46 | }, 47 | &cli.StringFlag{ 48 | Name: "wireguard-key", 49 | Aliases: []string{"wg-key"}, 50 | Usage: "The private key for the wireguard client. It should be base64 encoded. You must specify this or wireguard-key-file.", 51 | EnvVars: []string{"TUNNEL_WIREGUARD_KEY"}, 52 | }, 53 | &cli.StringFlag{ 54 | Name: "wireguard-key-file", 55 | Aliases: []string{"wg-key-file"}, 56 | Usage: "The file containing the private key for the wireguard client. It should contain a base64 encoded key. The file will be created and populated with a fresh key if it does not exist. You must specify this or wireguard-key.", 57 | EnvVars: []string{"TUNNEL_WIREGUARD_KEY_FILE"}, 58 | }, 59 | }, 60 | Action: runApp, 61 | } 62 | 63 | err := app.Run(os.Args) 64 | if err != nil { 65 | log.Fatal(err) 66 | } 67 | } 68 | 69 | func runApp(ctx *cli.Context) error { 70 | var ( 71 | verbose = ctx.Bool("verbose") 72 | apiURL = ctx.String("api-url") 73 | wireguardKey = ctx.String("wireguard-key") 74 | wireguardKeyFile = ctx.String("wireguard-key-file") 75 | ) 76 | if apiURL == "" { 77 | return xerrors.New("api-url is required. See --help for more information.") 78 | } 79 | if wireguardKey == "" && wireguardKeyFile == "" { 80 | return xerrors.New("wireguard-key or wireguard-key-file is required. See --help for more information.") 81 | } 82 | if wireguardKey != "" && wireguardKeyFile != "" { 83 | return xerrors.New("Only one of wireguard-key or wireguard-key-file can be specified. See --help for more information.") 84 | } 85 | 86 | if ctx.Args().Len() != 1 { 87 | return xerrors.New("exactly one argument (target-address) is required. 
See --help for more information.") 88 | } 89 | targetAddress := ctx.Args().Get(0) 90 | if targetAddress == "" { 91 | return xerrors.New("target-address is empty") 92 | } 93 | _, _, err := net.SplitHostPort(targetAddress) 94 | if err != nil { 95 | return xerrors.Errorf("target-address %q is not a valid host:port: %w", targetAddress, err) 96 | } 97 | 98 | logger := slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelInfo) 99 | if verbose { 100 | logger = logger.Leveled(slog.LevelDebug) 101 | } 102 | 103 | apiURLParsed, err := url.Parse(apiURL) 104 | if err != nil { 105 | return xerrors.Errorf("failed to parse api-url %q: %w", apiURL, err) 106 | } 107 | 108 | if wireguardKeyFile != "" { 109 | fileBytes, err := os.ReadFile(wireguardKeyFile) 110 | if xerrors.Is(err, os.ErrNotExist) { 111 | key, err := tunnelsdk.GeneratePrivateKey() 112 | if err != nil { 113 | return xerrors.Errorf("failed to generate wireguard key: %w", err) 114 | } 115 | 116 | fileBytes = []byte(key.String()) 117 | err = os.WriteFile(wireguardKeyFile, fileBytes, 0600) 118 | if err != nil { 119 | return xerrors.Errorf("failed to write wireguard key to file %q: %w", wireguardKeyFile, err) 120 | } 121 | } else if err != nil { 122 | return xerrors.Errorf("failed to read wireguard-key-file %q: %w", wireguardKeyFile, err) 123 | } 124 | wireguardKey = string(fileBytes) 125 | } 126 | 127 | wireguardKeyParsed, err := tunnelsdk.ParsePrivateKey(wireguardKey) 128 | if err != nil { 129 | return xerrors.Errorf("could not parse wireguard-key or wireguard-key-file: %w", err) 130 | } 131 | 132 | client := tunnelsdk.New(apiURLParsed) 133 | tunnel, err := client.LaunchTunnel(ctx.Context, tunnelsdk.TunnelConfig{ 134 | Log: logger, 135 | PrivateKey: wireguardKeyParsed, 136 | }) 137 | if err != nil { 138 | return xerrors.Errorf("launch tunnel: %w", err) 139 | } 140 | defer func() { 141 | err := tunnel.Close() 142 | if err != nil { 143 | logger.Error(ctx.Context, "close tunnel", slog.Error(err)) 144 | } 145 | }() 146 | 
147 | _, _ = fmt.Fprintln(os.Stderr, "Tunnel is ready. You can now connect to one of the following URLs:") 148 | _, _ = fmt.Fprintln(os.Stderr, " -", tunnel.URL.String()) 149 | for _, u := range tunnel.OtherURLs { 150 | _, _ = fmt.Fprintln(os.Stderr, " -", u.String()) 151 | } 152 | 153 | // Start forwarding traffic to/from the tunnel. 154 | go func() { 155 | for { 156 | conn, err := tunnel.Listener.Accept() 157 | if err != nil { 158 | logger.Error(ctx.Context, "accept tunnel connection", slog.Error(err)) 159 | _ = tunnel.Close() // Accept failed, so no further connections can be forwarded; close the tunnel (best effort). 160 | return 161 | } 162 | 163 | go func() { 164 | defer conn.Close() 165 | 166 | dialCtx, dialCancel := context.WithTimeout(ctx.Context, 10*time.Second) 167 | defer dialCancel() 168 | 169 | targetConn, err := (&net.Dialer{}).DialContext(dialCtx, "tcp", targetAddress) 170 | if err != nil { 171 | logger.Warn(ctx.Context, "could not dial target", slog.F("target_address", targetAddress), slog.Error(err)) 172 | return 173 | } 174 | defer targetConn.Close() 175 | 176 | go func() { 177 | _, err := io.Copy(targetConn, conn) 178 | if err != nil && !xerrors.Is(err, io.EOF) { 179 | logger.Warn(ctx.Context, "could not copy from tunnel to target", slog.Error(err)) 180 | } 181 | }() 182 | 183 | _, err = io.Copy(conn, targetConn) // When this returns, the deferred Closes above also unblock the tunnel->target copy. 184 | if err != nil && !xerrors.Is(err, io.EOF) { 185 | logger.Warn(ctx.Context, "could not copy from target to tunnel", slog.Error(err)) 186 | } 187 | }() 188 | } 189 | }() 190 | 191 | _, _ = fmt.Printf("\nTunnel is ready! You can now connect to %s\n", tunnel.URL.String()) 192 | 193 | notifyCtx, notifyStop := signal.NotifyContext(ctx.Context, InterruptSignals...)
194 | defer notifyStop() 195 | 196 | select { 197 | case <-notifyCtx.Done(): 198 | _, _ = fmt.Printf("\nClosing tunnel due to signal...\n") 199 | return tunnel.Close() 200 | case <-tunnel.Wait(): 201 | } 202 | 203 | return nil 204 | } 205 | -------------------------------------------------------------------------------- /cmd/tunnel/signal_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package main 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | var InterruptSignals = []os.Signal{ 11 | os.Interrupt, 12 | syscall.SIGTERM, 13 | syscall.SIGHUP, 14 | } 15 | -------------------------------------------------------------------------------- /cmd/tunnel/signal_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "os" 7 | ) 8 | 9 | var InterruptSignals = []os.Signal{os.Interrupt} 10 | -------------------------------------------------------------------------------- /cmd/tunneld/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "log" 8 | "net/http" 9 | "net/http/pprof" 10 | "net/netip" 11 | "net/url" 12 | "os" 13 | "os/signal" 14 | "time" 15 | 16 | "github.com/urfave/cli/v2" 17 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 18 | "go.opentelemetry.io/otel" 19 | "go.opentelemetry.io/otel/propagation" 20 | sdktrace "go.opentelemetry.io/otel/sdk/trace" 21 | "golang.org/x/sync/errgroup" 22 | "golang.org/x/xerrors" 23 | 24 | "cdr.dev/slog" 25 | "cdr.dev/slog/sloggers/sloghuman" 26 | "github.com/coder/wgtunnel/buildinfo" 27 | "github.com/coder/wgtunnel/tunneld" 28 | "github.com/coder/wgtunnel/tunnelsdk" 29 | ) 30 | 31 | func main() { 32 | cli.VersionFlag = &cli.BoolFlag{ 33 | Name: "version", 34 | Aliases: []string{"V"}, 35 | Usage: "Print the version.", 36 | } 37 | 38 | 
app := &cli.App{ 39 | Name: "tunneld", 40 | Usage: "run a wgtunnel server", 41 | Version: buildinfo.Version(), 42 | Flags: []cli.Flag{ 43 | &cli.BoolFlag{ 44 | Name: "verbose", 45 | Aliases: []string{"v"}, 46 | Usage: "Enable verbose logging.", 47 | EnvVars: []string{"TUNNELD_VERBOSE"}, 48 | }, 49 | &cli.StringFlag{ 50 | Name: "listen-address", 51 | Aliases: []string{"a"}, 52 | Usage: "HTTP listen address for the API and tunnel traffic.", 53 | Value: "127.0.0.1:8080", 54 | EnvVars: []string{"TUNNELD_LISTEN_ADDRESS"}, 55 | }, 56 | &cli.StringFlag{ 57 | Name: "base-url", 58 | Aliases: []string{"u"}, 59 | Usage: "The base URL to use for the tunnel, including scheme. All tunnels will be subdomains of this hostname.", 60 | EnvVars: []string{"TUNNELD_BASE_URL"}, 61 | }, 62 | &cli.StringFlag{ 63 | Name: "wireguard-endpoint", 64 | Aliases: []string{"wg-endpoint"}, 65 | Usage: "The UDP address advertised to clients that they will connect to for wireguard connections. It should be in the form host:port.", 66 | EnvVars: []string{"TUNNELD_WIREGUARD_ENDPOINT"}, 67 | }, 68 | // Technically a uint16. 69 | &cli.UintFlag{ 70 | Name: "wireguard-port", 71 | Aliases: []string{"wg-port"}, 72 | Usage: "The UDP port that the wireguard server will listen on. It should be the same as the port in wireguard-endpoint.", 73 | EnvVars: []string{"TUNNELD_WIREGUARD_PORT"}, 74 | }, 75 | &cli.StringFlag{ 76 | Name: "wireguard-key", 77 | Aliases: []string{"wg-key"}, 78 | Usage: "The private key for the wireguard server. It should be base64 encoded. You can generate a key with `wg genkey`. Mutually exclusive with wireguard-key-file.", 79 | EnvVars: []string{"TUNNELD_WIREGUARD_KEY"}, 80 | }, 81 | &cli.StringFlag{ 82 | Name: "wireguard-key-file", 83 | Aliases: []string{"wg-key-file"}, 84 | Usage: "The file path containing the private key for the wireguard server. The contents should be base64 encoded. If the file does not exist, a key will be generated for you and written to the file. 
Mutually exclusive with wireguard-key.", 85 | EnvVars: []string{"TUNNELD_WIREGUARD_KEY_FILE"}, 86 | }, 87 | &cli.IntFlag{ 88 | Name: "wireguard-mtu", 89 | Aliases: []string{"wg-mtu"}, 90 | Usage: "The MTU to use for the wireguard interface.", 91 | Value: tunneld.DefaultWireguardMTU, 92 | EnvVars: []string{"TUNNELD_WIREGUARD_MTU"}, 93 | }, 94 | &cli.StringFlag{ 95 | Name: "wireguard-server-ip", 96 | Aliases: []string{"wg-server-ip"}, 97 | Usage: "The virtual IP address of this server in the wireguard network. Must be an IPv6 address contained within wireguard-network-prefix.", 98 | Value: tunneld.DefaultWireguardServerIP.String(), 99 | EnvVars: []string{"TUNNELD_WIREGUARD_SERVER_IP"}, 100 | }, 101 | &cli.StringFlag{ 102 | Name: "wireguard-network-prefix", 103 | Aliases: []string{"wg-network-prefix"}, 104 | Usage: "The CIDR of the wireguard network. All client IPs will be generated within this network. Must be a IPv6 CIDR and have at least 64 bits available.", 105 | Value: tunneld.DefaultWireguardNetworkPrefix.String(), 106 | EnvVars: []string{"TUNNELD_WIREGUARD_NETWORK_PREFIX"}, 107 | }, 108 | &cli.StringFlag{ 109 | Name: "real-ip-header", 110 | Usage: "Use the given header as the real IP address rather than the remote socket address.", 111 | Value: "", 112 | EnvVars: []string{"TUNNELD_REAL_IP_HEADER"}, 113 | }, 114 | &cli.StringFlag{ 115 | Name: "pprof-listen-address", 116 | Usage: "The address to listen on for pprof. If set to an empty string, pprof will not be enabled.", 117 | Value: "127.0.0.1:6060", 118 | EnvVars: []string{"TUNNELD_PPROF_LISTEN_ADDRESS"}, 119 | }, 120 | &cli.StringFlag{ 121 | Name: "tracing-honeycomb-team", 122 | Usage: "The Honeycomb team ID to send tracing data to. 
If not specified, tracing will not be shipped anywhere.", 123 | EnvVars: []string{"TUNNELD_TRACING_HONEYCOMB_TEAM"}, 124 | }, 125 | &cli.StringFlag{ 126 | Name: "tracing-instance-id", 127 | Usage: "The instance ID to annotate all traces with that uniquely identifies this deployment.", 128 | EnvVars: []string{"TUNNELD_TRACING_INSTANCE_ID"}, 129 | }, 130 | }, 131 | Action: runApp, 132 | } 133 | 134 | err := app.Run(os.Args) 135 | if err != nil { 136 | log.Fatal(err) 137 | } 138 | } 139 |
// runApp is the CLI action for tunneld. It does the following:
//   - validates the flags;
//   - configures logging and, when tracing-honeycomb-team is set,
//     OpenTelemetry tracing;
//   - loads (or generates) the wireguard private key;
//   - builds the tunneld API and serves it on listen-address until one of
//     InterruptSignals is received or the HTTP server fails.
//
// A graceful shutdown triggered by a signal returns nil. Any other failure is
// returned to the caller (main passes it to log.Fatal).
func runApp(ctx *cli.Context) error {
	var (
		verbose                = ctx.Bool("verbose")
		listenAddress          = ctx.String("listen-address")
		baseURL                = ctx.String("base-url")
		wireguardEndpoint      = ctx.String("wireguard-endpoint")
		wireguardPort          = ctx.Uint("wireguard-port")
		wireguardKey           = ctx.String("wireguard-key")
		wireguardKeyFile       = ctx.String("wireguard-key-file")
		wireguardMTU           = ctx.Int("wireguard-mtu")
		wireguardServerIP      = ctx.String("wireguard-server-ip")
		wireguardNetworkPrefix = ctx.String("wireguard-network-prefix")
		realIPHeader           = ctx.String("real-ip-header")
		pprofListenAddress     = ctx.String("pprof-listen-address")
		tracingHoneycombTeam   = ctx.String("tracing-honeycomb-team")
		tracingInstanceID      = ctx.String("tracing-instance-id")
	)
	if baseURL == "" {
		return xerrors.New("base-url is required. See --help for more information.")
	}
	if wireguardEndpoint == "" {
		return xerrors.New("wireguard-endpoint is required. See --help for more information.")
	}
	if wireguardPort < 1 || wireguardPort > 65535 {
		return xerrors.New("wireguard-port is required and must be between 1 and 65535. See --help for more information.")
	}
	if wireguardKey == "" && wireguardKeyFile == "" {
		return xerrors.New("wireguard-key or wireguard-key-file is required. See --help for more information.")
	}
	if wireguardKey != "" && wireguardKeyFile != "" {
		return xerrors.New("wireguard-key and wireguard-key-file are mutually exclusive. See --help for more information.")
	}

	logger := slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelInfo)
	if verbose {
		logger = logger.Leveled(slog.LevelDebug)
	}

	// Initiate tracing. tp stays nil when tracing is disabled. It is checked
	// below to decide whether to wrap the HTTP handler with otelhttp.
	var tp *sdktrace.TracerProvider
	if tracingHoneycombTeam != "" {
		exp, err := newHoneycombExporter(ctx.Context, tracingHoneycombTeam)
		if err != nil {
			return xerrors.Errorf("create honeycomb telemetry exporter: %w", err)
		}

		// Create a new tracer provider with a batch span processor and the otlp
		// exporter. Assign (don't redeclare with :=) so the outer tp is set;
		// otherwise the otelhttp wrapping below never happens.
		tp = newTraceProvider(exp, tracingInstanceID)
		otel.SetTracerProvider(tp)
		otel.SetTextMapPropagator(
			propagation.NewCompositeTextMapPropagator(
				propagation.TraceContext{},
				propagation.Baggage{},
			),
		)

		defer func() {
			// Allow time for traces to flush even if the command context is
			// canceled.
			flushCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = tp.Shutdown(flushCtx)
		}()
	}

	baseURLParsed, err := url.Parse(baseURL)
	if err != nil {
		return xerrors.Errorf("could not parse base-url %q: %w", baseURL, err)
	}
	wireguardServerIPParsed, err := netip.ParseAddr(wireguardServerIP)
	if err != nil {
		return xerrors.Errorf("could not parse wireguard-server-ip %q: %w", wireguardServerIP, err)
	}
	wireguardNetworkPrefixParsed, err := netip.ParsePrefix(wireguardNetworkPrefix)
	if err != nil {
		return xerrors.Errorf("could not parse wireguard-network-prefix %q: %w", wireguardNetworkPrefix, err)
	}

	if wireguardKeyFile != "" {
		// Generate and persist a key on first run, then always load it from
		// the file so both paths go through the same parsing below.
		_, err = os.Stat(wireguardKeyFile)
		if errors.Is(err, os.ErrNotExist) {
			logger.Info(ctx.Context, "generating private key to file", slog.F("path", wireguardKeyFile))
			key, err := tunnelsdk.GeneratePrivateKey()
			if err != nil {
				return xerrors.Errorf("could not generate private key: %w", err)
			}

			err = os.WriteFile(wireguardKeyFile, []byte(key.String()), 0600)
			if err != nil {
				return xerrors.Errorf("could not write base64-encoded private key to %q: %w", wireguardKeyFile, err)
			}
		} else if err != nil {
			return xerrors.Errorf("could not stat wireguard-key-file %q: %w", wireguardKeyFile, err)
		}

		logger.Info(ctx.Context, "reading private key from file", slog.F("path", wireguardKeyFile))
		wireguardKeyBytes, err := os.ReadFile(wireguardKeyFile)
		if err != nil {
			return xerrors.Errorf("could not read wireguard-key-file %q: %w", wireguardKeyFile, err)
		}
		wireguardKey = string(wireguardKeyBytes)
	}

	wireguardKeyParsed, err := tunnelsdk.ParsePrivateKey(wireguardKey)
	if err != nil {
		// Never echo the key itself: it is a secret, and this error is logged
		// by main.
		if wireguardKeyFile != "" {
			return xerrors.Errorf("could not parse private key in wireguard-key-file %q: %w", wireguardKeyFile, err)
		}
		return xerrors.Errorf("could not parse wireguard-key: %w", err)
	}
	logger.Info(ctx.Context, "parsed private key", slog.F("hash", wireguardKeyParsed.Hash()))

	options := &tunneld.Options{
		BaseURL:                baseURLParsed,
		WireguardEndpoint:      wireguardEndpoint,
		WireguardPort:          uint16(wireguardPort),
		WireguardKey:           wireguardKeyParsed,
		WireguardMTU:           wireguardMTU,
		WireguardServerIP:      wireguardServerIPParsed,
		WireguardNetworkPrefix: wireguardNetworkPrefixParsed,
		RealIPHeader:           realIPHeader,
	}
	td, err := tunneld.New(options)
	if err != nil {
		return xerrors.Errorf("create tunneld.API instance: %w", err)
	}

	// ReadHeaderTimeout is purposefully not enabled. It caused some issues with
	// websockets over the dev tunnel.
	// See: https://github.com/coder/coder/pull/3730
	//nolint:gosec
	server := &http.Server{
		// These errors are typically noise like "TLS: EOF". Vault does similar:
		// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
		ErrorLog: log.New(io.Discard, "", 0),
		Addr:     listenAddress,
		Handler:  td.Router(),
	}
	if tp != nil {
		server.Handler = otelhttp.NewHandler(server.Handler, "tunneld")
	}

	// Start the pprof server if requested. A nil Handler serves
	// http.DefaultServeMux, where the pprof package registers its handlers.
	if pprofListenAddress != "" {
		var _ = pprof.Handler
		go func() {
			pprofServer := &http.Server{
				// See above for why we discard these errors.
				ErrorLog:          log.New(io.Discard, "", 0),
				ReadHeaderTimeout: 15 * time.Second,
				Addr:              pprofListenAddress,
				Handler:           nil, // use pprof
			}

			logger.Info(ctx.Context, "starting pprof server", slog.F("listen_address", pprofListenAddress))
			// pprof is best-effort, so a failure here must not stop tunneld,
			// but e.g. "address already in use" should not be hidden either.
			err := pprofServer.ListenAndServe()
			logger.Warn(ctx.Context, "pprof server exited", slog.Error(err))
		}()
	}

	eg, egCtx := errgroup.WithContext(ctx.Context)
	eg.Go(func() error {
		logger.Info(egCtx, "listening for requests", slog.F("listen_address", listenAddress))
		err := server.ListenAndServe()
		// ErrServerClosed is the expected result of the graceful Shutdown
		// below. Reporting it would cancel egCtx and make main exit non-zero
		// on every clean stop.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			return xerrors.Errorf("error in ListenAndServe: %w", err)
		}
		return nil
	})

	notifyCtx, notifyStop := signal.NotifyContext(ctx.Context, InterruptSignals...)
	defer notifyStop()

	eg.Go(func() error {
		select {
		case <-notifyCtx.Done():
			// Restore default signal handling so a second signal terminates
			// the process if the graceful shutdown hangs.
			notifyStop()
			logger.Info(egCtx, "shutting down server due to signal")
		case <-egCtx.Done():
			// The listener failed (or the parent context was canceled). Shut
			// down anyway so eg.Wait returns instead of waiting for a signal.
		}

		// Derived from Background so the drain window is not cut short by
		// errgroup cancellation.
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer shutdownCancel()
		return server.Shutdown(shutdownCtx)
	})

	return eg.Wait()
}
320 | -------------------------------------------------------------------------------- /cmd/tunneld/signal_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package main 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | var InterruptSignals = []os.Signal{ 11 | os.Interrupt, 12 | syscall.SIGTERM, 13 | syscall.SIGHUP, 14 | } 15 | -------------------------------------------------------------------------------- /cmd/tunneld/signal_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "os" 7 | ) 8 | 9 | var InterruptSignals = []os.Signal{os.Interrupt} 10 | -------------------------------------------------------------------------------- /cmd/tunneld/tracing.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace" 7 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 8 | "go.opentelemetry.io/otel/sdk/resource" 9 | sdktrace "go.opentelemetry.io/otel/sdk/trace" 10 | semconv "go.opentelemetry.io/otel/semconv/v1.11.0" 11 | "google.golang.org/grpc/credentials" 12 | 13 | "github.com/coder/wgtunnel/buildinfo" 14 | ) 15 | 16 | func newHoneycombExporter(ctx context.Context, teamID string) (*otlptrace.Exporter, error) { 17 | opts := []otlptracegrpc.Option{ 18 | otlptracegrpc.WithEndpoint("api.honeycomb.io:443"), 19 | otlptracegrpc.WithHeaders(map[string]string{ 20 | "x-honeycomb-team": teamID, 21 | }), 22 |
otlptracegrpc.WithTLSCredentials(credentials.NewClientTLSFromCert(nil, "")), 23 | } 24 | 25 | client := otlptracegrpc.NewClient(opts...) 26 | return otlptrace.New(ctx, client) 27 | } 28 | 29 | func newTraceProvider(exp *otlptrace.Exporter, instanceID string) *sdktrace.TracerProvider { 30 | rsc := resource.NewWithAttributes( 31 | semconv.SchemaURL, 32 | semconv.ServiceNameKey.String("WireguardTunnel"), 33 | semconv.ServiceInstanceIDKey.String(instanceID), 34 | semconv.ServiceVersionKey.String(buildinfo.Version()), 35 | ) 36 | 37 | return sdktrace.NewTracerProvider( 38 | sdktrace.WithBatcher(exp), 39 | sdktrace.WithResource(rsc), 40 | ) 41 | } 42 | -------------------------------------------------------------------------------- /compose/.env.example: -------------------------------------------------------------------------------- 1 | CLOUDFLARE_TOKEN= 2 | HONEYCOMB_TEAM= 3 | -------------------------------------------------------------------------------- /compose/.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | -------------------------------------------------------------------------------- /compose/Makefile: -------------------------------------------------------------------------------- 1 | # Use a single bash shell for each job, and immediately exit on failure 2 | SHELL := bash 3 | .SHELLFLAGS := -ceu 4 | .ONESHELL: 5 | 6 | # Don't print the commands in the file unless you specify VERBOSE. This is 7 | # essentially the same as putting "@" at the start of each line. 8 | ifndef VERBOSE 9 | .SILENT: 10 | endif 11 | 12 | up: 13 | pushd .. 
14 | make -B build 15 | popd 16 | docker compose -p wgtunnel up --build 17 | .PHONY: up 18 | -------------------------------------------------------------------------------- /compose/caddy/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CADDY_VERSION=2.6.4 2 | FROM caddy:${CADDY_VERSION}-builder AS builder 3 | 4 | RUN xcaddy build \ 5 | --with github.com/lucaslorentz/caddy-docker-proxy/v2 \ 6 | --with github.com/caddy-dns/cloudflare 7 | 8 | FROM caddy:${CADDY_VERSION} 9 | 10 | COPY --from=builder /usr/bin/caddy /usr/bin/caddy 11 | 12 | CMD ["caddy", "docker-proxy"] 13 | -------------------------------------------------------------------------------- /compose/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.9" 2 | services: 3 | caddy: 4 | build: ./caddy 5 | ports: 6 | - 8080:80 7 | - 4443:443 8 | environment: 9 | - CADDY_INGRESS_NETWORKS=caddy 10 | networks: 11 | - caddy 12 | volumes: 13 | - /var/run/docker.sock:/var/run/docker.sock 14 | - caddy_data:/data 15 | restart: unless-stopped 16 | 17 | tunnel: 18 | build: .. 
19 | restart: always 20 | ports: 21 | - 55551:55551/udp 22 | networks: 23 | - caddy 24 | environment: 25 | TUNNELD_LISTEN_ADDRESS: "0.0.0.0:8080" 26 | TUNNELD_BASE_URL: "https://local.try.coder.app:4443" 27 | TUNNELD_WIREGUARD_ENDPOINT: "local.try.coder.app:55551" 28 | TUNNELD_WIREGUARD_PORT: "55551" 29 | TUNNELD_WIREGUARD_KEY_FILE: "/home/tunneld/wg.key" 30 | TUNNELD_WIREGUARD_MTU: "1280" 31 | TUNNELD_WIREGUARD_SERVER_IP: "fcca::1" 32 | TUNNELD_WIREGUARD_NETWORK_PREFIX: "fcca::/16" 33 | TUNNELD_REAL_IP_HEADER: "X-Forwarded-For" 34 | TUNNELD_PPROF_LISTEN_ADDRESS: "127.0.0.1:6060" 35 | TUNNELD_TRACING_HONEYCOMB_TEAM: "${HONEYCOMB_TEAM}" 36 | TUNNELD_TRACING_INSTANCE_ID: "local" 37 | labels: 38 | caddy: "local.try.coder.app, *.local.try.coder.app" 39 | caddy.reverse_proxy: "{{upstreams 8080}}" 40 | caddy.tls.dns: cloudflare ${CLOUDFLARE_TOKEN} 41 | 42 | networks: 43 | caddy: 44 | external: true 45 | 46 | volumes: 47 | caddy_data: {} 48 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coder/wgtunnel 2 | 3 | go 1.20 4 | 5 | replace github.com/tailscale/wireguard-go => github.com/coder/wireguard-go v0.0.0-20240502122727-a4cb23ac736d 6 | 7 | require ( 8 | cdr.dev/slog v1.6.2-0.20230901043036-3e17d6de9749 9 | github.com/go-chi/chi/v5 v5.0.10 10 | github.com/go-chi/hostrouter v0.2.0 11 | github.com/go-chi/httprate v0.7.4 12 | github.com/riandyrn/otelchi v0.5.1 13 | github.com/stretchr/testify v1.8.4 14 | github.com/tailscale/wireguard-go v0.0.0-20231121184858-cc193a0b3272 15 | github.com/urfave/cli/v2 v2.25.7 16 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 17 | go.opentelemetry.io/otel v1.18.0 18 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 19 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0 20 | go.opentelemetry.io/otel/sdk v1.18.0 21 | 
go.opentelemetry.io/otel/trace v1.18.0 22 | golang.org/x/mod v0.12.0 23 | golang.org/x/sync v0.3.0 24 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 25 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 26 | google.golang.org/grpc v1.58.3 27 | ) 28 | 29 | require ( 30 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 31 | github.com/cenkalti/backoff/v4 v4.2.1 // indirect 32 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 33 | github.com/charmbracelet/lipgloss v0.7.1 // indirect 34 | github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect 35 | github.com/davecgh/go-spew v1.1.1 // indirect 36 | github.com/felixge/httpsnoop v1.0.3 // indirect 37 | github.com/go-logr/logr v1.2.4 // indirect 38 | github.com/go-logr/stdr v1.2.2 // indirect 39 | github.com/golang/protobuf v1.5.3 // indirect 40 | github.com/google/btree v1.1.2 // indirect 41 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 // indirect 42 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 43 | github.com/mattn/go-isatty v0.0.19 // indirect 44 | github.com/mattn/go-runewidth v0.0.15 // indirect 45 | github.com/muesli/reflow v0.3.0 // indirect 46 | github.com/muesli/termenv v0.15.2 // indirect 47 | github.com/pmezard/go-difflib v1.0.0 // indirect 48 | github.com/rivo/uniseg v0.4.4 // indirect 49 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 50 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect 51 | go.opentelemetry.io/contrib v1.19.0 // indirect 52 | go.opentelemetry.io/otel/metric v1.18.0 // indirect 53 | go.opentelemetry.io/proto/otlp v1.0.0 // indirect 54 | golang.org/x/crypto v0.17.0 // indirect 55 | golang.org/x/net v0.17.0 // indirect 56 | golang.org/x/sys v0.15.0 // indirect 57 | golang.org/x/term v0.15.0 // indirect 58 | golang.org/x/text v0.14.0 // indirect 59 | golang.org/x/time v0.3.0 // indirect 60 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect 61 | google.golang.org/genproto/googleapis/api 
v0.0.0-20230822172742-b8732ec3820d // indirect 62 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect 63 | google.golang.org/protobuf v1.33.0 // indirect 64 | gopkg.in/yaml.v3 v3.0.1 // indirect 65 | gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect 66 | ) 67 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cdr.dev/slog v1.6.2-0.20230901043036-3e17d6de9749 h1:KGTttdvivQTsOJMbkjAHD8QhERoje0egSZn7hLvHdio= 2 | cdr.dev/slog v1.6.2-0.20230901043036-3e17d6de9749/go.mod h1:NaoTA7KwopCrnaSb0JXTC0PTp/O/Y83Lndnq0OEV3ZQ= 3 | cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= 4 | cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= 5 | cloud.google.com/go/logging v1.8.1 h1:26skQWPeYhvIasWKm48+Eq7oUqdcdbwsCVwz5Ys0FvU= 6 | cloud.google.com/go/longrunning v0.5.1 h1:Fr7TXftcqTudoyRJa113hyaqlGdiBQkp0Gq7tErFDWI= 7 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= 8 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= 9 | github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= 10 | github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= 11 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 12 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 13 | github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E= 14 | github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c= 15 | github.com/coder/wireguard-go v0.0.0-20240502122727-a4cb23ac736d h1:9bX/NUIgbQN2wDDTIIt/Gn60D1Ff7QjH3VhDAH5dgv0= 16 | github.com/coder/wireguard-go 
v0.0.0-20240502122727-a4cb23ac736d/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= 17 | github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= 18 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 19 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 20 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 21 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 22 | github.com/felixge/httpsnoop v1.0.2/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 23 | github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= 24 | github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 25 | github.com/go-chi/chi/v5 v5.0.0/go.mod h1:BBug9lr0cqtdAhsu6R4AAdvufI0/XBzAQSsUqJpoZOs= 26 | github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 27 | github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= 28 | github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 29 | github.com/go-chi/hostrouter v0.2.0 h1:GwC7TZz8+SlJN/tV/aeJgx4F+mI5+sp+5H1PelQUjHM= 30 | github.com/go-chi/hostrouter v0.2.0/go.mod h1:pJ49vWVmtsKRKZivQx0YMYv4h0aX+Gcn6V23Np9Wf1s= 31 | github.com/go-chi/httprate v0.7.4 h1:a2GIjv8he9LRf3712zxxnRdckQCm7I8y8yQhkJ84V6M= 32 | github.com/go-chi/httprate v0.7.4/go.mod h1:6GOYBSwnpra4CQfAKXu8sQZg+nZ0M1g9QnyFvxrAB8A= 33 | github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 34 | github.com/go-logr/logr v1.2.1/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 35 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 36 | github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= 37 | github.com/go-logr/logr v1.2.4/go.mod 
h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 38 | github.com/go-logr/stdr v1.2.0/go.mod h1:YkVgnZu1ZjjL7xTxrfm/LLZBfkhTqSR1ydtm6jTKKwI= 39 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 40 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 41 | github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE= 42 | github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= 43 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 44 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 45 | github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= 46 | github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= 47 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 48 | github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 49 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 50 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 h1:RtRsiaGvWxcwd8y3BiRZxsylPT8hLWZ5SPcfI+3IDNk= 51 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0/go.mod h1:TzP6duP4Py2pHLVPPQp42aoYI92+PCrVotyR5e8Vqlk= 52 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 53 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 54 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= 55 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 56 | github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= 57 | github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 58 | github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= 59 | 
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= 60 | github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 61 | github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= 62 | github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= 63 | github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= 64 | github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= 65 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 66 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 67 | github.com/riandyrn/otelchi v0.5.1 h1:0/45omeqpP7f/cvdL16GddQBfAEmZvUyl2QzLSE6uYo= 68 | github.com/riandyrn/otelchi v0.5.1/go.mod h1:ZxVxNEl+jQ9uHseRYIxKWRb3OY8YXFEu+EkNiiSNUEA= 69 | github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 70 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 71 | github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= 72 | github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 73 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 74 | github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= 75 | github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 76 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 77 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 78 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 79 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 80 | github.com/urfave/cli/v2 v2.25.7 
h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs= 81 | github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= 82 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= 83 | github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= 84 | go.opentelemetry.io/contrib v1.0.0/go.mod h1:EH4yDYeNoaTqn/8yCWQmfNB78VHfGX2Jt2bvnvzBlGM= 85 | go.opentelemetry.io/contrib v1.19.0 h1:rnYI7OEPMWFeM4QCqWQ3InMJ0arWMR1i0Cx9A5hcjYM= 86 | go.opentelemetry.io/contrib v1.19.0/go.mod h1:gIzjwWFoGazJmtCaDgViqOSJPde2mCWzv60o0bWPcZs= 87 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 h1:KfYpVmrjI7JuToy5k8XV3nkapjWx48k4E4JOtVstzQI= 88 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0/go.mod h1:SeQhzAEccGVZVEy7aH87Nh0km+utSpo1pTv6eMMop48= 89 | go.opentelemetry.io/otel v1.3.0/go.mod h1:PWIKzi6JCp7sM0k9yZ43VX+T345uNbAkDKwHVjb2PTs= 90 | go.opentelemetry.io/otel v1.18.0 h1:TgVozPGZ01nHyDZxK5WGPFB9QexeTMXEH7+tIClWfzs= 91 | go.opentelemetry.io/otel v1.18.0/go.mod h1:9lWqYO0Db579XzVuCKFNPDl4s73Voa+zEck3wHaAYQI= 92 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0 h1:IAtl+7gua134xcV3NieDhJHjjOVeJhXAnYf/0hswjUY= 93 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.18.0/go.mod h1:w+pXobnBzh95MNIkeIuAKcHe/Uu/CX2PKIvBP6ipKRA= 94 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0 h1:yE32ay7mJG2leczfREEhoW3VfSZIvHaB+gvVo1o8DQ8= 95 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.18.0/go.mod h1:G17FHPDLt74bCI7tJ4CMitEk4BXTYG4FW6XUpkPBXa4= 96 | go.opentelemetry.io/otel/metric v1.18.0 h1:JwVzw94UYmbx3ej++CwLUQZxEODDj/pOuTCvzhtRrSQ= 97 | go.opentelemetry.io/otel/metric v1.18.0/go.mod h1:nNSpsVDjWGfb7chbRLUNW+PBNdcSTHD4Uu5pfFMOI0k= 98 | go.opentelemetry.io/otel/sdk v1.3.0/go.mod h1:rIo4suHNhQwBIPg9axF8V9CA72Wz2mKF1teNrup8yzs= 99 | go.opentelemetry.io/otel/sdk v1.18.0 
h1:e3bAB0wB3MljH38sHzpV/qWrOTCFrdZF2ct9F8rBkcY= 100 | go.opentelemetry.io/otel/sdk v1.18.0/go.mod h1:1RCygWV7plY2KmdskZEDDBs4tJeHG92MdHZIluiYs/M= 101 | go.opentelemetry.io/otel/trace v1.3.0/go.mod h1:c/VDhno8888bvQYmbYLqe41/Ldmr/KKunbvWM4/fEjk= 102 | go.opentelemetry.io/otel/trace v1.18.0 h1:NY+czwbHbmndxojTEKiSMHkG2ClNH2PwmcHrdo0JY10= 103 | go.opentelemetry.io/otel/trace v1.18.0/go.mod h1:T2+SGJGuYZY3bjj5rgh/hN7KIrlpWC5nS8Mjvzckz+0= 104 | go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= 105 | go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= 106 | go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= 107 | golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= 108 | golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= 109 | golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= 110 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 111 | golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= 112 | golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= 113 | golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= 114 | golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 115 | golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 116 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 117 | golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= 118 | golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 119 | golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= 120 | golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= 121 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 122 | 
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 123 | golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= 124 | golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 125 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 126 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= 127 | golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= 128 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= 129 | golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= 130 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= 131 | golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= 132 | google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 h1:L6iMMGrtzgHsWofoFcihmDEMYeDR9KN/ThbPWGrh++g= 133 | google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d h1:DoPTO70H+bcDXcd39vOqb2viZxgqeBeSGtZ55yZU4/Q= 134 | google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d/go.mod h1:KjSP20unUpOx5kyQUFa7k4OJg0qeJ7DEZflGDu2p6Bk= 135 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= 136 | google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= 137 | google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= 138 | google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= 139 | google.golang.org/protobuf v1.26.0-rc.1/go.mod 
h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= 140 | google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 141 | google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= 142 | google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 143 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 144 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 145 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 146 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 147 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 148 | gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= 149 | gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= 150 | -------------------------------------------------------------------------------- /scripts/check_unstaged.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | cd "$(dirname "$0")/.." 6 | 7 | FILES="$(git ls-files --other --modified --exclude-standard)" 8 | if [[ "$FILES" != "" ]]; then 9 | mapfile -t files <<<"$FILES" 10 | 11 | echo 12 | echo "The following files contain unstaged changes:" 13 | echo 14 | for file in "${files[@]}"; do 15 | echo " - $file" 16 | done 17 | 18 | echo 19 | echo "These are the changes:" 20 | echo 21 | for file in "${files[@]}"; do 22 | git --no-pager diff "$file" 1>&2 23 | done 24 | 25 | echo 26 | echo "Unstaged changes, see above for details." 
27 | exit 1 28 | fi 29 | -------------------------------------------------------------------------------- /scripts/version.sh: --------------------------------------------------------------------------------
#!/usr/bin/env bash

# This script generates the version string used by wgtunnel, including for dev
# versions. Note: the version returned by this script will NOT include the "v"
# prefix that is included in the Git tag.
#
# If $WGTUNNEL_FORCE_VERSION is set, it is printed verbatim and nothing else is
# checked.
#
# If $WGTUNNEL_RELEASE is set to "true", the returned version will equal the
# current git tag. If the current commit is not tagged, this will fail.
#
# If $WGTUNNEL_RELEASE is not set, the returned version will always be a dev
# version even if the current commit is tagged.

set -euo pipefail
cd "$(dirname "$0")"

# error prints a message to stderr and exits with status 1. make won't exit on
# $(shell cmd) failures, so when our parent process is make we also kill it to
# abort the build instead of letting it continue with an empty version.
error() {
	echo "ERROR: version.sh: $*" 1>&2
	if [[ "$(ps -o comm= "$PPID" || true)" == *make* ]]; then
		kill "$PPID" || true
	fi
	exit 1
}

if [[ "${WGTUNNEL_FORCE_VERSION:-}" != "" ]]; then
	echo "$WGTUNNEL_FORCE_VERSION"
	exit 0
fi

last_tag="$(git describe --tags --abbrev=0)"
version="$last_tag"

# Release builds must be made from a tagged commit. Every other build is a dev
# version, denoted by the "-devel+" suffix with a trailing commit short SHA.
#
# NOTE: any WGTUNNEL_RELEASE value containing a "t" (e.g. "true") enables
# release mode.
if [[ "${WGTUNNEL_RELEASE:-}" == *t* ]]; then
	# $last_tag will equal `git describe --always` if we currently have the tag
	# checked out.
	if [[ "$last_tag" != "$(git describe --always)" ]]; then
		error "the current commit is not tagged with an annotated tag"
	fi
else
	version+="-devel+$(git rev-parse --short HEAD)"
fi

# Remove the "v" prefix.
echo "${version#v}"
47 | -------------------------------------------------------------------------------- /tunneld/api.go: -------------------------------------------------------------------------------- 1 | package tunneld 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "net/http/httputil" 10 | "net/netip" 11 | "net/url" 12 | "strings" 13 | "time" 14 | 15 | "github.com/go-chi/chi/v5" 16 | "github.com/go-chi/hostrouter" 17 | "github.com/riandyrn/otelchi" 18 | "github.com/tailscale/wireguard-go/device" 19 | "go.opentelemetry.io/otel/attribute" 20 | "go.opentelemetry.io/otel/trace" 21 | "golang.org/x/xerrors" 22 | 23 | "github.com/coder/wgtunnel/tunneld/httpapi" 24 | "github.com/coder/wgtunnel/tunneld/httpmw" 25 | "github.com/coder/wgtunnel/tunnelsdk" 26 | ) 27 | 28 | func (api *API) Router() http.Handler { 29 | var ( 30 | hr = hostrouter.New() 31 | apiRouter = chi.NewRouter() 32 | proxyRouter = chi.NewRouter() 33 | unknownRouter = chi.NewRouter() 34 | ) 35 | 36 | hr.Map(api.BaseURL.Host, apiRouter) 37 | hr.Map("*."+api.BaseURL.Host, proxyRouter) 38 | hr.Map("*", unknownRouter) 39 | 40 | proxyRouter.Use( 41 | otelchi.Middleware("proxy"), 42 | httpmw.LimitBody(50*1<<20), // 50MB 43 | ) 44 | proxyRouter.Mount("/", http.HandlerFunc(api.handleTunnel)) 45 | 46 | apiRouter.Use( 47 | otelchi.Middleware("api", otelchi.WithChiRoutes(apiRouter)), 48 | httpmw.LimitBody(1<<20), // 1MB 49 | httpmw.RateLimit(httpmw.RateLimitConfig{ 50 | Log: api.Log.Named("ratelimier"), 51 | Count: 10, 52 | Window: 10 * time.Second, 53 | RealIPHeader: api.Options.RealIPHeader, 54 | }), 55 | ) 56 | 57 | apiRouter.Get("/", func(w http.ResponseWriter, r *http.Request) { 58 | w.WriteHeader(http.StatusOK) 59 | _, _ = w.Write([]byte("https://coder.com")) 60 | }) 61 | apiRouter.Post("/tun", api.postTun) 62 | apiRouter.Post("/api/v2/clients", api.postClients) 63 | 64 | notFound := func(rw http.ResponseWriter, r *http.Request) { 65 | httpapi.Write(r.Context(),
rw, http.StatusNotFound, tunnelsdk.Response{ 66 | Message: "Not found.", 67 | }) 68 | } 69 | apiRouter.NotFound(notFound) 70 | unknownRouter.NotFound(notFound) 71 | 72 | return hr 73 | } 74 | 75 | type LegacyPostTunRequest struct { 76 | PublicKey device.NoisePublicKey `json:"public_key"` 77 | } 78 | 79 | type LegacyPostTunResponse struct { 80 | Hostname string `json:"hostname"` 81 | ServerEndpoint string `json:"server_endpoint"` 82 | ServerIP netip.Addr `json:"server_ip"` 83 | ServerPublicKey string `json:"server_public_key"` // hex 84 | ClientIP netip.Addr `json:"client_ip"` 85 | } 86 | 87 | // postTun provides compatibility with the old tunnel client contained in older 88 | // versions of coder/coder. It essentially converts the old request format to a 89 | // newer request, and the newer response to the old response format. 90 | func (api *API) postTun(rw http.ResponseWriter, r *http.Request) { 91 | ctx := r.Context() 92 | 93 | var req LegacyPostTunRequest 94 | if !httpapi.Read(ctx, rw, r, &req) { 95 | return 96 | } 97 | 98 | registerReq := tunnelsdk.ClientRegisterRequest{ 99 | Version: tunnelsdk.TunnelVersion1, 100 | PublicKey: req.PublicKey, 101 | } 102 | 103 | resp, exists, err := api.registerClient(registerReq) 104 | if err != nil { 105 | httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{ 106 | Message: "Failed to register client.", 107 | Detail: err.Error(), 108 | }) 109 | return 110 | } 111 | 112 | if len(resp.TunnelURLs) == 0 { 113 | httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{ 114 | Message: "No tunnel URLs found.", 115 | }) 116 | return 117 | } 118 | 119 | u, err := url.Parse(resp.TunnelURLs[0]) 120 | if err != nil { 121 | httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{ 122 | Message: "Failed to parse tunnel URL.", 123 | Detail: err.Error(), 124 | }) 125 | return 126 | } 127 | 128 | status := http.StatusCreated 129 | if exists { 130 | status = http.StatusOK 131 | } 132 | 
httpapi.Write(ctx, rw, status, LegacyPostTunResponse{ 133 | Hostname: u.Host, 134 | ServerEndpoint: resp.ServerEndpoint, 135 | ServerIP: resp.ServerIP, 136 | ServerPublicKey: hex.EncodeToString(resp.ServerPublicKey[:]), 137 | ClientIP: resp.ClientIP, 138 | }) 139 | } 140 | 141 | func (api *API) postClients(rw http.ResponseWriter, r *http.Request) { 142 | ctx := r.Context() 143 | 144 | var req tunnelsdk.ClientRegisterRequest 145 | if !httpapi.Read(r.Context(), rw, r, &req) { 146 | return 147 | } 148 | 149 | resp, _, err := api.registerClient(req) 150 | if err != nil { 151 | httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{ 152 | Message: "Failed to register client.", 153 | Detail: err.Error(), 154 | }) 155 | return 156 | } 157 | 158 | httpapi.Write(ctx, rw, http.StatusOK, resp) 159 | } 160 | 161 | func (api *API) registerClient(req tunnelsdk.ClientRegisterRequest) (tunnelsdk.ClientRegisterResponse, bool, error) { 162 | if req.Version <= 0 || req.Version > tunnelsdk.TunnelVersionLatest { 163 | req.Version = tunnelsdk.TunnelVersionLatest 164 | } 165 | 166 | ip, urls := api.WireguardPublicKeyToIPAndURLs(req.PublicKey, req.Version) 167 | 168 | api.pkeyCacheMu.Lock() 169 | api.pkeyCache[ip] = cachedPeer{ 170 | key: req.PublicKey, 171 | lastHandshake: time.Now(), 172 | } 173 | api.pkeyCacheMu.Unlock() 174 | 175 | exists := true 176 | if api.wgDevice.LookupPeer(req.PublicKey) == nil { 177 | exists = false 178 | 179 | api.pkeyCacheMu.Lock() 180 | api.pkeyCache[ip] = cachedPeer{ 181 | key: req.PublicKey, 182 | lastHandshake: time.Now(), 183 | } 184 | api.pkeyCacheMu.Unlock() 185 | 186 | err := api.wgDevice.IpcSet(fmt.Sprintf(`public_key=%x 187 | allowed_ip=%s/128`, 188 | req.PublicKey, 189 | ip.String(), 190 | )) 191 | if err != nil { 192 | return tunnelsdk.ClientRegisterResponse{}, false, xerrors.Errorf("register client with wireguard: %w", err) 193 | } 194 | } 195 | 196 | urlsStr := make([]string, len(urls)) 197 | for i, u := range urls { 198 | 
urlsStr[i] = u.String() 199 | } 200 | 201 | return tunnelsdk.ClientRegisterResponse{ 202 | Version: req.Version, 203 | ReregisterWait: api.PeerRegisterInterval, 204 | TunnelURLs: urlsStr, 205 | ClientIP: ip, 206 | ServerEndpoint: api.WireguardEndpoint, 207 | ServerIP: api.WireguardServerIP, 208 | ServerPublicKey: api.WireguardKey.NoisePublicKey(), 209 | WireguardMTU: api.WireguardMTU, 210 | }, exists, nil 211 | } 212 | 213 | type ipPortKey struct{} 214 | 215 | func (api *API) handleTunnel(rw http.ResponseWriter, r *http.Request) { 216 | ctx := r.Context() 217 | 218 | host := r.Host 219 | subdomain, _ := splitHostname(host) 220 | subdomainParts := strings.Split(subdomain, "-") 221 | user := subdomainParts[len(subdomainParts)-1] 222 | 223 | span := trace.SpanFromContext(ctx) 224 | span.SetAttributes( 225 | attribute.Bool("proxy_request", true), 226 | attribute.String("user", user), 227 | ) 228 | 229 | ip, err := api.HostnameToWireguardIP(user) 230 | if err != nil { 231 | httpapi.Write(ctx, rw, http.StatusBadRequest, tunnelsdk.Response{ 232 | Message: "Invalid tunnel URL.", 233 | Detail: err.Error(), 234 | }) 235 | return 236 | } 237 | 238 | api.pkeyCacheMu.RLock() 239 | pkey, ok := api.pkeyCache[ip] 240 | api.pkeyCacheMu.RUnlock() 241 | 242 | if !ok || time.Since(pkey.lastHandshake) > api.PeerTimeout { 243 | httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{ 244 | Message: "Peer is not connected.", 245 | Detail: "", 246 | }) 247 | return 248 | } 249 | 250 | // The transport on the reverse proxy uses this ctx value to know which 251 | // IP to dial. See tunneld.go. 252 | ctx = context.WithValue(ctx, ipPortKey{}, netip.AddrPortFrom(ip, tunnelsdk.TunnelPort)) 253 | r = r.WithContext(ctx) 254 | 255 | rp := httputil.ReverseProxy{ 256 | // This can only happen when it fails to dial. 
257 | ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { 258 | httpapi.Write(ctx, rw, http.StatusBadGateway, tunnelsdk.Response{ 259 | Message: "Failed to dial peer.", 260 | Detail: err.Error(), 261 | }) 262 | }, 263 | Director: func(rp *http.Request) { 264 | rp.URL.Scheme = "http" 265 | rp.URL.Host = r.Host 266 | rp.Host = r.Host 267 | }, 268 | Transport: api.transport, 269 | } 270 | 271 | rp.ServeHTTP(rw, r) 272 | } 273 | 274 | // splitHostname splits a hostname into the subdomain and the rest of the 275 | // string, stripping any port data and leading/trailing periods. 276 | func splitHostname(hostname string) (subdomain string, rest string) { 277 | hostname = strings.Trim(hostname, ".") 278 | hostnameHost, _, err := net.SplitHostPort(hostname) 279 | if err == nil { 280 | hostname = hostnameHost 281 | } 282 | 283 | parts := strings.SplitN(hostname, ".", 2) 284 | if len(parts) != 2 { 285 | return hostname, "" 286 | } 287 | 288 | return parts[0], parts[1] 289 | } 290 | -------------------------------------------------------------------------------- /tunneld/api_test.go: -------------------------------------------------------------------------------- 1 | package tunneld_test 2 | 3 | import ( 4 | "context" 5 | "encoding/hex" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | 15 | "github.com/coder/wgtunnel/tunneld" 16 | "github.com/coder/wgtunnel/tunnelsdk" 17 | ) 18 | 19 | // Test for the compatibility endpoint which allows old tunnels to connect to 20 | // the new server. 
21 | func Test_postTun(t *testing.T) { 22 | t.Parallel() 23 | 24 | td, client := createTestTunneld(t, nil) 25 | 26 | key, err := tunnelsdk.GeneratePrivateKey() 27 | require.NoError(t, err) 28 | 29 | expectedIP, expectedURLs := td.WireguardPublicKeyToIPAndURLs(key.NoisePublicKey(), tunnelsdk.TunnelVersion1) 30 | require.Len(t, expectedURLs, 2) 31 | require.Len(t, strings.Split(expectedURLs[0].Host, ".")[0], 32) 32 | expectedHostname := expectedURLs[0].Host 33 | 34 | // First request should return a 201. 35 | resp, err := client.Request(context.Background(), http.MethodPost, "/tun", tunneld.LegacyPostTunRequest{ 36 | PublicKey: key.NoisePublicKey(), 37 | }) 38 | require.NoError(t, err) 39 | defer resp.Body.Close() 40 | require.Equal(t, http.StatusCreated, resp.StatusCode) 41 | 42 | var legacyRes tunneld.LegacyPostTunResponse 43 | require.NoError(t, json.NewDecoder(resp.Body).Decode(&legacyRes)) 44 | require.Equal(t, expectedIP, legacyRes.ClientIP) 45 | require.Equal(t, expectedHostname, legacyRes.Hostname) 46 | 47 | // Register on the new endpoint so we can compare the values to the legacy 48 | // endpoint. 49 | newRes, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 50 | Version: tunnelsdk.TunnelVersion1, 51 | PublicKey: key.NoisePublicKey(), 52 | }) 53 | require.NoError(t, err) 54 | require.Equal(t, tunnelsdk.TunnelVersion1, newRes.Version) 55 | 56 | require.Equal(t, legacyRes.ServerEndpoint, newRes.ServerEndpoint) 57 | require.Equal(t, legacyRes.ServerIP, newRes.ServerIP) 58 | require.Equal(t, legacyRes.ServerPublicKey, hex.EncodeToString(newRes.ServerPublicKey[:])) 59 | require.Equal(t, legacyRes.ClientIP, newRes.ClientIP) 60 | 61 | // Second request should return a 200. 
62 | resp, err = client.Request(context.Background(), http.MethodPost, "/tun", tunneld.LegacyPostTunRequest{ 63 | PublicKey: key.NoisePublicKey(), 64 | }) 65 | require.NoError(t, err) 66 | defer resp.Body.Close() 67 | require.Equal(t, http.StatusOK, resp.StatusCode) 68 | 69 | var legacyRes2 tunneld.LegacyPostTunResponse 70 | require.NoError(t, json.NewDecoder(resp.Body).Decode(&legacyRes2)) 71 | require.Equal(t, legacyRes, legacyRes2) 72 | } 73 | 74 | func Test_postClients(t *testing.T) { 75 | t.Parallel() 76 | 77 | td, client := createTestTunneld(t, nil) 78 | 79 | key, err := tunnelsdk.GeneratePrivateKey() 80 | require.NoError(t, err) 81 | 82 | expectedIP, expectedURLs := td.WireguardPublicKeyToIPAndURLs(key.NoisePublicKey(), tunnelsdk.TunnelVersion2) 83 | 84 | expectedURLsStr := make([]string, len(expectedURLs)) 85 | for i, u := range expectedURLs { 86 | expectedURLsStr[i] = u.String() 87 | } 88 | 89 | // Register a client. 90 | res, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 91 | // No version should default to 2. 92 | PublicKey: key.NoisePublicKey(), 93 | }) 94 | require.NoError(t, err) 95 | 96 | require.Equal(t, tunnelsdk.TunnelVersion2, res.Version) 97 | require.Equal(t, expectedURLsStr, res.TunnelURLs) 98 | require.Equal(t, expectedIP, res.ClientIP) 99 | require.Equal(t, td.WireguardEndpoint, res.ServerEndpoint) 100 | require.Equal(t, td.WireguardServerIP, res.ServerIP) 101 | require.Equal(t, td.WireguardKey.NoisePublicKey(), res.ServerPublicKey) 102 | require.Equal(t, td.WireguardMTU, res.WireguardMTU) 103 | require.Equal(t, td.PeerRegisterInterval, res.ReregisterWait) 104 | 105 | // Register the same client again. 
106 | res2, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 107 | Version: tunnelsdk.TunnelVersion2, 108 | PublicKey: key.NoisePublicKey(), 109 | }) 110 | require.NoError(t, err) 111 | require.Equal(t, res, res2) 112 | 113 | // Register the same client with the old version. 114 | res3, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 115 | Version: tunnelsdk.TunnelVersion1, 116 | PublicKey: key.NoisePublicKey(), 117 | }) 118 | require.NoError(t, err) 119 | 120 | // Should be equal after reversing the URL list. 121 | require.Equal(t, tunnelsdk.TunnelVersion1, res3.Version) 122 | res3.TunnelURLs[0], res3.TunnelURLs[1] = res3.TunnelURLs[1], res3.TunnelURLs[0] 123 | res3.Version = tunnelsdk.TunnelVersion2 124 | require.Equal(t, res, res3) 125 | } 126 | 127 | func Test_getRoot(t *testing.T) { 128 | t.Parallel() 129 | 130 | _, client := createTestTunneld(t, nil) 131 | 132 | res, err := client.Request(context.Background(), http.MethodGet, "/", nil) 133 | require.NoError(t, err) 134 | defer res.Body.Close() 135 | 136 | out, err := io.ReadAll(res.Body) 137 | require.NoError(t, err) 138 | assert.Equal(t, "https://coder.com", string(out)) 139 | } 140 | -------------------------------------------------------------------------------- /tunneld/httpapi/httpapi.go: -------------------------------------------------------------------------------- 1 | package httpapi 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "net/http" 8 | 9 | "github.com/coder/wgtunnel/tunnelsdk" 10 | ) 11 | 12 | // Read decodes JSON from the HTTP request into the value provided. 
13 | func Read(ctx context.Context, rw http.ResponseWriter, r *http.Request, value interface{}) bool { 14 | // TODO: tracing 15 | err := json.NewDecoder(r.Body).Decode(value) 16 | if err != nil { 17 | Write(ctx, rw, http.StatusBadRequest, tunnelsdk.Response{ 18 | Message: "Request body must be valid JSON.", 19 | Detail: err.Error(), 20 | }) 21 | return false 22 | } 23 | 24 | return true 25 | } 26 | 27 | // Write outputs the given value as JSON to the response. 28 | func Write(_ context.Context, rw http.ResponseWriter, status int, response interface{}) { 29 | // TODO: tracing 30 | buf := &bytes.Buffer{} 31 | enc := json.NewEncoder(buf) 32 | enc.SetEscapeHTML(true) 33 | 34 | err := enc.Encode(response) 35 | if err != nil { 36 | http.Error(rw, err.Error(), http.StatusInternalServerError) 37 | return 38 | } 39 | 40 | rw.Header().Set("Content-Type", "application/json; charset=utf-8") 41 | rw.WriteHeader(status) 42 | 43 | _, err = rw.Write(buf.Bytes()) 44 | if err != nil { 45 | http.Error(rw, err.Error(), http.StatusInternalServerError) 46 | return 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tunneld/httpmw/limitbody.go: -------------------------------------------------------------------------------- 1 | package httpmw 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | 7 | "golang.org/x/xerrors" 8 | ) 9 | 10 | var ErrLimitReached = xerrors.Errorf("i/o limit reached") 11 | 12 | // LimitReader is like io.LimitReader except that it returns ErrLimitReached 13 | // when the limit has been reached. 
14 | type LimitReader struct { 15 | Limit int64 16 | N int64 17 | R io.Reader 18 | } 19 | 20 | func (l *LimitReader) Reset(n int64) { 21 | l.N = 0 22 | l.Limit = n 23 | } 24 | 25 | func (l *LimitReader) Read(p []byte) (int, error) { 26 | if l.N >= l.Limit { 27 | return 0, ErrLimitReached 28 | } 29 | 30 | if int64(len(p)) > l.Limit-l.N { 31 | p = p[:l.Limit-l.N] 32 | } 33 | 34 | n, err := l.R.Read(p) 35 | l.N += int64(n) 36 | return n, err 37 | } 38 | 39 | type LimitedBody struct { 40 | R *LimitReader 41 | original io.ReadCloser 42 | } 43 | 44 | func (r LimitedBody) Read(p []byte) (n int, err error) { 45 | return r.R.Read(p) 46 | } 47 | 48 | func (r LimitedBody) Close() error { 49 | return r.original.Close() 50 | } 51 | 52 | func SetBodyLimit(r *http.Request, n int64) { 53 | if body, ok := r.Body.(LimitedBody); ok { 54 | body.R.Reset(n) 55 | } else { 56 | r.Body = LimitedBody{ 57 | R: &LimitReader{R: r.Body, Limit: n}, 58 | original: r.Body, 59 | } 60 | } 61 | } 62 | 63 | func LimitBody(n int64) func(h http.Handler) http.Handler { 64 | return func(next http.Handler) http.Handler { 65 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 66 | SetBodyLimit(r, n) 67 | next.ServeHTTP(w, r) 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /tunneld/httpmw/limitbody_test.go: -------------------------------------------------------------------------------- 1 | package httpmw_test 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | 12 | "github.com/coder/wgtunnel/tunneld/httpmw" 13 | ) 14 | 15 | func TestLimitBody(t *testing.T) { 16 | t.Parallel() 17 | 18 | tests := []struct { 19 | Name string 20 | Limit int64 21 | Size int 22 | LimitReached bool 23 | }{ 24 | { 25 | Name: "under", 26 | Limit: 1024, 27 | Size: 512, 28 | LimitReached: false, 29 | }, 30 | { 31 | Name: "under-by-one", 32 | Limit: 1024, 33 | Size: 
1023, 34 | LimitReached: false, 35 | }, 36 | { 37 | Name: "exact", 38 | Limit: 1024, 39 | Size: 1024, 40 | LimitReached: true, 41 | }, 42 | { 43 | Name: "over", 44 | Limit: 1024, 45 | Size: 1025, 46 | LimitReached: true, 47 | }, 48 | { 49 | Name: "default-under", 50 | Limit: 1 << 20, 51 | Size: 1 << 19, 52 | LimitReached: false, 53 | }, 54 | { 55 | Name: "default-over", 56 | Limit: 1 << 20, 57 | Size: 1<<20 + 1, 58 | LimitReached: true, 59 | }, 60 | } 61 | 62 | for _, test := range tests { 63 | test := test 64 | t.Run(test.Name, func(t *testing.T) { 65 | t.Parallel() 66 | 67 | var buf bytes.Buffer 68 | buf.Grow(test.Size) 69 | for i := 0; i < test.Size; i++ { 70 | err := buf.WriteByte('1') 71 | require.NoError(t, err, "expected to write byte to buffer successfully") 72 | } 73 | 74 | req := httptest.NewRequest("POST", "/", &buf) 75 | middleware := httpmw.LimitBody(test.Limit) 76 | 77 | handlerCalled := false 78 | 79 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 80 | // Read as much as we can from the body, discarding output 81 | written, err := io.Copy(io.Discard, req.Body) 82 | if test.LimitReached { 83 | require.ErrorIs(t, err, httpmw.ErrLimitReached, "expected stream to return ErrLimitReached") 84 | require.EqualValues(t, test.Limit, written, "expect that the amount of data copied matches the limit") 85 | } else { 86 | require.NoError(t, err, "no error should occur") 87 | require.EqualValues(t, test.Size, written, "expect that the amount of data copied matches the input size") 88 | } 89 | 90 | handlerCalled = true 91 | }) 92 | 93 | middleware(nextHandler).ServeHTTP(httptest.NewRecorder(), req) 94 | require.True(t, handlerCalled, "expected handler to be invoked") 95 | }) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /tunneld/httpmw/ratelimit.go: -------------------------------------------------------------------------------- 1 | package httpmw 2 | 3 | import ( 4 | "fmt" 5 | 
"net" 6 | "net/http" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "github.com/go-chi/httprate" 12 | 13 | "cdr.dev/slog" 14 | "github.com/coder/wgtunnel/tunneld/httpapi" 15 | "github.com/coder/wgtunnel/tunnelsdk" 16 | ) 17 | 18 | type RateLimitConfig struct { 19 | Log slog.Logger 20 | 21 | // Count of the amount of requests allowed in the Window. If the Count is 22 | // zero, the rate limiter is disabled. 23 | Count int 24 | Window time.Duration 25 | 26 | // RealIPHeader is the header to use to get the real IP address of the 27 | // request. If this is empty, the request's RemoteAddr is used. 28 | RealIPHeader string 29 | } 30 | 31 | // RateLimit returns a handler that limits requests based on IP. 32 | func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler { 33 | if cfg.Count <= 0 { 34 | return func(handler http.Handler) http.Handler { 35 | return handler 36 | } 37 | } 38 | 39 | var logMissingHeaderOnce sync.Once 40 | 41 | return httprate.Limit( 42 | cfg.Count, 43 | cfg.Window, 44 | httprate.WithKeyFuncs(func(r *http.Request) (string, error) { 45 | if cfg.RealIPHeader != "" { 46 | val := r.Header.Get(cfg.RealIPHeader) 47 | if val != "" { 48 | val = strings.TrimSpace(strings.Split(val, ",")[0]) 49 | return canonicalizeIP(val), nil 50 | } 51 | 52 | logMissingHeaderOnce.Do(func() { 53 | cfg.Log.Warn(r.Context(), "real IP header not found or invalid on request", slog.F("header", cfg.RealIPHeader), slog.F("value", val)) 54 | }) 55 | } 56 | 57 | return httprate.KeyByIP(r) 58 | }), 59 | httprate.WithLimitHandler(func(rw http.ResponseWriter, r *http.Request) { 60 | httpapi.Write(r.Context(), rw, http.StatusTooManyRequests, tunnelsdk.Response{ 61 | Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", cfg.Count, cfg.Window), 62 | }) 63 | }), 64 | ) 65 | } 66 | 67 | // canonicalizeIP returns a form of ip suitable for comparison to other IPs. 68 | // For IPv4 addresses, this is simply the whole string. 
69 | // For IPv6 addresses, this is the /64 prefix. 70 | // 71 | // This function is taken directly from go-chi/httprate: 72 | // https://github.com/go-chi/httprate/blob/0ea2148d09a46ae62efcad05b70d87418d8e4f43/httprate.go#L111 73 | func canonicalizeIP(ip string) string { 74 | isIPv6 := false 75 | // This is how net.ParseIP decides if an address is IPv6 76 | // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704 77 | for i := 0; !isIPv6 && i < len(ip); i++ { 78 | switch ip[i] { 79 | case '.': 80 | // IPv4 81 | return ip 82 | case ':': 83 | // IPv6 84 | isIPv6 = true 85 | } 86 | } 87 | if !isIPv6 { 88 | // Not an IP address at all 89 | return ip 90 | } 91 | 92 | ipv6 := net.ParseIP(ip) 93 | if ipv6 == nil { 94 | return ip 95 | } 96 | 97 | return ipv6.Mask(net.CIDRMask(64, 128)).String() 98 | } 99 | -------------------------------------------------------------------------------- /tunneld/options.go: -------------------------------------------------------------------------------- 1 | package tunneld 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/base32" 6 | "encoding/hex" 7 | "net" 8 | "net/http" 9 | "net/netip" 10 | "net/url" 11 | "strings" 12 | "time" 13 | 14 | "github.com/tailscale/wireguard-go/device" 15 | "golang.org/x/xerrors" 16 | 17 | "cdr.dev/slog" 18 | "github.com/coder/wgtunnel/tunnelsdk" 19 | ) 20 | 21 | const ( 22 | DefaultWireguardMTU = 1280 23 | DefaultPeerDialTimeout = 10 * time.Second 24 | DefaultPeerPollDuration = 30 * time.Second 25 | DefaultPeerTimeout = 2 * time.Minute 26 | ) 27 | 28 | var ( 29 | DefaultWireguardServerIP = netip.MustParseAddr("fcca::1") 30 | DefaultWireguardNetworkPrefix = netip.MustParsePrefix("fcca::/16") 31 | ) 32 | 33 | var newHostnameEncoder = base32.HexEncoding.WithPadding(base32.NoPadding) 34 | 35 | type Options struct { 36 | Log slog.Logger 37 | 38 | // BaseURL is the base URL to use for the tunnel, including scheme. All 39 | // tunnels will be subdomains of this hostname. 40 | // e.g. 
"https://tunnel.example.com" will place tunnels at 41 | // "https://xyz.tunnel.example.com" 42 | BaseURL *url.URL 43 | 44 | // WireguardEndpoint is the UDP address advertised to clients that they will 45 | // connect to for wireguard connections. It should be in the form 46 | // "$ip:$port" or "$hostname:$port". 47 | WireguardEndpoint string 48 | // WireguardPort is the UDP port that the wireguard server will listen on. 49 | // It should be the same as the port in WireguardEndpoint. 50 | WireguardPort uint16 51 | // WireguardKey is the private key for the wireguard server. 52 | WireguardKey tunnelsdk.Key 53 | 54 | // WireguardMTU is the MTU to use for the wireguard interface. Defaults to 55 | // 1280. 56 | WireguardMTU int 57 | // WireguardServerIP is the virtual IP address of this server in the 58 | // wireguard network. Must be an IPv6 address contained within 59 | // WireguardNetworkPrefix. Defaults to fcca::1. 60 | WireguardServerIP netip.Addr 61 | // WireguardNetworkPrefix is the CIDR of the wireguard network. All client 62 | // IPs will be generated within this network. Must be a IPv6 CIDR and have 63 | // at least 64 bits of space available. Defaults to fcca::/16. 64 | WireguardNetworkPrefix netip.Prefix 65 | 66 | // RealIPHeader is the header to use for getting a request's IP address. If 67 | // not set, the request's RemoteAddr will be used. 68 | // 69 | // Used for rate limiting. 70 | RealIPHeader string 71 | 72 | // PeerDialTimeout is the timeout for dialing a peer on a request. Defaults 73 | // to 10 seconds. 74 | PeerDialTimeout time.Duration 75 | 76 | // PeerRegisterInterval is how often the clients should re-register. 77 | PeerRegisterInterval time.Duration 78 | 79 | // PeerTimeout is how long the server will wait before removing the peer. 80 | PeerTimeout time.Duration 81 | } 82 | 83 | // Validate checks that the options are valid and populates default values for 84 | // missing fields. 
85 | func (options *Options) Validate() error { 86 | if options == nil { 87 | return xerrors.New("options is nil") 88 | } 89 | if options.BaseURL == nil { 90 | return xerrors.New("BaseURL is required") 91 | } 92 | if options.WireguardEndpoint == "" { 93 | return xerrors.New("WireguardEndpoint is required") 94 | } 95 | _, _, err := net.SplitHostPort(options.WireguardEndpoint) 96 | if err != nil { 97 | return xerrors.Errorf("WireguardEndpoint %q is not a valid host:port combination: %w", options.WireguardEndpoint, err) 98 | } 99 | if options.WireguardPort == 0 { 100 | return xerrors.New("WireguardPort is required") 101 | } 102 | if options.WireguardKey.IsZero() { 103 | return xerrors.New("WireguardKey is required") 104 | } 105 | if !options.WireguardKey.IsPrivate() { 106 | return xerrors.New("WireguardKey must be a private key") 107 | } 108 | // Key is parsed and validated when the server is started. 109 | if options.WireguardMTU <= 0 { 110 | options.WireguardMTU = DefaultWireguardMTU 111 | } 112 | if options.WireguardServerIP.BitLen() == 0 { 113 | options.WireguardServerIP = DefaultWireguardServerIP 114 | } 115 | if options.WireguardServerIP.BitLen() != 128 { 116 | return xerrors.New("WireguardServerIP must be an IPv6 address") 117 | } 118 | if options.WireguardNetworkPrefix.Bits() <= 0 { 119 | options.WireguardNetworkPrefix = DefaultWireguardNetworkPrefix 120 | } 121 | if options.WireguardNetworkPrefix.Bits() > 64 { 122 | return xerrors.New("WireguardNetworkPrefix must have at least 64 bits available") 123 | } 124 | if options.WireguardNetworkPrefix.Bits()%8 != 0 { 125 | return xerrors.New("WireguardNetworkPrefix must be a multiple of 8 bits") 126 | } 127 | if !options.WireguardNetworkPrefix.Contains(options.WireguardServerIP) { 128 | return xerrors.New("WireguardServerIP must be contained within WireguardNetworkPrefix") 129 | } 130 | 131 | if options.RealIPHeader != "" { 132 | options.RealIPHeader = http.CanonicalHeaderKey(options.RealIPHeader) 133 | } 134 | 135 | 
if options.PeerDialTimeout <= 0 { 136 | options.PeerDialTimeout = DefaultPeerDialTimeout 137 | } 138 | if options.PeerRegisterInterval <= 0 { 139 | options.PeerRegisterInterval = DefaultPeerPollDuration 140 | } 141 | if options.PeerTimeout <= 0 { 142 | options.PeerTimeout = DefaultPeerTimeout 143 | } 144 | if options.PeerRegisterInterval >= options.PeerTimeout { 145 | return xerrors.Errorf("PeerRegisterInterval(%s) must be less than PeerTimeout(%s)", 146 | options.PeerRegisterInterval.String(), 147 | options.PeerTimeout.String(), 148 | ) 149 | } 150 | 151 | return nil 152 | } 153 | 154 | // WireguardPublicKeyToIPAndURLs returns the IP address that corresponds to the 155 | // given wireguard public key, as well as all accepted tunnel URLs for the key. 156 | // 157 | // We support an older 32 character format ("old format") and a newer 12 158 | // character format ("good format") which is preferred. The first URL returned 159 | // should be considered "preferred", and all other URLs are provided for 160 | // compatibility with older deployments only. The "good format" is preferred as 161 | // it's shorter to avoid issues with hostname length limits when apps prefixes 162 | // are added to the equation. 163 | // 164 | // "good format": 165 | // 166 | // Take the first 8 bytes of the hash of the public key, and convert to 167 | // base32. 168 | // 169 | // "old format": 170 | // 171 | // Take the network prefix, and create a new address filling the last n bytes 172 | // with the first n bytes of the hash of the public key. Then convert to hex. 
173 | func (options *Options) WireguardPublicKeyToIPAndURLs(publicKey device.NoisePublicKey, version tunnelsdk.TunnelVersion) (netip.Addr, []*url.URL) { 174 | var ( 175 | keyHash = sha256.Sum256(publicKey[:]) 176 | addrBytes = options.WireguardNetworkPrefix.Addr().As16() 177 | ) 178 | 179 | // IPv6 address: 180 | // For the IP address, we take the first 64 bits of the network prefix and 181 | // the first 64 bits of the hash of the public key. 182 | copy(addrBytes[8:], keyHash[:8]) 183 | 184 | // Good format: 185 | goodFormatBytes := make([]byte, 8) 186 | copy(goodFormatBytes, keyHash[:8]) 187 | goodFormat := newHostnameEncoder.EncodeToString(goodFormatBytes) 188 | goodFormatURL := *options.BaseURL 189 | goodFormatURL.Host = strings.ToLower(goodFormat) + "." + goodFormatURL.Host 190 | 191 | // Old format: 192 | oldFormatBytes := make([]byte, 16) 193 | copy(oldFormatBytes, addrBytes[:]) 194 | prefixLenBytes := options.WireguardNetworkPrefix.Bits() / 8 195 | copy(oldFormatBytes[prefixLenBytes:], keyHash[:16-prefixLenBytes]) 196 | oldFormat := hex.EncodeToString(oldFormatBytes) 197 | oldFormatURL := *options.BaseURL 198 | oldFormatURL.Host = strings.ToLower(oldFormat) + "." + oldFormatURL.Host 199 | 200 | urls := []*url.URL{&goodFormatURL, &oldFormatURL} 201 | if version == tunnelsdk.TunnelVersion1 { 202 | // Return the old format first for backwards compatibility. 203 | urls = []*url.URL{&oldFormatURL, &goodFormatURL} 204 | } 205 | 206 | return netip.AddrFrom16(addrBytes), urls 207 | } 208 | 209 | // HostnameToWireguardIP returns the wireguard IP address that corresponds to a 210 | // given encoded hostname label as returned by WireguardPublicKeyToIPAndURLs. 
211 | func (options *Options) HostnameToWireguardIP(hostname string) (netip.Addr, error) { 212 | var addrLast8Bytes []byte 213 | 214 | if len(hostname) == 32 { 215 | // "Old format": 216 | decoded, err := hex.DecodeString(hostname) 217 | if err != nil { 218 | return netip.Addr{}, xerrors.Errorf("decode old hostname %q as hex: %w", hostname, err) 219 | } 220 | if len(decoded) != 16 { 221 | return netip.Addr{}, xerrors.Errorf("invalid old hostname length: got %d, expected 16", len(decoded)) 222 | } 223 | 224 | // Even though the hostname will have the entire old IP address, we only 225 | // care about the first 8 bytes after the prefix length. 226 | prefixLenBytes := options.WireguardNetworkPrefix.Bits() / 8 227 | addrLast8Bytes = decoded[prefixLenBytes : prefixLenBytes+8] 228 | } else { 229 | // "Good format": 230 | decoded, err := newHostnameEncoder.DecodeString(strings.ToUpper(hostname)) 231 | if err != nil { 232 | return netip.Addr{}, xerrors.Errorf("decode new hostname %q as base32: %w", hostname, err) 233 | } 234 | if len(decoded) != 8 { 235 | return netip.Addr{}, xerrors.Errorf("invalid new hostname length: got %d, expected 8", len(decoded)) 236 | } 237 | 238 | addrLast8Bytes = decoded 239 | } 240 | 241 | if addrLast8Bytes == nil { 242 | return netip.Addr{}, xerrors.Errorf("invalid hostname %q, does not match new or old format", hostname) 243 | } 244 | 245 | addrBytes := options.WireguardNetworkPrefix.Addr().As16() 246 | copy(addrBytes[8:], addrLast8Bytes[:]) 247 | return netip.AddrFrom16(addrBytes), nil 248 | } 249 | -------------------------------------------------------------------------------- /tunneld/options_test.go: -------------------------------------------------------------------------------- 1 | package tunneld_test 2 | 3 | import ( 4 | "fmt" 5 | "net/netip" 6 | "net/url" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/require" 12 | "github.com/tailscale/wireguard-go/device" 13 | 
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" 14 | 15 | "github.com/coder/wgtunnel/tunneld" 16 | "github.com/coder/wgtunnel/tunnelsdk" 17 | ) 18 | 19 | func Test_Option(t *testing.T) { 20 | t.Parallel() 21 | 22 | key, err := tunnelsdk.GeneratePrivateKey() 23 | require.NoError(t, err) 24 | 25 | t.Run("Validate", func(t *testing.T) { 26 | t.Parallel() 27 | 28 | t.Run("FullValid", func(t *testing.T) { 29 | t.Parallel() 30 | 31 | o := tunneld.Options{ 32 | BaseURL: &url.URL{ 33 | Scheme: "http", 34 | Host: "localhost", 35 | }, 36 | WireguardEndpoint: "localhost:1234", 37 | WireguardPort: 1234, 38 | WireguardKey: key, 39 | WireguardMTU: tunneld.DefaultWireguardMTU + 1, 40 | WireguardServerIP: netip.MustParseAddr("feed::1"), 41 | WireguardNetworkPrefix: netip.MustParsePrefix("feed::1/64"), 42 | RealIPHeader: "X-Real-Ip", 43 | PeerDialTimeout: 1 * time.Second, 44 | PeerRegisterInterval: time.Second, 45 | PeerTimeout: 2 * time.Second, 46 | } 47 | 48 | clone := o 49 | clone.BaseURL = &url.URL{ 50 | Scheme: o.BaseURL.Scheme, 51 | Host: o.BaseURL.Host, 52 | } 53 | clonePtr := &clone 54 | err := clonePtr.Validate() 55 | require.NoError(t, err) 56 | 57 | // Should not have updated the struct. 
58 | require.Equal(t, o, clone) 59 | }) 60 | 61 | t.Run("Valid", func(t *testing.T) { 62 | t.Parallel() 63 | 64 | o := &tunneld.Options{ 65 | BaseURL: &url.URL{ 66 | Scheme: "http", 67 | Host: "localhost", 68 | }, 69 | WireguardEndpoint: "localhost:1234", 70 | WireguardPort: 1234, 71 | WireguardKey: key, 72 | RealIPHeader: "x-real-ip", 73 | } 74 | 75 | err := o.Validate() 76 | require.NoError(t, err) 77 | 78 | require.Equal(t, &url.URL{Scheme: "http", Host: "localhost"}, o.BaseURL) 79 | require.Equal(t, "localhost:1234", o.WireguardEndpoint) 80 | require.EqualValues(t, 1234, o.WireguardPort) 81 | require.Equal(t, key, o.WireguardKey) 82 | require.EqualValues(t, tunneld.DefaultWireguardMTU, o.WireguardMTU) 83 | require.Equal(t, tunneld.DefaultWireguardServerIP, o.WireguardServerIP) 84 | require.Equal(t, tunneld.DefaultWireguardNetworkPrefix, o.WireguardNetworkPrefix) 85 | // should be canonicalized. 86 | require.Equal(t, "X-Real-Ip", o.RealIPHeader) 87 | }) 88 | 89 | t.Run("Invalid", func(t *testing.T) { 90 | t.Parallel() 91 | 92 | t.Run("Nil", func(t *testing.T) { 93 | t.Parallel() 94 | 95 | err := (*tunneld.Options)(nil).Validate() 96 | require.Error(t, err) 97 | require.ErrorContains(t, err, "options is nil") 98 | }) 99 | 100 | t.Run("BaseURL", func(t *testing.T) { 101 | t.Parallel() 102 | 103 | o := &tunneld.Options{ 104 | BaseURL: nil, 105 | WireguardEndpoint: "localhost:1234", 106 | WireguardPort: 1234, 107 | WireguardKey: key, 108 | } 109 | 110 | err := o.Validate() 111 | require.Error(t, err) 112 | require.ErrorContains(t, err, "BaseURL is required") 113 | }) 114 | 115 | t.Run("WireguardEndpoint", func(t *testing.T) { 116 | t.Parallel() 117 | 118 | o := &tunneld.Options{ 119 | BaseURL: &url.URL{ 120 | Scheme: "http", 121 | Host: "localhost", 122 | }, 123 | WireguardEndpoint: "", 124 | WireguardPort: 1234, 125 | WireguardKey: key, 126 | } 127 | 128 | err := o.Validate() 129 | require.Error(t, err) 130 | require.ErrorContains(t, err, "WireguardEndpoint is 
required") 131 | 132 | o.WireguardEndpoint = "localhost" 133 | 134 | err = o.Validate() 135 | require.Error(t, err) 136 | require.ErrorContains(t, err, "not a valid host:port") 137 | }) 138 | 139 | t.Run("WireguardPort", func(t *testing.T) { 140 | t.Parallel() 141 | 142 | o := &tunneld.Options{ 143 | BaseURL: &url.URL{ 144 | Scheme: "http", 145 | Host: "localhost", 146 | }, 147 | WireguardEndpoint: "localhost:1234", 148 | WireguardPort: 0, 149 | WireguardKey: key, 150 | } 151 | 152 | err := o.Validate() 153 | require.Error(t, err) 154 | require.ErrorContains(t, err, "WireguardPort is required") 155 | }) 156 | 157 | t.Run("WireguardKey", func(t *testing.T) { 158 | t.Parallel() 159 | 160 | o := &tunneld.Options{ 161 | BaseURL: &url.URL{ 162 | Scheme: "http", 163 | Host: "localhost", 164 | }, 165 | WireguardEndpoint: "localhost:1234", 166 | WireguardPort: 1234, 167 | WireguardKey: tunnelsdk.Key{}, 168 | } 169 | 170 | err := o.Validate() 171 | require.Error(t, err) 172 | require.ErrorContains(t, err, "WireguardKey is required") 173 | 174 | o.WireguardKey, err = key.PublicKey() 175 | require.NoError(t, err) 176 | 177 | err = o.Validate() 178 | require.Error(t, err) 179 | require.ErrorContains(t, err, "WireguardKey must be a private key") 180 | }) 181 | 182 | t.Run("WireguardServerIP", func(t *testing.T) { 183 | t.Parallel() 184 | 185 | o := &tunneld.Options{ 186 | BaseURL: &url.URL{ 187 | Scheme: "http", 188 | Host: "localhost", 189 | }, 190 | WireguardEndpoint: "localhost:1234", 191 | WireguardPort: 1234, 192 | WireguardKey: key, 193 | WireguardServerIP: netip.MustParseAddr("127.0.0.1"), 194 | } 195 | 196 | err := o.Validate() 197 | require.Error(t, err) 198 | require.ErrorContains(t, err, "WireguardServerIP must be an IPv6 address") 199 | }) 200 | 201 | t.Run("WireguardNetworkPrefix", func(t *testing.T) { 202 | t.Parallel() 203 | 204 | o := &tunneld.Options{ 205 | BaseURL: &url.URL{ 206 | Scheme: "http", 207 | Host: "localhost", 208 | }, 209 | WireguardEndpoint: 
"localhost:1234", 210 | WireguardPort: 1234, 211 | WireguardKey: key, 212 | WireguardServerIP: netip.MustParseAddr("feed::1"), 213 | WireguardNetworkPrefix: netip.MustParsePrefix("feed::1/128"), 214 | } 215 | 216 | err := o.Validate() 217 | require.Error(t, err) 218 | require.ErrorContains(t, err, "WireguardNetworkPrefix must have at least 64 bits available") 219 | 220 | o.WireguardServerIP = netip.MustParseAddr("fcca::1") 221 | o.WireguardNetworkPrefix = netip.MustParsePrefix("feed::1/64") 222 | 223 | err = o.Validate() 224 | require.Error(t, err) 225 | require.ErrorContains(t, err, "WireguardServerIP must be contained within WireguardNetworkPrefix") 226 | }) 227 | }) 228 | }) 229 | 230 | t.Run("WireguardPublicKeyToIPAndURLs", func(t *testing.T) { 231 | t.Parallel() 232 | 233 | cases := []struct { 234 | // base64 encoded 235 | key string 236 | ip string 237 | urls []string 238 | }{ 239 | { 240 | key: "8HGwtvNSGqXyO2s7UCW/NtvQM7L5jUL+s76h3qZbeG0=", 241 | ip: "f8bf:98cd:3caf:3e62", 242 | urls: []string{ 243 | "http://v2vphj9slsv64.localhost.com", 244 | "http://fccaf8bf98cd3caf3e6270a5db3140f9.localhost.com", 245 | }, 246 | }, 247 | { 248 | key: "ikEH8jCTwDMpQb7B1SbLi7itzDHJrlLzZtdNmuiLZHo=", 249 | ip: "2150:c2ea:38fe:21f", 250 | urls: []string{ 251 | "http://458c5qhovo11u.localhost.com", 252 | "http://fcca2150c2ea38fe021f76fac00cd533.localhost.com", 253 | }, 254 | }, 255 | { 256 | key: "8yxYMm//sfv27tkSz9itIa/8Ihql+vFRpsvjTSTaYAg=", 257 | ip: "c17e:72e4:c52e:a6c4", 258 | urls: []string{ 259 | "http://o5v75p655qjc8.localhost.com", 260 | "http://fccac17e72e4c52ea6c4fbb4ef809339.localhost.com", 261 | }, 262 | }, 263 | { 264 | key: "Gl7xZzfkCyFTbB+Uejc17GmfbjLy6s8NEZBaJKx/swU=", 265 | ip: "f773:2e08:771d:7a6f", 266 | urls: []string{ 267 | "http://utpis23n3lt6u.localhost.com", 268 | "http://fccaf7732e08771d7a6f6fdcb4a1f367.localhost.com", 269 | }, 270 | }, 271 | { 272 | key: "f8YjkcGgOggYzlIr2KtShY+8ZgR0hIXmJHPjCG8wi2Q=", 273 | ip: "dcf1:4e76:15bd:b2c7", 274 | urls: 
[]string{ 275 | "http://rjokstglnmpce.localhost.com", 276 | "http://fccadcf14e7615bdb2c7638238302374.localhost.com", 277 | }, 278 | }, 279 | { 280 | key: "Q3dubFlwwLnCpQTagjCckb1XLGtViZoBX1qHAZWV2gI=", 281 | ip: "25a2:8a43:2e91:1543", 282 | urls: []string{ 283 | "http://4mh8kgpei4ak6.localhost.com", 284 | "http://fcca25a28a432e9115439264ae85af84.localhost.com", 285 | }, 286 | }, 287 | } 288 | 289 | for i, c := range cases { 290 | i, c := i, c 291 | 292 | pubKey, err := wgtypes.ParseKey(c.key) 293 | require.NoError(t, err) 294 | 295 | t.Run(fmt.Sprintf("Default/%d", i), func(t *testing.T) { 296 | t.Parallel() 297 | 298 | options := &tunneld.Options{ 299 | BaseURL: &url.URL{ 300 | Scheme: "http", 301 | Host: "localhost.com", 302 | }, 303 | WireguardEndpoint: "localhost:1234", 304 | WireguardPort: 1234, 305 | WireguardKey: key, 306 | WireguardServerIP: tunneld.DefaultWireguardServerIP, 307 | WireguardNetworkPrefix: tunneld.DefaultWireguardNetworkPrefix, 308 | } 309 | err := options.Validate() 310 | require.NoError(t, err) 311 | 312 | expectedIP := "fcca::" + c.ip 313 | 314 | ip, urls := options.WireguardPublicKeyToIPAndURLs(device.NoisePublicKey(pubKey), tunnelsdk.TunnelVersion2) 315 | require.Equal(t, expectedIP, ip.String()) 316 | 317 | urlsStr := make([]string, len(urls)) 318 | for i, u := range urls { 319 | urlsStr[i] = u.String() 320 | } 321 | require.Equal(t, c.urls, urlsStr) 322 | 323 | // Try the old version, which should have a reversed URL list. 
324 | ip, urls = options.WireguardPublicKeyToIPAndURLs(device.NoisePublicKey(pubKey), tunnelsdk.TunnelVersion1) 325 | require.Equal(t, expectedIP, ip.String()) 326 | 327 | urlsStr = make([]string, len(urls)) 328 | for i, u := range urls { 329 | urlsStr[len(urls)-i-1] = u.String() 330 | } 331 | require.Equal(t, c.urls, urlsStr) 332 | }) 333 | 334 | t.Run(fmt.Sprintf("LongerPrefix/%d", i), func(t *testing.T) { 335 | t.Parallel() 336 | 337 | options := &tunneld.Options{ 338 | BaseURL: &url.URL{ 339 | Scheme: "http", 340 | Host: "localhost.com", 341 | }, 342 | WireguardEndpoint: "localhost:1234", 343 | WireguardPort: 1234, 344 | WireguardKey: key, 345 | WireguardServerIP: netip.MustParseAddr("feed:beef:deaf:deed::1"), 346 | WireguardNetworkPrefix: netip.MustParsePrefix("feed:beef:deaf:deed::1/64"), 347 | } 348 | err := options.Validate() 349 | require.NoError(t, err) 350 | 351 | expectedIP := "feed:beef:deaf:deed:" + c.ip 352 | 353 | // The second URL has a different IP prefix length, so adjust 354 | // accordingly. 355 | expectedURL2, err := url.Parse(c.urls[1]) 356 | require.NoError(t, err) 357 | hostRest := strings.SplitN(expectedURL2.Host, ".", 2)[1] 358 | expectedURL2.Host = "feedbeefdeafdeed" + expectedURL2.Host[4:20] + "." + hostRest 359 | t.Logf("mutated URL %q to %q", c.urls[1], expectedURL2.String()) 360 | expectedURLs := []string{ 361 | c.urls[0], 362 | expectedURL2.String(), 363 | } 364 | 365 | ip, urls := options.WireguardPublicKeyToIPAndURLs(device.NoisePublicKey(pubKey), tunnelsdk.TunnelVersion2) 366 | require.Equal(t, expectedIP, ip.String()) 367 | 368 | urlsStr := make([]string, len(urls)) 369 | for i, u := range urls { 370 | urlsStr[i] = u.String() 371 | } 372 | require.Equal(t, expectedURLs, urlsStr) 373 | 374 | // Try the old version, which should have a reversed URL list. 
375 | ip, urls = options.WireguardPublicKeyToIPAndURLs(device.NoisePublicKey(pubKey), tunnelsdk.TunnelVersion1) 376 | require.Equal(t, expectedIP, ip.String()) 377 | 378 | urlsStr = make([]string, len(urls)) 379 | for i, u := range urls { 380 | urlsStr[len(urls)-i-1] = u.String() 381 | } 382 | require.Equal(t, expectedURLs, urlsStr) 383 | }) 384 | } 385 | }) 386 | 387 | t.Run("HostnameToWireguardIP", func(t *testing.T) { 388 | t.Parallel() 389 | 390 | cases := []struct { 391 | hostname string 392 | ip string 393 | errContains string 394 | }{ 395 | // Good format: 396 | { 397 | hostname: "v2vphj9slsv64", 398 | ip: "f8bf:98cd:3caf:3e62", 399 | }, 400 | { 401 | hostname: "458c5qhovo11u", 402 | ip: "2150:c2ea:38fe:21f", 403 | }, 404 | { 405 | hostname: "o5v75p655qjc8", 406 | ip: "c17e:72e4:c52e:a6c4", 407 | }, 408 | { 409 | hostname: "utpis23n3lt6u", 410 | ip: "f773:2e08:771d:7a6f", 411 | }, 412 | { 413 | hostname: "rjokstglnmpce", 414 | ip: "dcf1:4e76:15bd:b2c7", 415 | }, 416 | { 417 | hostname: "4mh8kgpei4ak6", 418 | ip: "25a2:8a43:2e91:1543", 419 | }, 420 | 421 | // Good format errors: 422 | { 423 | hostname: "v2vphj9slsv64.localhost.com", 424 | errContains: "decode new hostname", 425 | }, 426 | { 427 | hostname: "4mh8kgpei4ak64mh8kgpei4ak6", 428 | errContains: "invalid new hostname length", 429 | }, 430 | 431 | // Bad format: 432 | { 433 | hostname: "fccaf8bf98cd3caf3e6270a5db3140f9", 434 | ip: "f8bf:98cd:3caf:3e62", 435 | }, 436 | { 437 | hostname: "fcca2150c2ea38fe021f76fac00cd533", 438 | ip: "2150:c2ea:38fe:21f", 439 | }, 440 | { 441 | hostname: "fccac17e72e4c52ea6c4fbb4ef809339", 442 | ip: "c17e:72e4:c52e:a6c4", 443 | }, 444 | { 445 | hostname: "fccaf7732e08771d7a6f6fdcb4a1f367", 446 | ip: "f773:2e08:771d:7a6f", 447 | }, 448 | { 449 | hostname: "fccadcf14e7615bdb2c7638238302374", 450 | ip: "dcf1:4e76:15bd:b2c7", 451 | }, 452 | { 453 | hostname: "fcca25a28a432e9115439264ae85af84", 454 | ip: "25a2:8a43:2e91:1543", 455 | }, 456 | 457 | // Bad format errors: 458 | 
{ 459 | hostname: "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 460 | errContains: "decode old hostname", 461 | }, 462 | } 463 | 464 | for i, c := range cases { 465 | c := c 466 | 467 | t.Run(fmt.Sprint(i), func(t *testing.T) { 468 | t.Parallel() 469 | 470 | t.Run("Default", func(t *testing.T) { 471 | t.Parallel() 472 | 473 | options := &tunneld.Options{ 474 | BaseURL: &url.URL{ 475 | Scheme: "http", 476 | Host: "localhost.com", 477 | }, 478 | WireguardEndpoint: "localhost:1234", 479 | WireguardPort: 1234, 480 | WireguardKey: key, 481 | WireguardServerIP: tunneld.DefaultWireguardServerIP, 482 | WireguardNetworkPrefix: tunneld.DefaultWireguardNetworkPrefix, 483 | } 484 | err := options.Validate() 485 | require.NoError(t, err) 486 | 487 | ip, err := options.HostnameToWireguardIP(c.hostname) 488 | if c.errContains != "" { 489 | require.Error(t, err) 490 | require.ErrorContains(t, err, c.errContains) 491 | return 492 | } 493 | 494 | require.NoError(t, err) 495 | require.Equal(t, "fcca::"+c.ip, ip.String()) 496 | }) 497 | 498 | t.Run("LongerPrefix", func(t *testing.T) { 499 | t.Parallel() 500 | 501 | options := &tunneld.Options{ 502 | BaseURL: &url.URL{ 503 | Scheme: "http", 504 | Host: "localhost.com", 505 | }, 506 | WireguardEndpoint: "localhost:1234", 507 | WireguardPort: 1234, 508 | WireguardKey: key, 509 | WireguardServerIP: netip.MustParseAddr("feed:beef:deaf:deed::1"), 510 | WireguardNetworkPrefix: netip.MustParsePrefix("feed:beef:deaf:deed::1/64"), 511 | } 512 | err := options.Validate() 513 | require.NoError(t, err) 514 | 515 | // The second hostname has a different IP prefix length, so 516 | // adjust accordingly. 
517 | hostname := c.hostname 518 | if len(hostname) == 32 { 519 | hostname = "feedbeefdeafdeed" + hostname[4:20] 520 | t.Logf("mutated hostname %q to %q", c.hostname, hostname) 521 | } 522 | 523 | ip, err := options.HostnameToWireguardIP(hostname) 524 | if c.errContains != "" { 525 | require.Error(t, err) 526 | require.ErrorContains(t, err, c.errContains) 527 | return 528 | } 529 | 530 | require.NoError(t, err) 531 | require.Equal(t, "feed:beef:deaf:deed:"+c.ip, ip.String()) 532 | }) 533 | }) 534 | } 535 | }) 536 | } 537 | -------------------------------------------------------------------------------- /tunneld/tunneld.go: -------------------------------------------------------------------------------- 1 | package tunneld 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "net/netip" 9 | "sync" 10 | "time" 11 | 12 | "github.com/tailscale/wireguard-go/conn" 13 | "github.com/tailscale/wireguard-go/device" 14 | "github.com/tailscale/wireguard-go/tun/netstack" 15 | "go.opentelemetry.io/otel" 16 | "go.opentelemetry.io/otel/attribute" 17 | "go.opentelemetry.io/otel/codes" 18 | "golang.org/x/xerrors" 19 | ) 20 | 21 | // TODO: add logging to API 22 | type API struct { 23 | *Options 24 | 25 | wgNet *netstack.Net 26 | wgDevice *device.Device 27 | transport *http.Transport 28 | 29 | pkeyCacheMu sync.RWMutex 30 | pkeyCache map[netip.Addr]cachedPeer 31 | } 32 | 33 | type cachedPeer struct { 34 | key device.NoisePublicKey 35 | lastHandshake time.Time 36 | } 37 | 38 | func New(options *Options) (*API, error) { 39 | if options == nil { 40 | options = &Options{} 41 | } 42 | err := options.Validate() 43 | if err != nil { 44 | return nil, xerrors.Errorf("invalid options: %w", err) 45 | } 46 | 47 | // Create the wireguard virtual TUN adapter and netstack. 48 | tun, wgNet, err := netstack.CreateNetTUN( 49 | []netip.Addr{options.WireguardServerIP}, 50 | // We don't do DNS resolution over the netstack, so don't specify any 51 | // DNS servers. 
52 | []netip.Addr{}, 53 | options.WireguardMTU, 54 | ) 55 | if err != nil { 56 | return nil, xerrors.Errorf("create wireguard virtual TUN adapter and netstack: %w", err) 57 | } 58 | 59 | // Create, configure and start the wireguard device. 60 | deviceLogger := options.Log.Named("wireguard_device") 61 | dlog := &device.Logger{ 62 | Verbosef: func(format string, args ...interface{}) { 63 | deviceLogger.Debug(context.Background(), fmt.Sprintf(format, args...)) 64 | }, 65 | Errorf: func(format string, args ...interface{}) { 66 | deviceLogger.Error(context.Background(), fmt.Sprintf(format, args...)) 67 | }, 68 | } 69 | dev := device.NewDevice(tun, conn.NewDefaultBind(), dlog) 70 | err = dev.IpcSet(fmt.Sprintf(`private_key=%s 71 | listen_port=%d`, 72 | options.WireguardKey.HexString(), 73 | options.WireguardPort, 74 | )) 75 | if err != nil { 76 | return nil, xerrors.Errorf("configure wireguard device: %w", err) 77 | } 78 | err = dev.Up() 79 | if err != nil { 80 | return nil, xerrors.Errorf("start wireguard device: %w", err) 81 | } 82 | 83 | return &API{ 84 | Options: options, 85 | wgNet: wgNet, 86 | wgDevice: dev, 87 | pkeyCache: make(map[netip.Addr]cachedPeer), 88 | transport: &http.Transport{ 89 | DialContext: func(ctx context.Context, network, addr string) (nc net.Conn, err error) { 90 | ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "(http.Transport).DialContext") 91 | defer span.End() 92 | defer func() { 93 | if err != nil { 94 | span.RecordError(err) 95 | span.SetStatus(codes.Error, err.Error()) 96 | } 97 | }() 98 | 99 | ip := ctx.Value(ipPortKey{}) 100 | if ip == nil { 101 | err = xerrors.New("no ip on context") 102 | return nil, err 103 | } 104 | 105 | ipp, ok := ip.(netip.AddrPort) 106 | if !ok { 107 | err = xerrors.Errorf("ip is incorrect type, got %T", ipp) 108 | return nil, err 109 | } 110 | 111 | span.SetAttributes(attribute.String("wireguard_addr", ipp.Addr().String())) 112 | 113 | dialCtx, dialCancel := context.WithTimeout(ctx, 
options.PeerDialTimeout) 114 | defer dialCancel() 115 | 116 | nc, err = wgNet.DialContextTCPAddrPort(dialCtx, ipp) 117 | if err != nil { 118 | return nil, err 119 | } 120 | 121 | return nc, nil 122 | }, 123 | ForceAttemptHTTP2: false, 124 | MaxIdleConns: 0, 125 | IdleConnTimeout: 90 * time.Second, 126 | TLSHandshakeTimeout: 10 * time.Second, 127 | ExpectContinueTimeout: 1 * time.Second, 128 | }, 129 | }, nil 130 | } 131 | 132 | func (api *API) Close() error { 133 | // Remove peers before closing to avoid a race condition between dev.Close() 134 | // and the peer goroutines which results in segfault. 135 | api.wgDevice.RemoveAllPeers() 136 | api.wgDevice.Close() 137 | <-api.wgDevice.Wait() 138 | 139 | return nil 140 | } 141 | -------------------------------------------------------------------------------- /tunneld/tunneld_test.go: -------------------------------------------------------------------------------- 1 | package tunneld_test 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io" 7 | "log" 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "net/netip" 12 | "net/url" 13 | "reflect" 14 | "strconv" 15 | "sync" 16 | "testing" 17 | "time" 18 | 19 | "github.com/stretchr/testify/assert" 20 | "github.com/stretchr/testify/require" 21 | 22 | "cdr.dev/slog/sloggers/slogtest" 23 | "github.com/coder/wgtunnel/tunneld" 24 | "github.com/coder/wgtunnel/tunnelsdk" 25 | ) 26 | 27 | // TestEndToEnd does an end-to-end tunnel test by creating a tunneld server, a 28 | // client, setting up the tunnel, and then doing a bunch of tests through the 29 | // tunnel to ensure it works. 30 | func TestEndToEnd(t *testing.T) { 31 | t.Parallel() 32 | 33 | td, client := createTestTunneld(t, nil) 34 | require.NotNil(t, td) 35 | 36 | // Start a tunnel. 37 | key, err := tunnelsdk.GeneratePrivateKey() 38 | require.NoError(t, err, "generate private key") 39 | tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{ 40 | Log: slogtest. 
41 | Make(t, &slogtest.Options{IgnoreErrors: true}). 42 | Named("tunnel_client"), 43 | PrivateKey: key, 44 | }) 45 | require.NoError(t, err, "launch tunnel") 46 | defer func() { 47 | _ = tunnel.Close() 48 | <-tunnel.Wait() 49 | }() 50 | 51 | require.NotNil(t, tunnel.URL) 52 | require.Len(t, tunnel.OtherURLs, 1) 53 | require.NotEqual(t, tunnel.URL.String(), tunnel.OtherURLs[0].String()) 54 | 55 | serveTunnel(t, tunnel) 56 | waitForTunnelReady(t, client, tunnel) 57 | 58 | // Make a bunch of requests concurrently. 59 | var wg sync.WaitGroup 60 | for i := 0; i < 1024; i++ { 61 | wg.Add(1) 62 | go func(i int) { 63 | defer wg.Done() 64 | 65 | // Do half of the requests to the primary URL and the other half to 66 | // the other URL (there's only one other URL right now). 67 | u := tunnel.URL 68 | if i%2 == 0 { 69 | u = tunnel.OtherURLs[0] 70 | } 71 | 72 | u, err := u.Parse("/test/" + strconv.Itoa(i)) 73 | if !assert.NoError(t, err) { 74 | return 75 | } 76 | 77 | // Do a third of the requests with a prefix before the hostname. 78 | if i%3 == 0 { 79 | u.Host = "prefix--" + u.Host 80 | } 81 | 82 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 83 | defer cancel() 84 | 85 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 86 | if !assert.NoError(t, err) { 87 | return 88 | } 89 | defer res.Body.Close() 90 | assert.Equal(t, http.StatusOK, res.StatusCode) 91 | 92 | body, err := io.ReadAll(res.Body) 93 | if !assert.NoError(t, err) { 94 | return 95 | } 96 | assert.Equal(t, "hello world /test/"+strconv.Itoa(i), string(body)) 97 | }(i) 98 | } 99 | 100 | wg.Wait() 101 | 102 | err = tunnel.Close() 103 | require.NoError(t, err, "close tunnel") 104 | 105 | <-tunnel.Wait() 106 | } 107 | 108 | // This test ensures that wgtunnel is compatible with the old closed-source 109 | // wgtunnel when register requests are made with version 1. 
110 | // 111 | // This uses real values generated by the old wgtunnel source code and checks 112 | // that the new wgtunnel can parse them and generates the expected values. 113 | func TestCompatibility(t *testing.T) { 114 | t.Parallel() 115 | 116 | /* 117 | wgKey: mCW7PwpK8iBmyXEFyGk55G24H0IU/AmJf5ZerzA3jGY= 118 | wgPubKey: Y9psPgU9BNRCvjPR93RNghbJUPyVh0LXBTnbHb+0TgU= 119 | publicKeyToV6: fcca:bbaf:8a9b:77f9:3fa9:fa65:7677:155e 120 | v6ToString: fccabbaf8a9b77f93fa9fa657677155e 121 | stringToV6: fcca:bbaf:8a9b:77f9:3fa9:fa65:7677:155e 122 | */ 123 | 124 | clientKey, err := tunnelsdk.ParsePrivateKey("mCW7PwpK8iBmyXEFyGk55G24H0IU/AmJf5ZerzA3jGY=") 125 | require.NoError(t, err) 126 | require.Equal(t, "mCW7PwpK8iBmyXEFyGk55G24H0IU/AmJf5ZerzA3jGY=", clientKey.String()) 127 | 128 | clientPublicKey, err := clientKey.PublicKey() 129 | require.NoError(t, err) 130 | require.Equal(t, "Y9psPgU9BNRCvjPR93RNghbJUPyVh0LXBTnbHb+0TgU=", clientPublicKey.String()) 131 | 132 | t.Run("Default", func(t *testing.T) { 133 | t.Parallel() 134 | 135 | td, client := createTestTunneld(t, &tunneld.Options{ 136 | BaseURL: &url.URL{ 137 | Scheme: "http", 138 | Host: "localhost.com", 139 | }, 140 | WireguardEndpoint: "", // generated automatically 141 | WireguardPort: 0, // generated automatically 142 | WireguardKey: tunnelsdk.Key{}, // generated automatically 143 | WireguardServerIP: tunneld.DefaultWireguardServerIP, 144 | WireguardNetworkPrefix: tunneld.DefaultWireguardNetworkPrefix, 145 | }) 146 | require.NotNil(t, td) 147 | 148 | ip1, urls1 := td.Options.WireguardPublicKeyToIPAndURLs(clientPublicKey.NoisePublicKey(), tunnelsdk.TunnelVersion1) 149 | ip2, urls2 := td.Options.WireguardPublicKeyToIPAndURLs(clientPublicKey.NoisePublicKey(), tunnelsdk.TunnelVersion2) 150 | 151 | // Identical IP address in both formats. 
This differs from the old 152 | // wgtunnel which uses all 16 bytes of the IP instead of just the prefix 153 | // and 8 bytes of the public key, but old clients don't care about the 154 | // IP anyways. 155 | require.Equal(t, ip1, ip2) 156 | // Swapped order of URLs in the new format. 157 | require.Equal(t, []string{urls1[0].String(), urls1[1].String()}, []string{urls2[1].String(), urls2[0].String()}) 158 | require.Equal(t, "fccabbaf8a9b77f93fa9fa657677155e.localhost.com", urls1[0].Host) 159 | 160 | // Register with the old format. 161 | res, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 162 | Version: tunnelsdk.TunnelVersion1, 163 | PublicKey: clientPublicKey.NoisePublicKey(), 164 | }) 165 | require.NoError(t, err) 166 | 167 | require.Equal(t, tunnelsdk.TunnelVersion1, res.Version) 168 | require.Equal(t, "http://fccabbaf8a9b77f93fa9fa657677155e.localhost.com", res.TunnelURLs[0]) 169 | require.Equal(t, ip1, res.ClientIP) 170 | 171 | // Now actually tunnel and check that the URL works. 172 | tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{ 173 | Log: slogtest.Make(t, &slogtest.Options{ 174 | IgnoreErrors: true, 175 | }), 176 | Version: tunnelsdk.TunnelVersion1, 177 | PrivateKey: clientKey, 178 | }) 179 | require.NoError(t, err) 180 | require.NotNil(t, tunnel) 181 | 182 | serveTunnel(t, tunnel) 183 | waitForTunnelReady(t, client, tunnel) 184 | 185 | // Make a request to the tunnel. 
186 | { 187 | u, err := url.Parse(res.TunnelURLs[0]) 188 | require.NoError(t, err) 189 | u.Path = "/test/1" 190 | 191 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 192 | defer cancel() 193 | 194 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 195 | if !assert.NoError(t, err) { 196 | return 197 | } 198 | defer res.Body.Close() 199 | assert.Equal(t, http.StatusOK, res.StatusCode) 200 | 201 | body, err := io.ReadAll(res.Body) 202 | if !assert.NoError(t, err) { 203 | return 204 | } 205 | assert.Equal(t, "hello world /test/1", string(body)) 206 | } 207 | }) 208 | 209 | // This test is mostly for completeness, but we don't use the longer prefix 210 | // functionality anyways, and it's not compatible with the old wgtunnel 211 | // implementation. 212 | t.Run("LongerPrefix", func(t *testing.T) { 213 | t.Parallel() 214 | 215 | td, client := createTestTunneld(t, &tunneld.Options{ 216 | BaseURL: &url.URL{ 217 | Scheme: "http", 218 | Host: "localhost.com", 219 | }, 220 | WireguardEndpoint: "", // generated automatically 221 | WireguardPort: 0, // generated automatically 222 | WireguardKey: tunnelsdk.Key{}, // generated automatically 223 | WireguardServerIP: netip.MustParseAddr("feed:beef:deaf:deed::1"), 224 | WireguardNetworkPrefix: netip.MustParsePrefix("feed:beef:deaf:deed::1/64"), 225 | }) 226 | require.NotNil(t, td) 227 | 228 | ip1, urls1 := td.Options.WireguardPublicKeyToIPAndURLs(clientPublicKey.NoisePublicKey(), tunnelsdk.TunnelVersion1) 229 | ip2, urls2 := td.Options.WireguardPublicKeyToIPAndURLs(clientPublicKey.NoisePublicKey(), tunnelsdk.TunnelVersion2) 230 | 231 | // Identical IP address in both formats. This differs from the old 232 | // wgtunnel which uses all 16 bytes of the IP instead of just the prefix 233 | // and 8 bytes of the public key, but old clients don't care about the 234 | // IP anyways. 235 | require.Equal(t, ip1, ip2) 236 | // Swapped order of URLs in the new format. 
237 | require.Equal(t, []string{urls1[0].String(), urls1[1].String()}, []string{urls2[1].String(), urls2[0].String()}) 238 | 239 | // For longer prefix, we use the prefix bytes, then the public key 240 | // bytes. We don't do any shifting. 241 | require.Equal(t, "feedbeefdeafdeedbbaf8a9b77f93fa9.localhost.com", urls1[0].Host) 242 | 243 | // Register with the old format. 244 | res, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{ 245 | Version: tunnelsdk.TunnelVersion1, 246 | PublicKey: clientPublicKey.NoisePublicKey(), 247 | }) 248 | require.NoError(t, err) 249 | 250 | require.Equal(t, tunnelsdk.TunnelVersion1, res.Version) 251 | require.Equal(t, "http://feedbeefdeafdeedbbaf8a9b77f93fa9.localhost.com", res.TunnelURLs[0]) 252 | require.Equal(t, ip1, res.ClientIP) 253 | 254 | // Now actually tunnel and check that the URL works. 255 | tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{ 256 | Log: slogtest.Make(t, &slogtest.Options{ 257 | IgnoreErrors: true, 258 | }), 259 | Version: tunnelsdk.TunnelVersion1, 260 | PrivateKey: clientKey, 261 | }) 262 | require.NoError(t, err) 263 | require.NotNil(t, tunnel) 264 | 265 | serveTunnel(t, tunnel) 266 | waitForTunnelReady(t, client, tunnel) 267 | 268 | // Make a request to the tunnel. 
269 | { 270 | u, err := url.Parse(res.TunnelURLs[0]) 271 | require.NoError(t, err) 272 | u.Path = "/test/1" 273 | 274 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 275 | defer cancel() 276 | 277 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 278 | if !assert.NoError(t, err) { 279 | return 280 | } 281 | defer res.Body.Close() 282 | assert.Equal(t, http.StatusOK, res.StatusCode) 283 | 284 | body, err := io.ReadAll(res.Body) 285 | if !assert.NoError(t, err) { 286 | return 287 | } 288 | assert.Equal(t, "hello world /test/1", string(body)) 289 | } 290 | }) 291 | } 292 | 293 | func TestTimeout(t *testing.T) { 294 | t.Parallel() 295 | 296 | td, client := createTestTunneld(t, &tunneld.Options{ 297 | BaseURL: &url.URL{ 298 | Scheme: "http", 299 | Host: "localhost.com", 300 | }, 301 | WireguardEndpoint: "", // generated automatically 302 | WireguardPort: 0, // generated automatically 303 | WireguardKey: tunnelsdk.Key{}, // generated automatically 304 | WireguardServerIP: tunneld.DefaultWireguardServerIP, 305 | WireguardNetworkPrefix: tunneld.DefaultWireguardNetworkPrefix, 306 | PeerDialTimeout: time.Second, 307 | }) 308 | require.NotNil(t, td) 309 | 310 | // Start a tunnel. 311 | key, err := tunnelsdk.GeneratePrivateKey() 312 | require.NoError(t, err, "generate private key") 313 | tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{ 314 | Log: slogtest. 315 | Make(t, &slogtest.Options{IgnoreErrors: true}). 316 | Named("tunnel_client"), 317 | PrivateKey: key, 318 | }) 319 | require.NoError(t, err, "launch tunnel") 320 | 321 | // Close the tunnel. 322 | err = tunnel.Close() 323 | require.NoError(t, err, "close tunnel") 324 | <-tunnel.Wait() 325 | 326 | // Requests should fail in roughly 1 second. 
327 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 328 | defer cancel() 329 | 330 | u := *tunnel.URL 331 | u.Path = "/test/1" 332 | 333 | now := time.Now() 334 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 335 | require.NoError(t, err) 336 | require.WithinDuration(t, now.Add(time.Second), time.Now(), 2*time.Second) 337 | defer res.Body.Close() 338 | require.Equal(t, http.StatusBadGateway, res.StatusCode) 339 | } 340 | 341 | func TestPeerTimeout(t *testing.T) { 342 | t.Parallel() 343 | 344 | td, client := createTestTunneld(t, &tunneld.Options{ 345 | PeerTimeout: time.Second, 346 | PeerRegisterInterval: 100 * time.Millisecond, 347 | }) 348 | require.NotNil(t, td) 349 | 350 | // Start a tunnel. 351 | key, err := tunnelsdk.GeneratePrivateKey() 352 | require.NoError(t, err, "generate private key") 353 | tunnel, err := client.LaunchTunnel(context.Background(), tunnelsdk.TunnelConfig{ 354 | Log: slogtest. 355 | Make(t, &slogtest.Options{IgnoreErrors: true}). 356 | Named("tunnel_client"), 357 | PrivateKey: key, 358 | }) 359 | require.NoError(t, err, "launch tunnel") 360 | defer func() { 361 | _ = tunnel.Close() 362 | <-tunnel.Wait() 363 | }() 364 | 365 | require.NotNil(t, tunnel.URL) 366 | require.Len(t, tunnel.OtherURLs, 1) 367 | require.NotEqual(t, tunnel.URL.String(), tunnel.OtherURLs[0].String()) 368 | 369 | serveTunnel(t, tunnel) 370 | waitForTunnelReady(t, client, tunnel) 371 | 372 | // Successfully send a request to the peer. 
373 | { 374 | u, err := tunnel.URL.Parse("/test/1") 375 | require.NoError(t, err) 376 | 377 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 378 | defer cancel() 379 | 380 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 381 | if !assert.NoError(t, err) { 382 | return 383 | } 384 | defer res.Body.Close() 385 | assert.Equal(t, http.StatusOK, res.StatusCode) 386 | 387 | body, err := io.ReadAll(res.Body) 388 | require.NoError(t, err) 389 | require.Equal(t, "hello world /test/1", string(body)) 390 | } 391 | 392 | err = tunnel.Close() 393 | require.NoError(t, err, "close tunnel") 394 | <-tunnel.Wait() 395 | 396 | time.Sleep(td.PeerTimeout) 397 | 398 | // The correct error should be returned after the peer goes away. 399 | { 400 | u, err := tunnel.URL.Parse("/test/1") 401 | require.NoError(t, err) 402 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 403 | defer cancel() 404 | 405 | res, err := client.Request(ctx, http.MethodGet, u.String(), nil) 406 | require.NoError(t, err) 407 | defer res.Body.Close() 408 | 409 | tres := tunnelsdk.Response{} 410 | err = json.NewDecoder(res.Body).Decode(&tres) 411 | require.NoError(t, err) 412 | 413 | require.Equal(t, http.StatusBadGateway, res.StatusCode) 414 | require.Equal(t, "Peer is not connected.", tres.Message) 415 | } 416 | } 417 | 418 | func freeUDPPort(t *testing.T) uint16 { 419 | t.Helper() 420 | 421 | l, err := net.ListenUDP("udp", &net.UDPAddr{ 422 | IP: net.ParseIP("127.0.0.1"), 423 | Port: 0, 424 | }) 425 | require.NoError(t, err, "listen on random UDP port") 426 | 427 | _, port, err := net.SplitHostPort(l.LocalAddr().String()) 428 | require.NoError(t, err, "split host port") 429 | 430 | portUint, err := strconv.ParseUint(port, 10, 16) 431 | require.NoError(t, err, "parse port") 432 | 433 | // This is prone to races, but since we have to tell wireguard to create the 434 | // listener and can't pass in a net.Listener, we have to do this. 
435 | err = l.Close() 436 | require.NoError(t, err, "close UDP listener") 437 | 438 | return uint16(portUint) 439 | } 440 | 441 | func createTestTunneld(t *testing.T, options *tunneld.Options) (*tunneld.API, *tunnelsdk.Client) { 442 | t.Helper() 443 | 444 | if options == nil { 445 | options = &tunneld.Options{} 446 | } 447 | if reflect.ValueOf(options.Log).IsZero() { 448 | options.Log = slogtest. 449 | Make(t, &slogtest.Options{IgnoreErrors: true}). 450 | Named("tunneld") 451 | } 452 | 453 | // Set required options if unset. 454 | if options.BaseURL == nil { 455 | options.BaseURL = &url.URL{ 456 | Scheme: "http", 457 | Host: "tunnel.dev", 458 | } 459 | } 460 | if options.WireguardEndpoint == "" && options.WireguardPort == 0 { 461 | port := freeUDPPort(t) 462 | options.WireguardEndpoint = "127.0.0.1:" + strconv.Itoa(int(port)) 463 | options.WireguardPort = port 464 | } 465 | if options.WireguardKey.IsZero() { 466 | key, err := tunnelsdk.GeneratePrivateKey() 467 | require.NoError(t, err, "generate wireguard private key") 468 | options.WireguardKey = key 469 | } 470 | 471 | err := options.Validate() 472 | require.NoError(t, err, "validate options") 473 | 474 | return createTestTunneldNoDefaults(t, options) 475 | } 476 | 477 | func createTestTunneldNoDefaults(t *testing.T, options *tunneld.Options) (*tunneld.API, *tunnelsdk.Client) { 478 | t.Helper() 479 | 480 | td, err := tunneld.New(options) 481 | require.NoError(t, err, "create tunneld") 482 | t.Cleanup(func() { 483 | _ = td.Close() 484 | }) 485 | 486 | srv := httptest.NewServer(td.Router()) 487 | t.Cleanup(srv.Close) 488 | 489 | u, err := url.Parse(srv.URL) 490 | require.NoError(t, err, "parse server URL") 491 | 492 | client := tunnelsdk.New(options.BaseURL) 493 | client.HTTPClient = tunnelHTTPClient(u) 494 | return td, client 495 | } 496 | 497 | func serveTunnel(t *testing.T, tunnel *tunnelsdk.Tunnel) { 498 | t.Helper() 499 | 500 | // Start a basic HTTP server with the listener. 
501 | srv := &http.Server{ 502 | // These errors are typically noise like "TLS: EOF". Vault does similar: 503 | // https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714 504 | ErrorLog: log.New(io.Discard, "", 0), 505 | ReadHeaderTimeout: 5 * time.Second, 506 | Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 507 | rw.Header().Set("Content-Type", "text/plain") 508 | rw.WriteHeader(http.StatusOK) 509 | _, _ = rw.Write([]byte("hello world " + r.URL.Path)) 510 | }), 511 | } 512 | 513 | done := make(chan struct{}) 514 | go func() { 515 | defer close(done) 516 | _ = srv.Serve(tunnel.Listener) 517 | }() 518 | t.Cleanup(func() { 519 | _ = srv.Close() 520 | <-done 521 | }) 522 | } 523 | 524 | // tunnelHTTPClient returns a HTTP client that disregards DNS and always 525 | // connects to the tunneld server IP. This is useful for testing connections to 526 | // generated tunnel URLs with custom hostnames that don't resolve. 527 | func tunnelHTTPClient(tunURL *url.URL) *http.Client { 528 | return &http.Client{ 529 | Transport: &http.Transport{ 530 | DisableKeepAlives: true, 531 | DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 532 | return (&net.Dialer{}).DialContext(ctx, network, tunURL.Host) 533 | }, 534 | }, 535 | } 536 | } 537 | 538 | func waitForTunnelReady(t *testing.T, client *tunnelsdk.Client, tunnel *tunnelsdk.Tunnel) { 539 | require.Eventually(t, func() bool { 540 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 541 | defer cancel() 542 | 543 | res, err := client.Request(ctx, http.MethodGet, tunnel.URL.String(), nil) 544 | require.NoError(t, err, "create request") 545 | if err == nil { 546 | _ = res.Body.Close() 547 | } 548 | return err == nil && res.StatusCode == http.StatusOK 549 | }, 15*time.Second, 100*time.Millisecond) 550 | } 551 | -------------------------------------------------------------------------------- /tunnelsdk/api.go: 
-------------------------------------------------------------------------------- 1 | package tunnelsdk 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | "net/netip" 8 | "time" 9 | 10 | "github.com/tailscale/wireguard-go/device" 11 | ) 12 | 13 | type Response struct { 14 | Message string `json:"message"` 15 | Detail string `json:"detail,omitempty"` 16 | } 17 | 18 | type ClientRegisterRequest struct { 19 | Version TunnelVersion `json:"version"` 20 | PublicKey device.NoisePublicKey `json:"public_key"` 21 | } 22 | 23 | type ClientRegisterResponse struct { 24 | Version TunnelVersion `json:"version"` 25 | ReregisterWait time.Duration `json:"reregister_wait"` 26 | // TunnelURLs contains a list of valid URLs that will be forwarded from the 27 | // server to this tunnel client once connected. The first URL is the 28 | // preferred URL, and the other URLs are provided for compatibility 29 | // purposes only. 30 | // 31 | // The order of the URLs changes based on the Version field in the request. 
32 | TunnelURLs []string `json:"tunnel_urls"` 33 | ClientIP netip.Addr `json:"client_ip"` 34 | 35 | ServerEndpoint string `json:"server_endpoint"` 36 | ServerIP netip.Addr `json:"server_ip"` 37 | ServerPublicKey device.NoisePublicKey `json:"server_public_key"` 38 | WireguardMTU int `json:"wireguard_mtu"` 39 | } 40 | 41 | func (c *Client) ClientRegister(ctx context.Context, req ClientRegisterRequest) (ClientRegisterResponse, error) { 42 | res, err := c.Request(ctx, http.MethodPost, "/api/v2/clients", req) 43 | if err != nil { 44 | return ClientRegisterResponse{}, err 45 | } 46 | defer res.Body.Close() 47 | if res.StatusCode != http.StatusOK { 48 | return ClientRegisterResponse{}, readBodyAsError(res) 49 | } 50 | 51 | var resp ClientRegisterResponse 52 | return resp, json.NewDecoder(res.Body).Decode(&resp) 53 | } 54 | -------------------------------------------------------------------------------- /tunnelsdk/client.go: -------------------------------------------------------------------------------- 1 | package tunnelsdk 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "mime" 11 | "net/http" 12 | "net/url" 13 | "strings" 14 | 15 | "golang.org/x/xerrors" 16 | ) 17 | 18 | // New creates a tunneld client for the provided URL. 19 | func New(serverURL *url.URL) *Client { 20 | return &Client{ 21 | HTTPClient: &http.Client{}, 22 | URL: serverURL, 23 | } 24 | } 25 | 26 | // Client provides HTTP methods for the tunneld API and a full wireguard tunnel 27 | // client implementation. 28 | type Client struct { 29 | HTTPClient *http.Client 30 | URL *url.URL 31 | } 32 | 33 | // Request performs an HTTP request with the body provided. The caller is 34 | // responsible for closing the response body. 
35 | func (c *Client) Request(ctx context.Context, method, path string, body interface{}) (*http.Response, error) { 36 | serverURL, err := c.URL.Parse(path) 37 | if err != nil { 38 | return nil, xerrors.Errorf("parse url: %w", err) 39 | } 40 | 41 | var buf bytes.Buffer 42 | if body != nil { 43 | if data, ok := body.([]byte); ok { 44 | buf = *bytes.NewBuffer(data) 45 | } else { 46 | // Assume JSON if not bytes. 47 | enc := json.NewEncoder(&buf) 48 | enc.SetEscapeHTML(false) 49 | err = enc.Encode(body) 50 | if err != nil { 51 | return nil, xerrors.Errorf("encode body: %w", err) 52 | } 53 | } 54 | } 55 | 56 | req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf) 57 | if err != nil { 58 | return nil, xerrors.Errorf("create request: %w", err) 59 | } 60 | 61 | if body != nil { 62 | req.Header.Set("Content-Type", "application/json") 63 | } 64 | 65 | resp, err := c.HTTPClient.Do(req) 66 | if err != nil { 67 | return nil, xerrors.Errorf("do: %w", err) 68 | } 69 | return resp, err 70 | } 71 | 72 | // readBodyAsError reads the response body as tunnelsdk.Error type for easily 73 | // reading the error message. 74 | func readBodyAsError(res *http.Response) error { 75 | if res == nil { 76 | return xerrors.Errorf("no body returned") 77 | } 78 | defer res.Body.Close() 79 | contentType := res.Header.Get("Content-Type") 80 | 81 | var method, u string 82 | if res.Request != nil { 83 | method = res.Request.Method 84 | if res.Request.URL != nil { 85 | u = res.Request.URL.String() 86 | } 87 | } 88 | 89 | resp, err := io.ReadAll(res.Body) 90 | if err != nil { 91 | return xerrors.Errorf("read body: %w", err) 92 | } 93 | 94 | mimeType, _, err := mime.ParseMediaType(contentType) 95 | if err != nil { 96 | mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0]) 97 | } 98 | if mimeType != "application/json" { 99 | if len(resp) > 1024 { 100 | resp = append(resp[:1024], []byte("...")...) 
101 | } 102 | if len(resp) == 0 { 103 | resp = []byte("no response body") 104 | } 105 | return &Error{ 106 | statusCode: res.StatusCode, 107 | Response: Response{ 108 | Message: "unexpected non-JSON response", 109 | Detail: string(resp), 110 | }, 111 | } 112 | } 113 | 114 | var m Response 115 | err = json.NewDecoder(bytes.NewBuffer(resp)).Decode(&m) 116 | if err != nil { 117 | if errors.Is(err, io.EOF) { 118 | return &Error{ 119 | statusCode: res.StatusCode, 120 | Response: Response{ 121 | Message: "empty response body", 122 | }, 123 | } 124 | } 125 | return xerrors.Errorf("decode body: %w", err) 126 | } 127 | if m.Message == "" { 128 | if len(resp) > 1024 { 129 | resp = append(resp[:1024], []byte("...")...) 130 | } 131 | m.Message = fmt.Sprintf("unexpected status code %d, response has no message", res.StatusCode) 132 | m.Detail = string(resp) 133 | } 134 | 135 | return &Error{ 136 | Response: m, 137 | statusCode: res.StatusCode, 138 | method: method, 139 | url: u, 140 | } 141 | } 142 | 143 | // Error represents an unaccepted or invalid request to the API. 
144 | type Error struct { 145 | Response 146 | 147 | statusCode int 148 | method string 149 | url string 150 | } 151 | 152 | func (e *Error) StatusCode() int { 153 | return e.statusCode 154 | } 155 | 156 | func (e *Error) Friendly() string { 157 | return e.Message 158 | } 159 | 160 | func (e *Error) Error() string { 161 | var builder strings.Builder 162 | if e.method != "" && e.url != "" { 163 | _, _ = fmt.Fprintf(&builder, "%v %v: ", e.method, e.url) 164 | } 165 | _, _ = fmt.Fprintf(&builder, "unexpected status code %d: %s", e.statusCode, e.Message) 166 | if e.Detail != "" { 167 | _, _ = fmt.Fprintf(&builder, "\n\tError: %s", e.Detail) 168 | } 169 | 170 | return builder.String() 171 | } 172 | -------------------------------------------------------------------------------- /tunnelsdk/tunnel.go: -------------------------------------------------------------------------------- 1 | package tunnelsdk 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "crypto/sha512" 7 | "encoding/hex" 8 | "errors" 9 | "fmt" 10 | "math/big" 11 | "net" 12 | "net/netip" 13 | "net/url" 14 | "time" 15 | 16 | "github.com/tailscale/wireguard-go/conn" 17 | "github.com/tailscale/wireguard-go/device" 18 | "github.com/tailscale/wireguard-go/tun/netstack" 19 | "golang.org/x/xerrors" 20 | "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 21 | 22 | "cdr.dev/slog" 23 | ) 24 | 25 | // TunnelPort is the port in the virtual wireguard network stack that the 26 | // listener is listening on. 27 | const TunnelPort = 8090 28 | 29 | // TunnelVersion is the version of the tunnel URL specification. 30 | type TunnelVersion int 31 | 32 | const ( 33 | // TunnelVersion1 is the "old style" tunnel URL. Each hostname base is 32 34 | // characters long and is base16 (hex) encoded. 35 | TunnelVersion1 TunnelVersion = 1 36 | // TunnelVersion2 is the "new style" tunnel URL. Each hostname base is ~12 37 | // characters long and is base32 encoded. 
38 | TunnelVersion2 TunnelVersion = 2 39 | 40 | TunnelVersionLatest = TunnelVersion2 41 | ) 42 | 43 | // Key is a Wireguard private or public key. 44 | type Key struct { 45 | k wgtypes.Key 46 | isPrivate bool 47 | } 48 | 49 | // GenerateWireguardPrivateKey generates a new wireguard private key using 50 | // secure cryptography. The caller should store the key (using key.String()) in 51 | // a safe place like the user's home directory, and use it in the future rather 52 | // than generating a new key each time. 53 | func GeneratePrivateKey() (Key, error) { 54 | key, err := wgtypes.GeneratePrivateKey() 55 | if err != nil { 56 | return Key{}, err 57 | } 58 | 59 | return Key{ 60 | k: key, 61 | isPrivate: true, 62 | }, nil 63 | } 64 | 65 | // ParsePrivateKey parses a private key generated using key.String(). 66 | func ParsePrivateKey(key string) (Key, error) { 67 | k, err := wgtypes.ParseKey(key) 68 | if err != nil { 69 | return Key{}, err 70 | } 71 | 72 | return Key{ 73 | k: k, 74 | // assume it's private, not really any way to tell unfortunately 75 | isPrivate: true, 76 | }, nil 77 | } 78 | 79 | // ParsePublicKey parses a public key generated using key.String(). 80 | func ParsePublicKey(key string) (Key, error) { 81 | k, err := wgtypes.ParseKey(key) 82 | if err != nil { 83 | return Key{}, err 84 | } 85 | 86 | return Key{ 87 | k: k, 88 | isPrivate: false, 89 | }, nil 90 | } 91 | 92 | // FromNoisePrivateKey converts a device.NoisePrivateKey to a Key. 93 | func FromNoisePrivateKey(k device.NoisePrivateKey) Key { 94 | return Key{ 95 | k: wgtypes.Key(k), 96 | isPrivate: true, 97 | } 98 | } 99 | 100 | // FromNoisePublicKey converts a device.NoisePublicKey to a Key. 101 | func FromNoisePublicKey(k device.NoisePublicKey) Key { 102 | return Key{ 103 | k: wgtypes.Key(k), 104 | isPrivate: false, 105 | } 106 | } 107 | 108 | // IsZero returns true if the Key is the zero value. 
109 | func (k Key) IsZero() bool { 110 | return k.k == wgtypes.Key{} 111 | } 112 | 113 | // IsPrivate returns true if the key is a private key. 114 | func (k Key) IsPrivate() bool { 115 | return k.isPrivate 116 | } 117 | 118 | // String returns a base64 encoded string representation of the key. 119 | func (k Key) String() string { 120 | return k.k.String() 121 | } 122 | 123 | // HexString returns the hex string representation of the key. 124 | func (k Key) HexString() string { 125 | return hex.EncodeToString(k.k[:]) 126 | } 127 | 128 | // Hash returns the SHA512 hash of the key. 129 | func (k Key) Hash() string { 130 | hash := sha512.Sum512(k.k[:]) 131 | return hex.EncodeToString(hash[:]) 132 | } 133 | 134 | // NoisePrivateKey returns the device.NoisePrivateKey for the key. If the key is 135 | // not a private key, an error is returned. 136 | func (k Key) NoisePrivateKey() (device.NoisePrivateKey, error) { 137 | if !k.isPrivate { 138 | return device.NoisePrivateKey{}, xerrors.Errorf("cannot call key.NoisePrivateKey() on a public key") 139 | } 140 | 141 | return device.NoisePrivateKey(k.k), nil 142 | } 143 | 144 | // NoisePublicKey returns the device.NoisePublicKey for the key. If the key is a 145 | // private key, it is converted to a public key automatically. 146 | func (k Key) NoisePublicKey() device.NoisePublicKey { 147 | if k.isPrivate { 148 | return device.NoisePublicKey(k.k.PublicKey()) 149 | } 150 | 151 | return device.NoisePublicKey(k.k) 152 | } 153 | 154 | // PublicKey returns the public key component of the Wireguard private key. If 155 | // the key is not a private key, an error is returned. 
156 | func (k Key) PublicKey() (Key, error) { 157 | if !k.isPrivate { 158 | return k, xerrors.Errorf("cannot call key.PublicKey() on a public key") 159 | } 160 | 161 | return Key{ 162 | k: k.k.PublicKey(), 163 | isPrivate: false, 164 | }, nil 165 | } 166 | 167 | type TunnelConfig struct { 168 | Log slog.Logger 169 | // Version denotes which version of the tunnel URL specification to use. 170 | // Undefined version is treated as the latest version. 171 | Version TunnelVersion 172 | // PrivateKey is the Wireguard private key. You can use GeneratePrivateKey 173 | // to generate a new key. It should be stored in a safe place for future 174 | // tunnel sessions, otherwise you will get a new hostname. 175 | PrivateKey Key 176 | } 177 | 178 | // LaunchTunnel makes a request to the tunneld server to register the client's 179 | // tunnel using the client's public key, then establishes a wireguard connection 180 | // to the server and returns a *Tunnel. Connections can be accepted from 181 | // tunnel.Listener. 
182 | func (c *Client) LaunchTunnel(ctx context.Context, cfg TunnelConfig) (*Tunnel, error) { 183 | if cfg.Version == 0 { 184 | cfg.Version = TunnelVersionLatest 185 | } 186 | 187 | pubKey := cfg.PrivateKey.NoisePublicKey() 188 | 189 | res, err := c.ClientRegister(ctx, ClientRegisterRequest{ 190 | Version: cfg.Version, 191 | PublicKey: pubKey, 192 | }) 193 | if err != nil { 194 | return nil, xerrors.Errorf("initial client registration: %w", err) 195 | } 196 | if len(res.TunnelURLs) == 0 { 197 | return nil, xerrors.Errorf("no tunnel urls returned from server") 198 | } 199 | if res.ReregisterWait <= 0 { 200 | return nil, xerrors.Errorf("invalid reregister wait time: %s", res.ReregisterWait) 201 | } 202 | 203 | primaryURL, err := url.Parse(res.TunnelURLs[0]) 204 | if err != nil { 205 | return nil, xerrors.Errorf("parse tunnel url: %w", err) 206 | } 207 | 208 | otherURLs := make([]*url.URL, len(res.TunnelURLs)-1) 209 | for i, u := range res.TunnelURLs[1:] { 210 | otherURLs[i], err = url.Parse(u) 211 | if err != nil { 212 | return nil, xerrors.Errorf("parse tunnel url %d (%q): %w", i, u, err) 213 | } 214 | } 215 | 216 | // Ensure the returned server endpoint from the API is an IP address and not 217 | // a hostname to avoid constant DNS lookups. 218 | host, port, err := net.SplitHostPort(res.ServerEndpoint) 219 | if err != nil { 220 | return nil, xerrors.Errorf("parse server endpoint: %w", err) 221 | } 222 | wgIP, err := net.ResolveIPAddr("ip", host) 223 | if err != nil { 224 | return nil, xerrors.Errorf("resolve endpoint: %w", err) 225 | } 226 | wgEndpoint := net.JoinHostPort(wgIP.String(), port) 227 | 228 | // Start re-registering the client every 30 seconds. 
229 | returnedOK := false 230 | tunnelCtx, tunnelCancel := context.WithCancel(context.Background()) 231 | defer func() { 232 | if !returnedOK { 233 | tunnelCancel() 234 | } 235 | }() 236 | go func() { 237 | ticker := time.NewTicker(res.ReregisterWait) 238 | defer ticker.Stop() 239 | 240 | for { 241 | select { 242 | case <-tunnelCtx.Done(): 243 | return 244 | case <-ticker.C: 245 | } 246 | 247 | ctx, cancel := context.WithTimeout(tunnelCtx, 10*time.Second) 248 | res, err := c.ClientRegister(ctx, ClientRegisterRequest{ 249 | PublicKey: pubKey, 250 | }) 251 | if err != nil && !errors.Is(err, context.Canceled) { 252 | cfg.Log.Warn(ctx, "periodically re-register tunnel", slog.Error(err)) 253 | } 254 | 255 | // If we failed to re-register, try again in 30 seconds plus a 256 | // random amount of time between 0 and 30 seconds. 257 | if res.ReregisterWait <= 0 { 258 | res.ReregisterWait = 30 * time.Second 259 | i, err := rand.Int(rand.Reader, big.NewInt(30)) 260 | if err != nil { 261 | i = big.NewInt(30) 262 | } 263 | res.ReregisterWait += time.Duration(i.Int64()) * time.Second 264 | } 265 | 266 | ticker.Reset(res.ReregisterWait) 267 | cancel() 268 | } 269 | }() 270 | 271 | // Create wireguard virtual network stack. 272 | tun, tnet, err := netstack.CreateNetTUN( 273 | []netip.Addr{res.ClientIP}, 274 | // We don't resolve hostnames in the tunnel, so we don't need a DNS 275 | // server. 276 | []netip.Addr{}, 277 | res.WireguardMTU, 278 | ) 279 | if err != nil { 280 | return nil, xerrors.Errorf("create net TUN: %w", err) 281 | } 282 | 283 | // Create wireguard device, configure it and start it. 
284 | deviceLogger := cfg.Log.Named("wireguard_device") 285 | dlog := &device.Logger{ 286 | Verbosef: func(format string, args ...any) { 287 | deviceLogger.Debug(ctx, fmt.Sprintf(format, args...)) 288 | }, 289 | Errorf: func(format string, args ...any) { 290 | deviceLogger.Error(ctx, fmt.Sprintf(format, args...)) 291 | }, 292 | } 293 | dev := device.NewDevice(tun, conn.NewDefaultBind(), dlog) 294 | err = dev.IpcSet(fmt.Sprintf(`private_key=%s 295 | public_key=%s 296 | endpoint=%s 297 | persistent_keepalive_interval=21 298 | allowed_ip=%s/128`, 299 | cfg.PrivateKey.HexString(), 300 | hex.EncodeToString(res.ServerPublicKey[:]), 301 | wgEndpoint, 302 | res.ServerIP.String(), 303 | )) 304 | if err != nil { 305 | return nil, xerrors.Errorf("configure wireguard ipc: %w", err) 306 | } 307 | err = dev.Up() 308 | if err != nil { 309 | return nil, xerrors.Errorf("wireguard device up: %w", err) 310 | } 311 | 312 | // Create a listener on the static tunnel port. 313 | wgListen, err := tnet.ListenTCP(&net.TCPAddr{Port: TunnelPort}) 314 | if err != nil { 315 | return nil, xerrors.Errorf("wireguard device listen: %w", err) 316 | } 317 | 318 | closed := make(chan struct{}, 1) 319 | closeFn := func() { 320 | tunnelCancel() 321 | 322 | _ = wgListen.Close() 323 | // Remove peers before closing to avoid a race condition between 324 | // dev.Close() and the peer goroutines which results in segfault. 
325 | dev.RemoveAllPeers() 326 | dev.Close() 327 | } 328 | go func() { 329 | defer close(closed) 330 | select { 331 | case <-ctx.Done(): 332 | closeFn() 333 | case <-dev.Wait(): 334 | tunnelCancel() 335 | } 336 | }() 337 | 338 | returnedOK = true 339 | return &Tunnel{ 340 | closeFn: closeFn, 341 | closed: closed, 342 | URL: primaryURL, 343 | OtherURLs: otherURLs, 344 | Listener: wgListen, 345 | }, nil 346 | } 347 | 348 | type Tunnel struct { 349 | closeFn func() 350 | closed <-chan struct{} 351 | URL *url.URL 352 | OtherURLs []*url.URL 353 | Listener net.Listener 354 | } 355 | 356 | func (t *Tunnel) Close() error { 357 | t.closeFn() 358 | return nil 359 | } 360 | 361 | func (t *Tunnel) Wait() <-chan struct{} { 362 | return t.closed 363 | } 364 | --------------------------------------------------------------------------------