├── .github └── workflows │ ├── codeql-analysis.yml │ ├── docker-image.yml │ └── go.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── UPGRADING.md ├── cmd ├── auth │ ├── auth.go │ ├── config │ │ └── configuration.go │ ├── database │ │ ├── config.go │ │ ├── config_test.go │ │ └── database.go │ └── ntlm │ │ ├── ntlm.go │ │ └── ntlm_test.go └── rdpgw │ ├── config │ └── configuration.go │ ├── identity │ ├── identity.go │ ├── identity_test.go │ └── user.go │ ├── kdcproxy │ └── proxy.go │ ├── main.go │ ├── protocol │ ├── client.go │ ├── common.go │ ├── common_test.go │ ├── errors.go │ ├── gateway.go │ ├── metrics.go │ ├── packet_reader.go │ ├── process.go │ ├── protocol_test.go │ ├── track.go │ ├── tunnel.go │ ├── types.go │ └── utf16.go │ ├── rdp │ ├── koanf │ │ └── parsers │ │ │ └── rdp │ │ │ ├── rdp.go │ │ │ └── rdp_test.go │ ├── rdp.go │ ├── rdp_test.go │ └── rdp_test_file.rdp │ ├── security │ ├── basic.go │ ├── basic_test.go │ ├── jwt.go │ ├── jwt_test.go │ └── string.go │ ├── transport │ ├── legacy.go │ ├── transport.go │ └── websocket.go │ └── web │ ├── basic.go │ ├── context.go │ ├── mux.go │ ├── ntlm.go │ ├── oidc.go │ ├── oidc_test.go │ ├── session.go │ ├── token.go │ ├── web.go │ └── web_test.go ├── dev ├── docker-distroless │ └── Dockerfile └── docker │ ├── Dockerfile │ ├── Dockerfile.xrdp │ ├── docker-compose-arm64.yml │ ├── docker-compose-local.yml │ ├── docker-compose.yml │ ├── docker-readme.md │ ├── rdpgw-pam │ ├── rdpgw.yaml │ ├── realm-export.json │ ├── run.sh │ ├── tmp.tar │ ├── xrdp.ini │ └── xrdp_users.txt ├── docs └── images │ ├── flow-auth.svg │ ├── flow-kerberos.svg │ ├── flow-openid.svg │ ├── flow-pam.svg │ └── flow.svg ├── go.mod ├── proto └── auth.proto └── shared └── auth ├── auth.pb.go └── auth_grpc.pb.go /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to 
your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '22 22 * * 3' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Install pam-devel 41 | run: sudo apt-get -y install libpam-dev 42 | 43 | - name: Checkout repository 44 | uses: actions/checkout@v4 45 | 46 | - name: Install Go 47 | uses: actions/setup-go@v4 48 | with: 49 | go-version-file: go.mod 50 | 51 | # Initializes the CodeQL tools for scanning. 52 | - name: Initialize CodeQL 53 | uses: github/codeql-action/init@v3 54 | with: 55 | languages: ${{ matrix.language }} 56 | # If you wish to specify custom queries, you can do so here or in a config file. 57 | # By default, queries listed here will override any specified in a config file. 58 | # Prefix the list here with "+" to use these queries and those in the config file. 
59 | 60 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 61 | # queries: security-extended,security-and-quality 62 | 63 | 64 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 65 | # If this step fails, then you should remove it and run the build manually (see below) 66 | - name: Autobuild 67 | uses: github/codeql-action/autobuild@v3 68 | 69 | # ℹ️ Command-line programs to run using the OS shell. 70 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 71 | 72 | # If the Autobuild fails above, remove it and uncomment the following three lines. 73 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 74 | 75 | # - run: | 76 | # echo "Run, Build Application using script" 77 | # ./location_of_script_within_repo/buildscript.sh 78 | 79 | - name: Perform CodeQL Analysis 80 | uses: github/codeql-action/analyze@v3 81 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | tags: [ "v*" ] 7 | 8 | jobs: 9 | 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | - name: Set up QEMU 18 | uses: docker/setup-qemu-action@v3 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v3 21 | - name: Login to Docker Hub 22 | uses: docker/login-action@v3 23 | with: 24 | username: ${{ secrets.DOCKER_USER }} 25 | password: ${{ secrets.DOCKER_PASSWORD }} 26 | - name: Build and push - latest 27 | uses: docker/build-push-action@v3 28 | with: 29 | context: ./dev/docker 30 | file: 
./dev/docker/Dockerfile 31 | platforms: linux/amd64,linux/arm64 32 | push: true 33 | tags: ${{ github.repository_owner }}/rdpgw:latest 34 | - name: Build and push - latest 35 | uses: docker/build-push-action@v3 36 | with: 37 | context: ./dev/docker 38 | file: ./dev/docker/Dockerfile 39 | platforms: linux/amd64,linux/arm64 40 | push: true 41 | tags: ${{ github.repository_owner }}/rdpgw:${{ github.ref_name }} 42 | - name: Update Docker Hub Description 43 | uses: peter-evans/dockerhub-description@v3 44 | with: 45 | username: ${{ secrets.DOCKER_USER }} 46 | password: ${{ secrets.DOCKER_PASSWORD }} 47 | repository: ${{ github.repository_owner }}/rdpgw 48 | readme-filepath: ./dev/docker/docker-readme.md 49 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.22 20 | id: go 21 | 22 | - name: Install pam-devel 23 | run: sudo apt-get -y install libpam-dev 24 | 25 | - name: Check out code into the Go module directory 26 | uses: actions/checkout@v2 27 | 28 | - name: Install golint 29 | run: go get -u golang.org/x/lint/golint 30 | 31 | - name: Update go.sum 32 | run: make mod 33 | 34 | - name: Build 35 | run: make build 36 | 37 | - name: Test 38 | run: make test 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | go.sum 2 | bin 3 | *.swp 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | BINDIR := 
$(CURDIR)/bin 2 | INSTALL_PATH ?= /usr/local/bin 3 | BINNAME ?= rdpgw 4 | BINNAME2 ?= rdpgw-auth 5 | 6 | # Rebuild the binary if any of these files change 7 | SRC := $(shell find . -type f -name '*.go' -print) go.mod go.sum 8 | 9 | # Required for globs to work correctly 10 | SHELL = /usr/bin/env bash 11 | 12 | GIT_COMMIT = $(shell git rev-parse HEAD) 13 | GIT_SHA = $(shell git rev-parse --short HEAD) 14 | GIT_TAG = $(shell git describe --tags --abbrev=0 --exact-match 2>/dev/null) 15 | GIT_DIRTY = $(shell test -n "`git status --porcelain`" && echo "dirty" || echo "clean") 16 | 17 | ifdef VERSION 18 | BINARY_VERSION = $(VERSION) 19 | endif 20 | BINARY_VERSION ?= ${GIT_TAG} 21 | 22 | VERSION_METADATA = unreleased 23 | # Clear the "unreleased" string in BuildMetadata 24 | ifneq ($(GIT_TAG),) 25 | VERSION_METADATA = 26 | endif 27 | 28 | .PHONY: all 29 | all: mod build 30 | 31 | # ------------------------------------------------------------------------------ 32 | # build 33 | 34 | .PHONY: build 35 | build: $(BINDIR)/$(BINNAME) 36 | 37 | $(BINDIR)/$(BINNAME): $(SRC) 38 | go build $(GOFLAGS) -trimpath -tags '$(TAGS)' -ldflags '$(LDFLAGS)' -o '$(BINDIR)'/$(BINNAME) ./cmd/rdpgw 39 | go build $(GOFLAGS) -trimpath -tags '$(TAGS)' -ldflags '$(LDFLAGS)' -o '$(BINDIR)'/$(BINNAME2) ./cmd/auth 40 | 41 | # ------------------------------------------------------------------------------ 42 | # install (both the gateway and the rdpgw-auth helper built above) 43 | 44 | .PHONY: install 45 | install: build 46 | @install "$(BINDIR)/$(BINNAME)" "$(INSTALL_PATH)/$(BINNAME)" 47 | @install "$(BINDIR)/$(BINNAME2)" "$(INSTALL_PATH)/$(BINNAME2)" 48 | # ------------------------------------------------------------------------------ 49 | # mod 50 | 51 | .PHONY: mod 52 | mod: 53 | go mod tidy -compat=1.22 54 | 55 | # ------------------------------------------------------------------------------ 56 | # test 57 | 58 | .PHONY: test 59 | test: 60 | go test -cover -v ./...
61 | # ------------------------------------------------------------------------------ 62 | # clean 63 | 64 | .PHONY: clean 65 | clean: 66 | @rm -rf '$(BINDIR)' ./_dist 67 | 68 | .PHONY: info 69 | info: 70 | @echo "Version: ${VERSION}" 71 | @echo "Git Tag: ${GIT_TAG}" 72 | @echo "Git Commit: ${GIT_COMMIT}" 73 | @echo "Git Tree State: ${GIT_DIRTY}" 74 | -------------------------------------------------------------------------------- /UPGRADING.md: -------------------------------------------------------------------------------- 1 | # Upgrading from 1.X to 2.0 2 | 3 | In 2.0 the options for configuring client-side RDP settings have been removed in favor of a template file. 4 | The template file is an RDP file that is used as a template for the connection. The template file is parsed 5 | and a few settings are replaced to ensure the client can connect to the server and the correct domain is used. 6 | 7 | The template file contains one setting per line in the form `<setting>:<type>:<value>`, where the type letter is, for example, `s` for a string or `i` for an integer: 8 | 9 | ``` 10 | # <setting>:<type>:<value> 11 | domain:s:testdomain 12 | connection type:i:2 13 | ``` 14 | 15 | The path to the template file is set under `client > defaults` in the configuration.
16 | -------------------------------------------------------------------------------- /cmd/auth/auth.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/bolkedebruin/rdpgw/cmd/auth/config" 8 | "github.com/bolkedebruin/rdpgw/cmd/auth/database" 9 | "github.com/bolkedebruin/rdpgw/cmd/auth/ntlm" 10 | "github.com/bolkedebruin/rdpgw/shared/auth" 11 | "github.com/msteinert/pam/v2" 12 | "github.com/thought-machine/go-flags" 13 | "google.golang.org/grpc" 14 | "log" 15 | "net" 16 | "os" 17 | "syscall" 18 | ) 19 | 20 | const ( 21 | protocol = "unix" 22 | ) 23 | 24 | var opts struct { 25 | ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"` 26 | SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"` 27 | ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"` 28 | } 29 | 30 | type AuthServiceImpl struct { 31 | auth.UnimplementedAuthenticateServer 32 | 33 | serviceName string 34 | ntlm *ntlm.NTLMAuth 35 | } 36 | 37 | var conf config.Configuration 38 | var _ auth.AuthenticateServer = (*AuthServiceImpl)(nil) 39 | 40 | func NewAuthService(serviceName string, database database.Database) auth.AuthenticateServer { 41 | s := &AuthServiceImpl{ 42 | serviceName: serviceName, 43 | ntlm: ntlm.NewNTLMAuth(database), 44 | } 45 | return s 46 | } 47 | 48 | func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) { 49 | t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) { 50 | switch s { 51 | case pam.PromptEchoOff: 52 | return message.Password, nil 53 | case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo: 54 | return "", nil 55 | } 56 | return "", errors.New("unrecognized PAM message style") 57 | }) 58 | 
59 | r := &auth.AuthResponse{} 60 | r.Authenticated = false 61 | 62 | if err != nil { 63 | log.Printf("Error authenticating user: %s due to: %s", message.Username, err) 64 | r.Error = err.Error() 65 | return r, err 66 | } 67 | defer func() { 68 | err := t.End() 69 | if err != nil { 70 | fmt.Fprintf(os.Stderr, "end: %v\n", err) 71 | os.Exit(1) 72 | } 73 | }() 74 | if err = t.Authenticate(0); err != nil { 75 | log.Printf("Authentication for user: %s failed due to: %s", message.Username, err) 76 | r.Error = err.Error() 77 | return r, nil 78 | } 79 | 80 | if err = t.AcctMgmt(0); err != nil { 81 | log.Printf("Account authorization for user: %s failed due to %s", message.Username, err) 82 | r.Error = err.Error() 83 | return r, nil 84 | } 85 | 86 | log.Printf("User: %s authenticated", message.Username) 87 | r.Authenticated = true 88 | return r, nil 89 | } 90 | 91 | func (s *AuthServiceImpl) NTLM(ctx context.Context, message *auth.NtlmRequest) (*auth.NtlmResponse, error) { 92 | r, err := s.ntlm.Authenticate(message) 93 | 94 | if err != nil { 95 | log.Printf("[%s] NTLM failed: %s", message.Session, err) 96 | } else if r.Authenticated { 97 | log.Printf("[%s] User: %s authenticated using NTLM", message.Session, r.Username) 98 | } else if r.NtlmMessage != "" { 99 | log.Printf("[%s] Sending NTLM challenge", message.Session) 100 | } 101 | 102 | return r, err 103 | } 104 | 105 | func main() { 106 | _, err := flags.Parse(&opts) 107 | if err != nil { 108 | var fErr *flags.Error 109 | if errors.As(err, &fErr) { 110 | if fErr.Type == flags.ErrHelp { 111 | fmt.Printf("Acknowledgements:\n") 112 | fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. 
(go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n") 113 | } 114 | } 115 | return 116 | } 117 | 118 | conf = config.Load(opts.ConfigFile) 119 | 120 | log.Printf("Starting auth server on %s", opts.SocketAddr) 121 | cleanup := func() { 122 | if _, err := os.Stat(opts.SocketAddr); err == nil { 123 | if err := os.RemoveAll(opts.SocketAddr); err != nil { 124 | log.Fatal(err) 125 | } 126 | } 127 | } 128 | cleanup() 129 | 130 | oldUmask := syscall.Umask(0) 131 | listener, err := net.Listen(protocol, opts.SocketAddr) 132 | syscall.Umask(oldUmask) 133 | if err != nil { 134 | log.Fatal(err) 135 | } 136 | server := grpc.NewServer() 137 | db := database.NewConfig(conf.Users) 138 | service := NewAuthService(opts.ServiceName, db) 139 | auth.RegisterAuthenticateServer(server, service) 140 | if err := server.Serve(listener); err != nil { 141 | log.Fatal(err) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /cmd/auth/config/configuration.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/knadh/koanf/parsers/yaml" 5 | "github.com/knadh/koanf/providers/confmap" 6 | "github.com/knadh/koanf/providers/file" 7 | "github.com/knadh/koanf/v2" 8 | "log" 9 | "os" 10 | ) 11 | 12 | type Configuration struct { 13 | Users []UserConfig `koanf:"users"` 14 | } 15 | 16 | type UserConfig struct { 17 | Username string `koanf:"username"` 18 | Password string `koanf:"password"` 19 | } 20 | 21 | var Conf Configuration 22 | 23 | func Load(configFile string) Configuration { 24 | 25 | var k = koanf.New(".") 26 | 27 | k.Load(confmap.Provider(map[string]interface{}{}, "."), nil) 28 | 29 | if _, err := os.Stat(configFile); os.IsNotExist(err) { 30 | log.Printf("Config file %s not found, skipping config file", configFile) 31 | } else { 32 | if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil { 33 | log.Fatalf("Error loading config from file: %v", err) 34 | } 35 | } 36 | 37 | koanfTag :=
koanf.UnmarshalConf{Tag: "koanf"} 38 | if err := k.UnmarshalWithConf("Users", &Conf.Users, koanfTag); err != nil { 39 | log.Fatalf("Error unmarshalling users from config: %v", err) 40 | } 41 | return Conf 42 | } 43 | -------------------------------------------------------------------------------- /cmd/auth/database/config.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "github.com/bolkedebruin/rdpgw/cmd/auth/config" 5 | ) 6 | 7 | type Config struct { 8 | users map[string]config.UserConfig 9 | } 10 | 11 | func NewConfig(users []config.UserConfig) *Config { 12 | usersMap := make(map[string]config.UserConfig, len(users)) 13 | 14 | for _, user := range users { 15 | usersMap[user.Username] = user 16 | } 17 | 18 | return &Config{ 19 | users: usersMap, 20 | } 21 | } 22 | 23 | func (c *Config) GetPassword(username string) string { 24 | return c.users[username].Password 25 | } 26 | -------------------------------------------------------------------------------- /cmd/auth/database/config_test.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "github.com/bolkedebruin/rdpgw/cmd/auth/config" 5 | "testing" 6 | ) 7 | 8 | func createTestDatabase() Database { 9 | var users = []config.UserConfig{} 10 | 11 | user1 := config.UserConfig{} 12 | user1.Username = "my_username" 13 | user1.Password = "my_password" 14 | users = append(users, user1) 15 | 16 | user2 := config.UserConfig{} 17 | user2.Username = "my_username2" 18 | user2.Password = "my_password2" 19 | users = append(users, user2) 20 | 21 | config := NewConfig(users) 22 | 23 | return config 24 | } 25 | 26 | func TestDatabaseConfigValidUsername(t *testing.T) { 27 | database := createTestDatabase() 28 | 29 | if database.GetPassword("my_username") != "my_password" { 30 | t.Fatalf("Wrong password returned") 31 | } 32 | if database.GetPassword("my_username2") != "my_password2" { 33 | t.Fatalf("Wrong password returned") 34 | } 35 | } 36 | 37 | func TestDatabaseInvalidUsername(t
*testing.T) { 38 | database := createTestDatabase() 39 | 40 | if database.GetPassword("my_invalid_username") != "" { 41 | t.Fatalf("Non empty password returned for invalid username") 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /cmd/auth/database/database.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | type Database interface { 4 | GetPassword (username string) string 5 | } -------------------------------------------------------------------------------- /cmd/auth/ntlm/ntlm.go: -------------------------------------------------------------------------------- 1 | package ntlm 2 | 3 | import ( 4 | "encoding/base64" 5 | "errors" 6 | "github.com/bolkedebruin/rdpgw/cmd/auth/database" 7 | "github.com/bolkedebruin/rdpgw/shared/auth" 8 | "github.com/patrickmn/go-cache" 9 | "github.com/m7913d/go-ntlm/ntlm" 10 | "fmt" 11 | "log" 12 | "time" 13 | ) 14 | 15 | const ( 16 | cacheExpiration = time.Minute 17 | cleanupInterval = time.Minute * 5 18 | ) 19 | 20 | type NTLMAuth struct { 21 | contextCache *cache.Cache 22 | 23 | // Information about the server, returned to the client during authentication 24 | ServerName string // e.g. EXAMPLE1 25 | DomainName string // e.g. EXAMPLE 26 | DnsServerName string // e.g. example1.example.com 27 | DnsDomainName string // e.g. example.com 28 | DnsTreeName string // e.g. 
example.com 29 | 30 | Database database.Database 31 | } 32 | 33 | func NewNTLMAuth (database database.Database) (*NTLMAuth) { 34 | return &NTLMAuth{ 35 | contextCache: cache.New(cacheExpiration, cleanupInterval), 36 | Database: database, 37 | } 38 | } 39 | 40 | func (h *NTLMAuth) Authenticate(message *auth.NtlmRequest) (*auth.NtlmResponse, error) { 41 | r := &auth.NtlmResponse{} 42 | r.Authenticated = false 43 | 44 | if message.Session == "" { 45 | return r, errors.New("Invalid (empty) session specified") 46 | } 47 | 48 | if message.NtlmMessage == "" { 49 | return r, errors.New("Empty NTLM message specified") 50 | } 51 | 52 | c := h.getContext(message.Session) 53 | err := c.Authenticate(message.NtlmMessage, r) 54 | 55 | if err != nil || r.Authenticated { 56 | h.removeContext(message.Session) 57 | } 58 | 59 | return r, err 60 | } 61 | 62 | func (h *NTLMAuth) getContext (session string) (*ntlmContext) { 63 | if c_, found := h.contextCache.Get(session); found { 64 | if c, ok := c_.(*ntlmContext); ok { 65 | return c 66 | } 67 | } 68 | c := new(ntlmContext) 69 | c.h = h 70 | h.contextCache.Set(session, c, cache.DefaultExpiration) 71 | return c 72 | } 73 | 74 | func (h *NTLMAuth) removeContext (session string) { 75 | h.contextCache.Delete(session) 76 | } 77 | 78 | type ntlmContext struct { 79 | session ntlm.ServerSession 80 | h *NTLMAuth 81 | } 82 | 83 | func (c *ntlmContext) Authenticate(authorisationEncoded string, r *auth.NtlmResponse) (error) { 84 | authorisation, err := base64.StdEncoding.DecodeString(authorisationEncoded) 85 | if err != nil { 86 | return errors.New(fmt.Sprintf("Failed to decode NTLM Authorisation header: %s", err)) 87 | } 88 | 89 | nm, err := ntlm.ParseNegotiateMessage(authorisation) 90 | if err == nil { 91 | return c.negotiate(nm, r) 92 | } 93 | if (nm != nil && nm.MessageType == 1) { 94 | return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err)) 95 | } else if c.session == nil { 96 | return errors.New(fmt.Sprintf("New 
NTLM auth sequence should start with negotioate request")) 97 | } 98 | 99 | am, err := ntlm.ParseAuthenticateMessage(authorisation, 2) 100 | if err == nil { 101 | return c.authenticate(am, r) 102 | } 103 | 104 | return errors.New(fmt.Sprintf("Failed to parse NTLM Authorisation header: %s", err)) 105 | } 106 | 107 | func (c *ntlmContext) negotiate(nm *ntlm.NegotiateMessage, r *auth.NtlmResponse) (error) { 108 | session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode) 109 | 110 | if err != nil { 111 | c.session = nil; 112 | return errors.New(fmt.Sprintf("Failed to create NTLM server session: %s", err)) 113 | } 114 | 115 | c.session = session 116 | c.session.SetRequireNtHash(true) 117 | c.session.SetDomainName(c.h.DomainName) 118 | c.session.SetComputerName(c.h.ServerName) 119 | c.session.SetDnsDomainName(c.h.DnsDomainName) 120 | c.session.SetDnsComputerName(c.h.DnsServerName) 121 | c.session.SetDnsTreeName(c.h.DnsTreeName) 122 | 123 | err = c.session.ProcessNegotiateMessage(nm) 124 | if err != nil { 125 | return errors.New(fmt.Sprintf("Failed to process NTLM negotiate message: %s", err)) 126 | } 127 | 128 | cm, err := c.session.GenerateChallengeMessage() 129 | if err != nil { 130 | return errors.New(fmt.Sprintf("Failed to generate NTLM challenge message: %s", err)) 131 | } 132 | 133 | r.NtlmMessage = base64.StdEncoding.EncodeToString(cm.Bytes()) 134 | return nil 135 | } 136 | 137 | func (c *ntlmContext) authenticate(am *ntlm.AuthenticateMessage, r *auth.NtlmResponse) (error) { 138 | if c.session == nil { 139 | return errors.New(fmt.Sprintf("NTLM Authenticate requires active session: first call negotioate")) 140 | } 141 | 142 | username := am.UserName.String() 143 | password := c.h.Database.GetPassword (username) 144 | if password == "" { 145 | log.Printf("NTLM: unknown username specified: %s", username) 146 | return nil 147 | } 148 | 149 | c.session.SetUserInfo(username,password,"") 150 | 151 | err := 
c.session.ProcessAuthenticateMessage(am) 152 | if err != nil { 153 | log.Printf("Failed to process NTLM authenticate message: %s", err) 154 | return nil 155 | } 156 | 157 | r.Authenticated = true 158 | r.Username = username 159 | return nil 160 | } 161 | -------------------------------------------------------------------------------- /cmd/auth/ntlm/ntlm_test.go: -------------------------------------------------------------------------------- 1 | package ntlm 2 | 3 | import ( 4 | "encoding/base64" 5 | "github.com/bolkedebruin/rdpgw/cmd/auth/config" 6 | "github.com/bolkedebruin/rdpgw/cmd/auth/database" 7 | "github.com/bolkedebruin/rdpgw/shared/auth" 8 | "github.com/m7913d/go-ntlm/ntlm" 9 | "testing" 10 | "log" 11 | ) 12 | 13 | func createTestDatabase () (database.Database) { 14 | user := config.UserConfig{} 15 | user.Username = "my_username" 16 | user.Password = "my_password" 17 | 18 | var users = []config.UserConfig{} 19 | users = append(users, user) 20 | 21 | config := database.NewConfig(users) 22 | 23 | return config 24 | } 25 | 26 | func TestNtlmValidCredentials(t *testing.T) { 27 | client := ntlm.V2ClientSession{} 28 | client.SetUserInfo("my_username", "my_password", "") 29 | 30 | authenticateResponse := authenticate(t, &client) 31 | if !authenticateResponse.Authenticated { 32 | t.Errorf("Failed to authenticate") 33 | return 34 | } 35 | if authenticateResponse.Username != "my_username" { 36 | t.Errorf("Wrong username returned") 37 | return 38 | } 39 | } 40 | 41 | func TestNtlmInvalidPassword(t *testing.T) { 42 | client := ntlm.V2ClientSession{} 43 | client.SetUserInfo("my_username", "my_invalid_password", "") 44 | 45 | authenticateResponse := authenticate(t, &client) 46 | if authenticateResponse.Authenticated { 47 | t.Errorf("Authenticated with wrong password") 48 | return 49 | } 50 | if authenticateResponse.Username != "" { 51 | t.Errorf("If authentication failed, no username should be returned") 52 | return 53 | } 54 | } 55 | 56 | func 
TestNtlmInvalidUsername(t *testing.T) { 57 | client := ntlm.V2ClientSession{} 58 | client.SetUserInfo("my_invalid_username", "my_password", "") 59 | 60 | authenticateResponse := authenticate(t, &client) 61 | if authenticateResponse.Authenticated { 62 | t.Errorf("Authenticated with wrong password") 63 | return 64 | } 65 | if authenticateResponse.Username != "" { 66 | t.Errorf("If authentication failed, no username should be returned") 67 | return 68 | } 69 | } 70 | 71 | func authenticate(t *testing.T, client *ntlm.V2ClientSession) (*auth.NtlmResponse) { 72 | session := "X" 73 | database := createTestDatabase() 74 | 75 | server := NewNTLMAuth(database) 76 | 77 | negotiate, err := client.GenerateNegotiateMessage() 78 | if err != nil { 79 | t.Errorf("Could not generate negotiate message: %s", err) 80 | return nil 81 | } 82 | 83 | negotiateRequest := &auth.NtlmRequest{} 84 | negotiateRequest.Session = session 85 | negotiateRequest.NtlmMessage = base64.StdEncoding.EncodeToString(negotiate.Bytes()) 86 | negotiateResponse, err := server.Authenticate(negotiateRequest) 87 | if err != nil { 88 | t.Errorf("Could not generate challenge message: %s", err) 89 | return nil 90 | } 91 | if negotiateResponse.Authenticated { 92 | t.Errorf("User should not be authenticated by after negotiate message") 93 | return nil 94 | } 95 | if negotiateResponse.NtlmMessage == "" { 96 | t.Errorf("Could not generate challenge message") 97 | return nil 98 | } 99 | 100 | decodedChallenge, err := base64.StdEncoding.DecodeString(negotiateResponse.NtlmMessage) 101 | if err != nil { 102 | t.Errorf("Challenge should be base64 encoded: %s", err) 103 | return nil 104 | } 105 | 106 | challenge, err := ntlm.ParseChallengeMessage(decodedChallenge) 107 | if err != nil { 108 | t.Errorf("Invalid challenge message generated: %s", err) 109 | return nil 110 | } 111 | 112 | client.ProcessChallengeMessage(challenge) 113 | authenticate, err := client.GenerateAuthenticateMessage() 114 | if err != nil { 115 | 
t.Errorf("Could not generate authenticate message: %s", err) 116 | return nil 117 | } 118 | 119 | authenticateRequest := &auth.NtlmRequest{} 120 | authenticateRequest.Session = session 121 | authenticateRequest.NtlmMessage = base64.StdEncoding.EncodeToString(authenticate.Bytes()) 122 | authenticateResponse, err := server.Authenticate(authenticateRequest) 123 | if err != nil { 124 | t.Errorf("Could not parse authenticate message: %s", err) 125 | return authenticateResponse 126 | } 127 | if authenticateResponse.NtlmMessage != "" { 128 | t.Errorf("Authenticate request should not generate a new NTLM message") 129 | return authenticateResponse 130 | } 131 | return authenticateResponse 132 | } 133 | 134 | func TestInvalidBase64 (t *testing.T) { 135 | testInvalidDataBase(t, "X", "X") // not valid base64 136 | } 137 | 138 | func TestInvalidData (t *testing.T) { 139 | testInvalidDataBase(t, "X", "XXXX") // valid base64 140 | } 141 | 142 | func TestInvalidDataEmptyMessage (t *testing.T) { 143 | testInvalidDataBase(t, "X", "") 144 | } 145 | 146 | func TestEmptySession (t *testing.T) { 147 | testInvalidDataBase(t, "", "XXXX") 148 | } 149 | 150 | func testInvalidDataBase (t *testing.T, session string, data string) { 151 | database := createTestDatabase() 152 | server := NewNTLMAuth(database) 153 | 154 | request := &auth.NtlmRequest{} 155 | request.Session = session 156 | request.NtlmMessage = data 157 | response, err := server.Authenticate(request) 158 | log.Printf("%s",err) 159 | if err == nil { 160 | t.Errorf("Invalid request should return an error") 161 | } 162 | if response.Authenticated { 163 | t.Errorf("User should not be authenticated using invalid data") 164 | } 165 | if response.NtlmMessage != "" { 166 | t.Errorf("No NTLM message should be generated for invalid data") 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /cmd/rdpgw/config/configuration.go: 
-------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" 5 | "github.com/knadh/koanf/parsers/yaml" 6 | "github.com/knadh/koanf/providers/confmap" 7 | "github.com/knadh/koanf/providers/env" 8 | "github.com/knadh/koanf/providers/file" 9 | "github.com/knadh/koanf/v2" 10 | "log" 11 | "os" 12 | "strings" 13 | ) 14 | 15 | const ( 16 | TlsDisable = "disable" 17 | TlsAuto = "auto" 18 | 19 | HostSelectionSigned = "signed" 20 | HostSelectionRoundRobin = "roundrobin" 21 | 22 | SessionStoreCookie = "cookie" 23 | SessionStoreFile = "file" 24 | 25 | AuthenticationOpenId = "openid" 26 | AuthenticationBasic = "local" 27 | AuthenticationKerberos = "kerberos" 28 | ) 29 | 30 | type Configuration struct { 31 | Server ServerConfig `koanf:"server"` 32 | OpenId OpenIDConfig `koanf:"openid"` 33 | Kerberos KerberosConfig `koanf:"kerberos"` 34 | Caps RDGCapsConfig `koanf:"caps"` 35 | Security SecurityConfig `koanf:"security"` 36 | Client ClientConfig `koanf:"client"` 37 | } 38 | 39 | type ServerConfig struct { 40 | GatewayAddress string `koanf:"gatewayaddress"` 41 | Port int `koanf:"port"` 42 | CertFile string `koanf:"certfile"` 43 | KeyFile string `koanf:"keyfile"` 44 | Hosts []string `koanf:"hosts"` 45 | HostSelection string `koanf:"hostselection"` 46 | SessionKey string `koanf:"sessionkey"` 47 | SessionEncryptionKey string `koanf:"sessionencryptionkey"` 48 | SessionStore string `koanf:"sessionstore"` 49 | MaxSessionLength int `koanf:"maxsessionlength"` 50 | SendBuf int `koanf:"sendbuf"` 51 | ReceiveBuf int `koanf:"receivebuf"` 52 | Tls string `koanf:"tls"` 53 | Authentication []string `koanf:"authentication"` 54 | AuthSocket string `koanf:"authsocket"` 55 | BasicAuthTimeout int `koanf:"basicauthtimeout"` 56 | } 57 | 58 | type KerberosConfig struct { 59 | Keytab string `koanf:"keytab"` 60 | Krb5Conf string `koanf:"krb5conf"` 61 | } 62 | 63 | type OpenIDConfig struct { 64 | 
ProviderUrl string `koanf:"providerurl"` 65 | ClientId string `koanf:"clientid"` 66 | ClientSecret string `koanf:"clientsecret"` 67 | } 68 | 69 | type RDGCapsConfig struct { 70 | SmartCardAuth bool `koanf:"smartcardauth"` 71 | TokenAuth bool `koanf:"tokenauth"` 72 | IdleTimeout int `koanf:"idletimeout"` 73 | RedirectAll bool `koanf:"redirectall"` 74 | DisableRedirect bool `koanf:"disableredirect"` 75 | EnableClipboard bool `koanf:"enableclipboard"` 76 | EnablePrinter bool `koanf:"enableprinter"` 77 | EnablePort bool `koanf:"enableport"` 78 | EnablePnp bool `koanf:"enablepnp"` 79 | EnableDrive bool `koanf:"enabledrive"` 80 | } 81 | 82 | type SecurityConfig struct { 83 | PAATokenEncryptionKey string `koanf:"paatokenencryptionkey"` 84 | PAATokenSigningKey string `koanf:"paatokensigningkey"` 85 | UserTokenEncryptionKey string `koanf:"usertokenencryptionkey"` 86 | UserTokenSigningKey string `koanf:"usertokensigningkey"` 87 | QueryTokenSigningKey string `koanf:"querytokensigningkey"` 88 | QueryTokenIssuer string `koanf:"querytokenissuer"` 89 | VerifyClientIp bool `koanf:"verifyclientip"` 90 | EnableUserToken bool `koanf:"enableusertoken"` 91 | } 92 | 93 | type ClientConfig struct { 94 | Defaults string `koanf:"defaults"` 95 | // kept for backwards compatibility 96 | UsernameTemplate string `koanf:"usernametemplate"` 97 | SplitUserDomain bool `koanf:"splituserdomain"` 98 | NoUsername bool `koanf:"nousername"` 99 | } 100 | 101 | func ToCamel(s string) string { 102 | s = strings.TrimSpace(s) 103 | n := strings.Builder{} 104 | n.Grow(len(s)) 105 | var capNext bool = true 106 | for i, v := range []byte(s) { 107 | vIsCap := v >= 'A' && v <= 'Z' 108 | vIsLow := v >= 'a' && v <= 'z' 109 | if capNext { 110 | if vIsLow { 111 | v += 'A' 112 | v -= 'a' 113 | } 114 | } else if i == 0 { 115 | if vIsCap { 116 | v += 'a' 117 | v -= 'A' 118 | } 119 | } 120 | if vIsCap || vIsLow { 121 | n.WriteByte(v) 122 | capNext = false 123 | } else if vIsNum := v >= '0' && v <= '9'; vIsNum { 124 | 
n.WriteByte(v) 125 | capNext = true 126 | } else { 127 | capNext = v == '_' || v == ' ' || v == '-' || v == '.' 128 | if v == '.' { 129 | n.WriteByte(v) 130 | } 131 | } 132 | } 133 | return n.String() 134 | } 135 | 136 | var Conf Configuration 137 | 138 | func Load(configFile string) Configuration { 139 | 140 | var k = koanf.New(".") 141 | 142 | k.Load(confmap.Provider(map[string]interface{}{ 143 | "Server.Tls": "auto", 144 | "Server.Port": 443, 145 | "Server.SessionStore": "cookie", 146 | "Server.HostSelection": "roundrobin", 147 | "Server.Authentication": "openid", 148 | "Server.AuthSocket": "/tmp/rdpgw-auth.sock", 149 | "Server.BasicAuthTimeout": 5, 150 | "Client.NetworkAutoDetect": 1, 151 | "Client.BandwidthAutoDetect": 1, 152 | "Security.VerifyClientIp": true, 153 | "Caps.TokenAuth": true, 154 | }, "."), nil) 155 | 156 | if _, err := os.Stat(configFile); os.IsNotExist(err) { 157 | log.Printf("Config file %s not found, using defaults and environment", configFile) 158 | } else { 159 | if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil { 160 | log.Fatalf("Error loading config from file: %v", err) 161 | } 162 | } 163 | 164 | if err := k.Load(env.ProviderWithValue("RDPGW_", ".", func(s string, v string) (string, interface{}) { 165 | key := strings.Replace(strings.ToLower(strings.TrimPrefix(s, "RDPGW_")), "__", ".", -1) 166 | key = ToCamel(key) 167 | 168 | v = strings.Trim(v, " ") 169 | 170 | // handle lists 171 | if strings.Contains(v, " ") { 172 | return key, strings.Split(v, " ") 173 | } 174 | return key, v 175 | 176 | }), nil); err != nil { 177 | log.Fatalf("Error loading config from environment: %v", err) 178 | } 179 | 180 | koanfTag := koanf.UnmarshalConf{Tag: "koanf"} 181 | k.UnmarshalWithConf("Server", &Conf.Server, koanfTag) 182 | k.UnmarshalWithConf("OpenId", &Conf.OpenId, koanfTag) 183 | k.UnmarshalWithConf("Caps", &Conf.Caps, koanfTag) 184 | k.UnmarshalWithConf("Security", &Conf.Security, koanfTag) 185 | k.UnmarshalWithConf("Client", 
&Conf.Client, koanfTag) 186 | k.UnmarshalWithConf("Kerberos", &Conf.Kerberos, koanfTag) 187 | 188 | if len(Conf.Security.PAATokenEncryptionKey) != 32 { 189 | Conf.Security.PAATokenEncryptionKey, _ = security.GenerateRandomString(32) 190 | log.Printf("No valid `security.paatokenencryptionkey` specified (empty or not 32 characters). Setting to random") 191 | } 192 | 193 | if len(Conf.Security.PAATokenSigningKey) != 32 { 194 | Conf.Security.PAATokenSigningKey, _ = security.GenerateRandomString(32) 195 | log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") 196 | } 197 | 198 | if Conf.Security.EnableUserToken { 199 | if len(Conf.Security.UserTokenEncryptionKey) != 32 { 200 | Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) 201 | log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") 202 | } 203 | } 204 | 205 | if len(Conf.Server.SessionKey) != 32 { 206 | Conf.Server.SessionKey, _ = security.GenerateRandomString(32) 207 | log.Printf("No valid `server.sessionkey` specified (empty or not 32 characters). Setting to random") 208 | } 209 | 210 | if len(Conf.Server.SessionEncryptionKey) != 32 { 211 | Conf.Server.SessionEncryptionKey, _ = security.GenerateRandomString(32) 212 | log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). 
Setting to random") 213 | } 214 | 215 | if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 { 216 | log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") 217 | } 218 | 219 | if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { 220 | log.Fatalf("basicauth=local and tls=disable are mutually exclusive") 221 | } 222 | 223 | if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() { 224 | log.Fatalf("ntlm and kerberos authentication are not stackable") 225 | } 226 | 227 | if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { 228 | log.Fatalf("openid is configured but tokenauth disabled") 229 | } 230 | 231 | if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" { 232 | log.Fatalf("kerberos is configured but no keytab was specified") 233 | } 234 | 235 | // prepend '//' if required for URL parsing 236 | if !strings.Contains(Conf.Server.GatewayAddress, "//") { 237 | Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress 238 | } 239 | 240 | return Conf 241 | 242 | } 243 | 244 | func (s *ServerConfig) OpenIDEnabled() bool { 245 | return s.matchAuth("openid") 246 | } 247 | 248 | func (s *ServerConfig) KerberosEnabled() bool { 249 | return s.matchAuth("kerberos") 250 | } 251 | 252 | func (s *ServerConfig) BasicAuthEnabled() bool { 253 | return s.matchAuth("local") || s.matchAuth("basic") 254 | } 255 | 256 | func (s *ServerConfig) NtlmEnabled() bool { 257 | return s.matchAuth("ntlm") 258 | } 259 | 260 | func (s *ServerConfig) matchAuth(needle string) bool { 261 | for _, q := range s.Authentication { 262 | if q == needle { 263 | return true 264 | } 265 | } 266 | return false 267 | } 268 | -------------------------------------------------------------------------------- /cmd/rdpgw/identity/identity.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | ) 8 | 9 | 
const ( 10 | CTXKey = "github.com/bolkedebruin/rdpgw/common/identity" 11 | 12 | AttrRemoteAddr = "remoteAddr" 13 | AttrClientIp = "clientIp" 14 | AttrProxies = "proxyAddresses" 15 | AttrAccessToken = "accessToken" // todo remove for security reasons 16 | ) 17 | 18 | type Identity interface { 19 | UserName() string 20 | SetUserName(string) 21 | DisplayName() string 22 | SetDisplayName(string) 23 | Domain() string 24 | SetDomain(string) 25 | Authenticated() bool 26 | SetAuthenticated(bool) 27 | AuthTime() time.Time 28 | SetAuthTime(time2 time.Time) 29 | SessionId() string 30 | SetAttribute(string, interface{}) 31 | GetAttribute(string) interface{} 32 | Attributes() map[string]interface{} 33 | DelAttribute(string) 34 | Email() string 35 | SetEmail(string) 36 | Expiry() time.Time 37 | SetExpiry(time.Time) 38 | Marshal() ([]byte, error) 39 | Unmarshal([]byte) error 40 | } 41 | 42 | func AddToRequestCtx(id Identity, r *http.Request) *http.Request { 43 | ctx := r.Context() 44 | ctx = context.WithValue(ctx, CTXKey, id) 45 | return r.WithContext(ctx) 46 | } 47 | 48 | func FromRequestCtx(r *http.Request) Identity { 49 | return FromCtx(r.Context()) 50 | } 51 | 52 | func FromCtx(ctx context.Context) Identity { 53 | if id, ok := ctx.Value(CTXKey).(Identity); ok { 54 | return id 55 | } 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /cmd/rdpgw/identity/identity_test.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | ) 7 | 8 | func TestMarshalling(t *testing.T) { 9 | u := NewUser() 10 | u.SetUserName("ANAME") 11 | u.SetAuthenticated(true) 12 | u.SetDomain("DOMAIN") 13 | 14 | c := NewUser() 15 | data, err := u.Marshal() 16 | if err != nil { 17 | log.Fatalf("Cannot marshal %s", err) 18 | } 19 | 20 | err = c.Unmarshal(data) 21 | if err != nil { 22 | t.Fatalf("Error while unmarshalling: %s", err) 23 | } 24 | 25 | if 
u.UserName() != c.UserName() || u.Authenticated() != c.Authenticated() || u.Domain() != c.Domain() { 26 | t.Fatalf("identities not equal: %+v != %+v", u, c) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /cmd/rdpgw/identity/user.go: -------------------------------------------------------------------------------- 1 | package identity 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "github.com/google/uuid" 7 | "time" 8 | ) 9 | 10 | type User struct { 11 | authenticated bool 12 | domain string 13 | userName string 14 | displayName string 15 | email string 16 | authTime time.Time 17 | sessionId string 18 | expiry time.Time 19 | attributes map[string]interface{} 20 | groupMembership map[string]bool 21 | } 22 | 23 | type user struct { 24 | Authenticated bool 25 | UserName string 26 | Domain string 27 | DisplayName string 28 | Email string 29 | AuthTime time.Time 30 | SessionId string 31 | Expiry time.Time 32 | Attributes map[string]interface{} 33 | GroupMembership map[string]bool 34 | } 35 | 36 | func NewUser() *User { 37 | uuid := uuid.New().String() 38 | return &User{ 39 | attributes: make(map[string]interface{}), 40 | groupMembership: make(map[string]bool), 41 | sessionId: uuid, 42 | } 43 | } 44 | 45 | func (u *User) UserName() string { 46 | return u.userName 47 | } 48 | 49 | func (u *User) SetUserName(s string) { 50 | u.userName = s 51 | } 52 | 53 | func (u *User) DisplayName() string { 54 | if u.displayName == "" { 55 | return u.userName 56 | } 57 | return u.displayName 58 | } 59 | 60 | func (u *User) SetDisplayName(s string) { 61 | u.displayName = s 62 | } 63 | 64 | func (u *User) Domain() string { 65 | return u.domain 66 | } 67 | 68 | func (u *User) SetDomain(s string) { 69 | u.domain = s 70 | } 71 | 72 | func (u *User) Authenticated() bool { 73 | return u.authenticated 74 | } 75 | 76 | func (u *User) SetAuthenticated(b bool) { 77 | u.authenticated = b 78 | } 79 | 80 | func (u *User) AuthTime() time.Time { 
81 | return u.authTime 82 | } 83 | 84 | func (u *User) SetAuthTime(t time.Time) { 85 | u.authTime = t 86 | } 87 | 88 | func (u *User) SessionId() string { 89 | return u.sessionId 90 | } 91 | 92 | func (u *User) SetAttribute(s string, i interface{}) { 93 | u.attributes[s] = i 94 | } 95 | 96 | func (u *User) GetAttribute(s string) interface{} { 97 | if found, ok := u.attributes[s]; ok { 98 | return found 99 | } 100 | return nil 101 | } 102 | 103 | func (u *User) Attributes() map[string]interface{} { 104 | return u.attributes 105 | } 106 | 107 | func (u *User) DelAttribute(s string) { 108 | delete(u.attributes, s) 109 | } 110 | 111 | func (u *User) Email() string { 112 | return u.email 113 | } 114 | 115 | func (u *User) SetEmail(s string) { 116 | u.email = s 117 | } 118 | 119 | func (u *User) Expiry() time.Time { 120 | return u.expiry 121 | } 122 | 123 | func (u *User) SetExpiry(t time.Time) { 124 | u.expiry = t 125 | } 126 | 127 | func (u *User) Marshal() ([]byte, error) { 128 | buf := new(bytes.Buffer) 129 | enc := gob.NewEncoder(buf) 130 | uu := user{ 131 | Authenticated: u.authenticated, 132 | UserName: u.userName, 133 | Domain: u.domain, 134 | DisplayName: u.displayName, 135 | Email: u.email, 136 | AuthTime: u.authTime, 137 | SessionId: u.sessionId, 138 | Expiry: u.expiry, 139 | Attributes: u.attributes, 140 | GroupMembership: u.groupMembership, 141 | } 142 | err := enc.Encode(uu) 143 | 144 | if err != nil { 145 | return []byte{}, err 146 | } 147 | return buf.Bytes(), nil 148 | } 149 | 150 | func (u *User) Unmarshal(b []byte) error { 151 | buf := bytes.NewBuffer(b) 152 | dec := gob.NewDecoder(buf) 153 | var uu user 154 | err := dec.Decode(&uu) 155 | if err != nil { 156 | return err 157 | } 158 | u.sessionId = uu.SessionId 159 | u.userName = uu.UserName 160 | u.domain = uu.Domain 161 | u.displayName = uu.DisplayName 162 | u.email = uu.Email 163 | u.authenticated = uu.Authenticated 164 | u.authTime = uu.AuthTime 165 | u.expiry = uu.Expiry 166 | u.attributes = 
uu.Attributes 167 | u.groupMembership = uu.GroupMembership 168 | 169 | return nil 170 | } 171 | -------------------------------------------------------------------------------- /cmd/rdpgw/kdcproxy/proxy.go: -------------------------------------------------------------------------------- 1 | package kdcproxy 2 | 3 | import ( 4 | "fmt" 5 | krbconfig "github.com/bolkedebruin/gokrb5/v8/config" 6 | "github.com/jcmturner/gofork/encoding/asn1" 7 | "io" 8 | "log" 9 | "net" 10 | "net/http" 11 | "time" 12 | ) 13 | 14 | const ( 15 | maxLength = 128 * 1024 16 | systemConfigPath = "/etc/krb5.conf" 17 | timeout = 5 * time.Second 18 | ) 19 | 20 | type KdcProxyMsg struct { 21 | Message []byte `asn1:"tag:0,explicit"` 22 | Realm string `asn1:"tag:1,optional"` 23 | Flags int `asn1:"tag:2,optional"` 24 | } 25 | 26 | type Kdc struct { 27 | Realm string 28 | Host string 29 | Proto string 30 | Conn net.Conn 31 | } 32 | 33 | type KerberosProxy struct { 34 | krb5Config *krbconfig.Config 35 | } 36 | 37 | func InitKdcProxy(krb5Conf string) KerberosProxy { 38 | path := systemConfigPath 39 | if krb5Conf != "" { 40 | path = krb5Conf 41 | } 42 | cfg, err := krbconfig.Load(path) 43 | if err != nil { 44 | log.Fatalf("Cannot load krb5 config %s due to %s", path, err) 45 | } 46 | 47 | return KerberosProxy{ 48 | krb5Config: cfg, 49 | } 50 | } 51 | 52 | func (k KerberosProxy) Handler(w http.ResponseWriter, r *http.Request) { 53 | if r.Method != "POST" { 54 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 55 | return 56 | } 57 | 58 | length := r.ContentLength 59 | if length == -1 { 60 | http.Error(w, "Content length required", http.StatusLengthRequired) 61 | return 62 | } 63 | 64 | if length > maxLength { 65 | http.Error(w, "Request entity too large", http.StatusRequestEntityTooLarge) 66 | return 67 | } 68 | 69 | data := make([]byte, length) 70 | _, err := io.ReadFull(r.Body, data) 71 | if err != nil { 72 | log.Printf("Error reading from stream: %s", err) 73 | http.Error(w, "Error 
reading from stream", http.StatusInternalServerError) 74 | return 75 | } 76 | 77 | msg, err := decode(data) 78 | if err != nil { 79 | log.Printf("Cannot unmarshal: %s", err) 80 | http.Error(w, "Invalid request", http.StatusBadRequest) 81 | return 82 | } 83 | 84 | krb5resp, err := k.forward(msg.Realm, msg.Message) 85 | if err != nil { 86 | log.Printf("cannot forward to kdc due to %s", err) 87 | http.Error(w, "Service unavailable", http.StatusServiceUnavailable) 88 | return 89 | } 90 | 91 | reply, err := encode(krb5resp) 92 | if err != nil { 93 | log.Printf("unable to encode krb5 message due to %s", err) 94 | http.Error(w, "encoding error", http.StatusInternalServerError) 95 | } 96 | 97 | w.Header().Set("Content-Type", "application/kerberos") 98 | w.Write(reply) 99 | } 100 | 101 | func (k *KerberosProxy) forward(realm string, data []byte) (resp []byte, err error) { 102 | if realm == "" { 103 | realm = k.krb5Config.LibDefaults.DefaultRealm 104 | } 105 | 106 | // load udp first as is the default for kerberos 107 | udpCnt, udpKdcs, err := k.krb5Config.GetKDCs(realm, false) 108 | if err != nil { 109 | return nil, fmt.Errorf("cannot get udp kdc for realm %s due to %s", realm, err) 110 | } 111 | 112 | // load tcp 113 | tcpCnt, tcpKdcs, err := k.krb5Config.GetKDCs(realm, true) 114 | if err != nil { 115 | return nil, fmt.Errorf("cannot get tcp kdc for realm %s due to %s", realm, err) 116 | } 117 | 118 | if tcpCnt+udpCnt == 0 { 119 | return nil, fmt.Errorf("cannot get any kdcs (tcp or udp) for realm %s", realm) 120 | } 121 | 122 | // merge the kdcs 123 | kdcs := make([]Kdc, tcpCnt+udpCnt) 124 | for i := range udpKdcs { 125 | kdcs[i] = Kdc{Realm: realm, Host: udpKdcs[i], Proto: "udp"} 126 | } 127 | for i := range tcpKdcs { 128 | kdcs[i+udpCnt] = Kdc{Realm: realm, Host: tcpKdcs[i], Proto: "tcp"} 129 | } 130 | 131 | replies := make(chan []byte, len(kdcs)) 132 | for i := range kdcs { 133 | conn, err := net.Dial(kdcs[i].Proto, kdcs[i].Host) 134 | 135 | if err != nil { 136 | 
log.Printf("error connecting to %s due to %s, trying next if available", kdcs[i], err) 137 | continue 138 | } 139 | conn.SetDeadline(time.Now().Add(timeout)) 140 | 141 | // if we proxy over UDP remove the length prefix 142 | if kdcs[i].Proto == "tcp" { 143 | _, err = conn.Write(data) 144 | } else { 145 | _, err = conn.Write(data[4:]) 146 | } 147 | if err != nil { 148 | log.Printf("cannot write packet data to %s due to %s, trying next if available", kdcs[i], err) 149 | conn.Close() 150 | continue 151 | } 152 | 153 | kdcs[i].Conn = conn 154 | go awaitReply(conn, kdcs[i].Proto == "udp", replies) 155 | } 156 | 157 | reply := <-replies 158 | 159 | // close all the connections and return the first reply 160 | for kdc := range kdcs { 161 | if kdcs[kdc].Conn != nil { 162 | kdcs[kdc].Conn.Close() 163 | } 164 | <-replies 165 | } 166 | 167 | if reply != nil { 168 | return reply, nil 169 | } 170 | 171 | return nil, fmt.Errorf("no replies received from kdcs for realm %s", realm) 172 | } 173 | 174 | func decode(data []byte) (msg *KdcProxyMsg, err error) { 175 | var m KdcProxyMsg 176 | rest, err := asn1.Unmarshal(data, &m) 177 | if err != nil { 178 | return nil, err 179 | } 180 | 181 | if len(rest) > 0 { 182 | return nil, fmt.Errorf("trailing data in request") 183 | } 184 | 185 | return &m, nil 186 | } 187 | 188 | func encode(krb5data []byte) (r []byte, err error) { 189 | m := KdcProxyMsg{Message: krb5data} 190 | enc, err := asn1.Marshal(m) 191 | if err != nil { 192 | log.Printf("cannot marshal due to %s", err) 193 | return nil, err 194 | } 195 | return enc, nil 196 | } 197 | 198 | func awaitReply(conn net.Conn, isUdp bool, reply chan<- []byte) { 199 | resp, err := io.ReadAll(conn) 200 | if err != nil { 201 | log.Printf("error reading from kdc due to %s", err) 202 | reply <- nil 203 | return 204 | } 205 | if isUdp { 206 | // udp will be missing the length prefix so add it 207 | resp = append([]byte{byte(len(resp))}, resp...) 
208 | } 209 | reply <- resp 210 | } 211 | -------------------------------------------------------------------------------- /cmd/rdpgw/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "fmt" 7 | "github.com/bolkedebruin/gokrb5/v8/keytab" 8 | "github.com/bolkedebruin/gokrb5/v8/service" 9 | "github.com/bolkedebruin/gokrb5/v8/spnego" 10 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/config" 11 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/kdcproxy" 12 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol" 13 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" 14 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/web" 15 | "github.com/coreos/go-oidc/v3/oidc" 16 | "github.com/gorilla/mux" 17 | "github.com/prometheus/client_golang/prometheus/promhttp" 18 | "github.com/thought-machine/go-flags" 19 | "golang.org/x/crypto/acme/autocert" 20 | "golang.org/x/oauth2" 21 | "log" 22 | "net/http" 23 | "net/url" 24 | "os" 25 | "strconv" 26 | ) 27 | 28 | const ( 29 | gatewayEndPoint = "/remoteDesktopGateway/" 30 | kdcProxyEndPoint = "/KdcProxy" 31 | ) 32 | 33 | var opts struct { 34 | ConfigFile string `short:"c" long:"conf" default:"rdpgw.yaml" description:"config file (yaml)"` 35 | } 36 | 37 | var conf config.Configuration 38 | 39 | func initOIDC(callbackUrl *url.URL) *web.OIDC { 40 | // set oidc config 41 | provider, err := oidc.NewProvider(context.Background(), conf.OpenId.ProviderUrl) 42 | if err != nil { 43 | log.Fatalf("Cannot get oidc provider: %s", err) 44 | } 45 | oidcConfig := &oidc.Config{ 46 | ClientID: conf.OpenId.ClientId, 47 | } 48 | verifier := provider.Verifier(oidcConfig) 49 | 50 | oauthConfig := oauth2.Config{ 51 | ClientID: conf.OpenId.ClientId, 52 | ClientSecret: conf.OpenId.ClientSecret, 53 | RedirectURL: callbackUrl.String(), 54 | Endpoint: provider.Endpoint(), 55 | Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, 56 | } 57 | security.OIDCProvider = provider 58 | 
security.Oauth2Config = oauthConfig 59 | 60 | o := web.OIDCConfig{ 61 | OAuth2Config: &oauthConfig, 62 | OIDCTokenVerifier: verifier, 63 | } 64 | 65 | return o.New() 66 | } 67 | 68 | func main() { 69 | // load config 70 | _, err := flags.Parse(&opts) 71 | if err != nil { 72 | panic(err) 73 | } 74 | conf = config.Load(opts.ConfigFile) 75 | 76 | // set callback url and external advertised gateway address 77 | url, err := url.Parse(conf.Server.GatewayAddress) 78 | if err != nil { 79 | log.Printf("Cannot parse server gateway address %s due to %s", url, err) 80 | } 81 | if url.Scheme == "" { 82 | url.Scheme = "https" 83 | } 84 | url.Path = "callback" 85 | 86 | // set security options 87 | security.VerifyClientIP = conf.Security.VerifyClientIp 88 | security.SigningKey = []byte(conf.Security.PAATokenSigningKey) 89 | security.EncryptionKey = []byte(conf.Security.PAATokenEncryptionKey) 90 | security.UserEncryptionKey = []byte(conf.Security.UserTokenEncryptionKey) 91 | security.UserSigningKey = []byte(conf.Security.UserTokenSigningKey) 92 | security.QuerySigningKey = []byte(conf.Security.QueryTokenSigningKey) 93 | security.HostSelection = conf.Server.HostSelection 94 | security.Hosts = conf.Server.Hosts 95 | 96 | // init session store 97 | web.InitStore([]byte(conf.Server.SessionKey), 98 | []byte(conf.Server.SessionEncryptionKey), 99 | conf.Server.SessionStore, 100 | conf.Server.MaxSessionLength, 101 | ) 102 | 103 | // configure web backend 104 | w := &web.Config{ 105 | QueryInfo: security.QueryInfo, 106 | QueryTokenIssuer: conf.Security.QueryTokenIssuer, 107 | EnableUserToken: conf.Security.EnableUserToken, 108 | Hosts: conf.Server.Hosts, 109 | HostSelection: conf.Server.HostSelection, 110 | RdpOpts: web.RdpOpts{ 111 | UsernameTemplate: conf.Client.UsernameTemplate, 112 | SplitUserDomain: conf.Client.SplitUserDomain, 113 | NoUsername: conf.Client.NoUsername, 114 | }, 115 | GatewayAddress: url, 116 | TemplateFile: conf.Client.Defaults, 117 | } 118 | 119 | if 
conf.Caps.TokenAuth { 120 | w.PAATokenGenerator = security.GeneratePAAToken 121 | } 122 | if conf.Security.EnableUserToken { 123 | w.UserTokenGenerator = security.GenerateUserToken 124 | } 125 | h := w.NewHandler() 126 | 127 | log.Printf("Starting remote desktop gateway server") 128 | cfg := &tls.Config{} 129 | 130 | // configure tls security 131 | if conf.Server.Tls == config.TlsDisable { 132 | log.Printf("TLS disabled - rdp gw connections require tls, make sure to have a terminator") 133 | } else { 134 | // auto config 135 | tlsConfigured := false 136 | 137 | tlsDebug := os.Getenv("SSLKEYLOGFILE") 138 | if tlsDebug != "" { 139 | w, err := os.OpenFile(tlsDebug, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 140 | if err != nil { 141 | log.Fatalf("Cannot open key log file %s for writing %s", tlsDebug, err) 142 | } 143 | log.Printf("Key log file set to: %s", tlsDebug) 144 | cfg.KeyLogWriter = w 145 | } 146 | 147 | if conf.Server.KeyFile != "" && conf.Server.CertFile != "" { 148 | cert, err := tls.LoadX509KeyPair(conf.Server.CertFile, conf.Server.KeyFile) 149 | if err != nil { 150 | log.Printf("Cannot load certfile or keyfile (%s) falling back to acme", err) 151 | } 152 | cfg.Certificates = append(cfg.Certificates, cert) 153 | tlsConfigured = true 154 | } 155 | 156 | if !tlsConfigured { 157 | log.Printf("Using acme / letsencrypt for tls configuration. Enabling http (port 80) for verification") 158 | // setup a simple handler which sends a HTHS header for six months (!) 
159 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 160 | w.Header().Set("Strict-Transport-Security", "max-age=15768000 ; includeSubDomains") 161 | fmt.Fprintf(w, "Hello from RDPGW") 162 | }) 163 | 164 | certMgr := autocert.Manager{ 165 | Prompt: autocert.AcceptTOS, 166 | HostPolicy: autocert.HostWhitelist(url.Host), 167 | Cache: autocert.DirCache("/tmp/rdpgw"), 168 | } 169 | cfg.GetCertificate = certMgr.GetCertificate 170 | 171 | go func() { 172 | http.ListenAndServe(":80", certMgr.HTTPHandler(nil)) 173 | }() 174 | } 175 | } 176 | 177 | // gateway confg 178 | gw := protocol.Gateway{ 179 | RedirectFlags: protocol.RedirectFlags{ 180 | Clipboard: conf.Caps.EnableClipboard, 181 | Drive: conf.Caps.EnableDrive, 182 | Printer: conf.Caps.EnablePrinter, 183 | Port: conf.Caps.EnablePort, 184 | Pnp: conf.Caps.EnablePnp, 185 | DisableAll: conf.Caps.DisableRedirect, 186 | EnableAll: conf.Caps.RedirectAll, 187 | }, 188 | IdleTimeout: conf.Caps.IdleTimeout, 189 | SmartCardAuth: conf.Caps.SmartCardAuth, 190 | TokenAuth: conf.Caps.TokenAuth, 191 | ReceiveBuf: conf.Server.ReceiveBuf, 192 | SendBuf: conf.Server.SendBuf, 193 | } 194 | 195 | if conf.Caps.TokenAuth { 196 | gw.CheckPAACookie = security.CheckPAACookie 197 | gw.CheckHost = security.CheckSession(security.CheckHost) 198 | } else { 199 | gw.CheckHost = security.CheckHost 200 | } 201 | 202 | r := mux.NewRouter() 203 | 204 | // ensure identity is set in context and get some extra info 205 | r.Use(web.EnrichContext) 206 | 207 | // prometheus metrics 208 | r.Handle("/metrics", promhttp.Handler()) 209 | 210 | // for sso callbacks 211 | r.HandleFunc("/tokeninfo", web.TokenInfo) 212 | 213 | // gateway endpoint 214 | rdp := r.PathPrefix(gatewayEndPoint).Subrouter() 215 | 216 | // openid 217 | if conf.Server.OpenIDEnabled() { 218 | log.Printf("enabling openid extended authentication") 219 | o := initOIDC(url) 220 | r.Handle("/connect", o.Authenticated(http.HandlerFunc(h.HandleDownload))) 221 | 
r.HandleFunc("/callback", o.HandleCallback) 222 | 223 | // only enable un-auth endpoint for openid only config 224 | if !conf.Server.KerberosEnabled() && !conf.Server.BasicAuthEnabled() && !conf.Server.NtlmEnabled() { 225 | rdp.Name("gw").HandlerFunc(gw.HandleGatewayProtocol) 226 | } 227 | } 228 | 229 | // for stacking of authentication 230 | auth := web.NewAuthMux() 231 | rdp.MatcherFunc(web.NoAuthz).HandlerFunc(auth.SetAuthenticate) 232 | 233 | // ntlm 234 | if conf.Server.NtlmEnabled() { 235 | log.Printf("enabling NTLM authentication") 236 | ntlm := web.NTLMAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} 237 | rdp.NewRoute().HeadersRegexp("Authorization", "NTLM").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) 238 | rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol)) 239 | auth.Register(`NTLM`) 240 | auth.Register(`Negotiate`) 241 | } 242 | 243 | // basic auth 244 | if conf.Server.BasicAuthEnabled() { 245 | log.Printf("enabling basic authentication") 246 | q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout} 247 | rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol)) 248 | auth.Register(`Basic realm="restricted", charset="UTF-8"`) 249 | } 250 | 251 | // spnego / kerberos 252 | if conf.Server.KerberosEnabled() { 253 | log.Printf("enabling kerberos authentication") 254 | keytab, err := keytab.Load(conf.Kerberos.Keytab) 255 | if err != nil { 256 | log.Fatalf("Cannot load keytab: %s", err) 257 | } 258 | rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").Handler( 259 | spnego.SPNEGOKRB5Authenticate(web.TransposeSPNEGOContext(http.HandlerFunc(gw.HandleGatewayProtocol)), 260 | keytab, 261 | service.Logger(log.Default()))) 262 | 263 | // kdcproxy 264 | k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf) 265 | r.HandleFunc(kdcProxyEndPoint, 
k.Handler).Methods("POST") 266 | auth.Register("Negotiate") 267 | } 268 | 269 | // setup server 270 | server := http.Server{ 271 | Addr: ":" + strconv.Itoa(conf.Server.Port), 272 | Handler: r, 273 | TLSConfig: cfg, 274 | TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 275 | } 276 | 277 | if conf.Server.Tls == config.TlsDisable { 278 | err = server.ListenAndServe() 279 | } else { 280 | err = server.ListenAndServeTLS("", "") 281 | } 282 | if err != nil { 283 | log.Fatal("ListenAndServe: ", err) 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /cmd/rdpgw/protocol/client.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net" 10 | ) 11 | 12 | const ( 13 | MajorVersion = 0x0 14 | MinorVersion = 0x0 15 | Version = 0x00 16 | ) 17 | 18 | type ClientConfig struct { 19 | SmartCardAuth bool 20 | PAAToken string 21 | NTLMAuth bool 22 | Session *Tunnel 23 | LocalConn net.Conn 24 | Server string 25 | Port int 26 | Name string 27 | } 28 | 29 | func (c *ClientConfig) ConnectAndForward() error { 30 | c.Session.transportOut.WritePacket(c.handshakeRequest()) 31 | 32 | for { 33 | messages, err := readMessage(c.Session.transportIn) 34 | if err != nil { 35 | log.Printf("Cannot read message from stream %s", err) 36 | return err 37 | } 38 | 39 | for _, message := range messages { 40 | if message.err != nil { 41 | log.Printf("Cannot read message from stream %p", err) 42 | continue 43 | } 44 | switch message.packetType { 45 | case PKT_TYPE_HANDSHAKE_RESPONSE: 46 | caps, err := c.handshakeResponse(message.msg) 47 | if err != nil { 48 | log.Printf("Cannot connect to %s due to %s", c.Server, err) 49 | return err 50 | } 51 | log.Printf("Handshake response received. 
Caps: %d", caps) 52 | c.Session.transportOut.WritePacket(c.tunnelRequest()) 53 | case PKT_TYPE_TUNNEL_RESPONSE: 54 | tid, caps, err := c.tunnelResponse(message.msg) 55 | if err != nil { 56 | log.Printf("Cannot setup tunnel due to %s", err) 57 | return err 58 | } 59 | log.Printf("Tunnel creation succesful. Tunnel id: %d and caps %d", tid, caps) 60 | c.Session.transportOut.WritePacket(c.tunnelAuthRequest()) 61 | case PKT_TYPE_TUNNEL_AUTH_RESPONSE: 62 | flags, timeout, err := c.tunnelAuthResponse(message.msg) 63 | if err != nil { 64 | log.Printf("Cannot do tunnel auth due to %s", err) 65 | return err 66 | } 67 | log.Printf("Tunnel auth succesful. Flags: %d and timeout %d", flags, timeout) 68 | c.Session.transportOut.WritePacket(c.channelRequest()) 69 | case PKT_TYPE_CHANNEL_RESPONSE: 70 | cid, err := c.channelResponse(message.msg) 71 | if err != nil { 72 | log.Printf("Cannot do tunnel auth due to %s", err) 73 | return err 74 | } 75 | if cid < 1 { 76 | log.Printf("Channel id (%d) is smaller than 1. This doesnt work for Windows clients", cid) 77 | } 78 | log.Printf("Channel creation succesful. 
Channel id: %d", cid) 79 | //go forward(c.LocalConn, c.Session.transportOut) 80 | case PKT_TYPE_DATA: 81 | receive(message.msg, c.LocalConn) 82 | default: 83 | log.Printf("Unknown packet type received: %d size %d", message.packetType, message.length) 84 | } 85 | } 86 | } 87 | } 88 | 89 | func (c *ClientConfig) handshakeRequest() []byte { 90 | var caps uint16 91 | 92 | if c.SmartCardAuth { 93 | caps = caps | HTTP_EXTENDED_AUTH_SC 94 | } 95 | 96 | if len(c.PAAToken) > 0 { 97 | caps = caps | HTTP_EXTENDED_AUTH_PAA 98 | } 99 | 100 | if c.NTLMAuth { 101 | caps = caps | HTTP_EXTENDED_AUTH_SSPI_NTLM 102 | } 103 | 104 | buf := new(bytes.Buffer) 105 | 106 | binary.Write(buf, binary.LittleEndian, byte(MajorVersion)) 107 | binary.Write(buf, binary.LittleEndian, byte(MinorVersion)) 108 | binary.Write(buf, binary.LittleEndian, uint16(Version)) 109 | 110 | binary.Write(buf, binary.LittleEndian, uint16(caps)) 111 | 112 | return createPacket(PKT_TYPE_HANDSHAKE_REQUEST, buf.Bytes()) 113 | } 114 | 115 | func (c *ClientConfig) handshakeResponse(data []byte) (caps uint16, err error) { 116 | var errorCode int32 117 | var major byte 118 | var minor byte 119 | var version uint16 120 | 121 | r := bytes.NewReader(data) 122 | binary.Read(r, binary.LittleEndian, &errorCode) 123 | binary.Read(r, binary.LittleEndian, &major) 124 | binary.Read(r, binary.LittleEndian, &minor) 125 | binary.Read(r, binary.LittleEndian, &version) 126 | binary.Read(r, binary.LittleEndian, &caps) 127 | 128 | if errorCode > 0 { 129 | return 0, fmt.Errorf("error code: %d", errorCode) 130 | } 131 | 132 | return caps, nil 133 | } 134 | 135 | func (c *ClientConfig) tunnelRequest() []byte { 136 | buf := new(bytes.Buffer) 137 | var caps uint32 138 | var size uint16 139 | var fields uint16 140 | 141 | if len(c.PAAToken) > 0 { 142 | fields = fields | HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE 143 | } 144 | 145 | caps = caps | HTTP_CAPABILITY_IDLE_TIMEOUT 146 | 147 | binary.Write(buf, binary.LittleEndian, caps) 148 | 
binary.Write(buf, binary.LittleEndian, fields) 149 | binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved 150 | 151 | if len(c.PAAToken) > 0 { 152 | utf16Token := EncodeUTF16(c.PAAToken) 153 | size = uint16(len(utf16Token)) 154 | binary.Write(buf, binary.LittleEndian, size) 155 | buf.Write(utf16Token) 156 | } 157 | 158 | return createPacket(PKT_TYPE_TUNNEL_CREATE, buf.Bytes()) 159 | } 160 | 161 | func (c *ClientConfig) tunnelResponse(data []byte) (tunnelId uint32, caps uint32, err error) { 162 | var version uint16 163 | var errorCode uint32 164 | var fields uint16 165 | 166 | r := bytes.NewReader(data) 167 | binary.Read(r, binary.LittleEndian, &version) 168 | binary.Read(r, binary.LittleEndian, &errorCode) 169 | binary.Read(r, binary.LittleEndian, &fields) 170 | r.Seek(2, io.SeekCurrent) 171 | if (fields & HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID) == HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID { 172 | binary.Read(r, binary.LittleEndian, &tunnelId) 173 | } 174 | if (fields & HTTP_TUNNEL_RESPONSE_FIELD_CAPS) == HTTP_TUNNEL_RESPONSE_FIELD_CAPS { 175 | binary.Read(r, binary.LittleEndian, &caps) 176 | } 177 | 178 | if errorCode != 0 { 179 | err = fmt.Errorf("tunnel error %d", errorCode) 180 | } 181 | 182 | return 183 | } 184 | 185 | func (c *ClientConfig) tunnelAuthRequest() []byte { 186 | utf16name := EncodeUTF16(c.Name) 187 | size := uint16(len(utf16name)) 188 | 189 | buf := new(bytes.Buffer) 190 | binary.Write(buf, binary.LittleEndian, size) 191 | buf.Write(utf16name) 192 | 193 | return createPacket(PKT_TYPE_TUNNEL_AUTH, buf.Bytes()) 194 | } 195 | 196 | func (c *ClientConfig) tunnelAuthResponse(data []byte) (flags uint32, timeout uint32, err error) { 197 | var errorCode uint32 198 | var fields uint16 199 | 200 | r := bytes.NewReader(data) 201 | binary.Read(r, binary.LittleEndian, &errorCode) 202 | binary.Read(r, binary.LittleEndian, &fields) 203 | r.Seek(2, io.SeekCurrent) 204 | 205 | if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS) == 
HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS { 206 | binary.Read(r, binary.LittleEndian, &flags) 207 | } 208 | if (fields & HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT) == HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT { 209 | binary.Read(r, binary.LittleEndian, &timeout) 210 | } 211 | 212 | if errorCode > 0 { 213 | return 0, 0, fmt.Errorf("tunnel auth error %d", errorCode) 214 | } 215 | 216 | return 217 | } 218 | 219 | func (c *ClientConfig) channelRequest() []byte { 220 | utf16server := EncodeUTF16(c.Server) 221 | 222 | buf := new(bytes.Buffer) 223 | binary.Write(buf, binary.LittleEndian, []byte{0x01}) // amount of server names 224 | binary.Write(buf, binary.LittleEndian, []byte{0x00}) // amount of alternate server names (range 0-3) 225 | binary.Write(buf, binary.LittleEndian, uint16(c.Port)) 226 | binary.Write(buf, binary.LittleEndian, uint16(3)) // protocol, must be 3 227 | 228 | binary.Write(buf, binary.LittleEndian, uint16(len(utf16server))) 229 | buf.Write(utf16server) 230 | 231 | return createPacket(PKT_TYPE_CHANNEL_CREATE, buf.Bytes()) 232 | } 233 | 234 | func (c *ClientConfig) channelResponse(data []byte) (channelId uint32, err error) { 235 | var errorCode uint32 236 | var fields uint16 237 | 238 | r := bytes.NewReader(data) 239 | binary.Read(r, binary.LittleEndian, &errorCode) 240 | binary.Read(r, binary.LittleEndian, &fields) 241 | r.Seek(2, io.SeekCurrent) 242 | 243 | if (fields & HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID) == HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID { 244 | binary.Read(r, binary.LittleEndian, &channelId) 245 | } 246 | 247 | if errorCode > 0 { 248 | return 0, fmt.Errorf("channel response error %d", errorCode) 249 | } 250 | 251 | return channelId, nil 252 | } 253 | -------------------------------------------------------------------------------- /cmd/rdpgw/protocol/common.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | 
"fmt" 8 | "io" 9 | "log" 10 | "net" 11 | "os" 12 | "syscall" 13 | 14 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" 15 | ) 16 | 17 | const ( 18 | maxFragmentSize = 65536 19 | ) 20 | 21 | type RedirectFlags struct { 22 | Clipboard bool 23 | Port bool 24 | Drive bool 25 | Printer bool 26 | Pnp bool 27 | DisableAll bool 28 | EnableAll bool 29 | } 30 | 31 | func handleMsgFrame(packet *packetReader) *message { 32 | pt, sz, msg, err := readHeader(packet.getPtr()) 33 | if err == nil { 34 | packet.incrementPtr(int(sz)) 35 | return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} 36 | } 37 | 38 | buf := make([]byte, maxFragmentSize) 39 | index := 0 40 | for { 41 | // keep parsing thfragment 42 | if len(packet.getPtr()) > len(buf[index:]) { 43 | return &message{packetType: int(pt), length: int(sz), msg: msg, err: fmt.Errorf("fragment exceeded max fragment size")} 44 | } 45 | index += copy(buf[index:], packet.getPtr()) 46 | // Get a new frame 47 | err := packet.read() 48 | if err != nil { 49 | // Failed to make a msg 50 | return &message{packetType: int(pt), length: int(sz), msg: msg, err: err} 51 | } 52 | pt, sz, msg, err = readHeader(append(buf[:index], packet.getPtr()...)) 53 | if err == nil { 54 | // the increment is based upon how much of the data we have used 55 | // in this packet. The index tells us how much is in the previous frame(s), 56 | // So we remove that from the size of the message. 57 | packet.incrementPtr(int(sz) - index) 58 | return &message{packetType: int(pt), length: int(sz), msg: msg, err: nil} 59 | } 60 | } 61 | } 62 | 63 | // readMessage parses and defragments a packet from a Transport. It returns 64 | // at most the bytes that have been reported by the packet. 
65 | func readMessage(in transport.Transport) ([]*message, error) { 66 | messages := make([]*message, 0) 67 | 68 | packet := newTransportPacket(in) 69 | err := packet.read() 70 | if err != nil { 71 | return messages, err 72 | } 73 | 74 | var message *message 75 | for packet.hasMoreData() { 76 | message = handleMsgFrame(packet) 77 | messages = append(messages, message) 78 | } 79 | return messages, nil 80 | } 81 | 82 | // createPacket wraps the data into the protocol packet 83 | func createPacket(pktType uint16, data []byte) (packet []byte) { 84 | size := len(data) + 8 85 | buf := new(bytes.Buffer) 86 | 87 | binary.Write(buf, binary.LittleEndian, uint16(pktType)) 88 | binary.Write(buf, binary.LittleEndian, uint16(0)) // reserved 89 | binary.Write(buf, binary.LittleEndian, uint32(size)) 90 | buf.Write(data) 91 | 92 | return buf.Bytes() 93 | } 94 | 95 | // readHeader parses a packet and verifies its reported size 96 | func readHeader(data []byte) (packetType uint16, size uint32, packet []byte, err error) { 97 | // header needs to be 8 min 98 | if len(data) < 8 { 99 | return 0, 0, nil, errors.New("header too short, fragment likely") 100 | } 101 | r := bytes.NewReader(data) 102 | binary.Read(r, binary.LittleEndian, &packetType) 103 | r.Seek(4, io.SeekStart) 104 | binary.Read(r, binary.LittleEndian, &size) 105 | if len(data) < int(size) { 106 | return packetType, size, data[8:], errors.New("data incomplete, fragment received") 107 | } 108 | return packetType, size, data[8:size], nil 109 | } 110 | 111 | // forwards data from a Connection to Transport and wraps it in the rdpgw protocol 112 | func forward(in net.Conn, tunnel *Tunnel) { 113 | defer in.Close() 114 | 115 | b1 := new(bytes.Buffer) 116 | buf := make([]byte, 4086) 117 | 118 | for { 119 | n, err := in.Read(buf) 120 | if err != nil { 121 | log.Printf("Error reading from local conn %s", err) 122 | break 123 | } 124 | binary.Write(b1, binary.LittleEndian, uint16(n)) 125 | b1.Write(buf[:n]) 126 | 
tunnel.Write(createPacket(PKT_TYPE_DATA, b1.Bytes())) 127 | b1.Reset() 128 | } 129 | } 130 | 131 | // receive data received from the gateway client, unwrap and forward the remote desktop server 132 | func receive(data []byte, out net.Conn) { 133 | buf := bytes.NewReader(data) 134 | 135 | var cblen uint16 136 | binary.Read(buf, binary.LittleEndian, &cblen) 137 | pkt := make([]byte, cblen) 138 | binary.Read(buf, binary.LittleEndian, &pkt) 139 | 140 | out.Write(pkt) 141 | } 142 | 143 | // wrapSyscallError takes an error and a syscall name. If the error is 144 | // a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. 145 | func wrapSyscallError(name string, err error) error { 146 | if _, ok := err.(syscall.Errno); ok { 147 | err = os.NewSyscallError(name, err) 148 | } 149 | return err 150 | } 151 | -------------------------------------------------------------------------------- /cmd/rdpgw/protocol/common_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type messageMock struct { 13 | buffer []byte 14 | msgBuffer []byte 15 | } 16 | 17 | const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 18 | 19 | func randBytes(message []byte) { 20 | for index := range message { 21 | message[index] = letterBytes[rand.Intn(len(letterBytes))] 22 | } 23 | } 24 | 25 | func newMessageMock(packetType uint16, message []byte) *messageMock { 26 | randBytes(message) 27 | buf := createPacket(packetType, message) 28 | return &messageMock{msgBuffer: buf[8:], buffer: buf} 29 | } 30 | 31 | type packetMock struct { 32 | bytes []byte 33 | err error 34 | } 35 | 36 | func newPacketMock() *packetMock { 37 | return &packetMock{bytes: make([]byte, 0)} 38 | } 39 | 40 | func (p *packetMock) addBytes(b []byte) { 41 | p.bytes = append(p.bytes, b...) 
42 | } 43 | 44 | func (p *packetMock) GetPacket() (int, []byte, error) { 45 | return len(p.bytes), p.bytes, p.err 46 | } 47 | 48 | type transportMock struct { 49 | lock sync.Mutex 50 | packets []*packetMock 51 | packetPtr int 52 | } 53 | 54 | func newTransportMock() *transportMock { 55 | return &transportMock{packets: make([]*packetMock, 0)} 56 | } 57 | 58 | func (t *transportMock) addPacket(p *packetMock) { 59 | t.lock.Lock() 60 | defer t.lock.Unlock() 61 | 62 | t.packets = append(t.packets, p) 63 | } 64 | 65 | func (t *transportMock) ReadPacket() (n int, p []byte, err error) { 66 | t.lock.Lock() 67 | defer t.lock.Unlock() 68 | 69 | if t.packetPtr >= len(t.packets) { 70 | return 0, nil, fmt.Errorf("no packets available") 71 | } 72 | packet := t.packets[t.packetPtr] 73 | t.packetPtr++ 74 | return packet.GetPacket() 75 | } 76 | 77 | func (t *transportMock) WritePacket(b []byte) (n int, err error) { 78 | return 0, fmt.Errorf("not tested") 79 | } 80 | 81 | func (t *transportMock) Close() error { 82 | return nil 83 | } 84 | 85 | func TestSimplePacket(t *testing.T) { 86 | transport := newTransportMock() 87 | m := newMessageMock(6, make([]byte, 10)) 88 | p := newPacketMock() 89 | p.addBytes(m.buffer) 90 | transport.addPacket(p) 91 | 92 | messages, err := readMessage(transport) 93 | assert.Nil(t, err) 94 | assert.NotNil(t, messages) 95 | assert.Len(t, messages, 1) 96 | assert.Equal(t, 6, messages[0].packetType) 97 | assert.Equal(t, 18, messages[0].length) 98 | assert.Equal(t, m.msgBuffer, messages[0].msg) 99 | } 100 | 101 | func TestMultiMessageInPacket(t *testing.T) { 102 | transport := newTransportMock() 103 | p := newPacketMock() 104 | 105 | m := newMessageMock(6, make([]byte, 10)) 106 | p.addBytes(m.buffer) 107 | 108 | m2 := newMessageMock(8, make([]byte, 12)) 109 | p.addBytes(m2.buffer) 110 | 111 | m3 := newMessageMock(8, make([]byte, 12)) 112 | p.addBytes(m3.buffer) 113 | 114 | transport.addPacket(p) 115 | 116 | messages, err := readMessage(transport) 117 | 
assert.Nil(t, err) 118 | assert.NotNil(t, messages) 119 | assert.Len(t, messages, 3) 120 | assert.Nil(t, messages[0].err) 121 | assert.Equal(t, 6, messages[0].packetType) 122 | assert.Equal(t, 18, messages[0].length) 123 | assert.Equal(t, m.msgBuffer, messages[0].msg) 124 | 125 | assert.Nil(t, messages[1].err) 126 | assert.Equal(t, 8, messages[1].packetType) 127 | assert.Equal(t, 20, messages[1].length) 128 | assert.Equal(t, m2.msgBuffer, messages[1].msg) 129 | 130 | assert.Nil(t, messages[2].err) 131 | assert.Equal(t, 8, messages[2].packetType) 132 | assert.Equal(t, 20, messages[2].length) 133 | assert.Equal(t, m3.msgBuffer, messages[2].msg) 134 | } 135 | 136 | func TestFragment(t *testing.T) { 137 | transport := newTransportMock() 138 | p1 := newPacketMock() 139 | p2 := newPacketMock() 140 | 141 | m := newMessageMock(6, make([]byte, 100)) 142 | // split the message across 2 packets 143 | p1.addBytes(m.buffer[0:50]) 144 | p2.addBytes(m.buffer[50:]) 145 | transport.addPacket(p1) 146 | transport.addPacket(p2) 147 | 148 | messages, err := readMessage(transport) 149 | assert.Nil(t, err) 150 | assert.NotNil(t, messages) 151 | assert.Len(t, messages, 1) 152 | assert.Equal(t, 6, messages[0].packetType) 153 | assert.Equal(t, 108, messages[0].length) 154 | assert.Equal(t, m.msgBuffer, messages[0].msg) 155 | 156 | _, err = readMessage(transport) 157 | // no more packets 158 | assert.NotNil(t, err) 159 | } 160 | 161 | func TestDroppedBytes(t *testing.T) { 162 | transport := newTransportMock() 163 | p1 := newPacketMock() 164 | 165 | m := newMessageMock(6, make([]byte, 100)) 166 | // add only partial bytes 167 | p1.addBytes(m.buffer[0:50]) 168 | transport.addPacket(p1) 169 | 170 | messages, err := readMessage(transport) 171 | assert.Nil(t, err) 172 | assert.Len(t, messages, 1) 173 | assert.NotNil(t, messages[0].err) 174 | 175 | _, err = readMessage(transport) 176 | // no more packets 177 | assert.NotNil(t, err) 178 | } 179 | 180 | func TestTooMuchData(t *testing.T) { 181 | 
transport := newTransportMock() 182 | p1 := newPacketMock() 183 | 184 | m := newMessageMock(6, make([]byte, 100)) 185 | // add only partial bytes 186 | p1.addBytes(m.buffer) 187 | p1.addBytes([]byte{0, 0, 0}) 188 | // add some junk bytes 189 | transport.addPacket(p1) 190 | 191 | messages, err := readMessage(transport) 192 | assert.Nil(t, err) 193 | assert.NotNil(t, messages) 194 | assert.Len(t, messages, 2) 195 | assert.Nil(t, messages[0].err) 196 | assert.NotNil(t, messages[1].err) 197 | 198 | _, err = readMessage(transport) 199 | // no more packets 200 | assert.NotNil(t, err) 201 | } 202 | 203 | func TestJumbo(t *testing.T) { 204 | transport := newTransportMock() 205 | p1 := newPacketMock() 206 | p2 := newPacketMock() 207 | 208 | m := newMessageMock(6, make([]byte, maxFragmentSize)) 209 | // add only partial bytes 210 | p1.addBytes(m.buffer[0 : maxFragmentSize/2]) 211 | p2.addBytes(m.buffer[maxFragmentSize/2:]) 212 | // add some junk bytes 213 | transport.addPacket(p1) 214 | transport.addPacket(p2) 215 | 216 | messages, err := readMessage(transport) 217 | assert.Nil(t, err) 218 | assert.NotNil(t, messages) 219 | assert.Len(t, messages, 1) 220 | assert.Equal(t, m.msgBuffer, messages[0].msg) 221 | } 222 | 223 | func TestManyFragments(t *testing.T) { 224 | transport := newTransportMock() 225 | 226 | m := newMessageMock(6, make([]byte, 256)) 227 | fragmentSize := len(m.buffer) / 5 228 | bufferSize := len(m.buffer) 229 | for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { 230 | p := newPacketMock() 231 | p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) 232 | transport.addPacket(p) 233 | } 234 | 235 | messages, err := readMessage(transport) 236 | assert.Nil(t, err) 237 | assert.NotNil(t, messages) 238 | assert.Len(t, messages, 1) 239 | assert.Nil(t, messages[0].err) 240 | assert.Equal(t, m.msgBuffer, messages[0].msg) 241 | 242 | messages, err = readMessage(transport) 243 | // no more packets 244 | fmt.Println(messages) 245 | 
assert.NotNil(t, err) 246 | } 247 | 248 | func TestFragmentTooLarge(t *testing.T) { 249 | transport := newTransportMock() 250 | 251 | m := newMessageMock(6, make([]byte, maxFragmentSize*2)) 252 | fragmentSize := len(m.buffer) / 5 253 | bufferSize := len(m.buffer) 254 | for fragPtr := 0; fragPtr < len(m.buffer); fragPtr += fragmentSize { 255 | p := newPacketMock() 256 | p.addBytes(m.buffer[fragPtr:min(bufferSize, fragPtr+fragmentSize)]) 257 | transport.addPacket(p) 258 | } 259 | 260 | messages, err := readMessage(transport) 261 | assert.Nil(t, err) 262 | assert.NotNil(t, messages[0].err) 263 | assert.Contains(t, "fragment exceeded max fragment size", messages[0].err.Error()) 264 | } 265 | 266 | // TestFragmentWithMultiMessage the first message is fragmented, 267 | // while the second message is found whole in the final packet 268 | func TestFragmentWithMultiMessage(t *testing.T) { 269 | transport := newTransportMock() 270 | p1 := newPacketMock() 271 | p2 := newPacketMock() 272 | 273 | m1 := newMessageMock(6, make([]byte, 100)) 274 | m2 := newMessageMock(6, make([]byte, 10)) 275 | // split the message across 2 packets 276 | p1.addBytes(m1.buffer[0:50]) 277 | p2.addBytes(m1.buffer[50:]) 278 | p2.addBytes(m2.buffer) 279 | transport.addPacket(p1) 280 | transport.addPacket(p2) 281 | 282 | messages, err := readMessage(transport) 283 | assert.Nil(t, err) 284 | assert.NotNil(t, messages) 285 | assert.Len(t, messages, 2) 286 | assert.Equal(t, 6, messages[0].packetType) 287 | assert.Equal(t, 108, messages[0].length) 288 | assert.Equal(t, m1.msgBuffer, messages[0].msg) 289 | 290 | assert.Equal(t, 6, messages[1].packetType) 291 | assert.Equal(t, 18, messages[1].length) 292 | assert.Equal(t, m2.msgBuffer, messages[1].msg) 293 | 294 | _, err = readMessage(transport) 295 | // no more packets 296 | assert.NotNil(t, err) 297 | } 298 | -------------------------------------------------------------------------------- /cmd/rdpgw/protocol/errors.go: 
-------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | /* 4 | const ( 5 | ERROR_NO = 0x0000000 6 | ERROR_CLIENT_DISCONNECT = 0x0000001 7 | ERROR_CLIENT_LOGOFF = 0x0000002 8 | ERROR_NETWORK_DISCONNECT = 0x0000003 9 | ERROR_NOT_FOUND = 0x0000104 10 | ERROR_NO_MEM = 0x0000106 11 | ERROR_CONNECT_TIMEOUT = 0x0000108 12 | ERROR_SMARTCARD_SERVICE = 0x000010A 13 | ERROR_UNAVAILABLE = 0x0000204 14 | ERROR_SMARTCARD_READER = 0x000020A 15 | ERROR_NETWORK = 0x0000304 16 | ERROR_SMARTCART_NOCARD = 0x000030A 17 | ERROR_SECURITY = 0x0000406 18 | ERROR_INVALID_NAME = 0x0000408 19 | ERROR_SMARTCARD_SUBSYSTEM = 0x000040A 20 | ERROR_GENERIC = 0x0000704 21 | ERROR_CONSOLE_EXIST = 0x0000708 22 | ERROR_LICENSING_PROTOCOL = 0x0000808 23 | ERROR_NETWORK_GENERIC = 0x0000904 24 | ERROR_SECURITY_UNEXPECTED_CERTIFICATE = 0x0000907 25 | ERROR_LICENSING_TIMEOUT = 0x0000908 26 | ERROR_SECURITY_USER = 0x0000A07 27 | ERROR_GENERIC_UNAVAIL = 0x0000B04 28 | ERROR_ENCRYPTION = 0x0000B06 29 | ERROR_SECURITY_USER_DISABLED = 0x0000B07 30 | ERROR_SECURITY_NLA_REQUIRED = 0x0000B09 31 | ERROR_SECURITY_USER_RESTRICTION = 0x0000C07 32 | ERROR_DECOMPRESSION = 0x0000C08 33 | ERROR_SECURITY_USER_LOCKED_OUT = 0x0000D07 34 | ERROR_SECURITY_USER_DIALOG_REQUIRED = 0x0000D09 35 | ERROR_SECURITY_FIPS_REQUIRED = 0x0000E06 36 | ERROR_SECURITY_USER_EXPIRED = 0x0000E07 37 | ERROR_GENERIC_FAILED = 0x0000E08 38 | ERROR_SERVER_RA_UNAVAILABLE = 0x0000E09 39 | ERROR_SECURITY_USER_PASSWORD_EXPIRED = 0x0000F07 40 | ERROR_SECURITY_USER_CREDENTIALS_NOT_SENT = 0x0000F08 41 | ERROR_SECURITY_USER_TIME_RESTRICTION = 0x0001007 42 | ERROR_LOW_VIDEO = 0x0001008 43 | ERROR_SECURITY_USER_COMPUTER_RANGE = 0x0001107 44 | ERROR_SECURITY_USER_CHANGE_PASSWORD = 0x0001207 45 | ERROR_SECURITY_USER_LOGON_TYPE = 0x0001307 46 | ERROR_KRB_SUB_REQUIRED = 0x0001407 47 | ERROR_SECURITY_SERVER_INVALID_CERTIFICATE = 0x0001B07 48 | ERROR_SECURITY_SERVER_TIMESKEW = 0x0001D07 49 | 
ERROR_SECURITY_SMARTCARD_LOCKEDOUT = 0x0002207 50 | ERROR_RELAUNCH_APP = 0x0002507 51 | ERROR_UPGRADE_CLIENT = 0x0002604 52 | ERROR_RELAUNCH_REMOTE = 0x2000001 53 | ERROR_REMOTEAPP_UNSUPPORTED = 0x2000002 54 | ERROR_SECURITY_USER_PASSWORD_INVALID = 0x3000001 55 | ERROR_SECURITY_CERTIFICATE_REVOKE_LIST_UNAVAIL = 0x3000002 56 | ERROR_SECURITY_CERTIFICATE_INVALID = 0x3000003 57 | ERROR_SECURITY_CERTIFICATE_REVOKED = 0x3000004 58 | ERROR_SECURITY_GATEWAY_IDENTITY = 0x3000005 59 | ERROR_SECURITY_GATEWAY_SUBJECT = 0x3000006 60 | ERROR_SECURITY_GATEWAY_EXPIRED = 0x3000007 61 | ERROR_SECURITY_REMOTE_ERROR = 0x3000008 62 | ERROR_GATEWAY_NETWORK_SEND = 0x3000009 63 | ERROR_GATEWAY_NETWORK_RECEIVE = 0x300000A 64 | ERROR_SECURITY_ALTERNATE = 0x300000B 65 | ERROR_GATEWAY_INVALID_ADDRESS = 0x300000C 66 | ERROR_GATEWAY_TEMP_UNAVAIL = 0x300000D 67 | ERROR_REMOTE_CLIENT_MISSING = 0x300000E 68 | ERROR_GATEWAY_LOW_RESOURCES = 0x300000F 69 | ERROR_GATEWAY_CLIENT_DLL = 0x3000010 70 | ERROR_SMARTCART_NOSERVICE = 0x3000011 71 | ERROR_SECURITY_SMARTCARD_REMOVED = 0x3000012 72 | ERROR_SECURITY_SMARTCARD_REQUIRED = 0x3000013 73 | ERROR_SECURITY_SMARTCARD_REMOVED2 = 0x3000014 74 | ERROR_SECURITY_USER_PASSWORD_INVALID2 = 0x3000015 75 | ERROR_SECURITY_TRANSPORT = 0x3000017 76 | ERROR_GATEWAY_TERMINATE = 0x3000018 77 | ERROR_GATEWAY_ADMIN_TERMINATE = 0x3000019 78 | ERROR_SECURITY_USER_CREDENTIALS = 0x300001A 79 | ERROR_SECURITY_GATEWAY_NOT_PERMITTED = 0x300001B 80 | ERROR_SECURITY_GATEWAY_UNAUTHORIZED = 0x300001C 81 | ERROR_SECURITY_GATEWAY_RESTRICTED = 0x300001F 82 | ERROR_SECURITY_PROXY_AUTH = 0x3000020 83 | ERROR_SECURITY_USER_PASSWORD_MUST_CHANGE = 0x3000021 84 | ERROR_GATEWAY_MAX_REACHED = 0x3000022 85 | ERROR_GATEWAY_UNSUPPORTED_REQUEST = 0x3000023 86 | ERROR_GATEWAY_UNSUPPORTED_CAP = 0x3000024 87 | ERROR_GATEWAY_INCOMPAT = 0x3000025 88 | ERROR_SECURITY_SMARTCARD_INVALID_CREDENTIALS = 0x3000026 89 | ERROR_SECURITY_NLA_INVALID = 0x3000027 90 | ERROR_GATEWAY_NO_CERTIFICATE = 0x3000028 91 | 
ERROR_GATEWAY_NOT_ALLOWED = 0x3000029 92 | ERROR_GATEWAY_INVALID_CERTIFICATE = 0x300002A 93 | ERROR_SECURITY_GATEWAY_USER_PASSWORD_REQUIRED = 0x300002B 94 | ERROR_SECURITY_GATEWAY_SMARTCARD_REQUIRED = 0x300002C 95 | ERROR_SECURITY_SMARTCARD_UNAVAIL = 0x300002D 96 | ERROR_SECURITY_FIREWALL_NOAUTH = 0x300002F 97 | ERROR_SECURITY_FIREWALL_AUTH = 0x3000030 98 | ERROR_NO_INPUT = 0x3000032 99 | ERROR_TIMEOUT = 0x3000033 100 | ERROR_SECURITY_GATEWAY_COOKIE_INVALID = 0x3000034 101 | ERROR_SECURITY_GATEWAY_COOKIE_REJECTED = 0x3000035 102 | ERROR_SECURITY_GATEWAY_AUTH_METHOD = 0x3000037 103 | ERROR_SECURITY_USER_PERIOD_AUTH = 0x3000038 104 | ERROR_SECURITY_USER_PERIOD_AUTHZ = 0x3000039 105 | ERROR_SECURITY_GATEWAY_POLICY = 0x300003B 106 | ERROR_SECURITY_SMARTCARD_CERTIFICATE = 0x300003C 107 | ERROR_LOGON_FIRST = 0x300003D 108 | ERROR_AUTH_LOGON_FIRST = 0x300003E 109 | ERROR_SESSION_ENDED = 0x300003F 110 | ERROR_SESSION_ENDED_AUTH = 0x3000040 111 | ERROR_SECURITY_GATEWAY_NAP = 0x3000041 112 | ERROR_COOKIE_SIZE = 0x3000042 113 | ERROR_PROXY_CONFIG = 0x3000044 114 | ERROR_NO_PERMISSION = 0x3000045 115 | ERROR_NO_RESOURCES = 0x3000046 116 | ERROR_RESOURCE_ACCESS = 0x3000047 117 | ERROR_UPGRADE_CLIENT2 = 0x3000049 118 | ERROR_SECURITY_NETWORK_HTTPS = 0x300004A 119 | ERROR_TEMP_FAIL = 0x300004B 120 | ERROR_SECURITY_USER_MISMATCH = 0x300004C 121 | ERROR_AZURE_TOO_MANY = 0x300004D 122 | ERROR_MAX_USER = 0x300004E 123 | ERROR_AZURE_TRIAL = 0x300004F 124 | ERROR_AZURE_EXPIRED = 0x3000050 125 | ) 126 | */ 127 | 128 | /* Common Error Code */ 129 | const ( 130 | ERROR_SUCCESS = 0x00000000 131 | ERROR_ACCESS_DENIED = 0x00000005 132 | E_PROXY_INTERNALERROR = 0x800759D8 133 | E_PROXY_RAP_ACCESSDENIED = 0x800759DA 134 | E_PROXY_NAP_ACCESSDENIED = 0x800759DB 135 | E_PROXY_ALREADYDISCONNECTED = 0x800759DF 136 | E_PROXY_QUARANTINE_ACCESSDENIED = 0x800759ED 137 | E_PROXY_NOCERTAVAILABLE = 0x800759EE 138 | E_PROXY_COOKIE_BADPACKET = 0x800759F7 139 | E_PROXY_COOKIE_AUTHENTICATION_ACCESS_DENIED = 
0x800759F8 140 | E_PROXY_UNSUPPORTED_AUTHENTICATION_METHOD = 0x800759F9 141 | E_PROXY_CAPABILITYMISMATCH = 0x800759E9 142 | E_PROXY_TS_CONNECTFAILED = 0x000059DD 143 | E_PROXY_MAXCONNECTIONSREACHED = 0x000059E6 144 | // E_PROXY_INTERNALERROR = 0x000059D8 145 | ERROR_GRACEFUL_DISCONNECT = 0x000004CA 146 | E_PROXY_NOTSUPPORTED = 0x000059E8 147 | SEC_E_LOGON_DENIED = 0x8009030C 148 | E_PROXY_SESSIONTIMEOUT = 0x000059F6 149 | E_PROXY_REAUTH_AUTHN_FAILED = 0x000059FA 150 | E_PROXY_REAUTH_CAP_FAILED = 0x000059FB 151 | E_PROXY_REAUTH_RAP_FAILED = 0x000059FC 152 | E_PROXY_SDR_NOT_SUPPORTED_BY_TS = 0x000059FD 153 | E_PROXY_REAUTH_NAP_FAILED = 0x00005A00 154 | E_PROXY_CONNECTIONABORTED = 0x000004D4 155 | ) 156 | -------------------------------------------------------------------------------- /cmd/rdpgw/protocol/gateway.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 7 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" 8 | "github.com/google/uuid" 9 | "github.com/gorilla/websocket" 10 | "github.com/patrickmn/go-cache" 11 | "log" 12 | "net" 13 | "net/http" 14 | "reflect" 15 | "syscall" 16 | "time" 17 | ) 18 | 19 | const ( 20 | rdgConnectionIdKey = "Rdg-Connection-Id" 21 | MethodRDGIN = "RDG_IN_DATA" 22 | MethodRDGOUT = "RDG_OUT_DATA" 23 | ) 24 | 25 | type CheckPAACookieFunc func(context.Context, string) (bool, error) 26 | type CheckClientNameFunc func(context.Context, string) (bool, error) 27 | type CheckHostFunc func(context.Context, string) (bool, error) 28 | 29 | type Gateway struct { 30 | // CheckPAACookie verifies if the PAA cookie sent by the client is valid 31 | CheckPAACookie CheckPAACookieFunc 32 | 33 | // CheckClientName verifies if the client name is allowed to connect 34 | CheckClientName CheckClientNameFunc 35 | 36 | // CheckHost verifies if the client is allowed to connect to the remote host 37 | CheckHost 
CheckHostFunc 38 | 39 | // RedirectFlags sets what devices the client is allowed to redirect to the remote host 40 | RedirectFlags RedirectFlags 41 | 42 | // IdleTimeOut is used to determine when to disconnect clients that have been idle 43 | IdleTimeout int 44 | 45 | // SmartCardAuth sets whether to use smart card based authentication 46 | SmartCardAuth bool 47 | 48 | // TokenAuth sets whether to use token/cookie based authentication 49 | TokenAuth bool 50 | 51 | ReceiveBuf int 52 | SendBuf int 53 | } 54 | 55 | var upgrader = websocket.Upgrader{} 56 | var c = cache.New(5*time.Minute, 10*time.Minute) 57 | 58 | func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) { 59 | connectionCache.Set(float64(c.ItemCount())) 60 | 61 | var t *Tunnel 62 | 63 | ctx := r.Context() 64 | id := identity.FromRequestCtx(r) 65 | 66 | connId := r.Header.Get(rdgConnectionIdKey) 67 | x, found := c.Get(connId) 68 | if !found { 69 | t = &Tunnel{ 70 | RDGId: connId, 71 | RemoteAddr: id.GetAttribute(identity.AttrRemoteAddr).(string), 72 | User: id, 73 | } 74 | } else { 75 | t = x.(*Tunnel) 76 | } 77 | ctx = context.WithValue(ctx, CtxTunnel, t) 78 | 79 | if r.Method == MethodRDGOUT { 80 | if r.Header.Get("Connection") != "upgrade" && r.Header.Get("Upgrade") != "websocket" { 81 | g.handleLegacyProtocol(w, r.WithContext(ctx), t) 82 | return 83 | } 84 | r.Method = "GET" // force 85 | conn, err := upgrader.Upgrade(w, r, nil) 86 | if err != nil { 87 | log.Printf("Cannot upgrade falling back to old protocol: %t", err) 88 | return 89 | } 90 | defer conn.Close() 91 | 92 | err = g.setSendReceiveBuffers(conn.UnderlyingConn()) 93 | if err != nil { 94 | log.Printf("Cannot set send/receive buffers: %t", err) 95 | } 96 | 97 | g.handleWebsocketProtocol(ctx, conn, t) 98 | } else if r.Method == MethodRDGIN { 99 | g.handleLegacyProtocol(w, r.WithContext(ctx), t) 100 | } 101 | } 102 | 103 | func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { 104 | if g.SendBuf < 1 && 
g.ReceiveBuf < 1 { 105 | return nil 106 | } 107 | 108 | // conn == tls.Tunnel 109 | ptr := reflect.ValueOf(conn) 110 | val := reflect.Indirect(ptr) 111 | 112 | if val.Kind() != reflect.Struct { 113 | return errors.New("didn't get a struct from conn") 114 | } 115 | 116 | // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn 117 | ptrConn := val.FieldByName("conn") 118 | valConn := reflect.Indirect(ptrConn) 119 | if !valConn.IsValid() { 120 | return errors.New("cannot find conn field") 121 | } 122 | valConn = valConn.Elem().Elem() 123 | 124 | // net.FD 125 | ptrNetFd := valConn.FieldByName("fd") 126 | valNetFd := reflect.Indirect(ptrNetFd) 127 | if !valNetFd.IsValid() { 128 | return errors.New("cannot find fd field") 129 | } 130 | 131 | // pfd member 132 | ptrPfd := valNetFd.FieldByName("pfd") 133 | valPfd := reflect.Indirect(ptrPfd) 134 | if !valPfd.IsValid() { 135 | return errors.New("cannot find pfd field") 136 | } 137 | 138 | // finally the exported Sysfd 139 | ptrSysFd := valPfd.FieldByName("Sysfd") 140 | if !ptrSysFd.IsValid() { 141 | return errors.New("cannot find Sysfd field") 142 | } 143 | fd := int(ptrSysFd.Int()) 144 | 145 | if g.ReceiveBuf > 0 { 146 | err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) 147 | if err != nil { 148 | return wrapSyscallError("setsockopt", err) 149 | } 150 | } 151 | 152 | if g.SendBuf > 0 { 153 | err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf) 154 | if err != nil { 155 | return wrapSyscallError("setsockopt", err) 156 | } 157 | } 158 | 159 | return nil 160 | } 161 | 162 | func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) { 163 | websocketConnections.Inc() 164 | defer websocketConnections.Dec() 165 | 166 | inout, _ := transport.NewWS(c) 167 | defer inout.Close() 168 | 169 | t.Id = uuid.New().String() 170 | t.transportOut = inout 171 | t.transportIn = inout 172 | t.ConnectedOn = time.Now() 173 | 174 | handler := 
NewProcessor(g, t) 175 | RegisterTunnel(t, handler) 176 | defer RemoveTunnel(t) 177 | handler.Process(ctx) 178 | } 179 | 180 | // The legacy protocol (no websockets) uses an RDG_IN_DATA for client -> server 181 | // and RDG_OUT_DATA for server -> client data. The handshakeRequest procedure is a bit different 182 | // to ensure the connections do not get cached or terminated by a proxy prematurely. 183 | func (g *Gateway) handleLegacyProtocol(w http.ResponseWriter, r *http.Request, t *Tunnel) { 184 | log.Printf("Session %s, %t, %t", t.RDGId, t.transportOut != nil, t.transportIn != nil) 185 | 186 | id := identity.FromRequestCtx(r) 187 | if r.Method == MethodRDGOUT { 188 | out, err := transport.NewLegacy(w) 189 | if err != nil { 190 | log.Printf("cannot hijack connection to support RDG OUT data channel: %s", err) 191 | return 192 | } 193 | log.Printf("Opening RDGOUT for client %s", id.GetAttribute(identity.AttrClientIp)) 194 | 195 | t.transportOut = out 196 | out.SendAccept(true) 197 | 198 | c.Set(t.RDGId, t, cache.DefaultExpiration) 199 | } else if r.Method == MethodRDGIN { 200 | legacyConnections.Inc() 201 | defer legacyConnections.Dec() 202 | 203 | in, err := transport.NewLegacy(w) 204 | if err != nil { 205 | log.Printf("cannot hijack connection to support RDG IN data channel: %s", err) 206 | return 207 | } 208 | defer in.Close() 209 | 210 | if t.transportIn == nil { 211 | t.Id = uuid.New().String() 212 | t.transportIn = in 213 | c.Set(t.RDGId, t, cache.DefaultExpiration) 214 | 215 | log.Printf("Opening RDGIN for client %s", id.GetAttribute(identity.AttrClientIp)) 216 | in.SendAccept(false) 217 | 218 | // read some initial data 219 | in.Drain() 220 | 221 | log.Printf("Legacy handshakeRequest done for client %s", id.GetAttribute(identity.AttrClientIp)) 222 | handler := NewProcessor(g, t) 223 | RegisterTunnel(t, handler) 224 | defer RemoveTunnel(t) 225 | handler.Process(r.Context()) 226 | } 227 | } 228 | } 229 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/metrics.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import "github.com/prometheus/client_golang/prometheus"
4 |
5 | var (
6 | 	connectionCache = prometheus.NewGauge(
7 | 		prometheus.GaugeOpts{
8 | 			Namespace: "rdpgw",
9 | 			Name:      "connection_cache",
10 | 			Help:      "The amount of connections in the cache",
11 | 		})
12 |
13 | 	websocketConnections = prometheus.NewGauge(
14 | 		prometheus.GaugeOpts{
15 | 			Namespace: "rdpgw",
16 | 			Name:      "websocket_connections",
17 | 			Help:      "The count of websocket connections",
18 | 		})
19 |
20 | 	legacyConnections = prometheus.NewGauge(
21 | 		prometheus.GaugeOpts{
22 | 			Namespace: "rdpgw",
23 | 			Name:      "legacy_connections",
24 | 			Help:      "The count of legacy https connections",
25 | 		})
26 | )
27 |
28 | func init() {
29 | 	prometheus.MustRegister(connectionCache)
30 | 	prometheus.MustRegister(legacyConnections)
31 | 	prometheus.MustRegister(websocketConnections)
32 | }
33 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/packet_reader.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
4 |
5 | type packetReader struct {
6 | 	in      transport.Transport
7 | 	size    int
8 | 	pkt     []byte
9 | 	err     error
10 | 	readPtr int
11 | }
12 |
13 | func newTransportPacket(in transport.Transport) *packetReader {
14 | 	return &packetReader{in: in}
15 | }
16 |
17 | func (t *packetReader) hasMoreData() bool {
18 | 	return t.readPtr < t.size
19 | }
20 |
21 | func (t *packetReader) getPtr() []byte {
22 | 	return t.pkt[t.readPtr:]
23 | }
24 |
25 | func (t *packetReader) incrementPtr(size int) {
26 | 	t.readPtr += size
27 | }
28 |
29 | func (t *packetReader) read() error {
30 | 	size, pkt, err := t.in.ReadPacket()
31 | 	if err != nil {
32 | 		t.size = 0
33 | 	} else {
34 | 		t.size = size
35 | 	}
36 | 	t.pkt = pkt
37 | 	t.err = err
38 | 	t.readPtr = 0
39 | 	return err
40 | }
41 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/protocol_test.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | 	"fmt"
5 | 	"log"
6 | 	"testing"
7 | )
8 |
9 | const (
10 | 	HeaderLen               = 8
11 | 	HandshakeRequestLen     = HeaderLen + 6
12 | 	HandshakeResponseLen    = HeaderLen + 10
13 | 	TunnelCreateRequestLen  = HeaderLen + 8 // + dynamic
14 | 	TunnelCreateResponseLen = HeaderLen + 18
15 | 	TunnelAuthLen           = HeaderLen + 2 // + dynamic
16 | 	TunnelAuthResponseLen   = HeaderLen + 16
17 | 	ChannelCreateLen        = HeaderLen + 8 // + dynamic
18 | 	ChannelResponseLen      = HeaderLen + 12
19 | )
20 |
21 | func verifyPacketHeader(data []byte, expPt uint16, expSize uint32) (uint16, uint32, []byte, error) {
22 | 	pt, size, pkt, err := readHeader(data)
23 | 	// Check err first: the other return values may not be populated on failure.
24 | 	if err != nil {
25 | 		return 0, 0, []byte{}, err
26 | 	}
27 |
28 | 	if pt != expPt {
29 | 		return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected packet type %d got %d", expPt, pt)
30 | 	}
31 |
32 | 	if size != expSize {
33 | 		return 0, 0, []byte{}, fmt.Errorf("readHeader failed, expected size %d, got %d", expSize, size)
34 | 	}
35 |
36 | 	return pt, size, pkt, nil
37 | }
38 |
39 | func TestHandshake(t *testing.T) {
40 | 	client := ClientConfig{
41 | 		PAAToken: "abab",
42 | 	}
43 | 	gw := &Gateway{}
44 | 	tunnel := &Tunnel{}
45 |
46 | 	h := NewProcessor(gw, tunnel)
47 |
48 | 	data := client.handshakeRequest()
49 |
50 | 	_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_REQUEST, HandshakeRequestLen)
51 |
52 | 	if err != nil {
53 | 		t.Fatalf("verifyHeader failed: %s", err)
54 | 	}
55 |
56 | 	log.Printf("pkt: %x", pkt)
57 |
58 | 	major, minor, version, extAuth := h.handshakeRequest(pkt)
59 | 	if major != MajorVersion || minor != MinorVersion || version != Version {
60 | 		t.Fatalf("handshakeRequest failed got version %d.%d protocol %d, expected %d.%d protocol %d",
61 | 			major, minor, version, MajorVersion, MinorVersion, Version)
62 | 	}
63 |
64 | 	if !((extAuth & HTTP_EXTENDED_AUTH_PAA) == HTTP_EXTENDED_AUTH_PAA) {
65 | 		t.Fatalf("handshakeRequest failed got ext auth %d, expected %d", extAuth, extAuth|HTTP_EXTENDED_AUTH_PAA)
66 | 	}
67 |
68 | 	data = h.handshakeResponse(0x0, 0x0, HTTP_EXTENDED_AUTH_PAA, ERROR_SUCCESS)
69 | 	_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_HANDSHAKE_RESPONSE, HandshakeResponseLen)
70 | 	if err != nil {
71 | 		t.Fatalf("verifyHeader failed: %s", err)
72 | 	}
73 | 	log.Printf("pkt: %x", pkt)
74 |
75 | 	caps, err := client.handshakeResponse(pkt)
76 | 	if err != nil || (caps&HTTP_EXTENDED_AUTH_PAA) != HTTP_EXTENDED_AUTH_PAA {
77 | 		t.Fatalf("handshakeResponse failed got caps %d (err: %v), expected %d", caps, err, caps|HTTP_EXTENDED_AUTH_PAA)
78 | 	}
79 | }
80 |
81 | func capsHelper(gw Gateway) uint16 {
82 | 	var caps uint16
83 | 	if gw.TokenAuth {
84 | 		caps = caps | HTTP_EXTENDED_AUTH_PAA
85 | 	}
86 | 	if gw.SmartCardAuth {
87 | 		caps = caps | HTTP_EXTENDED_AUTH_SC
88 | 	}
89 | 	return caps
90 | }
91 |
92 | func TestMatchAuth(t *testing.T) {
93 | 	gw := &Gateway{}
94 | 	tunnel := &Tunnel{}
95 |
96 | 	h := NewProcessor(gw, tunnel)
97 |
98 | 	in := uint16(0)
99 | 	caps, err := h.matchAuth(in)
100 | 	if err != nil {
101 | 		t.Fatalf("in caps: %x <= server caps %x, but %s", in, capsHelper(*gw), err)
102 | 	}
103 | 	if caps > in {
104 | 		t.Fatalf("returned server caps %x > client caps %x", caps, in)
105 | 	}
106 |
107 | 	in = HTTP_EXTENDED_AUTH_PAA
108 | 	caps, err = h.matchAuth(in)
109 | 	if err == nil {
110 | 		t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps)
111 | 	} else {
112 | 		t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
113 | 	}
114 |
115 | 	gw.SmartCardAuth = true
116 | 	caps, err = h.matchAuth(in)
117 | 	if err == nil {
118 | 		t.Fatalf("server cannot satisfy client caps %x but error is nil (server caps %x)", in, caps)
119 | 	} else {
120 | 		t.Logf("(SUCCESS) server cannot satisfy client caps : %s", err)
121 | 	}
122 |
123 | 	gw.TokenAuth = true
124 | 	caps, err = h.matchAuth(in)
125 | 	if err != nil {
126 | 		t.Fatalf("server caps %x (orig: %x) should match client request %x, %s", caps, capsHelper(*gw), in, err)
127 | 	}
128 | }
129 |
130 | func TestTunnelCreation(t *testing.T) {
131 | 	client := ClientConfig{
132 | 		PAAToken: "abab",
133 | 	}
134 | 	gw := &Gateway{TokenAuth: true}
135 | 	tunnel := &Tunnel{}
136 |
137 | 	h := NewProcessor(gw, tunnel)
138 |
139 | 	data := client.tunnelRequest()
140 | 	_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_CREATE,
141 | 		uint32(TunnelCreateRequestLen+2+len(client.PAAToken)*2))
142 | 	if err != nil {
143 | 		t.Fatalf("verifyHeader failed: %s", err)
144 | 	}
145 |
146 | 	caps, token := h.tunnelRequest(pkt)
147 | 	if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
148 | 		t.Fatalf("tunnelRequest failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
149 | 	}
150 | 	if token != client.PAAToken {
151 | 		t.Fatalf("tunnelRequest failed got token %s, expected %s", token, client.PAAToken)
152 | 	}
153 |
154 | 	data = h.tunnelResponse(ERROR_SUCCESS)
155 | 	_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_RESPONSE, TunnelCreateResponseLen)
156 | 	if err != nil {
157 | 		t.Fatalf("verifyHeader failed: %s", err)
158 | 	}
159 |
160 | 	tid, caps, err := client.tunnelResponse(pkt)
161 | 	if err != nil {
162 | 		t.Fatalf("Error %s", err)
163 | 	}
164 | 	if tid != tunnelId {
165 | 		t.Fatalf("tunnelResponse failed tunnel id %d, expected %d", tid, tunnelId)
166 | 	}
167 | 	if !((caps & HTTP_CAPABILITY_IDLE_TIMEOUT) == HTTP_CAPABILITY_IDLE_TIMEOUT) {
168 | 		t.Fatalf("tunnelResponse failed got caps %d, expected %d", caps, caps|HTTP_CAPABILITY_IDLE_TIMEOUT)
169 | 	}
170 | }
171 |
172 | func TestTunnelAuth(t *testing.T) {
173 | 	name := "test_name"
174 | 	client := ClientConfig{
175 | 		Name: name,
176 | 	}
177 | 	gw := &Gateway{
178 | 		TokenAuth:     true,
179 | 		IdleTimeout:   10,
180 | 		RedirectFlags: RedirectFlags{Clipboard: true},
181 | 	}
182 | 	tunnel := &Tunnel{}
183 | 	h := NewProcessor(gw, tunnel)
184 |
185 | 	data := client.tunnelAuthRequest()
186 | 	_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH, uint32(TunnelAuthLen+len(name)*2))
187 | 	if err != nil {
188 | 		t.Fatalf("verifyHeader failed: %s", err)
189 | 	}
190 |
191 | 	n := h.tunnelAuthRequest(pkt)
192 | 	if n != name {
193 | 		t.Fatalf("tunnelAuthRequest failed got name %s, expected %s", n, name)
194 | 	}
195 |
196 | 	data = h.tunnelAuthResponse(ERROR_SUCCESS)
197 | 	_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_TUNNEL_AUTH_RESPONSE, TunnelAuthResponseLen)
198 | 	if err != nil {
199 | 		t.Fatalf("verifyHeader failed: %s", err)
200 | 	}
201 | 	flags, timeout, err := client.tunnelAuthResponse(pkt)
202 | 	if err != nil {
203 | 		t.Fatalf("tunnel auth error %s", err)
204 | 	}
205 | 	if (flags & HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD) == HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD {
206 | 		t.Fatalf("tunnelAuthResponse failed got flags %d, expected %d",
207 | 			flags, flags&^HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD)
208 | 	}
209 | 	if int(timeout) != gw.IdleTimeout {
210 | 		t.Fatalf("tunnelAuthResponse failed got timeout %d, expected %d",
211 | 			timeout, gw.IdleTimeout)
212 | 	}
213 | }
214 |
215 | func TestChannelCreation(t *testing.T) {
216 | 	server := "test_server"
217 | 	client := ClientConfig{
218 | 		Server: server,
219 | 		Port:   3389,
220 | 	}
221 | 	gw := &Gateway{
222 | 		TokenAuth:   true,
223 | 		IdleTimeout: 10,
224 | 		RedirectFlags: RedirectFlags{
225 | 			Clipboard: true,
226 | 		},
227 | 	}
228 | 	tunnel := &Tunnel{}
229 | 	h := NewProcessor(gw, tunnel)
230 |
231 | 	data := client.channelRequest()
232 | 	_, _, pkt, err := verifyPacketHeader(data, PKT_TYPE_CHANNEL_CREATE, uint32(ChannelCreateLen+len(server)*2))
233 | 	if err != nil {
234 | 		t.Fatalf("verifyHeader failed: %s", err)
235 | 	}
236 | 	hServer, hPort := h.channelRequest(pkt)
237 | 	if hServer != server {
238 | 		t.Fatalf("channelRequest failed got server %s, expected %s", hServer, server)
239 | 	}
240 | 	if int(hPort) != client.Port {
241 | 		t.Fatalf("channelRequest failed got port %d, expected %d", hPort, client.Port)
242 | 	}
243 |
244 | 	data = h.channelResponse(ERROR_SUCCESS)
245 | 	_, _, pkt, err = verifyPacketHeader(data, PKT_TYPE_CHANNEL_RESPONSE, uint32(ChannelResponseLen))
246 | 	if err != nil {
247 | 		t.Fatalf("verifyHeader failed: %s", err)
248 | 	}
249 | 	channelId, err := client.channelResponse(pkt)
250 | 	if err != nil {
251 | 		t.Fatalf("channelResponse failed: %s", err)
252 | 	}
253 | 	if channelId < 1 {
254 | 		t.Fatalf("channelResponse failed got channel id %d, expected > 0", channelId)
255 | 	}
256 | }
257 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/track.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import "fmt"
4 |
5 | var Connections map[string]*Monitor
6 |
7 | type Monitor struct {
8 | 	Processor *Processor
9 | 	Tunnel    *Tunnel
10 | }
11 |
12 | const (
13 | 	ctlDisconnect = -1
14 | )
15 |
16 | func RegisterTunnel(t *Tunnel, p *Processor) {
17 | 	if Connections == nil {
18 | 		Connections = make(map[string]*Monitor)
19 | 	}
20 |
21 | 	Connections[t.Id] = &Monitor{
22 | 		Processor: p,
23 | 		Tunnel:    t,
24 | 	}
25 | }
26 |
27 | func RemoveTunnel(t *Tunnel) {
28 | 	delete(Connections, t.Id)
29 | }
30 | // Disconnect sends ctlDisconnect to the processor registered under id, or returns an error if id is not registered.
31 | func Disconnect(id string) error {
32 | 	if Connections == nil {
33 | 		return fmt.Errorf("%s connection does not exist", id)
34 | 	}
35 | 	// Only a registered id has a processor to signal; an unknown id would yield a nil *Monitor.
36 | 	if m, ok := Connections[id]; ok {
37 | 		m.Processor.ctl <- ctlDisconnect // NOTE(review): blocks unless ctl is buffered or actively drained — confirm
38 | 		return nil
39 | 	}
40 |
41 | 	return fmt.Errorf("%s connection does not exist", id)
42 | }
43 |
44 | // CalculateSpeedPerSecond calculate moving average.
45 | /*
46 | func CalculateSpeedPerSecond(connId string) (in int, out int) {
47 | 	now := time.Now().UnixMilli()
48 |
49 | 	c := Connections[connId]
50 | 	total := int64(0)
51 | 	for _, v := range c.Tunnel.BytesReceived {
52 | 		total += v
53 | 	}
54 | 	in = int(total / (now - c.TimeStamp) * 1000)
55 |
56 | 	total = int64(0)
57 | 	for _, v := range c.BytesSent {
58 | 		total += v
59 | 	}
60 | 	out = int(total / (now - c.TimeStamp))
61 |
62 | 	return in, out
63 | }
64 | */
65 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/tunnel.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | 	"net"
5 | 	"time"
6 |
7 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
8 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
9 | )
10 |
11 | const (
12 | 	CtxTunnel = "github.com/bolkedebruin/rdpgw/tunnel"
13 | )
14 | // Tunnel holds the state of one client connection through the gateway.
15 | type Tunnel struct {
16 | 	// Id identifies the connection in the server
17 | 	Id string
18 | 	// The connection-id (RDG-ConnID) as reported by the client
19 | 	RDGId string
20 | 	// The underlying incoming transport being either websocket or legacy http
21 | 	// in case of websocket transportOut will equal transportIn
22 | 	transportIn transport.Transport
23 | 	// The underlying outgoing transport being either websocket or legacy http
24 | 	// in case of websocket transportOut will equal transportIn
25 | 	transportOut transport.Transport
26 | 	// The remote desktop server (rdp, vnc etc) the client intends to connect to
27 | 	TargetServer string
28 | 	// The obtained client ip address
29 | 	RemoteAddr string
30 | 	// User is the identity associated with this tunnel
31 | 	User identity.Identity
32 |
33 | 	// rwc is the underlying connection to the remote desktop server.
34 | 	// It is of the type *net.TCPConn
35 | 	rwc net.Conn
36 |
37 | 	// BytesSent is the total amount of bytes sent by the server to the client minus tunnel overhead
38 | 	BytesSent int64
39 |
40 | 	// BytesReceived is the total amount of bytes received by the server from the client minus tunnel overhead
41 | 	BytesReceived int64
42 |
43 | 	// ConnectedOn is when the client connected to the server
44 | 	ConnectedOn time.Time
45 |
46 | 	// LastSeen is when the server received the last packet from the client
47 | 	LastSeen time.Time
48 | }
49 |
50 | type message struct {
51 | 	packetType int
52 | 	length     int
53 | 	msg        []byte
54 | 	err        error
55 | }
56 |
57 | // Write puts the packet on the outgoing transport and adds the bytes written to BytesSent.
58 | func (t *Tunnel) Write(pkt []byte) {
59 | 	n, _ := t.transportOut.WritePacket(pkt) // write errors are not surfaced; only the byte count is recorded
60 | 	t.BytesSent += int64(n)
61 | }
62 |
63 | // Read reads messages from the incoming transport via readMessage. For each
64 | // message it adds its length to BytesReceived and refreshes LastSeen. On a
65 | // read error it returns nil and the error without touching the statistics.
66 | func (t *Tunnel) Read() ([]*message, error) {
67 | 	messages, err := readMessage(t.transportIn)
68 | 	if err != nil {
69 | 		return nil, err
70 | 	}
71 | 	for _, m := range messages {
72 | 		t.BytesReceived += int64(m.length)
73 | 		t.LastSeen = time.Now()
74 | 	}
75 | 	return messages, nil
76 | }
77 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/types.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | const (
4 | 	PKT_TYPE_HANDSHAKE_REQUEST      = 0x1
5 | 	PKT_TYPE_HANDSHAKE_RESPONSE     = 0x2
6 | 	PKT_TYPE_EXTENDED_AUTH_MSG      = 0x3
7 | 	PKT_TYPE_TUNNEL_CREATE          = 0x4
8 | 	PKT_TYPE_TUNNEL_RESPONSE        = 0x5
9 | 	PKT_TYPE_TUNNEL_AUTH            = 0x6
10 | 	PKT_TYPE_TUNNEL_AUTH_RESPONSE   = 0x7
11 | 	PKT_TYPE_CHANNEL_CREATE         = 0x8
12 | 	PKT_TYPE_CHANNEL_RESPONSE       = 0x9
13 | 	PKT_TYPE_DATA                   = 0xA
14 | 	PKT_TYPE_SERVICE_MESSAGE        = 0xB
15 | 	PKT_TYPE_REAUTH_MESSAGE         = 0xC
16 | 	PKT_TYPE_KEEPALIVE              = 0xD
17 | 	PKT_TYPE_CLOSE_CHANNEL          = 0x10
18 | 	PKT_TYPE_CLOSE_CHANNEL_RESPONSE = 0x11
19 | )
20 |
21 | const (
22 | 	HTTP_TUNNEL_RESPONSE_FIELD_TUNNEL_ID   = 0x01
23 | 	HTTP_TUNNEL_RESPONSE_FIELD_CAPS        = 0x02
24 | 	HTTP_TUNNEL_RESPONSE_FIELD_SOH_REQ     = 0x04
25 | 	HTTP_TUNNEL_RESPONSE_FIELD_CONSENT_MSG = 0x10
26 | )
27 |
28 | const (
29 | 	HTTP_EXTENDED_AUTH_NONE      = 0x0
30 | 	HTTP_EXTENDED_AUTH_SC        = 0x1  /* Smart card authentication. */
31 | 	HTTP_EXTENDED_AUTH_PAA       = 0x02 /* Pluggable authentication. */
32 | 	HTTP_EXTENDED_AUTH_SSPI_NTLM = 0x04 /* NTLM extended authentication. */
33 | )
34 |
35 | const (
36 | 	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_REDIR_FLAGS  = 0x01
37 | 	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_IDLE_TIMEOUT = 0x02
38 | 	HTTP_TUNNEL_AUTH_RESPONSE_FIELD_SOH_RESPONSE = 0x04
39 | )
40 |
41 | const (
42 | 	HTTP_TUNNEL_REDIR_ENABLE_ALL        = 0x80000000
43 | 	HTTP_TUNNEL_REDIR_DISABLE_ALL       = 0x40000000
44 | 	HTTP_TUNNEL_REDIR_DISABLE_DRIVE     = 0x01
45 | 	HTTP_TUNNEL_REDIR_DISABLE_PRINTER   = 0x02
46 | 	HTTP_TUNNEL_REDIR_DISABLE_PORT      = 0x04
47 | 	HTTP_TUNNEL_REDIR_DISABLE_CLIPBOARD = 0x08
48 | 	HTTP_TUNNEL_REDIR_DISABLE_PNP       = 0x10
49 | )
50 |
51 | const (
52 | 	HTTP_CHANNEL_RESPONSE_FIELD_CHANNELID   = 0x01
53 | 	HTTP_CHANNEL_RESPONSE_FIELD_AUTHNCOOKIE = 0x02
54 | 	HTTP_CHANNEL_RESPONSE_FIELD_UDPPORT     = 0x04
55 | )
56 |
57 | const (
58 | 	HTTP_TUNNEL_PACKET_FIELD_PAA_COOKIE = 0x1
59 | )
60 |
61 | const (
62 | 	SERVER_STATE_INITIALIZED      = 0x0
63 | 	SERVER_STATE_HANDSHAKE        = 0x1
64 | 	SERVER_STATE_TUNNEL_CREATE    = 0x2
65 | 	SERVER_STATE_TUNNEL_AUTHORIZE = 0x3
66 | 	SERVER_STATE_CHANNEL_CREATE   = 0x4
67 | 	SERVER_STATE_OPENED           = 0x5
68 | 	SERVER_STATE_CLOSED           = 0x6
69 | )
70 |
71 | const (
72 | 	HTTP_CAPABILITY_TYPE_QUAR_SOH          = 0x1
73 | 	HTTP_CAPABILITY_IDLE_TIMEOUT           = 0x2
74 | 	HTTP_CAPABILITY_MESSAGING_CONSENT_SIGN = 0x4
75 | 	HTTP_CAPABILITY_MESSAGING_SERVICE_MSG  = 0x8
76 | 	HTTP_CAPABILITY_REAUTH                 = 0x10
77 | 	HTTP_CAPABILITY_UDP_TRANSPORT          = 0x20
78 | )
79 |
--------------------------------------------------------------------------------
/cmd/rdpgw/protocol/utf16.go:
--------------------------------------------------------------------------------
1 | package protocol
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"fmt"
7 | 	"unicode/utf16"
8 | 	"unicode/utf8"
9 | )
10 |
11 | // DecodeUTF16 decodes little-endian UTF-16 bytes into a Go string, dropping a
12 | // single trailing NUL character if present. It returns an error when b has an
13 | // odd length.
14 | func DecodeUTF16(b []byte) (string, error) {
15 | 	if len(b)%2 != 0 {
16 | 		return "", fmt.Errorf("must have even length byte slice")
17 | 	}
18 |
19 | 	// Gather all code units before decoding: decoding them one at a time turns
20 | 	// each half of a surrogate pair (runes above U+FFFF) into U+FFFD.
21 | 	u16s := make([]uint16, len(b)/2)
22 | 	for i := range u16s {
23 | 		u16s[i] = binary.LittleEndian.Uint16(b[2*i:])
24 | 	}
25 |
26 | 	ret := &bytes.Buffer{}
27 | 	b8buf := make([]byte, utf8.UTFMax)
28 | 	for _, r := range utf16.Decode(u16s) {
29 | 		n := utf8.EncodeRune(b8buf, r)
30 | 		ret.Write(b8buf[:n])
31 | 	}
32 |
33 | 	bret := ret.Bytes()
34 | 	if len(bret) > 0 && bret[len(bret)-1] == '\x00' {
35 | 		bret = bret[:len(bret)-1]
36 | 	}
37 | 	return string(bret), nil
38 | }
39 |
40 | // EncodeUTF16 encodes s as little-endian UTF-16 bytes without a terminator.
41 | func EncodeUTF16(s string) []byte {
42 | 	ret := new(bytes.Buffer)
43 | 	enc := utf16.Encode([]rune(s))
44 | 	for c := range enc {
45 | 		binary.Write(ret, binary.LittleEndian, enc[c])
46 | 	}
47 | 	return ret.Bytes()
48 | }
--------------------------------------------------------------------------------
/cmd/rdpgw/rdp/koanf/parsers/rdp/rdp.go:
--------------------------------------------------------------------------------
1 | package rdp
2 |
3 | import (
4 | 	"bufio"
5 | 	"bytes"
6 | 	"fmt"
7 | 	"sort"
8 | 	"strconv"
9 | 	"strings"
10 | )
11 |
12 | // RDP parses and serializes the .rdp format: one "key:type:value" entry per
13 | // line, where type is "i" (integer), "s" (string) or "b" (kept as a string).
14 | type RDP struct{}
15 |
16 | // Parser returns a new RDP parser.
17 | func Parser() *RDP {
18 | 	return &RDP{}
19 | }
20 |
21 | // Unmarshal parses .rdp content into a flat map. Blank lines and lines that
22 | // start with '#' are skipped; "i" values become int, all others string.
23 | func (p *RDP) Unmarshal(b []byte) (map[string]interface{}, error) {
24 | 	r := bytes.NewReader(b)
25 | 	scanner := bufio.NewScanner(r)
26 | 	mp := make(map[string]interface{})
27 |
28 | 	c := 0
29 | 	for scanner.Scan() {
30 | 		c++
31 | 		line := strings.TrimSpace(scanner.Text())
32 | 		if line == "" || strings.HasPrefix(line, "#") {
33 | 			continue
34 | 		}
35 | 		fields := strings.SplitN(line, ":", 3)
36 | 		if len(fields) != 3 {
37 | 			return nil, fmt.Errorf("malformed line %d: %q", c, line)
38 | 		}
39 |
40 | 		key := strings.TrimSpace(fields[0])
41 | 		t := strings.TrimSpace(fields[1])
42 | 		val := strings.TrimSpace(fields[2])
43 |
44 | 		switch t {
45 | 		case "i":
46 | 			intValue, err := strconv.Atoi(val)
47 | 			if err != nil {
48 | 				return nil, fmt.Errorf("cannot parse integer at line %d: %s", c, line)
49 | 			}
50 | 			mp[key] = intValue
51 | 		case "s":
52 | 			mp[key] = val
53 | 		case "b":
54 | 			mp[key] = val
55 | 		default:
56 | 			return nil, fmt.Errorf("malformed line %d: %s", c, line)
57 | 		}
58 | 	}
59 | 	// Scan stops silently on read errors or on lines longer than the scanner's
60 | 	// buffer; without this check a truncated file would parse "successfully".
61 | 	if err := scanner.Err(); err != nil {
62 | 		return nil, fmt.Errorf("reading rdp data at line %d: %w", c+1, err)
63 | 	}
64 | 	return mp, nil
65 | }
66 |
67 | // Marshal serializes o into .rdp format with keys in sorted order and CRLF
68 | // line endings. bool values are written as integers 0/1; only bool, int and
69 | // string values are supported.
70 | func (p *RDP) Marshal(o map[string]interface{}) ([]byte, error) {
71 | 	var b bytes.Buffer
72 |
73 | 	keys := make([]string, 0, len(o))
74 | 	for k := range o {
75 | 		keys = append(keys, k)
76 | 	}
77 | 	sort.Strings(keys)
78 |
79 | 	for _, key := range keys {
80 | 		switch v := o[key].(type) {
81 | 		case bool:
82 | 			if v {
83 | 				fmt.Fprintf(&b, "%s:i:1", key)
84 | 			} else {
85 | 				fmt.Fprintf(&b, "%s:i:0", key)
86 | 			}
87 | 		case int:
88 | 			fmt.Fprintf(&b, "%s:i:%d", key, v)
89 | 		case string:
90 | 			fmt.Fprintf(&b, "%s:s:%s", key, v)
91 | 		default:
92 | 			return nil, fmt.Errorf("error marshalling")
93 | 		}
94 | 		fmt.Fprint(&b, "\r\n")
95 | 	}
96 | 	return b.Bytes(), nil
97 | }
98 |
--------------------------------------------------------------------------------
/cmd/rdpgw/rdp/koanf/parsers/rdp/rdp_test.go:
--------------------------------------------------------------------------------
1 | package rdp
2 |
3 | import (
4 | 	"github.com/stretchr/testify/assert"
5 | 	"testing"
6 | )
7 |
8 | func TestUnmarshalRDPFile(t *testing.T) {
9 | 	rdp := Parser()
10 |
11 | 	testCases := []struct {
12 | 		name      string
13 | 		cfg       []byte
14 | 		expOutput map[string]interface{}
15 | 		err       error
16 | 	}{
17 | 		{
18 | 			name:      "empty",
19 | 			expOutput: map[string]interface{}{},
20 | 		},
21 | 		{
22 | 			name: "string",
23 | 			cfg:  []byte(`username:s:user1`),
24 | 			expOutput: map[string]interface{}{
25 | 				"username": "user1",
26 | 			},
27 | 		},
28 | 		{
29 | 			name: "integer",
30 | 			cfg:  []byte(`session bpp:i:32`),
31 | 			expOutput: map[string]interface{}{
32 | 				"session bpp": 32,
33 | 			},
34 | 		},
35 | 		{
36 | 			name: "multi",
37 | 			cfg:  []byte("compression:i:1\r\nusername:s:user2\r\n"),
38 | 			expOutput: map[string]interface{}{
39 | 				"compression": 1,
40 | 				"username":    "user2",
41 | 			},
42 | 		},
43 | 	}
44 |
45 | 	for _, tc := range testCases {
46 | 		t.Run(tc.name, func(t *testing.T) {
47 | 			outMap, err := rdp.Unmarshal(tc.cfg)
48 | 			assert.Equal(t, tc.err, err)
49 | 			assert.Equal(t, tc.expOutput, outMap)
50 | 		})
51 | 	}
52 | }
53 |
54 | func TestRDP_Marshal(t *testing.T) {
55 | 	testCases := []struct {
56 | 		name   string
57 | 		input  map[string]interface{}
58 | 		output []byte
59 | 		err    error
60 | 	}{
61 | 		{
62 | 			name:   "Empty RDP",
63 | 			input:  map[string]interface{}{},
64 | 			output: []byte(nil),
65 | 		},
66 | 		{
67 | 			name: "Valid RDP all types",
68 | 			input: map[string]interface{}{
69 | 				"compression": 1,
70 | 				"session bpp": 32,
71 | 				"username":    "user1",
72 | 			},
73 | 			output: []byte("compression:i:1\r\nsession bpp:i:32\r\nusername:s:user1\r\n"),
74 | 		},
75 | 	}
76 |
77 | 	rdp := Parser()
78 | 	for _, tc := range testCases {
79 | 		t.Run(tc.name, func(t *testing.T) {
80 | 			out, err := rdp.Marshal(tc.input)
81 | 			assert.Equal(t, tc.output, out)
82 | 			assert.Equal(t, tc.err, err)
83 | 		})
84 | 	}
85 | }
86 |
--------------------------------------------------------------------------------
/cmd/rdpgw/rdp/rdp.go:
--------------------------------------------------------------------------------
1 | package rdp
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp/koanf/parsers/rdp"
7 | 	"github.com/fatih/structs"
8 | 	"github.com/go-viper/mapstructure/v2"
9 | 	"github.com/knadh/koanf/providers/file"
10 | 	"github.com/knadh/koanf/v2"
11 | 	"log"
12 | 	"reflect"
13 | 	"strconv"
14 | 	"strings"
15 | )
16 |
17 | const (
18 | 	CRLF = "\r\n"
19 | )
20 |
21 | const (
22 | 	SourceNTLM int = iota
23 | 	SourceSmartCard
24 | 	SourceCurrent
25 | 	SourceBasic
26 | 	SourceUserSelect
27 | 	SourceCookie
28 | )
29 |
30 | type RdpSettings struct {
31 | 	AllowFontSmoothing bool `rdp:"allow font smoothing" default:"0"`
32 | 	AllowDesktopComposition bool `rdp:"allow desktop composition" default:"0"`
33 | 	DisableFullWindowDrag bool `rdp:"disable full window drag" default:"0"`
34 | 	DisableMenuAnims bool `rdp:"disable menu anims" default:"0"`
35 | 	DisableThemes bool `rdp:"disable themes" default:"0"`
36 | 	DisableCursorSetting bool `rdp:"disable cursor setting" default:"0"`
37 | 	GatewayHostname string `rdp:"gatewayhostname"`
38 | 	FullAddress string `rdp:"full address"`
39 | 	AlternateFullAddress string `rdp:"alternate full address"`
40 | 	Username string `rdp:"username"`
41 | 	Domain string `rdp:"domain"`
42 | 	GatewayCredentialsSource int `rdp:"gatewaycredentialssource" default:"0"`
43 | 	GatewayCredentialMethod int `rdp:"gatewayprofileusagemethod" default:"0"`
44 | 	GatewayUsageMethod int `rdp:"gatewayusagemethod" default:"0"`
45 | 	GatewayAccessToken string `rdp:"gatewayaccesstoken"`
46 | 	PromptCredentialsOnce bool `rdp:"promptcredentialonce" default:"true"`
47 | 	AuthenticationLevel int `rdp:"authentication level" default:"3"`
48 | 	EnableCredSSPSupport bool `rdp:"enablecredsspsupport" default:"true"`
49 | 	EnableRdsAasAuth bool `rdp:"enablerdsaadauth" default:"false"`
50 | 	DisableConnectionSharing bool `rdp:"disableconnectionsharing" default:"false"`
51 | 	AlternateShell string `rdp:"alternate shell"`
52 | 	AutoReconnectionEnabled bool `rdp:"autoreconnectionenabled" default:"true"`
53 | 	BandwidthAutodetect bool `rdp:"bandwidthautodetect" default:"true"`
54 | 	NetworkAutodetect bool `rdp:"networkautodetect" default:"true"`
55 | 	Compression bool `rdp:"compression" default:"true"`
56 | 	VideoPlaybackMode bool `rdp:"videoplaybackmode" default:"true"`
57 | 	ConnectionType int `rdp:"connection type" default:"2"`
58 | 	AudioCaptureMode bool `rdp:"audiocapturemode" default:"false"`
59 | 	EncodeRedirectedVideoCapture bool `rdp:"encode redirected video capture" default:"true"`
60 | 	RedirectedVideoCaptureEncodingQuality int `rdp:"redirected video capture encoding quality" default:"0"`
61 | 	AudioMode int `rdp:"audiomode" default:"0"`
62 | 	CameraStoreRedirect string `rdp:"camerastoredirect" default:"false"`
63 | 	DeviceStoreRedirect string `rdp:"devicestoredirect" default:"false"`
64 | 	DriveStoreRedirect string `rdp:"drivestoredirect" default:"false"`
65 | 	KeyboardHook int `rdp:"keyboardhook" default:"2"`
66 | 	RedirectClipboard bool `rdp:"redirectclipboard" default:"true"`
67 | 	RedirectComPorts bool `rdp:"redirectcomports" default:"false"`
68 | 	RedirectLocation bool `rdp:"redirectlocation" default:"false"`
69 | 	RedirectPrinters bool `rdp:"redirectprinters" default:"true"`
70 | 	RedirectSmartcards bool `rdp:"redirectsmartcards" default:"true"`
71 | 	RedirectWebAuthn bool `rdp:"redirectwebauthn" default:"true"`
72 | 	UsbDeviceStoRedirect string `rdp:"usbdevicestoredirect"`
73 | 	UseMultimon bool `rdp:"use multimon" default:"false"`
74 | 	SelectedMonitors string `rdp:"selectedmonitors"`
75 | 	MaximizeToCurrentDisplays bool `rdp:"maximizetocurrentdisplays" default:"false"`
76 | 	SingleMonInWindowedMode bool `rdp:"singlemoninwindowedmode" default:"0"`
77 | 	ScreenModeId int `rdp:"screen mode id" default:"2"`
78 | 	SmartSizing bool `rdp:"smart sizing" default:"false"`
79 | 	DynamicResolution bool `rdp:"dynamic resolution" default:"true"`
80 | 	DesktopSizeId int `rdp:"desktop size id"`
81 | 	DesktopHeight int `rdp:"desktopheight"`
82 | 	DesktopWidth int `rdp:"desktopwidth"`
83 | 	DesktopScaleFactor int `rdp:"desktopscalefactor"`
84 | 	BitmapCacheSize int `rdp:"bitmapcachesize" default:"1500"`
85 | 	BitmapCachePersistEnable bool `rdp:"bitmapcachepersistenable" default:"true"`
86 | 	RemoteApplicationCmdLine string `rdp:"remoteapplicationcmdline"`
87 | 	RemoteAppExpandWorkingDir bool `rdp:"remoteapplicationexpandworkingdir" default:"true"`
88 | 	RemoteApplicationFile string `rdp:"remoteapplicationfile" default:"true"`
89 | 	RemoteApplicationIcon string `rdp:"remoteapplicationicon"`
90 | 	RemoteApplicationMode bool `rdp:"remoteapplicationmode" default:"false"`
91 | 	RemoteApplicationName string `rdp:"remoteapplicationname"`
92 | 	RemoteApplicationProgram string `rdp:"remoteapplicationprogram"`
93 | }
94 |
95 | type Builder struct {
96 | 	Settings RdpSettings
97 | 	Metadata mapstructure.Metadata
98 | }
99 |
100 | func NewBuilder() *Builder {
101 | 	c := RdpSettings{}
102 |
103 | 	initStruct(&c)
104 |
105 | 	return &Builder{
106 | 		Settings: c,
107 | 		Metadata: mapstructure.Metadata{},
108 | 	}
109 | }
110 |
111 | func NewBuilderFromFile(filename string) (*Builder, error) {
112 | 	c := RdpSettings{}
113 | 	initStruct(&c)
114 | 	metadata := mapstructure.Metadata{}
115 |
116 | 	decoderConfig := &mapstructure.DecoderConfig{
117 | 		Result:           &c,
118 | 		Metadata:         &metadata,
119 | 		WeaklyTypedInput: true,
120 | 	}
121 |
122 | 	var k = koanf.New(".")
123 | 	if err := k.Load(file.Provider(filename), rdp.Parser()); err != nil {
124 | 		return nil, err
125 | 	}
126 | 	t := koanf.UnmarshalConf{Tag: "rdp", DecoderConfig: decoderConfig}
127 |
128 | 	if err := k.UnmarshalWithConf("", &c, t); err != nil {
129 | 		return nil, err
130 | 	}
131 | 	return &Builder{
132 | 		Settings: c,
133 | 		Metadata: metadata,
134 | 	}, nil
135 | }
136 |
137 | func (rb *Builder) String() string {
138 | 	var sb strings.Builder
139 |
140 | 	addStructToString(rb.Settings, rb.Metadata, &sb)
141 |
142 | 	return sb.String()
143 | }
144 |
145 | func addStructToString(st interface{}, metadata mapstructure.Metadata, sb *strings.Builder) {
146 | 	s := structs.New(st)
147 | 	for _, f := range s.Fields() {
148 | 		if isZero(f) && !isSet(f, metadata) {
149 | 			continue
150 | 		}
151 | 		sb.WriteString(f.Tag("rdp"))
152 | 		sb.WriteString(":")
153 |
154 | 		switch f.Kind() {
155 | 		case reflect.String:
156 | 			sb.WriteString("s:")
157 | 			sb.WriteString(f.Value().(string))
158 | 		case reflect.Int:
159 | 			sb.WriteString("i:")
160 | 			fmt.Fprintf(sb, "%d", f.Value())
161 | 		case reflect.Bool:
162 | 			sb.WriteString("i:")
163 | 			if f.Value().(bool) {
164 | 				sb.WriteString("1")
165 | 			} else {
166 | 				sb.WriteString("0")
167 |
168 | 			}
169 | 		}
170 | 		sb.WriteString(CRLF)
171 | 	}
172 | }
173 |
174 | func isZero(f *structs.Field) bool {
175 | 	t := f.Tag("default")
176 | 	if t == "" {
177 | 		return f.IsZero()
178 | 	}
179 |
180 | 	switch f.Kind() {
181 | 	case reflect.String:
182 | 		if f.Value().(string) != t {
183 | 			return false
184 | 		}
185 | 		return true
186 | 	case reflect.Int:
187 | 		i, err := strconv.Atoi(t)
188 | 		if err != nil {
189 | 			log.Fatalf("runtime error: default %s is not an integer", t)
190 | 		}
191 | 		if f.Value().(int) != i {
192 | 			return false
193 | 		}
194 | 		return true
195 | 	case reflect.Bool:
196 | 		b := false
197 | 		if t == "true" || t == "1" {
198 | 			b = true
199 | 		}
200 | 		if f.Value().(bool) != b {
201 | 			return false
202 | 		}
203 | 		return true
204 | 	}
205 |
206 | 	return f.IsZero()
207 | }
208 |
209 | func isSet(f *structs.Field, metadata mapstructure.Metadata) bool {
210 | 	for _, v := range metadata.Unset { // NOTE(review): Unset lists fields absent from the input and may hold tag names, not f.Name() — confirm intended semantics
211 | 		if v == f.Name() {
212 | 			log.Printf("field %s is unset", f.Name())
213 | 			return true
214 | 		}
215 | 	}
216 | 	return false
217 | }
218 |
219 | func initStruct(st interface{}) {
220 | 	s := structs.New(st)
221 | 	for _, f := range s.Fields() {
222 | 		t := f.Tag("default")
223 | 		if t == "" {
224 | 			continue
225 | 		}
226 |
227 | 		err := setVariable(f, t)
228 | 		if err != nil {
229 | 			log.Fatalf("cannot init rdp struct: %s", err)
230 | 		}
231 | 	}
232 | }
233 |
234 | func setVariable(f *structs.Field, v string) error {
235 | 	switch f.Kind() {
236 | 	case reflect.String:
237 | 		return f.Set(v)
238 | 	case reflect.Int:
239 | 		i, err := strconv.Atoi(v)
240 | 		if err != nil {
241 | 			return err
242 | 		}
243 | 		return f.Set(i)
244 | 	case reflect.Bool:
245 | 		b := false
246 | 		if v == "true" || v == "1" {
247 | 			b = true
248 | 		}
249 | 		return f.Set(b)
250 | 	default:
251 | 		return errors.New("invalid field type")
252 | 	}
253 | }
254 |
--------------------------------------------------------------------------------
/cmd/rdpgw/rdp/rdp_test.go:
--------------------------------------------------------------------------------
1 | package rdp
2 |
3 | import (
4 | 	"log"
5 | 	"strings"
6 | 	"testing"
7 | )
8 |
9 | const (
10 | 	GatewayHostName = "my.yahoo.com"
11 | )
12 |
13 | func TestRdpBuilder(t *testing.T) {
14 | 	builder := NewBuilder()
15 | 	builder.Settings.GatewayHostname = GatewayHostName
16 | 	builder.Settings.AutoReconnectionEnabled = true
17 | 	builder.Settings.SmartSizing = true
18 |
19 | 	s := builder.String()
20 | 	if !strings.Contains(s, "gatewayhostname:s:"+GatewayHostName+CRLF) {
21 | 		t.Fatalf("%s does not contain `gatewayhostname:s:%s`", s, GatewayHostName)
22 | 	}
23 | 	if strings.Contains(s, "autoreconnectionenabled") {
24 | 		t.Fatalf("autoreconnectionenabled is in %s, but it's default value", s)
25 | 	}
26 | 	if !strings.Contains(s, "smart sizing:i:1"+CRLF) {
27 | 		t.Fatalf("%s does not contain smart sizing:i:1", s)
28 |
29 | 	}
30 | 	log.Print(s)
31 | }
32 |
33 | func TestInitStruct(t *testing.T) {
34 | 	conn := RdpSettings{}
35 | 	initStruct(&conn)
36 |
37 | 	if !conn.PromptCredentialsOnce {
38 | 		t.Fatalf("conn.PromptCredentialsOnce != true")
39 | 	}
40 | }
41 |
42 | func TestLoadFile(t *testing.T) {
43 | 	_, err := NewBuilderFromFile("rdp_test_file.rdp")
44 | 	if err != nil {
45 | 		t.Fatalf("LoadFile failed: %v", err)
46 | 	}
47 | }
48 |
--------------------------------------------------------------------------------
/cmd/rdpgw/rdp/rdp_test_file.rdp:
--------------------------------------------------------------------------------
1 | Password:b:0200000000000000000000000000000000000000000000000800000072006400700000000E660000100000001000000031A2D4A21767565E3A268420A9397C4400000000048000001000000010000000A56C359BBBA13EC284391427E6A107BD20000000333E6F6DA024E1B6B4CC7DDF57BFC1783ED02F212B8FBD39997C888F9D4B438914000000A80D19234BA4CC5CE2695A34EF0B9B92D5D777A6
2 | ColorDepthID:i:1
3 | ScreenStyle:i:0
4 | DesktopWidth:i:640
5 | DesktopHeight:i:480
6 | UserName:s:rdesktop
7 | SavePassword:i:1
8 | Keyboard Layout:s:00000409
9 | BitmapPersistCacheSize:i:1
10 | BitmapCacheSize:i:21
11 | KeyboardFunctionKey:i:12
12 | KeyboardSubType:i:0
13 | KeyboardType:i:4
14 | KeyboardLayoutString:s:0xE0010409
15 | Disable Themes:i:0
16 | Disable Menu Anims:i:1
17 | Disable Full Window Drag:i:1
18 | Disable Wallpaper:i:1
19 | MaxReconnectAttempts:i:20
20 | KeyboardHookMode:i:0
21 | Compress:i:1
22 | BBarShowPinBtn:i:0
23 | BitmapPersistenceEnabled:i:0
24 | AudioRedirectionMode:i:2
25 | EnablePortRedirection:i:0
26 | EnableDriveRedirection:i:0
AutoReconnectEnabled:i:1
28 | EnableSCardRedirection:i:1
29 | EnablePrinterRedirection:i:0
30 | BBarEnabled:i:0
31 | DisableFileAccess:i:0
32 | MinutesToIdleTimeout:i:5
33 | GrabFocusOnConnect:i:0
34 | StartFullScreen:i:1
35 | Domain:s:GE3SDT8KLRL4J
36 | enablecredsspsupport:i:0
37 | use multimon:i:1
38 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/security/basic.go:
--------------------------------------------------------------------------------
1 | package security
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"log"
8 | 	"strings"
9 | )
10 | 
11 | var (
12 | 	// Hosts lists the target hosts a client may connect to. An entry may
13 | 	// contain the "{{ preferred_username }}" placeholder, which CheckHost
14 | 	// replaces with the tunnel user's name before comparing.
15 | 	Hosts []string
16 | 
17 | 	// HostSelection selects how CheckHost validates a target host: "any",
18 | 	// "signed", "roundrobin" or "unsigned".
19 | 	HostSelection string
20 | )
21 | 
22 | // CheckHost reports whether host is an acceptable connection target under
23 | // the current HostSelection mode:
24 | //
25 | //   - "any" accepts every host.
26 | //   - "signed" always fails, because the token data needed to verify the
27 | //     host is not available here.
28 | //   - "roundrobin" and "unsigned" accept host only when it equals an entry
29 | //     of Hosts after the first "{{ preferred_username }}" in that entry is
30 | //     replaced with the name of the user on the tunnel stored in ctx.
31 | //
32 | // Any other mode is rejected with an error.
33 | func CheckHost(ctx context.Context, host string) (bool, error) {
34 | 	switch HostSelection {
35 | 	case "any":
36 | 		return true, nil
37 | 	case "signed":
38 | 		// todo get from context?
39 | 		return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
40 | 	case "roundrobin", "unsigned":
41 | 		s := getTunnel(ctx)
42 | 		// getTunnel returns nil when ctx carries no tunnel; check it (and the
43 | 		// user) before dereferencing so a missing session is an error, not a panic.
44 | 		if s == nil || s.User == nil || s.User.UserName() == "" {
45 | 			return false, errors.New("no valid session info or username found in context")
46 | 		}
47 | 
48 | 		userName := s.User.UserName()
49 | 		log.Printf("Checking host for user %s", userName)
50 | 		for _, h := range Hosts {
51 | 			h = strings.Replace(h, "{{ preferred_username }}", userName, 1)
52 | 			if h == host {
53 | 				return true, nil
54 | 			}
55 | 		}
56 | 		return false, fmt.Errorf("invalid host %s", host)
57 | 	}
58 | 
59 | 	return false, errors.New("unrecognized host selection criteria")
60 | }
61 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/security/basic_test.go:
--------------------------------------------------------------------------------
1 | package security
2 | 
3 | import (
4 | 	"context"
5 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
6 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
7 | 	"testing"
8 | )
9 | 
10 | var (
11 | 	info = protocol.Tunnel{
12 | 		RDGId:        "myid",
13 | 		TargetServer: "my.remote.server",
14 | 		RemoteAddr:   "10.0.0.1",
15 | 	}
16 | 
17 | 	hosts = []string{"localhost:3389", "my-{{ preferred_username }}-host:3389"}
18 | )
19 | 
20 | func TestCheckHost(t *testing.T) {
21 | 	info.User = identity.NewUser()
22 | 	info.User.SetUserName("MYNAME")
23 | 
24 | 	ctx := context.WithValue(context.Background(), protocol.CtxTunnel, &info)
25 | 
26 | 	Hosts = hosts
27 | 
28 | 	// check any
29 | 	HostSelection = "any"
30 | 	host := "try.my.server:3389"
31 | 	if ok, err := CheckHost(ctx, host); !ok || err != nil {
32 | 		t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
33 | 	}
34 | 
35 | 	HostSelection = "signed"
36 | 	if ok, err := CheckHost(ctx, host); ok || err == nil {
37 | 		t.Fatalf("signed host selection isnt supported at the moment")
38 | 	}
39 | 
40 | 	HostSelection = "roundrobin"
41 | 	if ok, err := CheckHost(ctx, host); ok {
42 | 		t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
43 | 	}
44 | 
45 | 	host = "my-MYNAME-host:3389"
46 | 	if ok, err := CheckHost(ctx, host); !ok {
47 | 		t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
48 | 	}
49 | 
50 | 	// without a tunnel in the context the host must be rejected, not panic
51 | 	if ok, err := CheckHost(context.Background(), host); ok || err == nil {
52 | 		t.Fatalf("%s should NOT be allowed without session info (host selection %s)", host, HostSelection)
53 | 	}
54 | }
55 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/security/jwt.go:
--------------------------------------------------------------------------------
1 | package security
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
8 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
9 | 	"github.com/coreos/go-oidc/v3/oidc"
10 | 	"github.com/go-jose/go-jose/v4"
11 | 	"github.com/go-jose/go-jose/v4/jwt"
12 | 	"golang.org/x/oauth2"
13 | 	"log"
14 | 	"time"
15 | )
16 | 
17 | // Keys and OIDC settings used by the token helpers in this file: SigningKey
18 | // signs and verifies PAA tokens, UserEncryptionKey and UserSigningKey protect
19 | // user tokens, QuerySigningKey signs and verifies query tokens, and
20 | // OIDCProvider/Oauth2Config validate the access token carried by a PAA token.
21 | // EncryptionKey is not read in this file.
22 | var (
23 | 	SigningKey        []byte
24 | 	EncryptionKey     []byte
25 | 	UserSigningKey    []byte
26 | 	UserEncryptionKey []byte
27 | 	QuerySigningKey   []byte
28 | 	OIDCProvider      *oidc.Provider
29 | 	Oauth2Config      oauth2.Config
30 | )
31 | 
32 | // NOTE(review): ExpiryTime is not read in this file; the tokens created here
33 | // use a fixed five minute lifetime. Confirm its unit with its users.
34 | var ExpiryTime time.Duration = 5
35 | 
36 | // VerifyClientIP makes CheckSession require that the tunnel's RemoteAddr
37 | // matches the client IP of the identity in the request context.
38 | var VerifyClientIP bool = true
39 | 
40 | // customClaims are the private claims carried by a PAA token.
41 | type customClaims struct {
42 | 	RemoteServer string `json:"remoteServer"`
43 | 	ClientIP     string `json:"clientIp"`
44 | 	AccessToken  string `json:"accessToken"`
45 | }
46 | 
47 | // CheckSession wraps next with checks against the tunnel stored in ctx: the
48 | // requested host must equal the tunnel's TargetServer and, when
49 | // VerifyClientIP is set, the client IP of the identity in ctx must equal the
50 | // tunnel's RemoteAddr. A failed check returns (false, nil) without calling
51 | // next; a missing tunnel returns an error.
52 | func CheckSession(next protocol.CheckHostFunc) protocol.CheckHostFunc {
53 | 	return func(ctx context.Context, host string) (bool, error) {
54 | 		tunnel := getTunnel(ctx)
55 | 		if tunnel == nil {
56 | 			return false, errors.New("no valid session info found in context")
57 | 		}
58 | 
59 | 		if tunnel.TargetServer != host {
60 | 			log.Printf("Client specified host %s does not match token host %s", host, tunnel.TargetServer)
61 | 			return false, nil
62 | 		}
63 | 
64 | 		// use identity from context rather then set by tunnel
65 | 		id := identity.FromCtx(ctx)
66 | 		if VerifyClientIP && tunnel.RemoteAddr != id.GetAttribute(identity.AttrClientIp) {
67 | 			log.Printf("Current client ip address %s does not match token client ip %s",
68 | 				id.GetAttribute(identity.AttrClientIp), tunnel.RemoteAddr)
69 | 			return false, nil
70 | 		}
71 | 		return next(ctx, host)
72 | 	}
73 | }
74 | 
75 | // CheckPAACookie validates a PAA token created by GeneratePAAToken. The token
76 | // must be HS256-signed with SigningKey, issued by "rdpgw" and unexpired, and
77 | // the OIDC access token inside it must be accepted by OIDCProvider's userinfo
78 | // endpoint. On success the tunnel in ctx takes the token's target server and
79 | // client IP, and the OIDC subject becomes the tunnel user's name.
80 | func CheckPAACookie(ctx context.Context, tokenString string) (bool, error) {
81 | 	if tokenString == "" {
82 | 		log.Printf("no token to parse")
83 | 		return false, errors.New("no token to parse")
84 | 	}
85 | 
86 | 	token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
87 | 	if err != nil {
88 | 		log.Printf("cannot parse token due to: %s", err)
89 | 		return false, err
90 | 	}
91 | 
92 | 	// check if the signing algo matches what we expect
93 | 	for _, header := range token.Headers {
94 | 		if header.Algorithm != string(jose.HS256) {
95 | 			return false, fmt.Errorf("unexpected signing method: %v", header.Algorithm)
96 | 		}
97 | 	}
98 | 
99 | 	standard := jwt.Claims{}
100 | 	custom := customClaims{}
101 | 
102 | 	// Claims automagically checks the signature...
103 | 	err = token.Claims(SigningKey, &standard, &custom)
104 | 	if err != nil {
105 | 		log.Printf("token signature validation failed due to %s", err)
106 | 		return false, err
107 | 	}
108 | 
109 | 	// ...but doesn't check the expiry claim :/
110 | 	err = standard.Validate(jwt.Expected{
111 | 		Issuer: "rdpgw",
112 | 		Time:   time.Now(),
113 | 	})
114 | 	if err != nil {
115 | 		log.Printf("token validation failed due to %s", err)
116 | 		return false, err
117 | 	}
118 | 
119 | 	// Resolve the tunnel and check for an OIDC provider before contacting the
120 | 	// identity provider; dereferencing either when absent would panic.
121 | 	tunnel := getTunnel(ctx)
122 | 	if tunnel == nil || tunnel.User == nil {
123 | 		return false, errors.New("no valid session info found in context")
124 | 	}
125 | 	if OIDCProvider == nil {
126 | 		return false, errors.New("no oidc provider configured to validate the access token")
127 | 	}
128 | 
129 | 	// validate the access token
130 | 	tokenSource := Oauth2Config.TokenSource(ctx, &oauth2.Token{AccessToken: custom.AccessToken})
131 | 	user, err := OIDCProvider.UserInfo(ctx, tokenSource)
132 | 	if err != nil {
133 | 		log.Printf("Cannot get user info for access token: %s", err)
134 | 		return false, err
135 | 	}
136 | 
137 | 	tunnel.TargetServer = custom.RemoteServer
138 | 	tunnel.RemoteAddr = custom.ClientIP
139 | 	tunnel.User.SetUserName(user.Subject)
140 | 
141 | 	return true, nil
142 | }
143 | 
144 | // GeneratePAAToken returns a PAA token for username and server: a JWT
145 | // HS256-signed with SigningKey (at least 32 bytes), issued by "rdpgw" and
146 | // valid for five minutes. The client IP and OIDC access token are copied from
147 | // the identity in ctx; a missing identity or attribute is reported as an
148 | // error.
149 | func GeneratePAAToken(ctx context.Context, username string, server string) (string, error) {
150 | 	if len(SigningKey) < 32 {
151 | 		return "", errors.New("token signing key not long enough or not specified")
152 | 	}
153 | 	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: SigningKey}, nil)
154 | 	if err != nil {
155 | 		log.Printf("Cannot obtain signer %s", err)
156 | 		return "", err
157 | 	}
158 | 
159 | 	standard := jwt.Claims{
160 | 		Issuer:  "rdpgw",
161 | 		Expiry:  jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
162 | 		Subject: username,
163 | 	}
164 | 
165 | 	id := identity.FromCtx(ctx)
166 | 	if id == nil {
167 | 		return "", errors.New("no identity found in context")
168 | 	}
169 | 	// Checked type assertions so an unset attribute yields an error, not a panic.
170 | 	clientIP, ok := id.GetAttribute(identity.AttrClientIp).(string)
171 | 	if !ok {
172 | 		return "", errors.New("no client ip found in identity")
173 | 	}
174 | 	accessToken, ok := id.GetAttribute(identity.AttrAccessToken).(string)
175 | 	if !ok {
176 | 		return "", errors.New("no access token found in identity")
177 | 	}
178 | 
179 | 	private := customClaims{
180 | 		RemoteServer: server,
181 | 		ClientIP:     clientIP,
182 | 		AccessToken:  accessToken,
183 | 	}
184 | 
185 | 	token, err := jwt.Signed(sig).Claims(standard).Claims(private).Serialize()
186 | 	if err != nil {
187 | 		log.Printf("Cannot sign PAA token %s", err)
188 | 		return "", err
189 | 	}
190 | 	return token, nil
191 | }
192 | 
193 | // GenerateUserToken returns a user token for userName, issued by "rdpgw" and
194 | // valid for five minutes. The token is encrypted (direct shared key,
195 | // A128CBC-HS256, DEFLATE-compressed) with UserEncryptionKey, which must be at
196 | // least 32 bytes; when UserSigningKey is set the claims are HS256-signed
197 | // before encryption.
198 | func GenerateUserToken(ctx context.Context, userName string) (string, error) {
199 | 	if len(UserEncryptionKey) < 32 {
200 | 		return "", errors.New("user token encryption key not long enough or not specified")
201 | 	}
202 | 
203 | 	claims := jwt.Claims{
204 | 		Subject: userName,
205 | 		Expiry:  jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
206 | 		Issuer:  "rdpgw",
207 | 	}
208 | 
209 | 	enc, err := jose.NewEncrypter(
210 | 		jose.A128CBC_HS256,
211 | 		jose.Recipient{
212 | 			Algorithm: jose.DIRECT,
213 | 			Key:       UserEncryptionKey,
214 | 		},
215 | 		(&jose.EncrypterOptions{Compression: jose.DEFLATE}).WithContentType("JWT"),
216 | 	)
217 | 	if err != nil {
218 | 		log.Printf("Cannot encrypt user token due to %s", err)
219 | 		return "", err
220 | 	}
221 | 
222 | 	var token string
223 | 	if len(UserSigningKey) > 0 {
224 | 		// this makes the token bigger and we deal with a limited space of 511 characters
225 | 		sig, sigErr := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: UserSigningKey}, nil)
226 | 		if sigErr != nil {
227 | 			// the builder cannot sign without a signer, so stop here
228 | 			log.Printf("Cannot obtain signer for user token %s", sigErr)
229 | 			return "", sigErr
230 | 		}
231 | 		token, err = jwt.SignedAndEncrypted(sig, enc).Claims(claims).Serialize()
232 | 	} else {
233 | 		// no signature
234 | 		token, err = jwt.Encrypted(enc).Claims(claims).Serialize()
235 | 	}
236 | 	if err != nil {
237 | 		return "", err
238 | 	}
239 | 	if len(token) > 511 {
240 | 		log.Printf("WARNING: token too long: len %d > 511", len(token))
241 | 	}
242 | 	return token, nil
243 | }
244 | 
245 | // UserInfo decodes and validates a user token created by GenerateUserToken
246 | // and returns its standard claims. With both UserEncryptionKey and
247 | // UserSigningKey set the token must be signed and encrypted; without
248 | // UserSigningKey it must be encrypted only. The issuer must be "rdpgw" and
249 | // the token must not be expired.
250 | func UserInfo(ctx context.Context, token string) (jwt.Claims, error) {
251 | 	standard := jwt.Claims{}
252 | 	switch {
253 | 	case len(UserEncryptionKey) > 0 && len(UserSigningKey) > 0:
254 | 		enc, err := jwt.ParseSignedAndEncrypted(
255 | 			token,
256 | 			[]jose.KeyAlgorithm{jose.DIRECT},
257 | 			[]jose.ContentEncryption{jose.A128CBC_HS256},
258 | 			[]jose.SignatureAlgorithm{jose.HS256},
259 | 		)
260 | 		if err != nil {
261 | 			log.Printf("Cannot get token %s", err)
262 | 			return standard, errors.New("cannot get token")
263 | 		}
264 | 		signed, err := enc.Decrypt(UserEncryptionKey)
265 | 		if err != nil {
266 | 			log.Printf("Cannot decrypt token %s", err)
267 | 			return standard, errors.New("cannot decrypt token")
268 | 		}
269 | 		if err = signed.Claims(UserSigningKey, &standard); err != nil {
270 | 			log.Printf("cannot verify signature %s", err)
271 | 			return standard, errors.New("cannot verify signature")
272 | 		}
273 | 	case len(UserSigningKey) == 0:
274 | 		encrypted, err := jwt.ParseEncrypted(token, []jose.KeyAlgorithm{jose.DIRECT}, []jose.ContentEncryption{jose.A128CBC_HS256})
275 | 		if err != nil {
276 | 			log.Printf("Cannot get token %s", err)
277 | 			return standard, errors.New("cannot get token")
278 | 		}
279 | 		if err = encrypted.Claims(UserEncryptionKey, &standard); err != nil {
280 | 			log.Printf("Cannot decrypt token %s", err)
281 | 			return standard, errors.New("cannot decrypt token")
282 | 		}
283 | 	default:
284 | 		// A signing key without an encryption key cannot decode any token
285 | 		// GenerateUserToken produces; report that instead of failing
286 | 		// validation on empty claims.
287 | 		return standard, errors.New("user token encryption key not specified")
288 | 	}
289 | 
290 | 	// go-jose doesnt verify the expiry
291 | 	if err := standard.Validate(jwt.Expected{
292 | 		Issuer: "rdpgw",
293 | 		Time:   time.Now(),
294 | 	}); err != nil {
295 | 		log.Printf("token validation failed due to %s", err)
296 | 		return standard, fmt.Errorf("token validation failed due to %s", err)
297 | 	}
298 | 
299 | 	return standard, nil
300 | }
301 | 
302 | // QueryInfo verifies a query token (HS256, QuerySigningKey) issued by issuer
303 | // and not expired, and returns its subject.
304 | func QueryInfo(ctx context.Context, tokenString string, issuer string) (string, error) {
305 | 	standard := jwt.Claims{}
306 | 	token, err := jwt.ParseSigned(tokenString, []jose.SignatureAlgorithm{jose.HS256})
307 | 	if err != nil {
308 | 		log.Printf("Cannot get token %s", err)
309 | 		return "", errors.New("cannot get token")
310 | 	}
311 | 	// Claims verifies the signature while decoding the claims.
312 | 	if err = token.Claims(QuerySigningKey, &standard); err != nil {
313 | 		log.Printf("cannot verify signature %s", err)
314 | 		return "", errors.New("cannot verify signature")
315 | 	}
316 | 
317 | 	// go-jose doesnt verify the expiry
318 | 	err = standard.Validate(jwt.Expected{
319 | 		Issuer: issuer,
320 | 		Time:   time.Now(),
321 | 	})
322 | 
323 | 	if err != nil {
324 | 		log.Printf("token validation failed due to %s", err)
325 | 		return "", fmt.Errorf("token validation failed due to %s", err)
326 | 	}
327 | 
328 | 	return standard.Subject, nil
329 | }
330 | 
331 | // GenerateQueryToken returns query as the subject of a JWT HS256-signed with
332 | // QuerySigningKey (at least 32 bytes), issued by issuer and valid for five
333 | // minutes. It is a helper function for testing.
334 | func GenerateQueryToken(ctx context.Context, query string, issuer string) (string, error) {
335 | 	if len(QuerySigningKey) < 32 {
336 | 		return "", errors.New("query token encryption key not long enough or not specified")
337 | 	}
338 | 
339 | 	claims := jwt.Claims{
340 | 		Subject: query,
341 | 		Expiry:  jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
342 | 		Issuer:  issuer,
343 | 	}
344 | 
345 | 	sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: QuerySigningKey},
346 | 		(&jose.SignerOptions{}).WithBase64(true))
347 | 
348 | 	if err != nil {
349 | 		log.Printf("Cannot obtain signer for query token due to %s", err)
350 | 		return "", err
351 | 	}
352 | 
353 | 	token, err := jwt.Signed(sig).Claims(claims).Serialize()
354 | 	return token, err
355 | }
356 | 
357 | // getTunnel returns the *protocol.Tunnel stored in ctx under
358 | // protocol.CtxTunnel, or nil (after logging) when there is none.
359 | func getTunnel(ctx context.Context) *protocol.Tunnel {
360 | 	s, ok := ctx.Value(protocol.CtxTunnel).(*protocol.Tunnel)
361 | 	if !ok {
362 | 		log.Printf("cannot get session info from context")
363 | 		return nil
364 | 	}
365 | 	return s
366 | }
367 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/security/jwt_test.go:
--------------------------------------------------------------------------------
1 | package security
2 | 
3 | import (
4 | 	"context"
5 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
6 | 	"testing"
7 | )
8 | 
9 | func TestGenerateUserToken(t *testing.T) {
10 | 	cases := []struct {
11 | 		SigningKey    []byte
12 | 		EncryptionKey []byte
13 | 		name          string
14 | 		username      string
15 | 	}{
16 | 		{
17 | 			SigningKey:    []byte("5aa3a1568fe8421cd7e127d5ace28d2d"),
18 | 			EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
19 | 			name:          "sign_and_encrypt",
20 | 			username:      "test_sign_and_encrypt",
21 | 		},
22 | 		{
23 | 			SigningKey:    nil,
24 | 			EncryptionKey: []byte("d3ecd7e565e56e37e2f2e95b584d8c0c"),
25 | 			name:          "encrypt_only",
26 | 			username:      "test_encrypt_only",
27 | 		},
28 | 	}
29 | 
30 | 	// restore the package-level user token keys after the test
31 | 	origSigningKey, origEncryptionKey := UserSigningKey, UserEncryptionKey
32 | 	t.Cleanup(func() {
33 | 		UserSigningKey, UserEncryptionKey = origSigningKey, origEncryptionKey
34 | 	})
35 | 
36 | 	for _, tc := range cases {
37 | 		t.Run(tc.name, func(t *testing.T) {
38 | 			// the user token helpers read UserSigningKey, not SigningKey
39 | 			UserSigningKey = tc.SigningKey
40 | 			UserEncryptionKey = tc.EncryptionKey
41 | 			token, err := GenerateUserToken(context.Background(), tc.username)
42 | 			if err != nil {
43 | 				t.Fatalf("GenerateUserToken failed: %s", err)
44 | 			}
45 | 			claims, err := UserInfo(context.Background(), token)
46 | 			if err != nil {
47 | 				t.Fatalf("UserInfo failed: %s", err)
48 | 			}
49 | 			if claims.Subject != tc.username {
50 | 				t.Fatalf("Expected %s, got %s", tc.username, claims.Subject)
51 | 			}
52 | 		})
53 | 	}
54 | }
55 | 
56 | func TestPAACookie(t *testing.T) {
57 | 	SigningKey = []byte("5aa3a1568fe8421cd7e127d5ace28d2d")
58 | 	EncryptionKey = []byte("d3ecd7e565e56e37e2f2e95b584d8c0c")
59 | 
60 | 	username := "test_paa_cookie"
61 | 	attr_client_ip := "127.0.0.1"
62 | 	attr_access_token := "aabbcc"
63 | 
64 | 	id := identity.NewUser()
65 | 	id.SetUserName(username)
66 | 	id.SetAttribute(identity.AttrClientIp, attr_client_ip)
67 | 	id.SetAttribute(identity.AttrAccessToken, attr_access_token)
68 | 
69 | 	ctx := context.Background()
70 | 	ctx = context.WithValue(ctx, identity.CTXKey, id)
71 | 
72 | 	_, err := GeneratePAAToken(ctx, "test_paa_cookie", "host.does.not.exist")
73 | 	if err != nil {
74 | 		t.Fatalf("GeneratePAAToken failed: %s", err)
75 | 	}
76 | 	/*ok, err := CheckPAACookie(ctx, token)
77 | 	if err != nil {
78 | 		t.Fatalf("CheckPAACookie failed: %s", err)
79 | 	}
80 | 	if !ok {
81 | 		t.Fatalf("CheckPAACookie failed")
82 | 	}*/
83 | }
84 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/security/string.go:
--------------------------------------------------------------------------------
1 | package security
2 | 
3 | import (
4 | 	"crypto/rand"
5 | 	"math/big"
6 | )
7 | 
8 | // GenerateRandomBytes returns n securely generated random bytes.
9 | // It will return an error if the system's secure random
10 | // number generator fails to function correctly, in which
11 | // case the caller should not continue.
12 | func GenerateRandomBytes(n int) ([]byte, error) {
13 | 	b := make([]byte, n)
14 | 	// Note that err == nil only if we read len(b) bytes.
15 | 	if _, err := rand.Read(b); err != nil {
16 | 		return nil, err
17 | 	}
18 | 
19 | 	return b, nil
20 | }
21 | 
22 | // GenerateRandomString returns a securely generated random string of n
23 | // characters, each drawn uniformly from digits, ASCII letters and '-'.
24 | // It will return an error if the system's secure random
25 | // number generator fails to function correctly, in which
26 | // case the caller should not continue.
27 | func GenerateRandomString(n int) (string, error) {
28 | 	const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-"
29 | 	// The bound never changes, so build it once; rand.Int draws from [0, limit).
30 | 	limit := big.NewInt(int64(len(letters)))
31 | 	ret := make([]byte, n)
32 | 	for i := range ret {
33 | 		num, err := rand.Int(rand.Reader, limit)
34 | 		if err != nil {
35 | 			return "", err
36 | 		}
37 | 		ret[i] = letters[num.Int64()]
38 | 	}
39 | 
40 | 	return string(ret), nil
41 | }
42 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/transport/legacy.go:
--------------------------------------------------------------------------------
1 | package transport
2 | 
3 | import (
4 | 	"bufio"
5 | 	"crypto/rand"
6 | 	"errors"
7 | 	"io"
8 | 	"net"
9 | 	"net/http"
10 | 	"net/http/httputil"
11 | 	"time"
12 | )
13 | 
14 | const (
15 | 	crlf   = "\r\n"
16 | 	HttpOK = "HTTP/1.1 200 OK\r\n"
17 | )
18 | 
19 | // LegacyPKT is a Transport over a hijacked HTTP/1.1 connection: packets are
20 | // read from the chunked request body and written to the raw connection.
21 | type LegacyPKT struct {
22 | 	Conn          net.Conn
23 | 	ChunkedReader io.Reader
24 | 	Writer        *bufio.Writer
25 | }
26 | 
27 | // NewLegacy hijacks the connection behind w and wraps it in a LegacyPKT. It
28 | // returns an error when w cannot be hijacked or the hijack fails.
29 | func NewLegacy(w http.ResponseWriter) (*LegacyPKT, error) {
30 | 	hj, ok := w.(http.Hijacker)
31 | 	if !ok {
32 | 		return nil, errors.New("cannot hijack connection")
33 | 	}
34 | 
35 | 	conn, rw, err := hj.Hijack()
36 | 	if err != nil {
37 | 		// rw is not usable after a failed hijack; building the chunked
38 | 		// reader from it would panic.
39 | 		return nil, err
40 | 	}
41 | 
42 | 	return &LegacyPKT{
43 | 		Conn:          conn,
44 | 		ChunkedReader: httputil.NewChunkedReader(rw.Reader),
45 | 		Writer:        rw.Writer,
46 | 	}, nil
47 | }
48 | 
49 | // ReadPacket reads up to 4096 bytes of the chunked request body and returns a
50 | // copy of them. As with io.Reader, n may be non-zero when err is set.
51 | func (t *LegacyPKT) ReadPacket() (n int, p []byte, err error) {
52 | 	buf := make([]byte, 4096) // bufio.defaultBufSize
53 | 	n, err = t.ChunkedReader.Read(buf)
54 | 	p = make([]byte, n)
55 | 	copy(p, buf)
56 | 
57 | 	return n, p, err
58 | }
59 | 
60 | // WritePacket writes b directly to the hijacked connection.
61 | func (t *LegacyPKT) WritePacket(b []byte) (n int, err error) {
62 | 	return t.Conn.Write(b)
63 | }
64 | 
65 | // Close closes the hijacked connection.
66 | func (t *LegacyPKT) Close() error {
67 | 	return t.Conn.Close()
68 | }
69 | 
70 | // [MS-TSGU]: Terminal Services Gateway Server Protocol version 39.0
71 | // The server sends back the final status code 200 OK, and also a random entity body of limited size (100 bytes).
72 | // This enables a reverse proxy to start allowing data from the RDG server to the RDG client. The RDG server does
73 | // not specify an entity length in its response. It uses HTTP 1.0 semantics to send the entity body and closes the
74 | // connection after the last byte is sent.
75 | //
76 | // SendAccept writes that response: with doSeed false an empty body announced
77 | // by Content-Length: 0, otherwise a 10-byte random body. Write errors are not
78 | // reported.
79 | func (t *LegacyPKT) SendAccept(doSeed bool) {
80 | 	t.Writer.WriteString(HttpOK)
81 | 	// HTTP requires the Date header in IMF-fixdate form, which is always GMT;
82 | 	// time.RFC1123 on local time would name the local zone instead.
83 | 	t.Writer.WriteString("Date: " + time.Now().UTC().Format(http.TimeFormat) + crlf)
84 | 	if !doSeed {
85 | 		t.Writer.WriteString("Content-Length: 0" + crlf)
86 | 	}
87 | 	t.Writer.WriteString(crlf)
88 | 
89 | 	if doSeed {
90 | 		seed := make([]byte, 10)
91 | 		rand.Read(seed)
92 | 		// docs say it's a seed but 2019 responds with ab cd * 5
93 | 		t.Writer.Write(seed)
94 | 	}
95 | 	t.Writer.Flush()
96 | }
97 | 
98 | // Drain performs one read of up to 32767 bytes from the connection and
99 | // discards the data; errors are ignored.
100 | func (t *LegacyPKT) Drain() {
101 | 	p := make([]byte, 32767)
102 | 	t.Conn.Read(p)
103 | }
104 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/transport/transport.go:
--------------------------------------------------------------------------------
1 | package transport
2 | 
3 | // Transport is a packet-oriented connection to a gateway client; LegacyPKT
4 | // and WSPKT implement it.
5 | type Transport interface {
6 | 	ReadPacket() (n int, p []byte, err error)
7 | 	WritePacket(b []byte) (n int, err error)
8 | 	Close() error
9 | }
10 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/transport/websocket.go:
--------------------------------------------------------------------------------
1 | package transport
2 | 
3 | import (
4 | 	"errors"
5 | 	"github.com/gorilla/websocket"
6 | )
7 | 
8 | // WSPKT is a Transport over a websocket connection; every packet travels as
9 | // one binary websocket message.
10 | type WSPKT struct {
11 | 	Conn *websocket.Conn
12 | }
13 | 
14 | // NewWS wraps c in a WSPKT. The returned error is always nil.
15 | func NewWS(c *websocket.Conn) (*WSPKT, error) {
16 | 	w := &WSPKT{Conn: c}
17 | 	return w, nil
18 | }
19 | 
20 | // ReadPacket returns the next websocket message. A message that is not binary
21 | // is still returned, together with a "not a binary packet" error. On a read
22 | // error the returned slice is a two-byte placeholder, not data.
23 | func (t *WSPKT) ReadPacket() (n int, b []byte, err error) {
24 | 	mt, msg, err := t.Conn.ReadMessage()
25 | 	if err != nil {
26 | 		return 0, []byte{0, 0}, err
27 | 	}
28 | 
29 | 	if mt == websocket.BinaryMessage {
30 | 		return len(msg), msg, nil
31 | 	}
32 | 
33 | 	return len(msg), msg, errors.New("not a binary packet")
34 | }
35 | 
36 | // WritePacket sends b as a single binary websocket message and reports
37 | // len(b) on success.
38 | func (t *WSPKT) WritePacket(b []byte) (n int, err error) {
39 | 	err = t.Conn.WriteMessage(websocket.BinaryMessage, b)
40 | 
41 | 	if err != nil {
42 | 		return 0, err
43 | 	}
44 | 
45 | 	return len(b), nil
46 | }
47 | 
48 | // Close closes the underlying websocket connection.
49 | func (t *WSPKT) Close() error {
50 | 	return t.Conn.Close()
51 | }
52 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/web/basic.go:
--------------------------------------------------------------------------------
1 | package web
2 | 
3 | import (
4 | 	"context"
5 | 	"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
6 | 	"github.com/bolkedebruin/rdpgw/shared/auth"
7 | 	"google.golang.org/grpc"
8 | 	"google.golang.org/grpc/credentials/insecure"
9 | 	"log"
10 | 	"net"
11 | 	"net/http"
12 | 	"time"
13 | )
14 | 
15 | const (
16 | 	protocolGrpc = "unix"
17 | )
18 | 
19 | // BasicAuthHandler checks HTTP basic credentials against the auth service
20 | // reachable over the unix socket at SocketAddress. Timeout is the limit, in
21 | // seconds, for one authentication call.
22 | type BasicAuthHandler struct {
23 | 	SocketAddress string
24 | 	Timeout       int
25 | }
26 | 
27 | // BasicAuth wraps next so that it only runs for requests whose basic
28 | // credentials the auth service accepts; the request identity is then marked
29 | // authenticated. Missing or rejected credentials get a 401 with a Basic
30 | // challenge, and a failure to talk to the auth service gets a 500.
31 | func (h *BasicAuthHandler) BasicAuth(next http.HandlerFunc) http.HandlerFunc {
32 | 	return func(w http.ResponseWriter, r *http.Request) {
33 | 		username, password, ok := r.BasicAuth()
34 | 		if ok {
35 | 			authenticated, err := h.checkCredentials(r.Context(), username, password)
36 | 			if err != nil {
37 | 				// Answer once: a 401 written on top of this 500 would be a
38 | 				// superfluous WriteHeader and a mixed response body.
39 | 				http.Error(w, "Server error", http.StatusInternalServerError)
40 | 				return
41 | 			}
42 | 
43 | 			if authenticated {
44 | 				log.Printf("User %s authenticated", username)
45 | 				id := identity.FromRequestCtx(r)
46 | 				id.SetUserName(username)
47 | 				id.SetAuthenticated(true)
48 | 				id.SetAuthTime(time.Now())
49 | 				next.ServeHTTP(w, identity.AddToRequestCtx(id, r))
50 | 				return
51 | 			}
52 | 			log.Printf("User %s is not authenticated for this service", username)
53 | 		}
54 | 		// If the Authentication header is not present, is invalid, or the
55 | 		// username or password is wrong, then set a WWW-Authenticate
56 | 		// header to inform the client that we expect them to use basic
57 | 		// authentication and send a 401 Unauthorized response.
58 | 		w.Header().Add("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
59 | 		http.Error(w, "Unauthorized", http.StatusUnauthorized)
60 | 	}
61 | }
62 | 
63 | // authenticate reports whether the auth service accepts username/password.
64 | // On a service failure it writes a 500 to w and returns false.
65 | func (h *BasicAuthHandler) authenticate(w http.ResponseWriter, r *http.Request, username string, password string) (authenticated bool) {
66 | 	ok, err := h.checkCredentials(r.Context(), username, password)
67 | 	if err != nil {
68 | 		http.Error(w, "Server error", http.StatusInternalServerError)
69 | 		return false
70 | 	}
71 | 	return ok
72 | }
73 | 
74 | // checkCredentials asks the auth service whether username/password are
75 | // valid. Without a configured SocketAddress every credential is rejected.
76 | // The call is bounded by Timeout seconds and by ctx.
77 | //
78 | // NOTE(review): a Timeout of 0 makes the deadline expire immediately;
79 | // confirm the configuration always sets it.
80 | func (h *BasicAuthHandler) checkCredentials(ctx context.Context, username string, password string) (bool, error) {
81 | 	if h.SocketAddress == "" {
82 | 		return false, nil
83 | 	}
84 | 
85 | 	conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()),
86 | 		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
87 | 			return net.Dial(protocolGrpc, addr)
88 | 		}))
89 | 	if err != nil {
90 | 		log.Printf("Cannot reach authentication provider: %s", err)
91 | 		return false, err
92 | 	}
93 | 	defer conn.Close()
94 | 
95 | 	c := auth.NewAuthenticateClient(conn)
96 | 	ctx, cancel := context.WithTimeout(ctx, time.Second*time.Duration(h.Timeout))
97 | 	defer cancel()
98 | 
99 | 	req := &auth.UserPass{Username: username, Password: password}
100 | 	res, err := c.Authenticate(ctx, req)
101 | 	if err != nil {
102 | 		log.Printf("Error talking to authentication provider: %s", err)
103 | 		return false, err
104 | 	}
105 | 
106 | 	return res.Authenticated, nil
107 | }
108 | 
--------------------------------------------------------------------------------
/cmd/rdpgw/web/context.go:
-------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 5 | "github.com/jcmturner/goidentity/v6" 6 | "log" 7 | "net" 8 | "net/http" 9 | "strings" 10 | ) 11 | 12 | func EnrichContext(next http.Handler) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | id, err := GetSessionIdentity(r) 15 | if err != nil { 16 | http.Error(w, err.Error(), http.StatusInternalServerError) 17 | return 18 | } 19 | 20 | if id == nil { 21 | id = identity.NewUser() 22 | if err := SaveSessionIdentity(r, w, id); err != nil { 23 | http.Error(w, err.Error(), http.StatusInternalServerError) 24 | return 25 | } 26 | } 27 | 28 | log.Printf("Identity SessionId: %s, UserName: %s: Authenticated: %t", 29 | id.SessionId(), id.UserName(), id.Authenticated()) 30 | 31 | h := r.Header.Get("X-Forwarded-For") 32 | if h != "" { 33 | var proxies []string 34 | ips := strings.Split(h, ",") 35 | for i := range ips { 36 | ips[i] = strings.TrimSpace(ips[i]) 37 | } 38 | clientIp := ips[0] 39 | if len(ips) > 1 { 40 | proxies = ips[1:] 41 | } 42 | id.SetAttribute(identity.AttrClientIp, clientIp) 43 | id.SetAttribute(identity.AttrProxies, proxies) 44 | } 45 | 46 | id.SetAttribute(identity.AttrRemoteAddr, r.RemoteAddr) 47 | if h == "" { 48 | clientIp, _, _ := net.SplitHostPort(r.RemoteAddr) 49 | id.SetAttribute(identity.AttrClientIp, clientIp) 50 | } 51 | next.ServeHTTP(w, identity.AddToRequestCtx(id, r)) 52 | }) 53 | } 54 | 55 | func TransposeSPNEGOContext(next http.Handler) http.Handler { 56 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 | gid := goidentity.FromHTTPRequestContext(r) 58 | if gid != nil { 59 | id := identity.FromRequestCtx(r) 60 | id.SetUserName(gid.UserName()) 61 | id.SetAuthenticated(gid.Authenticated()) 62 | id.SetDomain(gid.Domain()) 63 | id.SetAuthTime(gid.AuthTime()) 64 | r = identity.AddToRequestCtx(id, r) 65 | } 
66 | next.ServeHTTP(w, r) 67 | }) 68 | } 69 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/mux.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/gorilla/mux" 5 | "net/http" 6 | ) 7 | 8 | type AuthMux struct { 9 | headers []string 10 | } 11 | 12 | func NewAuthMux() *AuthMux { 13 | return &AuthMux{} 14 | } 15 | 16 | func (a *AuthMux) Register(s string) { 17 | a.headers = append(a.headers, s) 18 | } 19 | 20 | func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) { 21 | for _, s := range a.headers { 22 | w.Header().Add("WWW-Authenticate", s) 23 | } 24 | http.Error(w, "Unauthorized", http.StatusUnauthorized) 25 | } 26 | 27 | func NoAuthz(r *http.Request, rm *mux.RouteMatch) bool { 28 | return r.Header.Get("Authorization") == "" 29 | } 30 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/ntlm.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 7 | "github.com/bolkedebruin/rdpgw/shared/auth" 8 | "google.golang.org/grpc" 9 | "google.golang.org/grpc/credentials/insecure" 10 | "log" 11 | "net" 12 | "net/http" 13 | "time" 14 | ) 15 | 16 | type ntlmAuthMode uint32 17 | const ( 18 | authNone ntlmAuthMode = iota 19 | authNTLM 20 | authNegotiate 21 | ) 22 | 23 | type NTLMAuthHandler struct { 24 | SocketAddress string 25 | Timeout int 26 | } 27 | 28 | func (h *NTLMAuthHandler) NTLMAuth(next http.HandlerFunc) http.HandlerFunc { 29 | return func(w http.ResponseWriter, r *http.Request) { 30 | authPayload, authMode, err := h.getAuthPayload(r) 31 | if err != nil { 32 | log.Printf("Failed parsing auth header: %s", err) 33 | h.requestAuthenticate(w) 34 | return 35 | } 36 | 37 | authenticated, username := h.authenticate(w, r, authPayload, 
authMode) 38 | 39 | if authenticated { 40 | log.Printf("NTLM: User %s authenticated", username) 41 | id := identity.FromRequestCtx(r) 42 | id.SetUserName(username) 43 | id.SetAuthenticated(true) 44 | id.SetAuthTime(time.Now()) 45 | next.ServeHTTP(w, identity.AddToRequestCtx(id, r)) 46 | } 47 | } 48 | } 49 | 50 | func (h *NTLMAuthHandler) getAuthPayload (r *http.Request) (payload string, authMode ntlmAuthMode, err error) { 51 | authorisationEncoded := r.Header.Get("Authorization") 52 | if authorisationEncoded[0:5] == "NTLM " { 53 | return authorisationEncoded[5:], authNTLM, nil 54 | } 55 | if authorisationEncoded[0:10] == "Negotiate " { 56 | return authorisationEncoded[10:], authNegotiate, nil 57 | } 58 | return "", authNone, errors.New("Invalid NTLM Authorisation header") 59 | } 60 | 61 | func (h *NTLMAuthHandler) requestAuthenticate (w http.ResponseWriter) { 62 | w.Header().Add("WWW-Authenticate", `NTLM`) 63 | w.Header().Add("WWW-Authenticate", `Negotiate`) 64 | http.Error(w, "Unauthorized", http.StatusUnauthorized) 65 | } 66 | 67 | func (h *NTLMAuthHandler) getAuthPrefix (authMode ntlmAuthMode) (prefix string) { 68 | if authMode == authNTLM { 69 | return "NTLM " 70 | } 71 | if authMode == authNegotiate { 72 | return "Negotiate " 73 | } 74 | return "" 75 | } 76 | 77 | func (h *NTLMAuthHandler) authenticate(w http.ResponseWriter, r *http.Request, authorisationEncoded string, authMode ntlmAuthMode) (authenticated bool, username string) { 78 | if h.SocketAddress == "" { 79 | return false, "" 80 | } 81 | 82 | ctx := r.Context() 83 | 84 | conn, err := grpc.Dial(h.SocketAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), 85 | grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { 86 | return net.Dial(protocolGrpc, addr) 87 | })) 88 | if err != nil { 89 | log.Printf("Cannot reach authentication provider: %s", err) 90 | http.Error(w, "Server error", http.StatusInternalServerError) 91 | return false, "" 92 | } 93 | defer 
conn.Close() 94 | 95 | c := auth.NewAuthenticateClient(conn) 96 | ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(h.Timeout)) 97 | defer cancel() 98 | 99 | req := &auth.NtlmRequest{Session: r.RemoteAddr, NtlmMessage: authorisationEncoded} 100 | res, err := c.NTLM(ctx, req) 101 | if err != nil { 102 | log.Printf("Error talking to authentication provider: %s", err) 103 | http.Error(w, "Server error", http.StatusInternalServerError) 104 | return false, "" 105 | } 106 | 107 | if res.NtlmMessage != "" { 108 | log.Printf("Sending NTLM challenge") 109 | w.Header().Add("WWW-Authenticate", h.getAuthPrefix(authMode)+res.NtlmMessage) 110 | http.Error(w, "Unauthorized", http.StatusUnauthorized) 111 | return false, "" 112 | } 113 | 114 | if !res.Authenticated { 115 | h.requestAuthenticate(w) 116 | return false, "" 117 | } 118 | 119 | return res.Authenticated, res.Username 120 | } 121 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/oidc.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "encoding/json" 7 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 8 | "github.com/coreos/go-oidc/v3/oidc" 9 | "github.com/patrickmn/go-cache" 10 | "golang.org/x/oauth2" 11 | "net/http" 12 | "time" 13 | ) 14 | 15 | const ( 16 | CacheExpiration = time.Minute * 2 17 | CleanupInterval = time.Minute * 5 18 | ) 19 | 20 | type OIDC struct { 21 | oAuth2Config *oauth2.Config 22 | oidcTokenVerifier *oidc.IDTokenVerifier 23 | stateStore *cache.Cache 24 | } 25 | 26 | type OIDCConfig struct { 27 | OAuth2Config *oauth2.Config 28 | OIDCTokenVerifier *oidc.IDTokenVerifier 29 | } 30 | 31 | func (c *OIDCConfig) New() *OIDC { 32 | return &OIDC{ 33 | oAuth2Config: c.OAuth2Config, 34 | oidcTokenVerifier: c.OIDCTokenVerifier, 35 | stateStore: cache.New(CacheExpiration, CleanupInterval), 36 | } 37 | } 38 | 39 | func (h 
*OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) { 40 | state := r.URL.Query().Get("state") 41 | s, found := h.stateStore.Get(state) 42 | if !found { 43 | http.Error(w, "unknown state", http.StatusBadRequest) 44 | return 45 | } 46 | url := s.(string) 47 | 48 | ctx := r.Context() 49 | oauth2Token, err := h.oAuth2Config.Exchange(ctx, r.URL.Query().Get("code")) 50 | if err != nil { 51 | http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) 52 | return 53 | } 54 | 55 | rawIDToken, ok := oauth2Token.Extra("id_token").(string) 56 | if !ok { 57 | http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) 58 | return 59 | } 60 | idToken, err := h.oidcTokenVerifier.Verify(ctx, rawIDToken) 61 | if err != nil { 62 | http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) 63 | return 64 | } 65 | 66 | resp := struct { 67 | OAuth2Token *oauth2.Token 68 | IDTokenClaims *json.RawMessage // ID Token payload is just JSON. 
69 | }{oauth2Token, new(json.RawMessage)} 70 | 71 | if err := idToken.Claims(&resp.IDTokenClaims); err != nil { 72 | http.Error(w, err.Error(), http.StatusInternalServerError) 73 | return 74 | } 75 | 76 | var data map[string]interface{} 77 | if err := json.Unmarshal(*resp.IDTokenClaims, &data); err != nil { 78 | http.Error(w, err.Error(), http.StatusInternalServerError) 79 | return 80 | } 81 | 82 | id := identity.FromRequestCtx(r) 83 | 84 | userName := findUsernameInClaims(data) 85 | if userName == "" { 86 | http.Error(w, "no oidc claim for username found", http.StatusInternalServerError) 87 | } 88 | 89 | id.SetUserName(userName) 90 | id.SetAuthenticated(true) 91 | id.SetAuthTime(time.Now()) 92 | id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken) 93 | 94 | if err = SaveSessionIdentity(r, w, id); err != nil { 95 | http.Error(w, err.Error(), http.StatusInternalServerError) 96 | } 97 | 98 | http.Redirect(w, r, url, http.StatusFound) 99 | } 100 | 101 | func findUsernameInClaims(data map[string]interface{}) string { 102 | candidates := []string{"preferred_username", "unique_name", "upn", "username"} 103 | for _, claim := range candidates { 104 | userName, found := data[claim].(string) 105 | if found { 106 | return userName 107 | } 108 | } 109 | 110 | return "" 111 | } 112 | 113 | func (h *OIDC) Authenticated(next http.Handler) http.Handler { 114 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | id := identity.FromRequestCtx(r) 116 | 117 | if !id.Authenticated() { 118 | seed := make([]byte, 16) 119 | _, err := rand.Read(seed) 120 | if err != nil { 121 | http.Error(w, err.Error(), http.StatusInternalServerError) 122 | return 123 | } 124 | state := hex.EncodeToString(seed) 125 | h.stateStore.Set(state, r.RequestURI, cache.DefaultExpiration) 126 | http.Redirect(w, r, h.oAuth2Config.AuthCodeURL(state), http.StatusFound) 127 | return 128 | } 129 | 130 | // replace the identity with the one from the sessions 131 | next.ServeHTTP(w, 
r) 132 | }) 133 | } 134 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/oidc_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import "testing" 4 | 5 | func TestFindUserNameInClaims(t *testing.T) { 6 | cases := []struct { 7 | data map[string]interface{} 8 | ret string 9 | name string 10 | }{ 11 | { 12 | data: map[string]interface{}{ 13 | "preferred_username": "exists", 14 | }, 15 | ret: "exists", 16 | name: "preferred_username", 17 | }, 18 | { 19 | data: map[string]interface{}{ 20 | "upn": "exists", 21 | }, 22 | ret: "exists", 23 | name: "upn", 24 | }, 25 | { 26 | data: map[string]interface{}{ 27 | "unique_name": "exists", 28 | }, 29 | ret: "exists", 30 | name: "unique_name", 31 | }, 32 | { 33 | data: map[string]interface{}{ 34 | "fail": "exists", 35 | }, 36 | ret: "", 37 | name: "fail", 38 | }, 39 | } 40 | 41 | for _, tc := range cases { 42 | t.Run(tc.name, func(t *testing.T) { 43 | s := findUsernameInClaims(tc.data) 44 | if s != tc.ret { 45 | t.Fatalf("expected return: %v, got: %v", tc.ret, s) 46 | } 47 | }) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/session.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 5 | "github.com/gorilla/sessions" 6 | "log" 7 | "net/http" 8 | "os" 9 | ) 10 | 11 | const ( 12 | rdpGwSession = "RDPGWSESSION" 13 | MaxAge = 120 14 | identityKey = "RDPGWID" 15 | maxSessionLength = 8192 16 | ) 17 | 18 | var sessionStore sessions.Store 19 | 20 | func InitStore(sessionKey []byte, encryptionKey []byte, storeType string, maxLength int) { 21 | if len(sessionKey) < 32 { 22 | log.Fatal("Session key too small") 23 | } 24 | if len(encryptionKey) < 32 { 25 | log.Fatal("Session encryption key too small") 26 | } 27 | 28 | if storeType == "file" { 29 |
log.Println("Filesystem is used as session storage") 30 | fs := sessions.NewFilesystemStore(os.TempDir(), sessionKey, encryptionKey) 31 | 32 | // set max length 33 | if maxLength == 0 { 34 | maxLength = maxSessionLength 35 | } 36 | log.Printf("Setting maximum session storage to %d bytes", maxLength) 37 | fs.MaxLength(maxLength) 38 | 39 | sessionStore = fs 40 | } else { 41 | log.Println("Cookies are used as session storage") 42 | sessionStore = sessions.NewCookieStore(sessionKey, encryptionKey) 43 | } 44 | } 45 | 46 | func GetSession(r *http.Request) (*sessions.Session, error) { 47 | session, err := sessionStore.Get(r, rdpGwSession) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return session, nil 52 | } 53 | 54 | func GetSessionIdentity(r *http.Request) (identity.Identity, error) { 55 | s, err := GetSession(r) 56 | if err != nil { 57 | return nil, err 58 | } 59 | 60 | idData, ok := s.Values[identityKey].([]byte) 61 | if !ok { 62 | // a missing or non-[]byte identity value is treated as "no identity" rather than panicking on the assertion 63 | return nil, nil 64 | } 65 | id := identity.NewUser() 66 | id.Unmarshal(idData) 67 | return id, nil 68 | } 69 | 70 | func SaveSessionIdentity(r *http.Request, w http.ResponseWriter, id identity.Identity) error { 71 | session, err := GetSession(r) 72 | if err != nil { 73 | return err 74 | } 75 | session.Options.MaxAge = MaxAge 76 | 77 | idData, err := id.Marshal() 78 | if err != nil { 79 | return err 80 | } 81 | session.Values[identityKey] = idData 82 | 83 | return sessionStore.Save(r, w, session) 84 | 85 | } 86 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/token.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" 8 | "log" 9 | "net/http" 10 | ) 11 | 12 | func TokenInfo(w http.ResponseWriter, r *http.Request) { 13 | if r.Method != http.MethodGet { 14 | http.Error(w, "Invalid request",
http.StatusMethodNotAllowed) 15 | return 16 | } 17 | 18 | tokens, ok := r.URL.Query()["access_token"] 19 | if !ok || len(tokens[0]) < 1 { 20 | log.Printf("Missing access_token in request") 21 | http.Error(w, "access_token missing in request", http.StatusBadRequest) 22 | return 23 | } 24 | 25 | token := tokens[0] 26 | 27 | info, err := security.UserInfo(context.Background(), token) 28 | if err != nil { 29 | log.Printf("Token validation failed due to %s", err) 30 | http.Error(w, fmt.Sprintf("token validation failed due to %s", err), http.StatusForbidden) 31 | return 32 | } 33 | 34 | w.Header().Set("Content-Type", "application/json; charset=UTF-8") 35 | if err = json.NewEncoder(w).Encode(info); err != nil { 36 | log.Printf("Cannot encode json due to %s", err) 37 | http.Error(w, "cannot encode json", http.StatusInternalServerError) 38 | return 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/web.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | "encoding/hex" 7 | "errors" 8 | "fmt" 9 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 10 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp" 11 | "hash/maphash" 12 | "log" 13 | rnd "math/rand" 14 | "net/http" 15 | "net/url" 16 | "strings" 17 | "time" 18 | ) 19 | 20 | type TokenGeneratorFunc func(context.Context, string, string) (string, error) 21 | type UserTokenGeneratorFunc func(context.Context, string) (string, error) 22 | type QueryInfoFunc func(context.Context, string, string) (string, error) 23 | 24 | type Config struct { 25 | PAATokenGenerator TokenGeneratorFunc 26 | UserTokenGenerator UserTokenGeneratorFunc 27 | QueryInfo QueryInfoFunc 28 | QueryTokenIssuer string 29 | EnableUserToken bool 30 | Hosts []string 31 | HostSelection string 32 | GatewayAddress *url.URL 33 | RdpOpts RdpOpts 34 | TemplateFile string 35 | } 36 | 37 | type RdpOpts struct { 38 | 
UsernameTemplate string 39 | SplitUserDomain bool 40 | NoUsername bool 41 | } 42 | 43 | type Handler struct { 44 | paaTokenGenerator TokenGeneratorFunc 45 | enableUserToken bool 46 | userTokenGenerator UserTokenGeneratorFunc 47 | queryInfo QueryInfoFunc 48 | queryTokenIssuer string 49 | gatewayAddress *url.URL 50 | hosts []string 51 | hostSelection string 52 | rdpOpts RdpOpts 53 | rdpDefaults string 54 | } 55 | 56 | func (c *Config) NewHandler() *Handler { 57 | if len(c.Hosts) < 1 { 58 | log.Fatal("Not enough hosts to connect to specified") 59 | } 60 | 61 | return &Handler{ 62 | paaTokenGenerator: c.PAATokenGenerator, 63 | enableUserToken: c.EnableUserToken, 64 | userTokenGenerator: c.UserTokenGenerator, 65 | queryInfo: c.QueryInfo, 66 | queryTokenIssuer: c.QueryTokenIssuer, 67 | gatewayAddress: c.GatewayAddress, 68 | hosts: c.Hosts, 69 | hostSelection: c.HostSelection, 70 | rdpOpts: c.RdpOpts, 71 | rdpDefaults: c.TemplateFile, 72 | } 73 | } 74 | 75 | func (h *Handler) selectRandomHost() string { 76 | r := rnd.New(rnd.NewSource(int64(new(maphash.Hash).Sum64()))) 77 | host := h.hosts[r.Intn(len(h.hosts))] 78 | return host 79 | } 80 | 81 | func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) { 82 | switch h.hostSelection { 83 | case "roundrobin": 84 | return h.selectRandomHost(), nil 85 | case "signed": 86 | hosts, ok := u.Query()["host"] 87 | if !ok { 88 | return "", errors.New("invalid query parameter") 89 | } 90 | host, err := h.queryInfo(ctx, hosts[0], h.queryTokenIssuer) 91 | if err != nil { 92 | return "", err 93 | } 94 | found := false 95 | for _, check := range h.hosts { 96 | if check == host { 97 | found = true 98 | break 99 | } 100 | } 101 | if !found { 102 | log.Printf("Invalid host %s specified in token", hosts[0]) 103 | return "", errors.New("invalid host specified in query token") 104 | } 105 | return host, nil 106 | case "unsigned": 107 | hosts, ok := u.Query()["host"] 108 | if !ok { 109 | return "", errors.New("invalid query 
parameter") 110 | } 111 | for _, check := range h.hosts { 112 | if check == hosts[0] { 113 | return hosts[0], nil 114 | } 115 | } 116 | // not found 117 | log.Printf("Invalid host %s specified in client request", hosts[0]) 118 | return "", errors.New("invalid host specified in query parameter") 119 | case "any": 120 | hosts, ok := u.Query()["host"] 121 | if !ok { 122 | return "", errors.New("invalid query parameter") 123 | } 124 | return hosts[0], nil 125 | default: 126 | return h.selectRandomHost(), nil 127 | } 128 | } 129 | 130 | func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) { 131 | id := identity.FromRequestCtx(r) 132 | ctx := r.Context() 133 | 134 | opts := h.rdpOpts 135 | 136 | if !id.Authenticated() { 137 | log.Printf("unauthenticated user %s", id.UserName()) 138 | http.Error(w, errors.New("cannot find session or user").Error(), http.StatusInternalServerError) 139 | return 140 | } 141 | 142 | // determine host to connect to 143 | host, err := h.getHost(ctx, r.URL) 144 | if err != nil { 145 | http.Error(w, err.Error(), http.StatusBadRequest) 146 | return 147 | } 148 | host = strings.Replace(host, "{{ preferred_username }}", id.UserName(), 1) 149 | 150 | // split the username into user and domain 151 | var user = id.UserName() 152 | var domain = "" 153 | if opts.SplitUserDomain { 154 | creds := strings.SplitN(id.UserName(), "@", 2) 155 | user = creds[0] 156 | if len(creds) > 1 { 157 | domain = creds[1] 158 | } 159 | } 160 | 161 | render := user 162 | if opts.UsernameTemplate != "" { 163 | render = fmt.Sprintf(h.rdpOpts.UsernameTemplate) 164 | render = strings.Replace(render, "{{ username }}", user, 1) 165 | if h.rdpOpts.UsernameTemplate == render { 166 | log.Printf("Invalid username template. 
%s == %s", h.rdpOpts.UsernameTemplate, user) 167 | http.Error(w, errors.New("invalid server configuration").Error(), http.StatusInternalServerError) 168 | return 169 | } 170 | } 171 | 172 | token, err := h.paaTokenGenerator(ctx, user, host) 173 | if err != nil { 174 | log.Printf("Cannot generate PAA token for user %s due to %s", user, err) 175 | http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) 176 | return 177 | } 178 | 179 | if h.enableUserToken { 180 | userToken, err := h.userTokenGenerator(ctx, user) 181 | if err != nil { 182 | log.Printf("Cannot generate token for user %s due to %s", user, err) 183 | http.Error(w, errors.New("unable to generate gateway credentials").Error(), http.StatusInternalServerError) 184 | return 185 | } 186 | render = strings.Replace(render, "{{ token }}", userToken, 1) 187 | } 188 | 189 | // authenticated 190 | seed := make([]byte, 16) 191 | _, err = rand.Read(seed) 192 | if err != nil { 193 | log.Printf("Cannot generate random seed due to %s", err) 194 | http.Error(w, errors.New("unable to generate random sequence").Error(), http.StatusInternalServerError) 195 | return 196 | } 197 | fn := hex.EncodeToString(seed) + ".rdp" 198 | 199 | w.Header().Set("Content-Disposition", "attachment; filename="+fn) 200 | w.Header().Set("Content-Type", "application/x-rdp") 201 | 202 | var d *rdp.Builder 203 | if h.rdpDefaults == "" { 204 | d = rdp.NewBuilder() 205 | } else { 206 | d, err = rdp.NewBuilderFromFile(h.rdpDefaults) 207 | if err != nil { 208 | log.Printf("Cannot load RDP template file %s due to %s", h.rdpDefaults, err) 209 | http.Error(w, errors.New("unable to load RDP template").Error(), http.StatusInternalServerError) 210 | return 211 | } 212 | } 213 | 214 | if !h.rdpOpts.NoUsername { 215 | d.Settings.Username = render 216 | if domain != "" { 217 | d.Settings.Domain = domain 218 | } 219 | } 220 | d.Settings.FullAddress = host 221 | d.Settings.GatewayHostname = h.gatewayAddress.Host 
222 | d.Settings.GatewayCredentialsSource = rdp.SourceCookie 223 | d.Settings.GatewayAccessToken = token 224 | d.Settings.GatewayCredentialMethod = 1 225 | d.Settings.GatewayUsageMethod = 1 226 | 227 | http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String())) 228 | } 229 | -------------------------------------------------------------------------------- /cmd/rdpgw/web/web_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" 6 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp" 7 | "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "os" 12 | "strings" 13 | "testing" 14 | ) 15 | 16 | const ( 17 | testuser = "test_user" 18 | gateway = "https://my.gateway.com:993" 19 | ) 20 | 21 | var ( 22 | hosts = []string{"10.0.0.1:3389", "10.1.1.1:3000", "32.32.11.1", "remote.host.com"} 23 | key = []byte("thisisasessionkeyreplacethisjetzt") 24 | ) 25 | 26 | func contains(needle string, haystack []string) bool { 27 | for _, val := range haystack { 28 | if val == needle { 29 | return true 30 | } 31 | } 32 | return false 33 | } 34 | 35 | func TestGetHost(t *testing.T) { 36 | ctx := context.Background() 37 | c := Config{ 38 | HostSelection: "roundrobin", 39 | Hosts: hosts, 40 | } 41 | h := c.NewHandler() 42 | 43 | u := &url.URL{ 44 | Host: "example.com", 45 | } 46 | vals := u.Query() 47 | 48 | host, err := h.getHost(ctx, u) 49 | if err != nil { 50 | t.Fatalf("roundrobin host selection failed: %v", err) 51 | } 52 | if !contains(host, hosts) { 53 | t.Fatalf("host %s is not in hosts list", host) 54 | } 55 | 56 | // check unsigned 57 | c.HostSelection = "unsigned" 58 | vals.Set("host", "in.valid.host") 59 | u.RawQuery = vals.Encode() 60 | h = c.NewHandler() 61 | host, err = h.getHost(ctx, u) 62 | if err == nil { 63 | t.Fatalf("Accepted host %s is not in hosts list", host) 64 | } 65 | 66 | vals.Set("host", hosts[0]) 67 |
u.RawQuery = vals.Encode() 68 | h = c.NewHandler() 69 | host, err = h.getHost(ctx, u) 70 | if err != nil { 71 | t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) 72 | } 73 | if host != hosts[0] { 74 | t.Fatalf("host %s is not equal to input %s", host, hosts[0]) 75 | } 76 | 77 | // check any 78 | c.HostSelection = "any" 79 | test := "bla.bla.com" 80 | vals.Set("host", test) 81 | u.RawQuery = vals.Encode() 82 | h = c.NewHandler() 83 | host, err = h.getHost(ctx, u) 84 | if err != nil { 85 | t.Fatalf("%s is not accepted", host) 86 | } 87 | if test != host { 88 | t.Fatalf("Returned host %s is not equal to input host %s", host, test) 89 | } 90 | 91 | // check signed 92 | c.HostSelection = "signed" 93 | c.QueryInfo = security.QueryInfo 94 | issuer := "rdpgwtest" 95 | security.QuerySigningKey = key 96 | queryToken, err := security.GenerateQueryToken(ctx, hosts[0], issuer) 97 | if err != nil { 98 | t.Fatalf("cannot generate token") 99 | } 100 | vals.Set("host", queryToken) 101 | u.RawQuery = vals.Encode() 102 | h = c.NewHandler() 103 | host, err = h.getHost(ctx, u) 104 | if err != nil { 105 | t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err) 106 | } 107 | if host != hosts[0] { 108 | t.Fatalf("%s does not equal %s", host, hosts[0]) 109 | } 110 | } 111 | 112 | func TestHandler_HandleDownload(t *testing.T) { 113 | req, err := http.NewRequest("GET", "/connect", nil) 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | rr := httptest.NewRecorder() 119 | id := identity.NewUser() 120 | 121 | id.SetUserName(testuser) 122 | id.SetAuthenticated(true) 123 | 124 | req = identity.AddToRequestCtx(id, req) 125 | ctx := req.Context() 126 | 127 | u, _ := url.Parse(gateway) 128 | c := Config{ 129 | HostSelection: "roundrobin", 130 | Hosts: hosts, 131 | PAATokenGenerator: paaTokenMock, 132 | GatewayAddress: u, 133 | RdpOpts: RdpOpts{SplitUserDomain: true}, 134 | } 135 | h := c.NewHandler() 136 | 137 | hh := 
http.HandlerFunc(h.HandleDownload) 138 | hh.ServeHTTP(rr, req) 139 | 140 | if status := rr.Code; status != http.StatusOK { 141 | t.Errorf("handler returned wrong status code: got %v want %v", 142 | status, http.StatusOK) 143 | } 144 | 145 | if ctype := rr.Header().Get("Content-Type"); ctype != "application/x-rdp" { 146 | t.Errorf("content type header does not match: got %v want %v", 147 | ctype, "application/x-rdp") 148 | } 149 | 150 | if cdisp := rr.Header().Get("Content-Disposition"); cdisp == "" { 151 | t.Errorf("content disposition header is empty") 152 | } 153 | 154 | data := rdpToMap(strings.Split(rr.Body.String(), rdp.CRLF)) 155 | if data["username"] != testuser { 156 | t.Errorf("username key in rdp does not match: got %v want %v", data["username"], testuser) 157 | } 158 | 159 | if data["gatewayhostname"] != u.Host { 160 | t.Errorf("gatewayhostname key in rdp does not match: got %v want %v", data["gatewayhostname"], u.Host) 161 | } 162 | 163 | if token, _ := paaTokenMock(ctx, testuser, data["full address"]); token != data["gatewayaccesstoken"] { 164 | t.Errorf("gatewayaccesstoken key in rdp does not match username_full address: got %v want %v", 165 | data["gatewayaccesstoken"], token) 166 | } 167 | 168 | if !contains(data["full address"], hosts) { 169 | t.Errorf("full address key in rdp is not in allowed hosts list: got %v want in %v", 170 | data["full address"], hosts) 171 | } 172 | 173 | } 174 | 175 | func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) { 176 | f, err := os.CreateTemp("", "rdp") 177 | if err != nil { 178 | t.Fatal(err) 179 | } 180 | defer os.Remove(f.Name()) 181 | f.Close() // only the path is used below; release the handle so the file can be rewritten and removed 182 | err = os.WriteFile(f.Name(), []byte("domain:s:testdomain\r\n"), 0644) 183 | if err != nil { 184 | t.Fatal(err) 185 | } 186 | 187 | req, err := http.NewRequest("GET", "/connect", nil) 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | 192 | rr := httptest.NewRecorder() 193 | id := identity.NewUser() 194 | 195 | id.SetUserName(testuser) 196 | id.SetAuthenticated(true) 197 |
198 | req = identity.AddToRequestCtx(id, req) 199 | 200 | u, _ := url.Parse(gateway) 201 | c := Config{ 202 | HostSelection: "roundrobin", 203 | Hosts: hosts, 204 | PAATokenGenerator: paaTokenMock, 205 | GatewayAddress: u, 206 | RdpOpts: RdpOpts{SplitUserDomain: true}, 207 | TemplateFile: f.Name(), 208 | } 209 | h := c.NewHandler() 210 | 211 | hh := http.HandlerFunc(h.HandleDownload) 212 | hh.ServeHTTP(rr, req) 213 | 214 | data := rdpToMap(strings.Split(rr.Body.String(), rdp.CRLF)) 215 | if data["domain"] != "testdomain" { 216 | t.Errorf("domain key in rdp does not match: got %v want %v", data["domain"], "testdomain") 217 | } 218 | } 219 | 220 | func paaTokenMock(ctx context.Context, username string, host string) (string, error) { 221 | return username + "_" + host, nil 222 | } 223 | 224 | func rdpToMap(lines []string) map[string]string { 225 | ret := make(map[string]string, len(lines)) 226 | // rdp entries are key:type:value; lines without all three fields are skipped instead of indexing past the split result 227 | for _, line := range lines { 228 | d := strings.SplitN(line, ":", 3) 229 | if len(d) == 3 { 230 | ret[d[0]] = d[2] 231 | } 232 | } 233 | 234 | return ret 235 | } 236 | -------------------------------------------------------------------------------- /dev/docker-distroless/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1 2 | WORKDIR /src 3 | ENV CGO_ENABLED=0 4 | COPY go.mod go.sum ./ 5 | RUN go mod download 6 | COPY . .
7 | RUN go build github.com/bolkedebruin/rdpgw/cmd/rdpgw 8 | 9 | FROM gcr.io/distroless/static-debian11:nonroot 10 | WORKDIR /config 11 | COPY --from=0 /src/rdpgw /rdpgw 12 | CMD ["/rdpgw"] 13 | -------------------------------------------------------------------------------- /dev/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # builder stage 2 | FROM golang:1.22-alpine as builder 3 | 4 | #RUN apt-get update && apt-get install -y libpam-dev 5 | RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl 6 | 7 | # add user 8 | RUN adduser --disabled-password --gecos "" --home /opt/rdpgw --uid 1001 rdpgw 9 | 10 | # certificate 11 | RUN mkdir -p /opt/rdpgw && cd /opt/rdpgw && \ 12 | random=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 32 | head -n 1) && \ 13 | openssl genrsa -des3 -passout pass:$random -out server.pass.key 2048 && \ 14 | openssl rsa -passin pass:$random -in server.pass.key -out key.pem && \ 15 | rm server.pass.key && \ 16 | openssl req -new -sha256 -key key.pem -out server.csr \ 17 | -subj "/C=US/ST=VA/L=SomeCity/O=MyCompany/OU=MyDivision/CN=rdpgw" && \ 18 | openssl x509 -req -days 365 -in server.csr -signkey key.pem -out server.pem 19 | 20 | # build rdpgw and set rights 21 | ARG CACHEBUST 22 | RUN git clone https://github.com/bolkedebruin/rdpgw.git /app && \ 23 | cd /app && \ 24 | go mod tidy -compat=1.19 && \ 25 | CGO_ENABLED=0 GOOS=linux go build -trimpath -tags '' -ldflags '' -o '/opt/rdpgw/rdpgw' ./cmd/rdpgw && \ 26 | CGO_ENABLED=1 GOOS=linux go build -trimpath -tags '' -ldflags '' -o '/opt/rdpgw/rdpgw-auth' ./cmd/auth && \ 27 | chmod +x /opt/rdpgw/rdpgw && \ 28 | chmod +x /opt/rdpgw/rdpgw-auth && \ 29 | chmod u+s /opt/rdpgw/rdpgw-auth 30 | 31 | FROM alpine:latest 32 | 33 | RUN apk --no-cache add linux-pam musl 34 | 35 | # make tempdir in case filestore is used 36 | ADD tmp.tar / 37 | 38 | COPY --chown=0 rdpgw-pam /etc/pam.d/rdpgw 39 | 40 | USER 1001 41 | COPY --chown=1001 run.sh 
run.sh 42 | COPY --chown=1001 --from=builder /opt/rdpgw /opt/rdpgw 43 | COPY --chown=1001 --from=builder /etc/passwd /etc/passwd 44 | COPY --chown=1001 --from=builder /etc/ssl/certs /etc/ssl/certs 45 | 46 | USER 0 47 | 48 | WORKDIR /opt/rdpgw 49 | ENTRYPOINT ["/bin/sh", "/run.sh"] 50 | -------------------------------------------------------------------------------- /dev/docker/Dockerfile.xrdp: -------------------------------------------------------------------------------- 1 | FROM rattydave/docker-ubuntu-xrdp-mate-custom:latest 2 | 3 | RUN cd /etc/xrdp/ && \ 4 | openssl req -x509 -newkey rsa:2048 -nodes -keyout key.pem -out cert.pem -days 3650 \ 5 | -subj "/C=US/ST=VA/L=SomeCity/O=MyCompany/OU=MyDivision/CN=xrdp" 6 | 7 | COPY xrdp.ini /etc/xrdp/xrdp.ini 8 | -------------------------------------------------------------------------------- /dev/docker/docker-compose-arm64.yml: -------------------------------------------------------------------------------- 1 | version: '3.4' 2 | 3 | volumes: 4 | mysql_data: 5 | driver: local 6 | realm-export.json: 7 | 8 | services: 9 | keycloak: 10 | container_name: keycloak 11 | image: richardjkendall/keycloak-arm:latest 12 | hostname: keycloak 13 | volumes: 14 | - ${PWD}/realm-export.json:/export/realm-export.json 15 | environment: 16 | KEYCLOAK_USER: admin 17 | KEYCLOAK_PASSWORD: admin 18 | KEYCLOAK_IMPORT: /export/realm-export.json 19 | ports: 20 | - 8080:8080 21 | restart: on-failure 22 | healthcheck: 23 | test: ["CMD", "curl", "-f", "http://localhost:8080/auth"] 24 | interval: 10s 25 | timeout: 3s 26 | retries: 10 27 | start_period: 5s 28 | xrdp: 29 | container_name: xrdp 30 | hostname: xrdp 31 | image: bolkedebruin/docker-ubuntu-xrdp-mate-rdpgw:latest 32 | ports: 33 | - 3389:3389 34 | restart: on-failure 35 | volumes: 36 | - ${PWD}/xrdp_users.txt:/root/createusers.txt 37 | environment: 38 | TZ: "Europe/Amsterdam" 39 | rdpgw: 40 | container_name: rdpgw 41 | hostname: rdpgw 42 | image: bolkedebruin/rdpgw:latest 43 | build: . 
44 | ports: 45 | - 9443:9443 46 | restart: on-failure 47 | depends_on: 48 | keycloak: 49 | condition: service_healthy 50 | environment: 51 | RDPGW_SERVER__SESSION_STORE: file 52 | RDPGW_SERVER__CERT_FILE: /opt/rdpgw/server.pem 53 | RDPGW_SERVER__KEY_FILE: /opt/rdpgw/key.pem 54 | RDPGW_SERVER__GATEWAY_ADDRESS: localhost:9443 55 | RDPGW_SERVER__PORT: 9443 56 | RDPGW_SERVER__HOSTS: xrdp:3389 57 | RDPGW_SERVER__ROUND_ROBIN: "false" 58 | RDPGW_OPEN_ID__PROVIDER_URL: "http://keycloak:8080/auth/realms/rdpgw" 59 | RDPGW_OPEN_ID__CLIENT_ID: rdpgw 60 | RDPGW_OPEN_ID__CLIENT_SECRET: 01cd304c-6f43-4480-9479-618eb6fd578f 61 | RDPGW_CLIENT__USERNAME_TEMPLATE: "{{ username }}" 62 | RDPGW_CAPS__TOKEN_AUTH: "true" 63 | healthcheck: 64 | test: ["CMD", "curl", "-f", "http://keycloak:8080"] 65 | interval: 10s 66 | timeout: 10s 67 | retries: 10 68 | -------------------------------------------------------------------------------- /dev/docker/docker-compose-local.yml: -------------------------------------------------------------------------------- 1 | version: '3.4' 2 | 3 | services: 4 | xrdp: 5 | container_name: xrdp 6 | hostname: xrdp 7 | image: bolkedebruin/docker-ubuntu-xrdp-mate-rdpgw:latest 8 | ports: 9 | - 3389:3389 10 | restart: on-failure 11 | volumes: 12 | - ${PWD}/xrdp_users.txt:/root/createusers.txt 13 | environment: 14 | TZ: "Europe/Amsterdam" 15 | rdpgw: 16 | container_name: rdpgw 17 | hostname: rdpgw 18 | image: bolkedebruin/rdpgw:latest 19 | build: . 
20 | ports: 21 | - 9443:9443 22 | restart: on-failure 23 | volumes: 24 | - ${PWD}/xrdp_users.txt:/root/createusers.txt 25 | environment: 26 | RDPGW_SERVER__SESSION_STORE: file 27 | RDPGW_SERVER__CERT_FILE: /opt/rdpgw/server.pem 28 | RDPGW_SERVER__KEY_FILE: /opt/rdpgw/key.pem 29 | RDPGW_SERVER__GATEWAY_ADDRESS: localhost:9443 30 | RDPGW_SERVER__PORT: 9443 31 | RDPGW_SERVER__HOSTS: xrdp:3389 32 | RDPGW_SERVER__ROUND_ROBIN: "false" 33 | RDPGW_SERVER__AUTHENTICATION: local 34 | RDPGW_CAPS__TOKEN_AUTH: "false" 35 | healthcheck: 36 | test: ["CMD", "curl", "-f", "http://localhost:9443/"] 37 | interval: 10s 38 | timeout: 10s 39 | retries: 10 40 | -------------------------------------------------------------------------------- /dev/docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.4' 2 | 3 | volumes: 4 | mysql_data: 5 | driver: local 6 | realm-export.json: 7 | 8 | services: 9 | keycloak: 10 | container_name: keycloak 11 | image: quay.io/keycloak/keycloak:latest 12 | hostname: keycloak 13 | volumes: 14 | - ${PWD}/realm-export.json:/opt/keycloak/data/import/realm-export.json 15 | environment: 16 | KEYCLOAK_USER: admin 17 | KEYCLOAK_PASSWORD: admin 18 | KEYCLOAK_ADMIN: admin 19 | KEYCLOAK_ADMIN_PASSWORD: admin 20 | ports: 21 | - 8080:8080 22 | restart: on-failure 23 | command: 24 | - start-dev 25 | - --import-realm 26 | - --http-relative-path=/auth 27 | healthcheck: 28 | test: ["CMD", "curl", "-f", "http://localhost:8080/auth"] 29 | interval: 10s 30 | timeout: 3s 31 | retries: 10 32 | start_period: 5s 33 | xrdp: 34 | container_name: xrdp 35 | hostname: xrdp 36 | image: bolkedebruin/docker-ubuntu-xrdp-mate-rdpgw:latest 37 | ports: 38 | - 3389:3389 39 | restart: on-failure 40 | volumes: 41 | - ${PWD}/xrdp_users.txt:/root/createusers.txt 42 | environment: 43 | TZ: "Europe/Amsterdam" 44 | rdpgw: 45 | build: . 
46 | ports: 47 | - 9443:9443 48 | restart: on-failure 49 | depends_on: 50 | keycloak: 51 | condition: service_healthy 52 | environment: 53 | RDPGW_SERVER__SESSION_STORE: file 54 | RDPGW_SERVER__CERT_FILE: /opt/rdpgw/server.pem 55 | RDPGW_SERVER__KEY_FILE: /opt/rdpgw/key.pem 56 | RDPGW_SERVER__GATEWAY_ADDRESS: localhost:9443 57 | RDPGW_SERVER__PORT: 9443 58 | RDPGW_SERVER__HOSTS: xrdp:3389 59 | RDPGW_SERVER__ROUND_ROBIN: "false" 60 | RDPGW_OPEN_ID__PROVIDER_URL: "http://keycloak:8080/auth/realms/rdpgw" 61 | RDPGW_OPEN_ID__CLIENT_ID: rdpgw 62 | RDPGW_OPEN_ID__CLIENT_SECRET: 01cd304c-6f43-4480-9479-618eb6fd578f 63 | RDPGW_CLIENT__USERNAME_TEMPLATE: "{{ username }}" 64 | RDPGW_CAPS__TOKEN_AUTH: "true" 65 | healthcheck: 66 | test: ["CMD", "curl", "-f", "http://keycloak:8080"] 67 | interval: 10s 68 | timeout: 10s 69 | retries: 10 70 | -------------------------------------------------------------------------------- /dev/docker/docker-readme.md: -------------------------------------------------------------------------------- 1 | # RDPGW 2 | ## What is RDPGW? 3 | Remote Desktop Gateway (RDPGW, RDG or RD Gateway) provides a secure encrypted connection 4 | to user desktops via RDP. It enhances control by removing all direct remote user access to 5 | your system and replacing it with a point-to-point remote desktop connection. 6 | 7 | ## How to use this image 8 | The remote desktop gateway relies on an OpenID Connect authentication service, such as Keycloak, 9 | Azure AD or Google, and a backend remote desktop service such as XRDP, gnome-remote-desktop, or 10 | Windows VMs. Make sure that these services have been properly set up and can be reached from 11 | where you will run this image. 12 | 13 | This image is stateless, which means it does not store any state by default. If you configure 14 | the session store to be `file`, a small amount of session information is stored temporarily.
This means 15 | that a load balancer needs to maintain session affinity for a while, which is typically the case. 16 | 17 | Session and token encryption keys are randomized on startup. As a consequence, sessions are 18 | invalidated on restarts, and if you are load balancing, the different instances will not be able to share 19 | user sessions. If this is not what you want, set these keys to static values so they can be shared 20 | across the different instances. 21 | 22 | ## Configuration through environment variables 23 | The session key, session encryption key and PAA token signing key must each be 32 characters long. 24 | 25 | ```bash 26 | docker run --name rdpgw \ 27 | -e RDPGW_SERVER__CERT_FILE=/etc/rdpgw/cert.pem \ 28 | -e RDPGW_SERVER__KEY_FILE=/etc/rdpgw/key.pem \ 29 | -e RDPGW_SERVER__GATEWAY_ADDRESS=https://localhost:443 \ 30 | -e RDPGW_SERVER__SESSION_KEY=thisisasessionkeyreplacethisjetz \ 31 | -e RDPGW_SERVER__SESSION_ENCRYPTION_KEY=thisisasessionkeyreplacethisnunu \ 32 | -e RDPGW_OPEN_ID__PROVIDER_URL=http://keycloak:8080/auth/realms/rdpgw \ 33 | -e RDPGW_OPEN_ID__CLIENT_ID=rdpgw \ 34 | -e RDPGW_OPEN_ID__CLIENT_SECRET=01cd304c-6f43-4480-9479-618eb6fd578f \ 35 | -e RDPGW_SECURITY__SECURITY_PAA_TOKEN_SIGNING_KEY=prettypleasereplacemeinproductio \ 36 | -v conf:/etc/rdpgw \ 37 | bolkedebruin/rdpgw:latest 38 | ``` -------------------------------------------------------------------------------- /dev/docker/rdpgw-pam: -------------------------------------------------------------------------------- 1 | # basic PAM configuration for rdpgw on Alpine 2 | auth include base-auth 3 | auth include base-account 4 | -------------------------------------------------------------------------------- /dev/docker/rdpgw.yaml: -------------------------------------------------------------------------------- 1 | Server: 2 | CertFile: /opt/rdpgw/server.pem 3 | KeyFile: /opt/rdpgw/key.pem 4 | GatewayAddress: localhost:9443 5 | Port: 9443 6 | Hosts: 7 | - xrdp:3389 8 | RoundRobin: false 9 | SessionKey: thisisasessionkeyreplacethisjetz 10 |
| SessionEncryptionKey: thisisasessionkeyreplacethisnunu 11 | OpenId: 12 | ProviderUrl: http://keycloak:8080/auth/realms/rdpgw 13 | ClientId: rdpgw 14 | ClientSecret: 01cd304c-6f43-4480-9479-618eb6fd578f 15 | Client: 16 | UsernameTemplate: "{{ username }}" 17 | Security: 18 | PAATokenSigningKey: prettypleasereplacemeinproductio 19 | Caps: 20 | TokenAuth: true 21 | -------------------------------------------------------------------------------- /dev/docker/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | USER=rdpgw 4 | 5 | file="/root/createusers.txt" 6 | if [ -f $file ] 7 | then 8 | while IFS=: read -r username password is_sudo 9 | do 10 | echo "Username: $username, Password: **** , Sudo: $is_sudo" 11 | 12 | if getent passwd "$username" > /dev/null 2>&1 13 | then 14 | echo "User Exists" 15 | else 16 | adduser -s /sbin/nologin "$username" 17 | echo "$username:$password" | chpasswd 18 | fi 19 | done <"$file" 20 | fi 21 | 22 | cd /opt/rdpgw || exit 1 23 | 24 | if [ -n "${RDPGW_SERVER__AUTHENTICATION}" ]; then 25 | if [ "${RDPGW_SERVER__AUTHENTICATION}" = "local" ]; then 26 | echo "Starting rdpgw-auth" 27 | /opt/rdpgw/rdpgw-auth & 28 | fi 29 | fi 30 | 31 | # drop privileges and run the application 32 | su -c /opt/rdpgw/rdpgw "${USER}" -- "$@" & 33 | wait 34 | exit $? 
35 | -------------------------------------------------------------------------------- /dev/docker/tmp.tar: -------------------------------------------------------------------------------- 1 | tmp/0001777000000000000000000000000013140662635010367 5ustar rootroot -------------------------------------------------------------------------------- /dev/docker/xrdp.ini: -------------------------------------------------------------------------------- 1 | [Globals] 2 | ; xrdp.ini file version number 3 | ini_version=1 4 | 5 | ; fork a new process for each incoming connection 6 | fork=true 7 | ; tcp port to listen 8 | port=3389 9 | ; regulate if the listening socket use socket option tcp_nodelay 10 | ; no buffering will be performed in the TCP stack 11 | tcp_nodelay=true 12 | ; regulate if the listening socket use socket option keepalive 13 | ; if the network connection disappear without close messages the connection will be closed 14 | tcp_keepalive=true 15 | #tcp_send_buffer_bytes=32768 16 | #tcp_recv_buffer_bytes=32768 17 | 18 | ; security layer can be 'tls', 'rdp' or 'negotiate' 19 | ; for client compatible layer 20 | security_layer=negotiate 21 | ; minimum security level allowed for client 22 | ; can be 'none', 'low', 'medium', 'high', 'fips' 23 | crypt_level=high 24 | ; X.509 certificate and private key 25 | ; openssl req -x509 -newkey rsa:2048 -nodes -keyout key.pem -out cert.pem -days 365 26 | certificate=cert.pem 27 | key_file=key.pem 28 | ; set SSL protocols 29 | ; can be comma separated list of 'SSLv3', 'TLSv1', 'TLSv1.1', 'TLSv1.2' 30 | ssl_protocols=TLSv1, TLSv1.1, TLSv1.2 31 | ; set TLS cipher suites 32 | #tls_ciphers=HIGH 33 | 34 | ; Section name to use for automatic login if the client sends username 35 | ; and password. If empty, the domain name sent by the client is used. 36 | ; If empty and no domain name is given, the first suitable section in 37 | ; this file will be used. 
38 | autorun= 39 | 40 | allow_channels=true 41 | allow_multimon=true 42 | bitmap_cache=true 43 | bitmap_compression=true 44 | bulk_compression=true 45 | #hidelogwindow=true 46 | max_bpp=16 47 | new_cursors=false 48 | ; fastpath - can be 'input', 'output', 'both', 'none' 49 | use_fastpath=both 50 | ; when true, userid/password *must* be passed on cmd line 51 | #require_credentials=true 52 | ; You can set the PAM error text in a gateway setup (MAX 256 chars) 53 | #pamerrortxt=change your password according to policy at http://url 54 | 55 | ; 56 | ; colors used by windows in RGB format 57 | ; 58 | blue=009cb5 59 | grey=dedede 60 | #black=000000 61 | #dark_grey=808080 62 | #blue=08246b 63 | #dark_blue=08246b 64 | #white=ffffff 65 | #red=ff0000 66 | #green=00ff00 67 | #background=626c72 68 | 69 | ; 70 | ; configure login screen 71 | ; 72 | 73 | ; Login Screen Window Title 74 | #ls_title=My Login Title 75 | 76 | ; top level window background color in RGB format 77 | ls_top_window_bg_color=009cb5 78 | 79 | ; width and height of login screen 80 | ls_width=350 81 | ls_height=430 82 | 83 | ; login screen background color in RGB format 84 | ls_bg_color=dedede 85 | 86 | ; optional background image filename (bmp format). 
87 | #ls_background_image= 88 | 89 | ; logo 90 | ; full path to bmp-file or file in shared folder 91 | ls_logo_filename= 92 | ls_logo_x_pos=55 93 | ls_logo_y_pos=50 94 | 95 | ; for positioning labels such as username, password, etc. 96 | ls_label_x_pos=30 97 | ls_label_width=60 98 | 99 | ; for positioning text and combo boxes next to the labels above 100 | ls_input_x_pos=110 101 | ls_input_width=210 102 | 103 | ; y position of the first label and combo box 104 | ls_input_y_pos=220 105 | 106 | ; OK button 107 | ls_btn_ok_x_pos=142 108 | ls_btn_ok_y_pos=370 109 | ls_btn_ok_width=85 110 | ls_btn_ok_height=30 111 | 112 | ; Cancel button 113 | ls_btn_cancel_x_pos=237 114 | ls_btn_cancel_y_pos=370 115 | ls_btn_cancel_width=85 116 | ls_btn_cancel_height=30 117 | 118 | [Logging] 119 | LogFile=xrdp.log 120 | LogLevel=debug 121 | EnableSyslog=true 122 | SyslogLevel=error 123 | ; LogLevel and SyslogLevel can be any of: core, error, warning, info or debug 124 | 125 | [Channels] 126 | ; Channel names not listed here will be blocked by XRDP. 127 | ; You can block any channel by setting its value to false. 128 | ; IMPORTANT! Not all channels are supported in all use 129 | ; cases, even if you set all values to true. 130 | ; You can override these settings in each session type. 131 | ; These settings are only used if allow_channels=true 132 | rdpdr=true 133 | rdpsnd=true 134 | drdynvc=true 135 | cliprdr=true 136 | rail=true 137 | xrdpvr=true 138 | tcutils=true 139 | 140 | ; for debugging xrdp, in section xrdp1, change port=-1 to this: 141 | #port=/tmp/.xrdp/xrdp_display_10 142 | 143 | ; for debugging xrdp, add the following line to section xrdp1 144 | #chansrvport=/tmp/.xrdp/xrdp_chansrv_socket_7210 145 | 146 | 147 | ; 148 | ; Session types 149 | ; 150 | 151 | [Xorg] 152 | name=Xorg - Resizing.
153 | lib=libxup.so 154 | username=ask 155 | password=ask 156 | ip=127.0.0.1 157 | port=-1 158 | code=20 159 | 160 | #[X11rdp] 161 | #name=X11rdp 162 | #lib=libxup.so 163 | #username=ask 164 | #password=ask 165 | #ip=127.0.0.1 166 | #port=-1 167 | #xserverbpp=24 168 | #code=10 169 | 170 | [Xvnc] 171 | name=Xvnc - Screen Sharing. 172 | lib=libvnc.so 173 | username=ask 174 | password=ask 175 | ip=127.0.0.1 176 | port=-1 177 | xserverbpp=16 178 | #delay_ms=2000 179 | 180 | [Reconnect] 181 | name=Reconnect 182 | lib=libvnc.so 183 | ip=127.0.0.1 184 | port=ask5910 185 | username=ask 186 | password=ask 187 | #delay_ms=2000 188 | 189 | #[vnc-any] 190 | #name=vnc-any 191 | #lib=libvnc.so 192 | #ip=ask 193 | #port=ask5900 194 | #username=na 195 | #password=ask 196 | #pamusername=asksame 197 | #pampassword=asksame 198 | #pamsessionmng=127.0.0.1 199 | #delay_ms=2000 200 | 201 | #[sesman-any] 202 | #name=sesman-any 203 | #lib=libvnc.so 204 | #ip=ask 205 | #port=-1 206 | #username=ask 207 | #password=ask 208 | #delay_ms=20 209 | -------------------------------------------------------------------------------- /dev/docker/xrdp_users.txt: -------------------------------------------------------------------------------- 1 | admin:admin:Y 2 | -------------------------------------------------------------------------------- /docs/images/flow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Client 11 | RDP Gateway 12 | RDP GW Auth 13 | PAM 14 | Passwd or LDAP 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bolkedebruin/rdpgw 2 | 3 | go 1.22 4 | toolchain go1.24.1 5 | 6 | require ( 7 | github.com/bolkedebruin/gokrb5/v8 v8.5.0 8 | github.com/coreos/go-oidc/v3 v3.9.0 9 | github.com/fatih/structs v1.1.0 10 | 
github.com/go-jose/go-jose/v4 v4.0.5 11 | github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 12 | github.com/google/uuid v1.6.0 13 | github.com/gorilla/mux v1.8.1 14 | github.com/gorilla/sessions v1.2.2 15 | github.com/gorilla/websocket v1.5.1 16 | github.com/jcmturner/gofork v1.7.6 17 | github.com/jcmturner/goidentity/v6 v6.0.1 18 | github.com/knadh/koanf/parsers/yaml v0.1.0 19 | github.com/knadh/koanf/providers/confmap v0.1.0 20 | github.com/knadh/koanf/providers/env v0.1.0 21 | github.com/knadh/koanf/providers/file v0.1.0 22 | github.com/knadh/koanf/v2 v2.1.0 23 | github.com/m7913d/go-ntlm v0.0.1 24 | github.com/msteinert/pam/v2 v2.0.0 25 | github.com/patrickmn/go-cache v2.1.0+incompatible 26 | github.com/prometheus/client_golang v1.19.0 27 | github.com/stretchr/testify v1.10.0 28 | github.com/thought-machine/go-flags v1.6.3 29 | golang.org/x/crypto v0.36.0 30 | golang.org/x/oauth2 v0.18.0 31 | google.golang.org/grpc v1.62.1 32 | google.golang.org/protobuf v1.33.0 33 | ) 34 | 35 | require ( 36 | github.com/beorn7/perks v1.0.1 // indirect 37 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 38 | github.com/davecgh/go-spew v1.1.1 // indirect 39 | github.com/fsnotify/fsnotify v1.7.0 // indirect 40 | github.com/go-jose/go-jose/v3 v3.0.4 // indirect 41 | github.com/golang/protobuf v1.5.4 // indirect 42 | github.com/gorilla/securecookie v1.1.2 // indirect 43 | github.com/hashicorp/go-uuid v1.0.3 // indirect 44 | github.com/jcmturner/aescts/v2 v2.0.0 // indirect 45 | github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect 46 | github.com/jcmturner/rpc/v2 v2.0.3 // indirect 47 | github.com/knadh/koanf/maps v0.1.1 // indirect 48 | github.com/kr/text v0.2.0 // indirect 49 | github.com/mitchellh/copystructure v1.2.0 // indirect 50 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 51 | github.com/pmezard/go-difflib v1.0.0 // indirect 52 | github.com/prometheus/client_model v0.6.0 // indirect 53 | github.com/prometheus/common v0.50.0 // indirect 54 | 
github.com/prometheus/procfs v0.13.0 // indirect 55 | golang.org/x/net v0.38.0 // indirect 56 | golang.org/x/sys v0.31.0 // indirect 57 | golang.org/x/text v0.23.0 // indirect 58 | google.golang.org/appengine v1.6.8 // indirect 59 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c // indirect 60 | gopkg.in/yaml.v3 v3.0.1 // indirect 61 | ) 62 | -------------------------------------------------------------------------------- /proto/auth.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package auth; 4 | 5 | option go_package = "./auth"; 6 | // UserPass is the request of the Authenticate RPC: a username/password pair. 7 | message UserPass { 8 | string username = 1; 9 | string password = 2; 10 | } 11 | // AuthResponse is the reply to Authenticate; error presumably explains a failed attempt (TODO confirm with the server implementation). 12 | message AuthResponse { 13 | bool authenticated = 1; 14 | string error = 2; 15 | } 16 | // NtlmRequest carries one NTLM message for the NTLM RPC; session presumably ties the legs of one handshake together (TODO confirm). 17 | message NtlmRequest { 18 | string session = 1; 19 | string ntlmMessage = 2; 20 | } 21 | // NtlmResponse is the reply to NTLM; presumably ntlmMessage is the next handshake message and username is set once authenticated (TODO confirm). 22 | message NtlmResponse { 23 | bool authenticated = 1; 24 | string username = 2; 25 | string ntlmMessage = 3; 26 | } 27 | // Authenticate exposes password and NTLM authentication over gRPC; presumably served by rdpgw-auth (cmd/auth) - confirm. 28 | service Authenticate { 29 | rpc Authenticate (UserPass) returns (AuthResponse) {} 30 | rpc NTLM (NtlmRequest) returns (NtlmResponse) {} 31 | } 32 | -------------------------------------------------------------------------------- /shared/auth/auth_grpc.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 2 | 3 | package auth 4 | 5 | import ( 6 | context "context" 7 | grpc "google.golang.org/grpc" 8 | codes "google.golang.org/grpc/codes" 9 | status "google.golang.org/grpc/status" 10 | ) 11 | 12 | // This is a compile-time assertion to ensure that this generated file 13 | // is compatible with the grpc package it is being compiled against. 14 | // Requires gRPC-Go v1.32.0 or later. 15 | const _ = grpc.SupportPackageIsVersion7 16 | 17 | // AuthenticateClient is the client API for Authenticate service.
18 | // 19 | // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 20 | type AuthenticateClient interface { 21 | Authenticate(ctx context.Context, in *UserPass, opts ...grpc.CallOption) (*AuthResponse, error) 22 | NTLM(ctx context.Context, in *NtlmRequest, opts ...grpc.CallOption) (*NtlmResponse, error) 23 | } 24 | // authenticateClient implements AuthenticateClient on top of a gRPC client connection. 25 | type authenticateClient struct { 26 | cc grpc.ClientConnInterface 27 | } 28 | // NewAuthenticateClient returns an AuthenticateClient that issues unary RPCs over cc. 29 | func NewAuthenticateClient(cc grpc.ClientConnInterface) AuthenticateClient { 30 | return &authenticateClient{cc} 31 | } 32 | // Authenticate invokes /auth.Authenticate/Authenticate; on error the returned response is nil. 33 | func (c *authenticateClient) Authenticate(ctx context.Context, in *UserPass, opts ...grpc.CallOption) (*AuthResponse, error) { 34 | out := new(AuthResponse) 35 | err := c.cc.Invoke(ctx, "/auth.Authenticate/Authenticate", in, out, opts...) 36 | if err != nil { 37 | return nil, err 38 | } 39 | return out, nil 40 | } 41 | // NTLM invokes /auth.Authenticate/NTLM; on error the returned response is nil. 42 | func (c *authenticateClient) NTLM(ctx context.Context, in *NtlmRequest, opts ...grpc.CallOption) (*NtlmResponse, error) { 43 | out := new(NtlmResponse) 44 | err := c.cc.Invoke(ctx, "/auth.Authenticate/NTLM", in, out, opts...) 45 | if err != nil { 46 | return nil, err 47 | } 48 | return out, nil 49 | } 50 | 51 | // AuthenticateServer is the server API for Authenticate service. 52 | // All implementations must embed UnimplementedAuthenticateServer 53 | // for forward compatibility. 54 | type AuthenticateServer interface { 55 | Authenticate(context.Context, *UserPass) (*AuthResponse, error) 56 | NTLM(context.Context, *NtlmRequest) (*NtlmResponse, error) 57 | mustEmbedUnimplementedAuthenticateServer() 58 | } 59 | 60 | // UnimplementedAuthenticateServer must be embedded to have forward compatible implementations.
61 | type UnimplementedAuthenticateServer struct { 62 | } 63 | // Authenticate rejects every call with codes.Unimplemented; NTLM below does the same. 64 | func (UnimplementedAuthenticateServer) Authenticate(context.Context, *UserPass) (*AuthResponse, error) { 65 | return nil, status.Errorf(codes.Unimplemented, "method Authenticate not implemented") 66 | } 67 | func (UnimplementedAuthenticateServer) NTLM(context.Context, *NtlmRequest) (*NtlmResponse, error) { 68 | return nil, status.Errorf(codes.Unimplemented, "method NTLM not implemented") 69 | } 70 | func (UnimplementedAuthenticateServer) mustEmbedUnimplementedAuthenticateServer() {} 71 | 72 | // UnsafeAuthenticateServer may be embedded to opt out of forward compatibility for this service. 73 | // Use of this interface is not recommended, as added methods to AuthenticateServer will 74 | // result in compilation errors. 75 | type UnsafeAuthenticateServer interface { 76 | mustEmbedUnimplementedAuthenticateServer() 77 | } 78 | // RegisterAuthenticateServer registers srv on s using Authenticate_ServiceDesc. 79 | func RegisterAuthenticateServer(s grpc.ServiceRegistrar, srv AuthenticateServer) { 80 | s.RegisterService(&Authenticate_ServiceDesc, srv) 81 | } 82 | // _Authenticate_Authenticate_Handler decodes a UserPass via dec and calls srv.Authenticate, going through interceptor when one is set. 83 | func _Authenticate_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 84 | in := new(UserPass) 85 | if err := dec(in); err != nil { 86 | return nil, err 87 | } 88 | if interceptor == nil { 89 | return srv.(AuthenticateServer).Authenticate(ctx, in) 90 | } 91 | info := &grpc.UnaryServerInfo{ 92 | Server: srv, 93 | FullMethod: "/auth.Authenticate/Authenticate", 94 | } 95 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 96 | return srv.(AuthenticateServer).Authenticate(ctx, req.(*UserPass)) 97 | } 98 | return interceptor(ctx, in, info, handler) 99 | } 100 | // _Authenticate_NTLM_Handler decodes an NtlmRequest via dec and calls srv.NTLM, going through interceptor when one is set. 101 | func _Authenticate_NTLM_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 102 | in := new(NtlmRequest) 103 | if err := dec(in); err != nil { 104 |
return nil, err 105 | } 106 | if interceptor == nil { 107 | return srv.(AuthenticateServer).NTLM(ctx, in) 108 | } 109 | info := &grpc.UnaryServerInfo{ 110 | Server: srv, 111 | FullMethod: "/auth.Authenticate/NTLM", 112 | } 113 | handler := func(ctx context.Context, req interface{}) (interface{}, error) { 114 | return srv.(AuthenticateServer).NTLM(ctx, req.(*NtlmRequest)) 115 | } 116 | return interceptor(ctx, in, info, handler) 117 | } 118 | 119 | // Authenticate_ServiceDesc is the grpc.ServiceDesc for Authenticate service. 120 | // It's only intended for direct use with grpc.RegisterService, 121 | // and not to be introspected or modified (even as a copy). 122 | var Authenticate_ServiceDesc = grpc.ServiceDesc{ 123 | ServiceName: "auth.Authenticate", 124 | HandlerType: (*AuthenticateServer)(nil), 125 | Methods: []grpc.MethodDesc{ 126 | { 127 | MethodName: "Authenticate", 128 | Handler: _Authenticate_Authenticate_Handler, 129 | }, 130 | { 131 | MethodName: "NTLM", 132 | Handler: _Authenticate_NTLM_Handler, 133 | }, 134 | }, 135 | Streams: []grpc.StreamDesc{}, 136 | Metadata: "auth.proto", 137 | } 138 | --------------------------------------------------------------------------------