├── .github ├── release.yml └── workflows │ ├── tagpr.yml │ └── test.yml ├── .gitignore ├── .goreleaser.yml ├── .tagpr ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── cmd └── wsgate-server │ └── main.go ├── go.mod ├── go.sum ├── internal ├── dumper │ └── dumper.go ├── handler │ ├── handler.go │ └── handler_test.go ├── mapping │ └── mapping.go └── publickey │ ├── publickey.go │ └── publickey_test.go └── sample-map.txt /.github/release.yml: -------------------------------------------------------------------------------- 1 | changelog: 2 | exclude: 3 | labels: 4 | - tagpr 5 | -------------------------------------------------------------------------------- /.github/workflows/tagpr.yml: -------------------------------------------------------------------------------- 1 | name: tagpr 2 | on: 3 | push: 4 | branches: 5 | - master 6 | jobs: 7 | tagpr: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | contents: write 11 | pull-requests: write 12 | issues: write 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - uses: Songmu/tagpr@v1 17 | id: tagpr 18 | env: 19 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 20 | 21 | - name: Setup Go 22 | uses: actions/setup-go@v4 23 | with: 24 | go-version-file: go.mod 25 | if: ${{ steps.tagpr.outputs.tag != ''}} 26 | 27 | - name: Run GoReleaser 28 | uses: goreleaser/goreleaser-action@v5 29 | with: 30 | version: latest 31 | args: release --clean 32 | env: 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | CGO_ENABLED: 0 35 | if: ${{ steps.tagpr.outputs.tag != ''}} 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - "**" 6 | jobs: 7 | test: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | contents: read 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v4 14 | with: 15 | fetch-depth: 0 16 | 17 | - name: Setup Go 18 | uses: 
actions/setup-go@v4 19 | with: 20 | go-version-file: go.mod 21 | 22 | - name: test 23 | run: make check 24 | env: 25 | CGO_ENABLED: 0 26 | 27 | - name: Snapshot GoReleaser 28 | uses: goreleaser/goreleaser-action@v5 29 | with: 30 | version: latest 31 | args: build --snapshot 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | /wsgate-server 27 | eg/ 28 | vendor/ 29 | dist/ 30 | *~ 31 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | builds: 2 | - binary: wsgate-server 3 | main: ./cmd/wsgate-server/main.go 4 | goos: 5 | - darwin 6 | - linux 7 | goarch: 8 | - amd64 9 | - arm64 10 | ignore: 11 | - goos: darwin 12 | goarch: arm64 13 | env: 14 | - CGO_ENABLED=0 15 | archives: 16 | - format: zip 17 | name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}" 18 | release: 19 | github: 20 | owner: kazeburo 21 | name: wsgate-server 22 | -------------------------------------------------------------------------------- /.tagpr: -------------------------------------------------------------------------------- 1 | # config file for the tagpr in git config format 2 | # The tagpr generates the initial configuration, which you can rewrite to suit your environment. 3 | # CONFIGURATIONS: 4 | # tagpr.releaseBranch 5 | # Generally, it is "main." It is the branch for releases. 
The tagpr tracks this branch, 6 | # creates or updates a pull request as a release candidate, or tags when they are merged. 7 | # 8 | # tagpr.versionFile 9 | # Versioning file containing the semantic version needed to be updated at release. 10 | # It will be synchronized with the "git tag". 11 | # Often this is a meta-information file such as gemspec, setup.cfg, package.json, etc. 12 | # Sometimes the source code file, such as version.go or Bar.pm, is used. 13 | # If you do not want to use versioning files but only git tags, specify the "-" string here. 14 | # You can specify multiple version files by comma separated strings. 15 | # 16 | # tagpr.vPrefix 17 | # Flag whether or not v-prefix is added to semver when git tagging. (e.g. v1.2.3 if true) 18 | # This is only a tagging convention, not how it is described in the version file. 19 | # 20 | # tagpr.changelog (Optional) 21 | # Flag whether or not changelog is added or changed during the release. 22 | # 23 | # tagpr.command (Optional) 24 | # Command to change files just before release. 25 | # 26 | # tagpr.template (Optional) 27 | # Pull request template file in go template format 28 | # 29 | # tagpr.templateText (Optional) 30 | # Pull request template text in go template format 31 | # 32 | # tagpr.release (Optional) 33 | # GitHub Release creation behavior after tagging [true, draft, false] 34 | # If this value is not set, the release is to be created. 35 | # 36 | # tagpr.majorLabels (Optional) 37 | # Label of major update targets. Default is [major] 38 | # 39 | # tagpr.minorLabels (Optional) 40 | # Label of minor update targets. Default is [minor] 41 | # 42 | # tagpr.commitPrefix (Optional) 43 | # Prefix of commit message. 
Default is "[tagpr]" 44 | # 45 | [tagpr] 46 | vPrefix = true 47 | releaseBranch = master 48 | versionFile = Makefile 49 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [v0.4.1](https://github.com/kazeburo/wsgate-server/compare/v0.4.0...v0.4.1) - 2025-04-19 4 | - modernize packages by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/9 5 | - add issue permission by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/11 6 | - replace jwt. remove seq with atomic by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/13 7 | 8 | ## [v0.4.0](https://github.com/kazeburo/wsgate-server/compare/v0.3.1...v0.4.0) - 2019-08-28 9 | - impl: compression by @azonti in https://github.com/kazeburo/wsgate-server/pull/8 10 | 11 | ## [v0.3.1](https://github.com/kazeburo/wsgate-server/compare/v0.3.0...v0.3.1) - 2019-06-07 12 | - reuse buffer by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/7 13 | 14 | ## [v0.3.0](https://github.com/kazeburo/wsgate-server/compare/v0.2.1...v0.3.0) - 2019-05-09 15 | - graceful stop by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/5 16 | - improve verify token by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/6 17 | 18 | ## [v0.2.1](https://github.com/kazeburo/wsgate-server/compare/v0.2.0...v0.2.1) - 2019-04-22 19 | - validate signed method by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/4 20 | 21 | ## [v0.2.0](https://github.com/kazeburo/wsgate-server/compare/v0.1.1...v0.2.0) - 2019-02-05 22 | - use zap by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/1 23 | - devide to 3 pkgs by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/2 24 | - Log tcp dump by @kazeburo in https://github.com/kazeburo/wsgate-server/pull/3 25 | 26 | ## 
[v0.1.1](https://github.com/kazeburo/wsgate-server/compare/v0.1.0...v0.1.1) - 2018-10-11 27 | 28 | ## [v0.1.0](https://github.com/kazeburo/wsgate-server/compare/v0.0.8...v0.1.0) - 2018-09-26 29 | 30 | ## [v0.0.8](https://github.com/kazeburo/wsgate-server/compare/v0.0.7...v0.0.8) - 2018-09-26 31 | 32 | ## [v0.0.7](https://github.com/kazeburo/wsgate-server/compare/v0.0.6...v0.0.7) - 2018-09-26 33 | 34 | ## [v0.0.6](https://github.com/kazeburo/wsgate-server/compare/v0.0.5...v0.0.6) - 2018-09-26 35 | 36 | ## [v0.0.5](https://github.com/kazeburo/wsgate-server/compare/v0.0.4...v0.0.5) - 2018-09-26 37 | 38 | ## [v0.0.4](https://github.com/kazeburo/wsgate-server/compare/v0.0.3...v0.0.4) - 2018-09-26 39 | 40 | ## [v0.0.3](https://github.com/kazeburo/wsgate-server/compare/v0.0.2...v0.0.3) - 2018-09-26 41 | 42 | ## [v0.0.2](https://github.com/kazeburo/wsgate-server/compare/v0.0.1...v0.0.2) - 2018-09-26 43 | 44 | ## [v0.0.1](https://github.com/kazeburo/wsgate-server/commits/v0.0.1) - 2018-04-06 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Masahiro Nagano 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VERSION=0.4.1 2 | LDFLAGS=-ldflags "-w -s -X main.Version=${VERSION}" 3 | all: wsgate-server 4 | 5 | .PHONY: wsgate-server 6 | 7 | wsgate-server: cmd/wsgate-server/main.go 8 | go build $(LDFLAGS) -o wsgate-server cmd/wsgate-server/main.go 9 | 10 | linux: cmd/wsgate-server/main.go 11 | GOOS=linux GOARCH=amd64 go build $(LDFLAGS) -o wsgate-server cmd/wsgate-server/main.go 12 | 13 | check: 14 | go test -v ./... 15 | 16 | fmt: 17 | go fmt ./... 
18 | 19 | clean: 20 | rm -rf wsgate-server 21 | 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wsgate-server - a websocket to tcp proxy/bridge server 2 | 3 | ``` 4 | [client] 5 | | 6 | | TCP 7 | | 8 | [wsgate-client] (https://github.com/kazeburo/wsgate-client) 9 | | 10 | | websocket (wss) 11 | | 12 | [reverse proxy] if required 13 | | 14 | | websocket (ws) 15 | | 16 | [wsgate-server] (https://github.com/kazeburo/wsgate-server) 17 | | 18 | | TCP 19 | | 20 | [server] 21 | ``` 22 | 23 | ## Example 24 | 25 | ### wsgate-server 26 | 27 | map-server.txt 28 | 29 | ``` 30 | mysql,127.0.0.1:3306 31 | ssh,127.0.0.1:22 32 | ``` 33 | Run the server: 34 | 35 | ``` 36 | $ wsgate-server --listen 0.0.0.0:8086 --map map-server.txt 37 | ``` 38 | 39 | ### wsgate-client 40 | 41 | map-client.txt 42 | 43 | ``` 44 | 127.0.0.1:8306,https://example.com/proxy/mysql 45 | 127.0.0.1:8022,https://example.com/proxy/ssh 46 | ``` 47 | 48 | Run the client: 49 | 50 | ``` 51 | $ wsgate-client --map map-client.txt 52 | ``` 53 | 54 | ### client 55 | 56 | ``` 57 | # mysql 58 | $ mysql -h 127.0.0.1 --port 8306 --user ... 59 | 60 | # ssh 61 | $ ssh -p 8022 user@127.0.0.1 62 | ``` 63 | 64 | ### Using go-sql-driver/mysql 65 | 66 | With go-sql-driver/mysql you can register a custom dialer with `mysql.RegisterDial` and connect to MySQL through wsgate-server directly, without running wsgate-client:
67 | 68 | ``` 69 | mysql.RegisterDial("websocket", func(url string) (net.Conn, error) { 70 | wsURL := strings.Replace(url, "http", "ws", 1) 71 | wsConf, err := websocket.NewConfig(wsURL, url) 72 | if err != nil { 73 | log.Fatalf("NewConfig failed: %v", err) 74 | } 75 | conn, err := websocket.DialConfig(wsConf) 76 | if err != nil { 77 | log.Fatalf("Dial to %q fail: %v", url, err) 78 | } 79 | conn.PayloadType = websocket.BinaryFrame 80 | return conn, err 81 | }) 82 | 83 | db, err := sql.Open("mysql", "yyyy:xxx@websocket(https://example.com/proxy/mysql)/test") 84 | ``` 85 | 86 | ## Usage 87 | 88 | ``` 89 | Usage of ./wsgate-server: 90 | -dial_timeout duration 91 | Dial timeout. (default 10s) 92 | -dump-tcp uint 93 | Dump TCP. 0 = disable, 1 = src to dest, 2 = both 94 | -handshake_timeout duration 95 | Handshake timeout. (default 10s) 96 | -jwt-freshness duration 97 | time in seconds to allow generated jwt tokens (default 1h0m0s) 98 | -listen string 99 | Address to listen to. (default "127.0.0.1:8086") 100 | -map string 101 | path and proxy host mapping file 102 | -public-key string 103 | public key for verifying JWT auth header 104 | -shutdown_timeout duration 105 | timeout to wait for all connections to be closed (default 24h0m0s) 106 | -version 107 | show version 108 | -write_timeout duration 109 | Write timeout. 
(default 10s) 110 | ``` 111 | -------------------------------------------------------------------------------- /cmd/wsgate-server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "runtime" 12 | "sync" 13 | "syscall" 14 | "time" 15 | 16 | "github.com/gorilla/mux" 17 | "github.com/kazeburo/wsgate-server/internal/handler" 18 | "github.com/kazeburo/wsgate-server/internal/mapping" 19 | "github.com/kazeburo/wsgate-server/internal/publickey" 20 | ss "github.com/lestrrat/go-server-starter-listener" 21 | "go.uber.org/zap" 22 | ) 23 | 24 | var ( 25 | // Version wsgate-server version 26 | Version string 27 | showVersion = flag.Bool("version", false, "Show version") 28 | listen = flag.String("listen", "127.0.0.1:8086", "Address to listen to") 29 | handshakeTimeout = flag.Duration("handshake_timeout", 10*time.Second, "Handshake timeout") 30 | dialTimeout = flag.Duration("dial_timeout", 10*time.Second, "Dial timeout") 31 | writeTimeout = flag.Duration("write_timeout", 10*time.Second, "Write timeout") 32 | shutdownTimeout = flag.Duration("shutdown_timeout", 86400*time.Second, "Timeout to wait for all connections to be closed") 33 | enableCompression = flag.Bool("enable_compression", false, "To enable WebSocket Per-Message Compression Extensions (RFC 7692)") 34 | mapFile = flag.String("map", "", "Path and proxy host mapping file") 35 | publicKeyFile = flag.String("public-key", "", "Public key for verifying JWT auth header") 36 | jwtFreshness = flag.Duration("jwt-freshness", 3600*time.Second, "Time in seconds to allow generated jwt tokens") 37 | dumpTCP = flag.Uint("dump-tcp", 0, "Dump TCP. 
0 = disable, 1 = src to dest, 2 = both") 38 | ) 39 | 40 | func printVersion() { 41 | fmt.Printf(`wsgate-server %s 42 | Compiler: %s %s 43 | `, 44 | Version, 45 | runtime.Compiler, 46 | runtime.Version()) 47 | } 48 | 49 | func main() { 50 | flag.Parse() 51 | 52 | if *showVersion { 53 | printVersion() 54 | return 55 | } 56 | 57 | logger, _ := zap.NewProduction() 58 | 59 | mp, err := mapping.New(*mapFile, logger) 60 | if err != nil { 61 | logger.Fatal("Failed init mapping", zap.Error(err)) 62 | } 63 | 64 | pk, err := publickey.New(*publicKeyFile, *jwtFreshness, logger) 65 | if err != nil { 66 | logger.Fatal("Failed init publickey", zap.Error(err)) 67 | } 68 | 69 | proxyHandler, err := handler.New( 70 | *handshakeTimeout, 71 | *dialTimeout, 72 | *writeTimeout, 73 | *enableCompression, 74 | mp, 75 | pk, 76 | *dumpTCP, 77 | logger, 78 | ) 79 | if err != nil { 80 | logger.Fatal("Failed init handler", zap.Error(err)) 81 | } 82 | 83 | wg := &sync.WaitGroup{} 84 | defer func() { 85 | c := make(chan struct{}) 86 | go func() { 87 | defer close(c) 88 | wg.Wait() 89 | }() 90 | select { 91 | case <-c: 92 | logger.Info("All connections closed. Shutdown") 93 | return 94 | case <-time.After(*shutdownTimeout): 95 | logger.Info("Timeout, close some connections. Shutdown") 96 | return 97 | } 98 | }() 99 | 100 | m := mux.NewRouter() 101 | m.HandleFunc("/", proxyHandler.Hello()) 102 | m.HandleFunc("/live", proxyHandler.Hello()) 103 | m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy(wg)) 104 | 105 | s := &http.Server{ 106 | Handler: m, 107 | ReadTimeout: 10 * time.Second, 108 | WriteTimeout: 10 * time.Second, 109 | MaxHeaderBytes: 1 << 20, 110 | } 111 | 112 | idleConnsClosed := make(chan struct{}) 113 | go func() { 114 | sigChan := make(chan os.Signal, 1) 115 | signal.Notify(sigChan, syscall.SIGTERM) 116 | <-sigChan 117 | logger.Info("Signal received. 
Start to shutdown") 118 | ctx, cancel := context.WithTimeout(context.Background(), *shutdownTimeout) 119 | if es := s.Shutdown(ctx); es != nil { 120 | logger.Warn("Shutdown error", zap.Error(es)) 121 | } 122 | cancel() 123 | close(idleConnsClosed) 124 | logger.Info("Waiting for all connections to be closed") 125 | }() 126 | 127 | l, err := ss.NewListener() 128 | if l == nil || err != nil { 129 | // Fallback if not running under Server::Starter 130 | l, err = net.Listen("tcp", *listen) 131 | if err != nil { 132 | logger.Fatal("Failed to listen to port", zap.String("listen", *listen)) 133 | } 134 | } 135 | 136 | if err := s.Serve(l); err != http.ErrServerClosed { 137 | logger.Error("Error in Serve", zap.Error(err)) 138 | } 139 | 140 | <-idleConnsClosed 141 | } 142 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kazeburo/wsgate-server 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.23.4 6 | 7 | require ( 8 | github.com/gorilla/mux v1.8.1 9 | github.com/gorilla/websocket v1.5.3 10 | github.com/lestrrat/go-server-starter-listener v0.0.0-20150507032651-00dd68592c85 11 | github.com/pkg/errors v0.9.1 12 | go.uber.org/zap v1.27.0 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | gopkg.in/yaml.v3 v3.0.1 // indirect 19 | ) 20 | 21 | require ( 22 | github.com/golang-jwt/jwt/v5 v5.2.2 23 | github.com/stretchr/testify v1.8.1 24 | go.uber.org/multierr v1.11.0 // indirect 25 | golang.org/x/net v0.39.0 26 | ) 27 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= 5 | github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 6 | github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= 7 | github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= 8 | github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= 9 | github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 10 | github.com/lestrrat/go-server-starter-listener v0.0.0-20150507032651-00dd68592c85 h1:gxayjLkXkf5th4qXa32uk8PngRdBRGOiH+cn2bamRCQ= 11 | github.com/lestrrat/go-server-starter-listener v0.0.0-20150507032651-00dd68592c85/go.mod h1:qWioISoOtEGapIqSRs4MTbygZUEd8fsWzdmYu5bY8Bs= 12 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 13 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 18 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 19 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 20 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 21 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 22 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 23 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 24 | 
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 25 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= 26 | go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 27 | go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= 28 | go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= 29 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= 30 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /internal/dumper/dumper.go: -------------------------------------------------------------------------------- 1 | package dumper 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "strings" 7 | "sync" 8 | 9 | "go.uber.org/zap" 10 | ) 11 | 12 | // Dumper dumper struct 13 | type Dumper struct { 14 | direction uint 15 | logger *zap.Logger 16 | buf *bytes.Buffer 17 | mu *sync.RWMutex 18 | } 19 | 20 | // New new handler 21 | func New(direction uint, logger *zap.Logger) *Dumper { 22 | d := &Dumper{ 23 | direction: direction, 24 | logger: logger, 25 | buf: new(bytes.Buffer), 26 | mu: new(sync.RWMutex), 27 | } 28 | return d 29 | } 30 | 31 | // Write to dump 32 | func (d *Dumper) Write(p []byte) (n int, err error) { 33 | d.mu.Lock() 34 | defer d.mu.Unlock() 35 | d.buf.Write(p) 36 | return len(p), nil 37 | } 38 | 39 | // Flush flush buffer 40 
| func (d *Dumper) Flush() { 41 | d.mu.Lock() 42 | defer d.mu.Unlock() 43 | if d.buf.Len() == 0 { 44 | return 45 | } 46 | hexdump := strings.Split(hex.Dump(d.buf.Bytes()), "\n") 47 | d.buf.Truncate(0) 48 | byteString := []string{} 49 | ascii := []string{} 50 | for _, hd := range hexdump { 51 | if hd == "" { 52 | continue 53 | } 54 | byteString = append(byteString, strings.TrimRight(strings.Replace(hd[10:58], " ", " ", 1), " ")) 55 | ascii = append(ascii, hd[61:len(hd)-1]) 56 | } 57 | d.logger.Info("dump", 58 | zap.Uint("direction", d.direction), 59 | zap.String("hex", strings.Join(byteString, " ")), 60 | zap.String("ascii", strings.Join(ascii, "")), 61 | ) 62 | } 63 | -------------------------------------------------------------------------------- /internal/handler/handler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "net/http" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "github.com/gorilla/mux" 13 | "github.com/gorilla/websocket" 14 | "github.com/kazeburo/wsgate-server/internal/dumper" 15 | "github.com/kazeburo/wsgate-server/internal/mapping" 16 | "github.com/kazeburo/wsgate-server/internal/publickey" 17 | "go.uber.org/zap" 18 | ) 19 | 20 | // BufferSize for coybuffer and websocket 21 | const BufferSize = 256 * 1024 22 | 23 | var ( 24 | websocketUpstream uint = 1 25 | upstreamWebsocket uint = 2 26 | flushDumperInterval time.Duration = 300 27 | ) 28 | 29 | // Handler handlers 30 | type Handler struct { 31 | logger *zap.Logger 32 | upgrader websocket.Upgrader 33 | dialTimeout time.Duration 34 | writeTimeout time.Duration 35 | mp *mapping.Mapping 36 | pk *publickey.Publickey 37 | dumpTCP uint 38 | sq *uint64 39 | } 40 | 41 | // New new handler 42 | func New( 43 | handshakeTimeout time.Duration, 44 | dialTimeout time.Duration, 45 | writeTimeout time.Duration, 46 | enableCompression bool, 47 | mp *mapping.Mapping, 48 | pk *publickey.Publickey, 49 | 
dumpTCP uint, 50 | logger *zap.Logger) (*Handler, error) { 51 | 52 | upgrader := websocket.Upgrader{ 53 | EnableCompression: enableCompression, 54 | ReadBufferSize: BufferSize, 55 | WriteBufferSize: BufferSize, 56 | HandshakeTimeout: handshakeTimeout, 57 | CheckOrigin: func(r *http.Request) bool { 58 | return true 59 | }, 60 | } 61 | 62 | seq := uint64(0) 63 | return &Handler{ 64 | logger: logger, 65 | upgrader: upgrader, 66 | dialTimeout: dialTimeout, 67 | writeTimeout: writeTimeout, 68 | mp: mp, 69 | pk: pk, 70 | dumpTCP: dumpTCP, 71 | sq: &seq, 72 | }, nil 73 | } 74 | 75 | func (h *Handler) GetSq() uint64 { 76 | return atomic.LoadUint64(h.sq) 77 | } 78 | 79 | // Hello hello handler 80 | func (h *Handler) Hello() func(w http.ResponseWriter, r *http.Request) { 81 | return func(w http.ResponseWriter, r *http.Request) { 82 | w.Write([]byte("OK\n")) 83 | } 84 | } 85 | 86 | // Proxy proxy handler 87 | func (h *Handler) Proxy(wg *sync.WaitGroup) func(w http.ResponseWriter, r *http.Request) { 88 | return func(w http.ResponseWriter, r *http.Request) { 89 | wg.Add(1) 90 | defer wg.Done() 91 | 92 | vars := mux.Vars(r) 93 | proxyDest := vars["dest"] 94 | upstream := "" 95 | readLen := int64(0) 96 | writeLen := int64(0) 97 | hasError := false 98 | disconnectAt := "" 99 | 100 | logger := h.logger.With( 101 | zap.Uint64("seq", atomic.AddUint64(h.sq, 1)), 102 | zap.String("x-forwarded-for", r.Header.Get("X-Forwarded-For")), 103 | zap.String("remote-addr", r.RemoteAddr), 104 | zap.String("destination", proxyDest), 105 | ) 106 | 107 | if h.pk.Enabled() { 108 | sub, err := h.pk.Verify(r.Header.Get("Authorization")) 109 | if err != nil { 110 | logger.Warn("Failed to authorize", zap.Error(err)) 111 | http.Error(w, err.Error(), http.StatusUnauthorized) 112 | return 113 | } 114 | logger = logger.With(zap.String("user-email", sub)) 115 | 116 | } else { 117 | logger = logger.With(zap.String("user-email", r.Header.Get("X-Goog-Authenticated-User-Email"))) 118 | } 119 | 120 | upstream, ok 
:= h.mp.Get(proxyDest) 121 | if !ok { 122 | hasError = true 123 | logger.Warn("No map found") 124 | http.Error(w, fmt.Sprintf("Not found: %s", proxyDest), 404) 125 | return 126 | } 127 | 128 | logger = logger.With(zap.String("upstream", upstream)) 129 | 130 | s, err := net.DialTimeout("tcp", upstream, h.dialTimeout) 131 | 132 | if err != nil { 133 | hasError = true 134 | logger.Warn("DialTimeout", zap.Error(err)) 135 | http.Error(w, fmt.Sprintf("Could not connect upstream: %v", err), 500) 136 | return 137 | } 138 | 139 | conn, err := h.upgrader.Upgrade(w, r, nil) 140 | if err != nil { 141 | hasError = true 142 | s.Close() 143 | logger.Warn("Failed to Upgrade", zap.Error(err)) 144 | return 145 | } 146 | 147 | logger.Info("log", zap.String("status", "Connected")) 148 | dr := dumper.New(websocketUpstream, logger) 149 | ds := dumper.New(upstreamWebsocket, logger) 150 | 151 | defer func() { 152 | dr.Flush() 153 | ds.Flush() 154 | status := "Suceeded" 155 | if hasError { 156 | status = "Failed" 157 | } 158 | logger.Info("log", 159 | zap.String("status", status), 160 | zap.Int64("read", readLen), 161 | zap.Int64("write", writeLen), 162 | zap.String("disconnect_at", disconnectAt), 163 | ) 164 | }() 165 | 166 | ticker := time.NewTicker(flushDumperInterval * time.Millisecond) 167 | defer ticker.Stop() 168 | go func() { 169 | for { 170 | select { 171 | case <-r.Context().Done(): 172 | dr.Flush() 173 | ds.Flush() 174 | return 175 | case <-ticker.C: 176 | dr.Flush() 177 | ds.Flush() 178 | } 179 | } 180 | }() 181 | 182 | doneCh := make(chan bool) 183 | goClose := false 184 | 185 | // websocket -> server 186 | go func() { 187 | defer func() { doneCh <- true }() 188 | b := make([]byte, BufferSize) 189 | for { 190 | mt, r, err := conn.NextReader() 191 | if websocket.IsCloseError(err, 192 | websocket.CloseNormalClosure, // Normal. 193 | websocket.CloseAbnormalClosure, // OpenSSH killed proxy client. 
194 | ) { 195 | return 196 | } 197 | if err != nil { 198 | if !goClose { 199 | logger.Warn("NextReader", zap.Error(err)) 200 | hasError = true 201 | } 202 | if disconnectAt == "" { 203 | disconnectAt = "client_nextreader" 204 | } 205 | return 206 | } 207 | if mt != websocket.BinaryMessage { 208 | logger.Warn("BinaryMessage required", zap.Int("messageType", mt)) 209 | hasError = true 210 | return 211 | } 212 | if h.dumpTCP > 0 { 213 | r = io.TeeReader(r, dr) 214 | } 215 | n, err := io.CopyBuffer(s, r, b) 216 | if err != nil { 217 | if !goClose { 218 | logger.Warn("Reading from websocket", zap.Error(err)) 219 | hasError = true 220 | } 221 | if disconnectAt == "" { 222 | disconnectAt = "client_upstream_copy" 223 | } 224 | return 225 | } 226 | readLen += n 227 | } 228 | }() 229 | 230 | // server -> websocket 231 | go func() { 232 | defer func() { doneCh <- true }() 233 | b := make([]byte, BufferSize) 234 | for { 235 | n, err := s.Read(b) 236 | if err != nil { 237 | if !goClose && err != io.EOF { 238 | logger.Warn("Reading from dest", zap.Error(err)) 239 | hasError = true 240 | } 241 | if disconnectAt == "" { 242 | disconnectAt = "upstream_read" 243 | } 244 | return 245 | } 246 | 247 | if h.dumpTCP > 1 { 248 | ds.Write(b) 249 | } 250 | if err := conn.WriteMessage(websocket.BinaryMessage, b[:n]); err != nil { 251 | if !goClose { 252 | logger.Warn("WriteMessage", zap.Error(err)) 253 | hasError = true 254 | } 255 | if disconnectAt == "" { 256 | disconnectAt = "client_write" 257 | } 258 | return 259 | } 260 | writeLen += int64(n) 261 | } 262 | }() 263 | 264 | <-doneCh 265 | goClose = true 266 | s.Close() 267 | conn.Close() 268 | <-doneCh 269 | 270 | } 271 | 272 | } 273 | -------------------------------------------------------------------------------- /internal/handler/handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net" 8 | "net/http" 9 | 
"net/http/httptest" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "golang.org/x/net/websocket" 15 | 16 | "github.com/gorilla/mux" 17 | "github.com/kazeburo/wsgate-server/internal/mapping" 18 | "github.com/kazeburo/wsgate-server/internal/publickey" 19 | "github.com/stretchr/testify/assert" 20 | "go.uber.org/zap" 21 | ) 22 | 23 | func TestHello(t *testing.T) { 24 | logger := zap.NewNop() 25 | h, err := New( 26 | 10*time.Second, 27 | 10*time.Second, 28 | 10*time.Second, 29 | true, 30 | nil, 31 | nil, 32 | 0, 33 | logger, 34 | ) 35 | assert.NoError(t, err) 36 | 37 | req := httptest.NewRequest(http.MethodGet, "/hello", nil) 38 | rec := httptest.NewRecorder() 39 | 40 | handler := h.Hello() 41 | handler(rec, req) 42 | 43 | assert.Equal(t, http.StatusOK, rec.Code) 44 | assert.Equal(t, "OK\n", rec.Body.String()) 45 | } 46 | 47 | func createClient(wsAddr string, disableKeepalive bool) *http.Client { 48 | client := &http.Client{ 49 | Transport: &http.Transport{ 50 | DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 51 | wsConf, err := websocket.NewConfig( 52 | fmt.Sprintf("ws://%s/proxy/dummy", wsAddr), 53 | fmt.Sprintf("http://%s/proxy/dummy", wsAddr), 54 | ) 55 | if err != nil { 56 | return nil, err 57 | } 58 | conn, err := websocket.DialConfig(wsConf) 59 | if err != nil { 60 | return nil, err 61 | } 62 | conn.PayloadType = websocket.BinaryFrame 63 | return conn, nil 64 | }, 65 | DisableKeepAlives: disableKeepalive, 66 | }, 67 | } 68 | return client 69 | } 70 | 71 | func TestWebSocket(t *testing.T) { 72 | logger := zap.NewExample() 73 | 74 | // Start a test server on a free port 75 | dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 76 | w.WriteHeader(http.StatusOK) 77 | w.Write([]byte("Hello from dummy server")) 78 | }) 79 | ts := httptest.NewServer(dummyHandler) 80 | defer ts.Close() 81 | addr := ts.Listener.Addr().String() 82 | t.Logf("dummy server address: %s", addr) 83 | mp, _ := mapping.New("", logger) 84 |
mp.Set("dummy", addr) 85 | 86 | pk, _ := publickey.New("", time.Minute, logger) 87 | 88 | proxyHandler, err := New( 89 | 10*time.Second, 90 | 10*time.Second, 91 | 10*time.Second, 92 | true, 93 | mp, 94 | pk, 95 | 0, 96 | logger, 97 | ) 98 | assert.NoError(t, err) 99 | 100 | wg := &sync.WaitGroup{} 101 | m := mux.NewRouter() 102 | m.HandleFunc("/proxy/{dest}", proxyHandler.Proxy(wg)) 103 | 104 | ws := httptest.NewServer(m) 105 | defer ws.Close() 106 | 107 | // Get the assigned port number (as host:port) 108 | wsAddr := ws.Listener.Addr().String() 109 | t.Logf("wsAddr: %s", wsAddr) 110 | 111 | // Connect via an http client 112 | { 113 | client := createClient(wsAddr, false) 114 | for i := 0; i < 10; i++ { 115 | req, _ := http.NewRequest(http.MethodGet, "http://example/test", nil) 116 | resp, err := client.Do(req) 117 | assert.NoError(t, err) 118 | defer resp.Body.Close() 119 | assert.Equal(t, http.StatusOK, resp.StatusCode) 120 | body, _ := io.ReadAll(resp.Body) 121 | assert.Equal(t, "Hello from dummy server", string(body)) 122 | } 123 | client.CloseIdleConnections() 124 | } 125 | { 126 | client := createClient(wsAddr, true) 127 | for i := 0; i < 3; i++ { 128 | req, _ := http.NewRequest(http.MethodGet, "http://example/test", nil) 129 | resp, err := client.Do(req) 130 | assert.NoError(t, err) 131 | defer resp.Body.Close() 132 | assert.Equal(t, http.StatusOK, resp.StatusCode) 133 | body, _ := io.ReadAll(resp.Body) 134 | assert.Equal(t, "Hello from dummy server", string(body)) 135 | } 136 | client.CloseIdleConnections() 137 | } 138 | assert.Equal(t, uint64(4), proxyHandler.GetSq()) 139 | } 140 | -------------------------------------------------------------------------------- /internal/mapping/mapping.go: -------------------------------------------------------------------------------- 1 | package mapping 2 | 3 | import ( 4 | "bufio" 5 | "os" 6 | "regexp" 7 | "strings" 8 | 9 | "github.com/pkg/errors" 10 | "go.uber.org/zap" 11 | ) 12 | 13 | // Mapping maps proxy destination names to upstream addresses. 14 | type Mapping struct { 15 | m map[string]string
16 | } 17 | 18 | // New builds a Mapping from mapFile, a text file with one "name,upstream" pair per line. 19 | // Lines consisting of optional spaces followed by '#', and blank (whitespace-only) lines, are skipped. 20 | // An empty mapFile yields an empty Mapping; entries can then be added with Set. 21 | func New(mapFile string, logger *zap.Logger) (*Mapping, error) { 22 | r := regexp.MustCompile(`^ *#`) 23 | m := make(map[string]string) 24 | if mapFile != "" { 25 | f, err := os.Open(mapFile) 26 | if err != nil { 27 | return nil, errors.Wrap(err, "Failed to open mapFile") 28 | } 29 | defer f.Close() 30 | s := bufio.NewScanner(f) 31 | for s.Scan() { 32 | line := s.Text() 33 | if r.MatchString(line) || strings.TrimSpace(line) == "" { 34 | continue 35 | } 36 | l := strings.SplitN(line, ",", 2) 37 | if len(l) != 2 { 38 | // err is nil at this point and errors.Wrapf(nil, ...) returns nil, 39 | // so build a fresh error; otherwise New would return (nil, nil). 40 | return nil, errors.Errorf("Invalid line: %s", line) 41 | } 42 | logger.Info("Created map", 43 | zap.String("from", l[0]), 44 | zap.String("to", l[1])) 45 | m[l[0]] = l[1] 46 | } 47 | // Surface read errors (e.g. an over-long line) instead of returning a silently truncated map. 48 | if err := s.Err(); err != nil { 49 | return nil, errors.Wrap(err, "Failed to read mapFile") 50 | } 51 | } 52 | return &Mapping{ 53 | m: m, 54 | }, nil 55 | } 56 | 57 | // Get returns the upstream address registered for proxyDest and whether one exists. 58 | func (mp *Mapping) Get(proxyDest string) (string, bool) { 59 | upstream, ok := mp.m[proxyDest] 60 | return upstream, ok 61 | } 62 | 63 | // Set registers (or overwrites) the upstream address for proxyDest. 64 | func (mp *Mapping) Set(proxyDest string, upstream string) { 65 | mp.m[proxyDest] = upstream 66 | } 67 | -------------------------------------------------------------------------------- /internal/publickey/publickey.go: -------------------------------------------------------------------------------- 1 | package publickey 2 | 3 | import ( 4 | "crypto/rsa" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "time" 9 | 10 | "github.com/golang-jwt/jwt/v5" 11 | "github.com/pkg/errors" 12 | "go.uber.org/zap" 13 | ) 14 | 15 | // Publickey verifies RS256/RS384/RS512-signed JWTs against an RSA public key. 16 | type Publickey struct { 17 | publicKeyFile string 18 | verifyKey *rsa.PublicKey 19 | freshnessTime time.Duration 20 | } 21 | 22 | // New loads the PEM-encoded RSA public key from publicKeyFile; an empty path yields a Publickey whose Enabled reports false (logger is currently unused). 23 | func New(publicKeyFile string, freshnessTime time.Duration, logger *zap.Logger) (*Publickey, error) { 24 | var verifyKey *rsa.PublicKey 25 | if publicKeyFile != "" { 26 | verifyBytes, err := os.ReadFile(publicKeyFile) 27 | if err != nil { 28 | return nil, errors.Wrap(err, "failed read pubkey") 29 | } 30 | verifyKey, err = jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
31 | if err != nil { 32 | return nil, errors.Wrap(err, "failed parse pubkey") 33 | } 34 | } 35 | return &Publickey{ 36 | publicKeyFile: publicKeyFile, 37 | verifyKey: verifyKey, 38 | freshnessTime: freshnessTime, 39 | }, nil 40 | } 41 | 42 | // Enabled reports whether a public key file was configured. 43 | func (pk Publickey) Enabled() bool { 44 | return pk.publicKeyFile != "" 45 | } 46 | 47 | // Verify checks a JWT (optionally prefixed with "Bearer ") against the loaded key, requires an exp in the future and an iat within freshnessTime, and returns the subject claim. 48 | func (pk Publickey) Verify(t string) (string, error) { 49 | if t == "" { 50 | return "", fmt.Errorf("no tokenString") 51 | } 52 | if pk.verifyKey == nil { 53 | // No key loaded (e.g. publicKeyFile was empty): fail cleanly instead of handing a nil *rsa.PublicKey to the RSA verifier. 54 | return "", fmt.Errorf("no public key configured") 55 | } 56 | t = strings.TrimPrefix(t, "Bearer ") 57 | 58 | claims := &jwt.RegisteredClaims{} 59 | _, err := jwt.ParseWithClaims(t, claims, func(token *jwt.Token) (interface{}, error) { 60 | return pk.verifyKey, nil 61 | }, jwt.WithValidMethods([]string{"RS256", "RS384", "RS512"})) 62 | 63 | if err != nil { 64 | return "", fmt.Errorf("token is invalid: %v", err) 65 | } 66 | 67 | now := time.Now() 68 | iat := now.Add(-pk.freshnessTime) 69 | 70 | if claims.ExpiresAt == nil || claims.ExpiresAt.Time.Before(now) { 71 | return "", fmt.Errorf("token is expired") 72 | } 73 | if claims.IssuedAt == nil || claims.IssuedAt.Time.Before(iat) { 74 | return "", fmt.Errorf("token is too old") 75 | } 76 | 77 | return claims.Subject, nil 78 | } 79 | -------------------------------------------------------------------------------- /internal/publickey/publickey_test.go: -------------------------------------------------------------------------------- 1 | package publickey 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/pem" 8 | "os" 9 | "testing" 10 | "time" 11 | 12 | "github.com/golang-jwt/jwt/v5" 13 | "github.com/stretchr/testify/assert" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | func generateTestKeys() (privateKey *rsa.PrivateKey, publicKeyPEM []byte, err error) { 18 | privateKey, err = rsa.GenerateKey(rand.Reader, 2048) 19 | if err != nil { 20 | return nil, nil, err 21 | } 22 | 23 | publicKeyBytes, _ :=
x509.MarshalPKIXPublicKey(&privateKey.PublicKey) 24 | publicKeyPEM = pem.EncodeToMemory(&pem.Block{ 25 | Type: "PUBLIC KEY", 26 | Bytes: publicKeyBytes, 27 | }) 28 | 29 | return privateKey, publicKeyPEM, nil 30 | } 31 | 32 | func TestNew(t *testing.T) { 33 | _, publicKeyPEM, err := generateTestKeys() 34 | assert.NoError(t, err) 35 | 36 | tempFile, err := os.CreateTemp("", "publickey_test_*.pem") 37 | assert.NoError(t, err) 38 | defer os.Remove(tempFile.Name()) 39 | 40 | _, err = tempFile.Write(publicKeyPEM) 41 | assert.NoError(t, err) 42 | tempFile.Close() 43 | 44 | logger := zap.NewNop() 45 | pk, err := New(tempFile.Name(), time.Minute, logger) 46 | assert.NoError(t, err) 47 | assert.NotNil(t, pk) 48 | assert.True(t, pk.Enabled()) 49 | } 50 | 51 | func TestVerify(t *testing.T) { 52 | privateKey, publicKeyPEM, err := generateTestKeys() 53 | assert.NoError(t, err) 54 | 55 | tempFile, err := os.CreateTemp("", "publickey_test_*.pem") 56 | assert.NoError(t, err) 57 | defer os.Remove(tempFile.Name()) 58 | 59 | _, err = tempFile.Write(publicKeyPEM) 60 | assert.NoError(t, err) 61 | tempFile.Close() 62 | 63 | logger := zap.NewNop() 64 | pk, err := New(tempFile.Name(), time.Minute, logger) 65 | assert.NoError(t, err) 66 | 67 | // Generate a valid token 68 | now := time.Now() 69 | token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{ 70 | Subject: "test-subject", 71 | IssuedAt: jwt.NewNumericDate(now), 72 | ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute)), 73 | }) 74 | tokenString, err := token.SignedString(privateKey) 75 | assert.NoError(t, err) 76 | 77 | subject, err := pk.Verify("Bearer " + tokenString) 78 | assert.NoError(t, err) 79 | assert.Equal(t, "test-subject", subject) 80 | 81 | // Test expired token 82 | expiredToken := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{ 83 | Subject: "test-subject", 84 | IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Minute)), 85 | ExpiresAt: jwt.NewNumericDate(now.Add(-time.Minute)), 86 | })
87 | expiredTokenString, err := expiredToken.SignedString(privateKey) 88 | assert.NoError(t, err) 89 | 90 | _, err = pk.Verify("Bearer " + expiredTokenString) 91 | assert.Error(t, err) 92 | assert.Contains(t, err.Error(), "token is expired") 93 | 94 | // Test token too old 95 | oldToken := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.RegisteredClaims{ 96 | Subject: "test-subject", 97 | IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Minute)), 98 | ExpiresAt: jwt.NewNumericDate(now.Add(time.Minute)), 99 | }) 100 | oldTokenString, err := oldToken.SignedString(privateKey) 101 | assert.NoError(t, err) 102 | 103 | _, err = pk.Verify("Bearer " + oldTokenString) 104 | assert.Error(t, err) 105 | assert.Contains(t, err.Error(), "token is too old") 106 | 107 | // Test invalid token 108 | _, err = pk.Verify("Bearer invalid-token") 109 | assert.Error(t, err) 110 | assert.Contains(t, err.Error(), "token is invalid") 111 | } 112 | 113 | func TestEnabled(t *testing.T) { 114 | logger := zap.NewNop() 115 | 116 | // Test with public key file 117 | _, publicKeyPEM, err := generateTestKeys() 118 | assert.NoError(t, err) 119 | 120 | tempFile, err := os.CreateTemp("", "publickey_test_*.pem") 121 | assert.NoError(t, err) 122 | defer os.Remove(tempFile.Name()) 123 | 124 | _, err = tempFile.Write(publicKeyPEM) 125 | assert.NoError(t, err) 126 | tempFile.Close() 127 | 128 | pk, err := New(tempFile.Name(), time.Minute, logger) 129 | assert.NoError(t, err) 130 | assert.True(t, pk.Enabled()) 131 | 132 | // Test without public key file 133 | pk, err = New("", time.Minute, logger) 134 | assert.NoError(t, err) 135 | assert.False(t, pk.Enabled()) 136 | } 137 | -------------------------------------------------------------------------------- /sample-map.txt: -------------------------------------------------------------------------------- 1 | mysql,127.0.0.1:3306 2 | ssh,127.0.0.1:22 3 | plack,127.0.0.1:8080 4 | --------------------------------------------------------------------------------