├── docs ├── .nojekyll ├── images │ ├── proxy.gif │ ├── psql.gif │ ├── rds-proxy-auth-flow.png │ ├── rds-proxy-client-startup-flow.png │ ├── shield-icon-white.svg │ └── shield-icon-gradient.svg ├── guides │ ├── README.md │ └── unique_port_per_db.md ├── _sidebar.md ├── _coverpage.md ├── architecture.md ├── index.html ├── security.md ├── quickstart.md └── reference.md ├── .gitignore ├── pkg ├── file │ ├── filesystem.go │ ├── helpers.go │ ├── filewriter.go │ └── filewriter_test.go ├── discovery │ ├── errors.go │ ├── client.go │ ├── factory │ │ └── from_config.go │ ├── static │ │ ├── client.go │ │ └── client_test.go │ ├── combined │ │ ├── client.go │ │ └── client_test.go │ └── rds │ │ ├── rds.go │ │ └── rds_test.go ├── config │ ├── acl_test.go │ ├── tags.go │ ├── kubernetes.go │ ├── ssl.go │ ├── acl.go │ ├── targets_test.go │ ├── targets.go │ ├── config.go │ └── config_test.go ├── log │ ├── log.go │ └── logger.go ├── proxy │ ├── config_test.go │ ├── manager.go │ ├── config.go │ └── proxy.go ├── cert │ └── cert.go ├── aws │ └── rds.go ├── kubernetes │ └── port_forward.go └── pg │ ├── backend.go │ ├── frontend.go │ └── ssl.go ├── .dockerignore ├── main.go ├── SECURITY.md ├── cmd ├── root.go ├── version.go ├── gen_certs.go ├── completion.go ├── proxy_server.go └── proxy_client.go ├── .github ├── workflows │ ├── .commitlint.json │ ├── release.yml │ ├── ci.yaml │ └── build.yml └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── CHANGELOG.md ├── go.mod ├── configs ├── server_config.yaml ├── client_config.yaml └── client_config_local.yaml ├── LICENSE ├── docker-compose.yml ├── examples └── interceptor.go ├── Makefile └── README.md /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | rds-proxy 2 | rds-auth-proxy 3 | coverage.out 4 | dist 
5 | -------------------------------------------------------------------------------- /docs/images/proxy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mothership/rds-auth-proxy/HEAD/docs/images/proxy.gif -------------------------------------------------------------------------------- /docs/images/psql.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mothership/rds-auth-proxy/HEAD/docs/images/psql.gif -------------------------------------------------------------------------------- /docs/images/rds-proxy-auth-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mothership/rds-auth-proxy/HEAD/docs/images/rds-proxy-auth-flow.png -------------------------------------------------------------------------------- /docs/images/rds-proxy-client-startup-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mothership/rds-auth-proxy/HEAD/docs/images/rds-proxy-client-startup-flow.png -------------------------------------------------------------------------------- /pkg/file/filesystem.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import "github.com/spf13/afero" 4 | 5 | var appFs = afero.NewOsFs() 6 | 7 | func GetFileSystem() afero.Fs { 8 | return appFs 9 | } 10 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | rds-proxy 2 | rds-auth-proxy 3 | Makefile 4 | Dockerfile 5 | docker-compose.yml 6 | README.md 7 | CHANGELOG.txt 8 | 9 | build/ 10 | !build/bin 11 | configs/ 12 | docs/ 13 | dist/ 14 | -------------------------------------------------------------------------------- /main.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/mothership/rds-auth-proxy/cmd" 7 | ) 8 | 9 | func main() { 10 | err := cmd.Execute() 11 | if err != nil { 12 | os.Exit(1) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /docs/guides/README.md: -------------------------------------------------------------------------------- 1 | # User Guides 2 | 3 | This section contains guides for using RDS Auth Proxy in unique situations or with other specific technologies. 4 | 5 | 6 | - [Unique Port Per Database](./unique_port_per_db.md) 7 | -------------------------------------------------------------------------------- /docs/_sidebar.md: -------------------------------------------------------------------------------- 1 | - [Getting Started](./quickstart.md) 2 | - [Architecture](./architecture.md) 3 | - [Security](./security.md) 4 | - [Guides](./guides/) 5 | - [Unique Port Per Database](./guides/unique_port_per_db.md) 6 | - [Reference](./reference.md) 7 | -------------------------------------------------------------------------------- /pkg/discovery/errors.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import "errors" 4 | 5 | // Possible errors returned by discovery clients 6 | var ( 7 | // ErrTargetNotFound should be returned when a target lookup by host 8 | // or name fails. 9 | ErrTargetNotFound = errors.New("target not found") 10 | ) 11 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | All for the moment :) 6 | 7 | ## Reporting a Vulnerability 8 | 9 | Send us an email at `security@mothership.com`.
10 | 11 | As this software is in early development, there are no bounties, but you'll have our eternal gratitude! 12 | -------------------------------------------------------------------------------- /cmd/root.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | ) 6 | 7 | var rootCmd = &cobra.Command{ 8 | Use: "rds-auth-proxy", 9 | Short: "rds-auth-proxy launches an SSL-capable postgres proxy", 10 | } 11 | 12 | // Execute kicks off the CLI 13 | func Execute() error { 14 | return rootCmd.Execute() 15 | } 16 | -------------------------------------------------------------------------------- /.github/workflows/.commitlint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["@commitlint/config-conventional"], 3 | "rules": { 4 | "subject-case": [0, "never", "sentence-case"], 5 | "type-enum": [ 6 | 2, 7 | "always", 8 | ["build", "ci", "chore", "docs", "feat", "fix", "perf", "refactor", "revert", "style", "test"] 9 | ] 10 | } 11 | } -------------------------------------------------------------------------------- /pkg/config/acl_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . 
"github.com/mothership/rds-auth-proxy/pkg/config" 7 | ) 8 | 9 | func TestACLInit(t *testing.T) { 10 | var acl ACL 11 | acl.Init() 12 | if acl.AllowedRDSTags == nil { 13 | t.Errorf("Expected allowed tags not to be nil") 14 | } 15 | 16 | if acl.BlockedRDSTags == nil { 17 | t.Errorf("Expected blocked tags not to be nil") 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /pkg/discovery/client.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/config" 7 | ) 8 | 9 | // Client is for discovering new database servers 10 | type Client interface { 11 | LookupTargetByHost(host string) (config.Target, error) 12 | LookupTargetByName(name string) (config.Target, error) 13 | GetTargets() []config.Target 14 | Refresh(ctx context.Context) error 15 | } 16 | -------------------------------------------------------------------------------- /docs/_coverpage.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |

RDS Auth Proxy

5 |
6 | 7 | `rds-auth-proxy` allows you to keep your databases firewalled off, 8 | control access through IAM policies, and never share a database password again! 9 | 10 | [Getting Started](./quickstart.md) [View on Github](https://github.com/mothership/rds-auth-proxy) 11 | -------------------------------------------------------------------------------- /pkg/config/tags.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | // Tag is an RDS tag 4 | type Tag struct { 5 | Name string `mapstructure:"name"` 6 | Value string `mapstructure:"value"` 7 | } 8 | 9 | // TagList is a list of tags 10 | type TagList []*Tag 11 | 12 | // Find returns a tag by name 13 | func (t TagList) Find(key string) *Tag { 14 | for _, tag := range t { 15 | if tag.Name == key { 16 | return tag 17 | } 18 | } 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v0.1.1 (2021-09-24) 4 | 5 | ### Bug Fixes 6 | 7 | - aurora instances are valid postgres instances (46b771d) 8 | - badge link for reportcard was bad (#10) (35ec1f8) 9 | - misspell in tests (2b17c8a) 10 | 11 | --- 12 | 13 | ## v0.1.0 (2021-09-23) 14 | 15 | ### Features 16 | 17 | - little more context to release script (#7) (43df2c4) 18 | - Initial import (#6) (2f1def0) 19 | - Changelog & Conventional Commit linter (f31ca29) 20 | 21 | ### Bug Fixes 22 | 23 | - fix release script (04f419d) 24 | 25 | --- 26 | 27 | -------------------------------------------------------------------------------- /cmd/version.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | var version string 10 | var commit string 11 | var date string 12 | 13 | var versionCmd = &cobra.Command{ 14 | Use: "version", 15 | Short: 
"Displays the current version", 16 | Long: `Displays the current version of the rds-auth-proxy binary`, 17 | Run: func(cmd *cobra.Command, args []string) { 18 | fmt.Printf("rds-auth-proxy: %s\n", version) 19 | fmt.Printf("commit: %s\n", commit) 20 | fmt.Printf("date: %s\n", date) 21 | }, 22 | } 23 | 24 | func init() { 25 | rootCmd.AddCommand(versionCmd) 26 | } 27 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mothership/rds-auth-proxy 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/AlecAivazis/survey/v2 v2.3.2 7 | github.com/aws/aws-sdk-go-v2 v1.9.1 8 | github.com/aws/aws-sdk-go-v2/config v1.8.2 9 | github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.1.7 10 | github.com/aws/aws-sdk-go-v2/service/rds v1.9.0 11 | github.com/imdario/mergo v0.3.8 // indirect 12 | github.com/jackc/pgproto3/v2 v2.1.1 13 | github.com/spf13/afero v1.6.0 14 | github.com/spf13/cobra v1.1.3 15 | github.com/spf13/viper v1.8.1 16 | go.uber.org/zap v1.19.1 17 | k8s.io/apimachinery v0.22.2 18 | k8s.io/client-go v0.22.2 19 | ) 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 
18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /pkg/config/kubernetes.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | // PortForward represents kubernetes port-forward config for tunneling a connection to the server-side proxy 4 | type PortForward struct { 5 | Namespace string `mapstructure:"namespace"` 6 | DeploymentName string `mapstructure:"deployment"` 7 | RemotePort string `mapstructure:"remote_port"` 8 | // Optional, if not set "0" is used 9 | LocalPort *string `mapstructure:"local_port"` 10 | Context string `mapstructure:"context"` 11 | KubeConfigFilePath string `mapstructure:"kube_config"` 12 | } 13 | 14 | // GetLocalPort returns the local port to be used for the port-forward 15 | func (p *PortForward) GetLocalPort() string { 16 | if p.LocalPort != nil { 17 | return *p.LocalPort 18 | } 19 | return "0" 20 | } 21 | -------------------------------------------------------------------------------- /configs/server_config.yaml: -------------------------------------------------------------------------------- 1 | proxy: 2 | # The listen addr of this proxy 3 | listen_addr: 0.0.0.0:8000 4 | ## 5 | # SSL Config 6 | # 7 | # The SSL config for the proxy itself. SSL for individual 8 | # hosts/targets is defined below 9 | ssl: 10 | enabled: true 11 | 12 | ## 13 | # Target ACL 14 | # 15 | # Configure allowed or blocked hosts / RDS instances. 16 | target_acl: 17 | allowed_rds_tags: [] 18 | blocked_rds_tags: [] 19 | 20 | ## 21 | # Target configuration 22 | # 23 | # This is where you can specify SSL settings for the upstream 24 | # databases or proxies. The keys MUST match an allowed target. 25 | # 26 | # RDS databases are handled automatically. 
27 | targets: 28 | postgres: 29 | host: postgres:5432 30 | ssl: 31 | mode: "disable" # options are "disable", "verify-full", "verify-ca", or "require" 32 | -------------------------------------------------------------------------------- /pkg/file/helpers.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | ) 7 | 8 | // ExpandPath replaces every literal "$HOME" in filePath with the current user's home directory. 9 | func ExpandPath(filePath string) (string, error) { 10 | home, err := os.UserHomeDir() 11 | if err != nil { 12 | return "", err 13 | } 14 | return strings.ReplaceAll(filePath, "$HOME", home), nil 15 | } 16 | 17 | // Exists reports whether filePath (after $HOME expansion) can be stat'ed and is not a directory. 18 | // It panics if the home directory cannot be determined. 19 | func Exists(filePath string) bool { 20 | path, err := ExpandPath(filePath) 21 | if err != nil { 22 | // Without a home directory "$HOME" cannot be expanded; treat this as unrecoverable. 23 | panic(err) 24 | } 25 | 26 | info, err := os.Stat(path) 27 | // Any Stat error (not only "does not exist", e.g. permission denied) leaves info nil, so report absent instead of dereferencing it. 28 | if err != nil { 29 | return false 30 | } 31 | return !info.IsDir() 32 | } 33 | 34 | // DirExists reports whether filePath (after $HOME expansion) can be stat'ed and is a directory. 35 | // It panics if the home directory cannot be determined. 36 | func DirExists(filePath string) bool { 37 | path, err := ExpandPath(filePath) 38 | if err != nil { 39 | // Without a home directory "$HOME" cannot be expanded; treat this as unrecoverable. 40 | panic(err) 41 | } 42 | 43 | info, err := os.Stat(path) 44 | // Any Stat error leaves info nil, so report absent instead of dereferencing it. 45 | if err != nil { 46 | return false 47 | } 48 | return info.IsDir() 49 | } 50 | -------------------------------------------------------------------------------- /pkg/discovery/factory/from_config.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "github.com/mothership/rds-auth-proxy/pkg/aws" 5 | "github.com/mothership/rds-auth-proxy/pkg/config" 6 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 7 | "github.com/mothership/rds-auth-proxy/pkg/discovery/combined" 8 | "github.com/mothership/rds-auth-proxy/pkg/discovery/rds" 9 | "github.com/mothership/rds-auth-proxy/pkg/discovery/static" 10 | ) 11 | 12 | // FromConfig returns a new DiscoveryClient from the settings in your configfile.
13 | func FromConfig(rdsClient aws.RDSClient, c *config.ConfigFile) discovery.Client { 14 | var staticTargets = make(map[string]config.Target, len(c.Targets)) 15 | for _, target := range c.Targets { 16 | staticTargets[target.Host] = *target 17 | } 18 | return combined.NewCombinedDiscoveryClient([]discovery.Client{ 19 | static.NewStaticDiscoveryClient(staticTargets), 20 | rds.NewRdsDiscoveryClient(rdsClient, c), 21 | }) 22 | } 23 | -------------------------------------------------------------------------------- /configs/client_config.yaml: -------------------------------------------------------------------------------- 1 | proxy: 2 | # The listen addr of this proxy 3 | listen_addr: 0.0.0.0:8001 4 | ## 5 | # SSL Config 6 | # 7 | # The SSL config for the proxy itself. SSL for individual 8 | # hosts/targets is defined below 9 | ssl: 10 | enabled: false 11 | 12 | ## 13 | # Target ACL 14 | # 15 | # Configure allowed or blocked hosts / RDS instances. 16 | target_acl: 17 | allowed_rds_tags: [] 18 | blocked_rds_tags: [] 19 | 20 | ## 21 | # Upstream Proxies Configuration 22 | # 23 | # This is where you can specify upstream proxy settings 24 | upstream_proxies: 25 | default: 26 | host: rds-proxy-server:8000 27 | ssl: 28 | mode: "require" 29 | 30 | ## 31 | # Target configuration 32 | # 33 | # This is where you can specify SSL settings for the upstream 34 | # (non-RDS) databases 35 | # 36 | # RDS databases are added automatically at runtime. 37 | targets: 38 | postgres: 39 | host: postgres:5432 40 | database: postgres 41 | -------------------------------------------------------------------------------- /docs/guides/unique_port_per_db.md: -------------------------------------------------------------------------------- 1 | # Unique Local Ports Per Database 2 | 3 | In some cases, you may want a unique local port per database instead of 4 | the default listening port for the proxy. 
5 | 6 | Maybe you want your staging environment to use the local port range 7 | `54000-54999`, and your production environment to use local ports 8 | `55000-55999`. Maybe you want to save and share connection details across 9 | various database GUIs or other tooling. 10 | 11 | Whatever the case, we can do that with the tag `rds-auth-proxy:local-port`. 12 | 13 | ## Adding the Tag 14 | 15 | ```bash 16 | aws rds add-tags-to-resource \ 17 | --resource-name {your-db-arn} \ 18 | --tags "[{\"Key\": \"rds-auth-proxy:local-port\",\"Value\": \"54000\"}]" 19 | ``` 20 | 21 | ## Try it out 22 | 23 | Now, when any of your developers runs the client proxy, they should see the local proxy 24 | boot on port `54000`. 25 | 26 | ```bash 27 | rds-auth-proxy client --target {my-db-identifier} 28 | ``` 29 | -------------------------------------------------------------------------------- /pkg/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "os" 5 | 6 | "go.uber.org/zap" 7 | ) 8 | 9 | // NewLogger returns a configured Zap Logger 10 | func NewLogger() *zap.Logger { 11 | var logConfig zap.Config 12 | if os.Getenv("DEBUG") == "true" { 13 | logConfig = zap.NewDevelopmentConfig() 14 | } else { 15 | logConfig = zap.NewProductionConfig() 16 | } 17 | 18 | level := os.Getenv("LOG_LEVEL") 19 | if level != "" { 20 | logConfig.Level = unmarshalLevel(level) 21 | } 22 | 23 | logger, err := logConfig.Build() 24 | if err != nil { 25 | panic(err) 26 | } 27 | return logger 28 | } 29 | 30 | func unmarshalLevel(l string) zap.AtomicLevel { 31 | switch l { 32 | case "warn": 33 | return zap.NewAtomicLevelAt(zap.WarnLevel) 34 | case "error": 35 | return zap.NewAtomicLevelAt(zap.ErrorLevel) 36 | case "debug": 37 | return zap.NewAtomicLevelAt(zap.DebugLevel) 38 | case "info": 39 | fallthrough 40 | default: 41 | return zap.NewAtomicLevelAt(zap.InfoLevel) 42 | } 43 | } 44 | 
-------------------------------------------------------------------------------- /configs/client_config_local.yaml: -------------------------------------------------------------------------------- 1 | # This config file should be used outside the docker container 2 | # to test the client binary for interactivity. 3 | proxy: 4 | # The listen addr of this proxy 5 | listen_addr: 0.0.0.0:8002 6 | ## 7 | # SSL Config 8 | # 9 | # The SSL config for the proxy itself. SSL for individual 10 | # hosts/targets is defined below 11 | ssl: 12 | enabled: false 13 | 14 | ## 15 | # Target ACL 16 | # 17 | # Configure allowed or blocked hosts / RDS instances. 18 | target_acl: 19 | allowed_rds_tags: [] 20 | blocked_rds_tags: [] 21 | 22 | ## 23 | # Upstream Proxies Configuration 24 | # 25 | # This is where you can specify upstream proxy settings 26 | upstream_proxies: 27 | default: 28 | host: 0.0.0.0:8000 29 | ssl: 30 | mode: "require" 31 | 32 | ## 33 | # Target configuration 34 | # 35 | # This is where you can specify SSL settings for the upstream 36 | # (non-RDS) databases 37 | # 38 | # RDS databases are added automatically at runtime. 39 | targets: 40 | postgres: 41 | host: postgres:5432 42 | database: postgres 43 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 
25 | 26 | **Client (please complete the following information):** 27 | - OS: [e.g. OSX] 28 | - Arch: [e.g. arm64] 29 | - Binary Version [e.g. 0.1.0] 30 | 31 | **Server (please complete the following information):** 32 | - OS: [e.g. Linux] 33 | - Arch: [e.g. arm64] 34 | - Binary Version [e.g. 0.1.0] 35 | 36 | **Config Files** 37 | 38 | If appropriate, please share the config files you're using (in a redacted form) 39 | 40 | **Additional context** 41 | Add any other context about the problem here. 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Mothership 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pkg/config/ssl.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/mothership/rds-auth-proxy/pkg/pg" 5 | ) 6 | 7 | // SSL represents settings for upstream (RDS instances, pg instances) 8 | type SSL struct { 9 | // Optional client certificate to use 10 | ClientCertificatePath *string `mapstructure:"client_certificate,omitempty"` 11 | // Optional client private key to use 12 | ClientPrivateKeyPath *string `mapstructure:"client_private_key,omitempty"` 13 | // SSL mode to verify upstream connection, defaults to "verify-full" 14 | Mode pg.SSLMode `mapstructure:"mode,omitempty"` 15 | // Path to a root certificate if the certificate is 16 | // not already in the system roots 17 | RootCertificatePath *string `mapstructure:"root_certificate"` 18 | } 19 | 20 | // ServerSSL is SSL settings for the proxy server 21 | type ServerSSL struct { 22 | Enabled bool `mapstructure:"enabled"` 23 | CertificatePath *string `mapstructure:"certificate,omitempty"` 24 | PrivateKeyPath *string `mapstructure:"private_key,omitempty"` 25 | ClientCertificatePath *string `mapstructure:"client_certificate,omitempty"` 26 | ClientPrivateKeyPath *string `mapstructure:"client_private_key,omitempty"` 27 | } 28 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.3' 2 | services: 3 | postgres: 4 | image: postgres:12.8-alpine 5 | environment: 6 | POSTGRES_PASSWORD: password 7 | POSTGRES_USER: postgres 8 | POSTGRES_DB: postgres 9 | PGDATA: /var/lib/postgresql/data/pgdata 10 | restart: always 11 | ports: 12 | - 5432:5432 13 | volumes: 14 | - /var/lib/postgresql/data/pgdata 15 | 16 | rds-proxy-server: 17 | build: 18 | context: . 
19 | dockerfile: build/Dockerfile 20 | command: "server --configfile /configs/server_config.yaml" 21 | depends_on: 22 | - postgres 23 | ports: 24 | - "8000:8000" 25 | volumes: 26 | - ./configs:/configs 27 | - $HOME/.aws/:/.aws/ 28 | environment: 29 | LOG_LEVEL: "debug" 30 | 31 | rds-proxy-client: 32 | build: 33 | context: . 34 | dockerfile: build/Dockerfile 35 | command: "client --target postgres --password password --configfile /configs/client_config.yaml" 36 | depends_on: 37 | - rds-proxy-server 38 | ports: 39 | - "8001:8001" 40 | volumes: 41 | - ./configs:/configs 42 | - $HOME/.aws/:/.aws/ 43 | environment: 44 | LOG_LEVEL: "debug" 45 | -------------------------------------------------------------------------------- /examples/interceptor.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/jackc/pgproto3/v2" 8 | "github.com/mothership/rds-auth-proxy/pkg/pg" 9 | "github.com/mothership/rds-auth-proxy/pkg/proxy" 10 | ) 11 | 12 | // BasicInterceptor just echoes back the query to the backend. 13 | // Since it returns nil, the proxy will handle sending the message to the frontend. 14 | func BasicInterceptor(frontend pg.SendOnlyFrontend, backend pg.SendOnlyBackend, msg *pgproto3.Query) error { 15 | message := fmt.Sprintf("Got query from client: %+v", msg.String) 16 | _ = backend.Send(&pgproto3.NoticeResponse{Message: message}) 17 | return nil 18 | } 19 | 20 | // BasicDelayedInterceptor calls a goroutine and tells the proxy it 21 | // will take care of sending the message to the frontend. 22 | func BasicDelayedInterceptor(frontend pg.SendOnlyFrontend, backend pg.SendOnlyBackend, msg *pgproto3.Query) error { 23 | go func(frontend pg.SendOnlyFrontend, backend pg.SendOnlyBackend, msg *pgproto3.Query) { 24 | message := "Starting long running task. Please wait." 
25 | _ = backend.Send(&pgproto3.NoticeResponse{Message: message}) 26 | time.Sleep(time.Second * 5) 27 | _ = frontend.Send(msg) 28 | }(frontend, backend, msg) 29 | return proxy.WillSendManually 30 | } 31 | -------------------------------------------------------------------------------- /pkg/log/logger.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import "go.uber.org/zap" 4 | 5 | var rootLogger *zap.Logger = NewLogger() 6 | 7 | // SetLogger sets a new logger as the root logger 8 | func SetLogger(l *zap.Logger) { 9 | rootLogger = l 10 | } 11 | 12 | // Debug forwards debug logs to the root logger 13 | func Debug(msg string, args ...zap.Field) { 14 | rootLogger.Debug(msg, args...) 15 | } 16 | 17 | // Info forwards info logs to the root logger 18 | func Info(msg string, args ...zap.Field) { 19 | rootLogger.Info(msg, args...) 20 | } 21 | 22 | // Warn forwards warn logs to the root logger 23 | func Warn(msg string, args ...zap.Field) { 24 | rootLogger.Warn(msg, args...) 25 | } 26 | 27 | // Error forwards error logs to the root logger 28 | func Error(msg string, args ...zap.Field) { 29 | rootLogger.Error(msg, args...) 30 | } 31 | 32 | // Fatal forwards fatal logs to the root logger 33 | func Fatal(msg string, args ...zap.Field) { 34 | rootLogger.Fatal(msg, args...) 35 | } 36 | 37 | // With returns a new logger with fields persisted 38 | func With(args ...zap.Field) *zap.Logger { 39 | return rootLogger.With(args...) 40 | } 41 | 42 | // WithOptions returns a new root-derived logger with the given fields persisted. NOTE(review): despite its name it accepts zap.Field values and calls With, not (*zap.Logger).WithOptions, so zap.Option values cannot be passed; fixing that changes the signature, so confirm callers first. 43 | func WithOptions(args ...zap.Field) *zap.Logger { 44 | return rootLogger.With(args...) 45 | } 46 | -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | # Architecture 2 | 3 | `rds-auth-proxy` is a binary containing two different proxies.
4 | One proxy is run in a VPC subnet that can reach your RDS instances, 5 | the other on your client machine (dev laptop, etc.) with access to 6 | aws credentials. 7 | 8 | ## Client Proxy Startup Flow 9 | 10 | The client proxy is responsible for picking a host (RDS instance), and 11 | generating a temporary password using the local IAM identity. The 12 | client proxy injects the desired host and password into the postgres 13 | startup message as additional parameters. 14 | 15 | ![Client startup flow](./images/rds-proxy-client-startup-flow.png) 16 | 17 | ## Server Proxy Startup Flow 18 | 19 | The server proxy accepts a connection from the client proxy, and 20 | unpacks the host and password parameters. The server proxy checks 21 | that it's allowed to connect to the postgres database, based on 22 | the set of allowed/blocked tags specified in the config file. 23 | 24 | The server proxy then opens a connection to the RDS database and intercepts 25 | the authentication request. It passes along the password it received from 26 | the client, and forwards the result to the client. 27 | 28 | ![Auth overview](./images/rds-proxy-auth-flow.png) 29 | 30 | After successful auth, all messages are proxied transparently between the 31 | client and database. 
32 | -------------------------------------------------------------------------------- /docs/images/shield-icon-white.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /pkg/discovery/static/client.go: -------------------------------------------------------------------------------- 1 | package static 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/config" 7 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 8 | ) 9 | 10 | type StaticDiscoveryClient struct { 11 | targets map[string]config.Target 12 | } 13 | 14 | var _ discovery.Client = (*StaticDiscoveryClient)(nil) 15 | 16 | func NewStaticDiscoveryClient(targets map[string]config.Target) *StaticDiscoveryClient { 17 | return &StaticDiscoveryClient{ 18 | targets: targets, 19 | } 20 | } 21 | 22 | func (s *StaticDiscoveryClient) LookupTargetByHost(host string) (config.Target, error) { 23 | if target, ok := s.targets[host]; ok { 24 | return target, nil 25 | } 26 | return config.Target{}, discovery.ErrTargetNotFound 27 | } 28 | 29 | func (s *StaticDiscoveryClient) LookupTargetByName(name string) (config.Target, error) { 30 | for _, target := range s.targets { 31 | if target.Name == name { 32 | return target, nil 33 | } 34 | } 35 | return config.Target{}, discovery.ErrTargetNotFound 36 | } 37 | 38 | func (s *StaticDiscoveryClient) GetTargets() []config.Target { 39 | targetList := make([]config.Target, 0, len(s.targets)) 40 | for _, target := range s.targets { 41 | targetList = append(targetList, target) 42 | } 43 | return targetList 44 | } 45 | 46 | func (s *StaticDiscoveryClient) Refresh(ctx context.Context) error { 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /pkg/config/acl.go: -------------------------------------------------------------------------------- 1 | package config 
2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/aws/aws-sdk-go-v2/service/rds/types" 7 | ) 8 | 9 | // ACL represents rds instance tags allowed, or blocked by the proxy 10 | type ACL struct { 11 | AllowedRDSTags TagList `mapstructure:"allowed_rds_tags"` 12 | BlockedRDSTags TagList `mapstructure:"blocked_rds_tags"` 13 | } 14 | 15 | // Init finishes initializing the ACL struct 16 | func (a *ACL) Init() { 17 | if a.AllowedRDSTags == nil { 18 | a.AllowedRDSTags = []*Tag{} 19 | } 20 | 21 | if a.BlockedRDSTags == nil { 22 | a.BlockedRDSTags = []*Tag{} 23 | } 24 | } 25 | 26 | // IsAllowed returns an error if the instance tags are either not allowed, 27 | // or explicitly blocked. 28 | func (a *ACL) IsAllowed(tagList []types.Tag) error { 29 | tags := map[string]string{} 30 | for _, t := range tagList { 31 | tags[*t.Key] = *t.Value 32 | } 33 | 34 | for _, matcher := range a.AllowedRDSTags { 35 | value, ok := tags[matcher.Name] 36 | if !ok { 37 | return fmt.Errorf("tag %q not found on instance", matcher.Name) 38 | } 39 | if value != matcher.Value { 40 | return fmt.Errorf("tag %q has wrong value %q (wanted: %q)", matcher.Name, value, matcher.Value) 41 | } 42 | } 43 | for _, matcher := range a.BlockedRDSTags { 44 | value, ok := tags[matcher.Name] 45 | if !ok { 46 | continue 47 | } 48 | if value == matcher.Value { 49 | return fmt.Errorf("blocked by tag %q (value: %q)", matcher.Name, value) 50 | } 51 | } 52 | return nil 53 | } 54 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | 7 | jobs: 8 | goreleaser: 9 | runs-on: macos-latest 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v2 13 | with: 14 | fetch-depth: 0 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: 1.17 20 | 21 | - name: Install gon 22 | run: | 23 | brew tap 
mitchellh/gon 24 | brew install mitchellh/gon/gon 25 | 26 | - name: Import Code-Signing Certificates 27 | uses: Apple-Actions/import-codesign-certs@v1 28 | with: 29 | p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} 30 | p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} 31 | 32 | - name: Run GoReleaser 33 | uses: goreleaser/goreleaser-action@v2 34 | with: 35 | distribution: goreleaser 36 | version: latest 37 | args: -f ./build/goreleaser.yml release --rm-dist 38 | env: 39 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 40 | AC_PASSWORD: ${{ secrets.AC_PASSWORD }} 41 | AC_USERNAME: ${{ secrets.AC_USERNAME }} 42 | 43 | - name: Notarize Apple Binaries 44 | run: | 45 | gon ./build/notarization-config.json 46 | env: 47 | AC_PASSWORD: ${{ secrets.AC_PASSWORD }} 48 | AC_USERNAME: ${{ secrets.AC_USERNAME }} 49 | -------------------------------------------------------------------------------- /pkg/config/targets_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | . 
"github.com/mothership/rds-auth-proxy/pkg/config" 7 | ) 8 | 9 | func TestTargetGetHost(t *testing.T) { 10 | cases := []struct { 11 | Target ProxyTarget 12 | ExpectedHost string 13 | }{ 14 | { 15 | Target: ProxyTarget{ 16 | Host: "0.0.0.0:8000", 17 | }, 18 | ExpectedHost: "0.0.0.0:8000", 19 | }, 20 | { 21 | Target: ProxyTarget{ 22 | Host: "0.0.0.0:8000", 23 | PortForward: &PortForward{ 24 | LocalPort: strPtr("8001"), 25 | }, 26 | }, 27 | ExpectedHost: "0.0.0.0:8001", 28 | }, 29 | } 30 | 31 | for idx, test := range cases { 32 | result := test.Target.GetHost() 33 | if result != test.ExpectedHost { 34 | t.Errorf("[Case %d] Expected %q, got %q", idx, test.ExpectedHost, result) 35 | } 36 | } 37 | } 38 | 39 | func TestTargetIsPortForward(t *testing.T) { 40 | cases := []struct { 41 | Target ProxyTarget 42 | Expected bool 43 | }{ 44 | { 45 | Target: ProxyTarget{ 46 | Host: "0.0.0.0:8000", 47 | }, 48 | Expected: false, 49 | }, 50 | { 51 | Target: ProxyTarget{ 52 | Host: "0.0.0.0:8000", 53 | PortForward: &PortForward{ 54 | LocalPort: strPtr("8001"), 55 | }, 56 | }, 57 | Expected: true, 58 | }, 59 | } 60 | 61 | for idx, test := range cases { 62 | result := test.Target.IsPortForward() 63 | if result != test.Expected { 64 | t.Errorf("[Case %d] Expected %t, got %t", idx, test.Expected, result) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /pkg/file/filewriter.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | 7 | "github.com/spf13/afero" 8 | ) 9 | 10 | type FileWriter struct { 11 | buffer *bytes.Buffer 12 | Fs afero.Fs 13 | err error 14 | } 15 | 16 | func NewFileWriter() *FileWriter { 17 | return &FileWriter{ 18 | buffer: bytes.NewBuffer(make([]byte, 0, 120)), 19 | Fs: appFs, 20 | } 21 | } 22 | 23 | func (f *FileWriter) P(fmtStr string, args ...interface{}) { 24 | if f.err != nil { 25 | return 26 | } 27 | _, err := 
f.buffer.WriteString(fmt.Sprintf(fmtStr, args...)) 28 | if err != nil { 29 | f.err = err 30 | return 31 | } 32 | _, err = f.buffer.WriteString("\n") 33 | f.err = err 34 | } 35 | 36 | func (f *FileWriter) Write(bytes []byte) (int, error) { 37 | if f.err != nil { 38 | return 0, f.err 39 | } 40 | count, err := f.buffer.Write(bytes) 41 | f.err = err 42 | return count, err 43 | } 44 | 45 | func (f *FileWriter) Save(path string) error { 46 | if f.err != nil { 47 | return f.err 48 | } 49 | 50 | tmpfile, err := afero.TempFile(f.Fs, "", "file*") 51 | if err != nil { 52 | return err 53 | } 54 | // Note: We don't check errors here because a successful 55 | // write means the tmpfile won't exist anymore 56 | //nolint:errcheck 57 | defer f.Fs.Remove(tmpfile.Name()) 58 | if _, err := f.buffer.WriteTo(tmpfile); err != nil { 59 | return err 60 | } 61 | if err = tmpfile.Sync(); err != nil { 62 | return err 63 | } 64 | if err = tmpfile.Close(); err != nil { 65 | return err 66 | } 67 | 68 | return f.Fs.Rename(tmpfile.Name(), path) 69 | } 70 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: ["main"] 5 | pull_request: 6 | jobs: 7 | commit-lint: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | with: 12 | fetch-depth: 0 13 | - uses: wagoid/commitlint-github-action@v4 14 | with: 15 | configFile: .github/workflows/.commitlint.json 16 | test: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout Branch 20 | uses: actions/checkout@v2 21 | - name: Setup Go 22 | uses: actions/setup-go@v2.1.3 23 | with: 24 | go-version: 1.17.x 25 | - name: Setup Environment 26 | run: | 27 | echo "GOPATH=$(go env GOPATH)" >> $GITHUB_ENV 28 | echo "$(go env GOPATH)/bin" >> $GITHUB_PATH 29 | - name: Restore Cache 30 | uses: actions/cache@v2 31 | with: 32 | path: | 33 | ~/go/pkg/mod 34 | 
~/.cache/go-build 35 | key: ubuntu-latest-go-${{ hashFiles('**/go.sum') }} 36 | restore-keys: | 37 | ubuntu-latest-go- 38 | - name: Install Dependencies 39 | run: go mod download 40 | - name: Test 41 | run: go test ./... 42 | lint: 43 | runs-on: ubuntu-latest 44 | steps: 45 | - uses: actions/checkout@v2 46 | - name: Setup Go 47 | uses: actions/setup-go@v2.1.3 48 | with: 49 | go-version: 1.17.x 50 | - name: Lint 51 | uses: golangci/golangci-lint-action@v3 52 | with: 53 | version: v1.51 54 | -------------------------------------------------------------------------------- /pkg/discovery/combined/client.go: -------------------------------------------------------------------------------- 1 | package combined 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/config" 7 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 8 | ) 9 | 10 | type CombinedDiscoveryClient struct { 11 | clients []discovery.Client 12 | } 13 | 14 | var _ discovery.Client = (*CombinedDiscoveryClient)(nil) 15 | 16 | func NewCombinedDiscoveryClient(clients []discovery.Client) *CombinedDiscoveryClient { 17 | return &CombinedDiscoveryClient{ 18 | clients: clients, 19 | } 20 | } 21 | 22 | func (c *CombinedDiscoveryClient) LookupTargetByHost(host string) (config.Target, error) { 23 | for _, client := range c.clients { 24 | if t, err := client.LookupTargetByHost(host); err == nil { 25 | return t, nil 26 | } 27 | } 28 | return config.Target{}, discovery.ErrTargetNotFound 29 | } 30 | 31 | func (c *CombinedDiscoveryClient) LookupTargetByName(name string) (config.Target, error) { 32 | for _, client := range c.clients { 33 | if t, err := client.LookupTargetByName(name); err == nil { 34 | return t, nil 35 | } 36 | } 37 | return config.Target{}, discovery.ErrTargetNotFound 38 | } 39 | 40 | func (c *CombinedDiscoveryClient) GetTargets() []config.Target { 41 | targetList := make([]config.Target, 0, 16) 42 | for _, client := range c.clients { 43 | targetList = append(targetList, 
client.GetTargets()...) 44 | } 45 | return targetList 46 | } 47 | 48 | func (c *CombinedDiscoveryClient) Refresh(ctx context.Context) error { 49 | for _, client := range c.clients { 50 | if err := client.Refresh(ctx); err != nil { 51 | return err 52 | } 53 | } 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | HOME?=$($HOME) 2 | CERT_DIR?=$(HOME)/.config/rds-auth-proxy 3 | CERTIFICATE_PATH?=$(CERT_DIR)/selfsigned_cert.pem 4 | PRIVATE_KEY_PATH?=$(CERT_DIR)/selfsigned_key.pem 5 | 6 | AC_USERNAME?= 7 | DEBUG_TARGET?=rds-auth-proxy-macos-amd 8 | 9 | DOCKER_REPO?=ghcr.io/mothership/rds-auth-proxy 10 | DOCKER_TAG?=dev 11 | 12 | .PHONY: debug 13 | debug: 14 | AC_USERNAME=$(AC_USERNAME) goreleaser build -f ./build/goreleaser.yml --snapshot --rm-dist --id $(DEBUG_TARGET) 15 | mv dist/rds-auth-proxy-macos-amd_darwin_amd64/rds-auth-proxy /usr/local/bin 16 | 17 | .PHONY: debug-release 18 | debug-release: 19 | AC_USERNAME=$(AC_USERNAME) goreleaser -f ./build/goreleaser.yml --snapshot --rm-dist 20 | AC_USERNAME=$(AC_USERNAME) gon ./build/notorizing-config.json 21 | 22 | .PHONY: release 23 | release: 24 | AC_USERNAME=$(AC_USERNAME) goreleaser -f ./build/goreleaser.yml --rm-dist 25 | AC_USERNAME=$(AC_USERNAME) gon ./build/notorizing-config.json 26 | 27 | gen-certs: debug 28 | mkdir -p $(CERT_DIR) 29 | rm -rf $(CERTIFICATE_PATH) $(PRIVATE_KEY_PATH) 30 | LOG_LEVEL=debug DEBUG=true rds-auth-proxy gen-cert \ 31 | --certificate $(CERTIFICATE_PATH) \ 32 | --key $(PRIVATE_KEY_PATH) 33 | 34 | .PHONY: docker 35 | docker: 36 | DOCKER_BUILDKIT=1 docker build \ 37 | -t $(DOCKER_REPO):$(DOCKER_TAG) \ 38 | -f ./build/Dockerfile \ 39 | . 40 | 41 | .PHONY: test 42 | test: 43 | go test -coverprofile=coverage.out ./... 
44 | 45 | .PHONY: test-cover 46 | test-cover: test 47 | go tool cover -html=coverage.out 48 | 49 | .PHONY: lint 50 | lint: 51 | golangci-lint run 52 | 53 | .PHONY: it-happen 54 | it-happen: 55 | docker-compose up --build 56 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | RDS Auth Proxy 6 | 7 | 8 | 9 | 10 | 23 | 24 | 25 | 26 |
27 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /pkg/config/targets.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "fmt" 4 | 5 | // ProxyTarget is a config block specifying an upstream proxy 6 | type ProxyTarget struct { 7 | Name string 8 | Host string `mapstructure:"host"` 9 | SSL SSL `mapstructure:"ssl"` 10 | // For tunneling the connection through a kubernetes port-forward, only useful 11 | // for client-side proxy targets 12 | PortForward *PortForward `mapstructure:"port_forward,omitempty"` 13 | } 14 | 15 | // Target is the actual DB server we're connecting to 16 | type Target struct { 17 | Host string `mapstructure:"host"` 18 | SSL SSL `mapstructure:"ssl"` 19 | // Hint for showing the default database in the connection string 20 | DefaultDatabase *string `mapstructure:"database,omitempty"` 21 | // LocalPort to use instead of the proxy's default ListenAddr port 22 | LocalPort *string `mapstructure:"local_port,omitempty"` 23 | // Name in target list, or RDS db instance identifier 24 | Name string 25 | // Only set for RDS instances 26 | Region string 27 | // Only set for RDS instances 28 | IsRDS bool 29 | } 30 | 31 | // GetHost returns the correct host + port combo for the proxy target 32 | // if the target is port-forwarded, this is a localhost address 33 | // otherwise, it's exposed over a VPN or by some other means. 
34 | func (p *ProxyTarget) GetHost() string { 35 | if p.PortForward == nil { 36 | return p.Host 37 | } 38 | if p.PortForward.LocalPort == nil { 39 | return "0.0.0.0:0" 40 | } 41 | return fmt.Sprintf("0.0.0.0:%s", *p.PortForward.LocalPort) 42 | } 43 | 44 | // IsPortForward returns true if this proxy target requires a port-forward connection 45 | func (p *ProxyTarget) IsPortForward() bool { 46 | return p.PortForward != nil 47 | } 48 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build Image 2 | on: 3 | push: 4 | branches: 5 | - 'main' 6 | tags: 7 | - 'v*' 8 | # Doesn't push, only confirms we can build :) 9 | pull_request: 10 | branches: 11 | - 'main' 12 | 13 | env: 14 | REGISTRY: ghcr.io 15 | IMAGE_NAME: ${{ github.repository }} 16 | 17 | jobs: 18 | build-and-push-image: 19 | runs-on: ubuntu-latest 20 | permissions: 21 | contents: read 22 | packages: write 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v2 26 | 27 | - name: Set up Docker Buildx 28 | uses: docker/setup-buildx-action@v1 29 | 30 | - name: Login to registry 31 | if: github.event_name != 'pull_request' 32 | uses: docker/login-action@v1 33 | with: 34 | registry: ${{ env.REGISTRY }} 35 | username: ${{ github.actor }} 36 | password: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | - name: Generate tags 39 | id: meta 40 | uses: docker/metadata-action@v3 41 | with: 42 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 43 | tags: | 44 | type=semver,pattern={{version}} 45 | type=semver,pattern={{major}}.{{minor}} 46 | type=semver,pattern={{major}} 47 | type=ref,event=branch 48 | type=ref,event=pr 49 | type=sha 50 | 51 | - name: Build and push Docker image 52 | uses: docker/build-push-action@v2 53 | with: 54 | context: . 
55 | push: ${{ github.event_name != 'pull_request' }} 56 | tags: ${{ steps.meta.outputs.tags }} 57 | labels: ${{ steps.meta.outputs.labels }} 58 | file: ./build/Dockerfile 59 | -------------------------------------------------------------------------------- /cmd/gen_certs.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/cert" 7 | "github.com/mothership/rds-auth-proxy/pkg/file" 8 | "github.com/spf13/cobra" 9 | ) 10 | 11 | var genCertsCommand = &cobra.Command{ 12 | Use: "gen-cert", 13 | Short: "Generates a self-signed certificate", 14 | Long: `Generates a self-signed certificate`, 15 | RunE: func(cmd *cobra.Command, args []string) error { 16 | keyPath, err := cmd.Flags().GetString("key") 17 | if err != nil { 18 | return err 19 | } 20 | if keyPath == "" { 21 | return fmt.Errorf("Key path must not be empty") 22 | } 23 | 24 | certPath, err := cmd.Flags().GetString("certificate") 25 | if err != nil { 26 | return err 27 | } 28 | if certPath == "" { 29 | return fmt.Errorf("Certificate path must not be empty") 30 | } 31 | 32 | if file.Exists(certPath) || file.Exists(keyPath) { 33 | return fmt.Errorf("certificate/key already exists at this location") 34 | } 35 | 36 | hosts, err := cmd.Flags().GetString("hosts") 37 | if err != nil { 38 | return err 39 | } 40 | 41 | certBytes, keyBytes, err := cert.GenerateSelfSignedCert(hosts, false) 42 | if err != nil { 43 | return err 44 | } 45 | err = cert.Save(certPath, certBytes) 46 | if err != nil { 47 | return err 48 | } 49 | return cert.Save(keyPath, keyBytes) 50 | }, 51 | } 52 | 53 | func init() { 54 | rootCmd.AddCommand(genCertsCommand) 55 | genCertsCommand.PersistentFlags().String("certificate", "", "Path to generate the certificate") 56 | _ = genCertsCommand.MarkPersistentFlagRequired("certificate") 57 | genCertsCommand.PersistentFlags().String("key", "", "Path to generate the private key") 58 | _ = 
genCertsCommand.MarkPersistentFlagRequired("key") 59 | 60 | genCertsCommand.PersistentFlags().String("hosts", "rds-auth-proxy", "Comma separated list of hosts to add to the certificate") 61 | } 62 | -------------------------------------------------------------------------------- /pkg/file/filewriter_test.go: -------------------------------------------------------------------------------- 1 | package file 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/spf13/afero" 10 | ) 11 | 12 | func TestFileWriter_Valid(t *testing.T) { 13 | cases := []struct { 14 | FileName string 15 | Content []byte 16 | }{ 17 | //Basic file in working directory 18 | { 19 | FileName: "/test.txt", 20 | Content: []byte("hello world"), 21 | }, 22 | } 23 | 24 | for idx, test := range cases { 25 | writer := NewFileWriter() 26 | writer.Fs = afero.NewMemMapFs() 27 | _, _ = writer.Write(test.Content) 28 | if err := writer.Save(test.FileName); err != nil { 29 | t.Errorf("[Case %d]: Error occurred while writing file: %t", idx, err) 30 | } 31 | info, err := writer.Fs.Stat(test.FileName) 32 | if os.IsNotExist(err) || info.IsDir() { 33 | t.Errorf("[Case %d]: Could not locate the file FileWriter is expected to write to '%s'", idx, test.FileName) 34 | } 35 | fileBytes, err := afero.ReadFile(writer.Fs, test.FileName) 36 | if os.IsNotExist(err) || !bytes.Equal(fileBytes, test.Content) { 37 | t.Errorf("[Case %d]: Could not read the file FileWriter is expected to write to '%s'", idx, test.FileName) 38 | } 39 | } 40 | } 41 | 42 | func TestFileWriter_Invalid(t *testing.T) { 43 | cases := []struct { 44 | FileName string 45 | Content []byte 46 | ExpectedError string 47 | }{ 48 | //attempt to write to home directory in a folder that does not exist 49 | { 50 | FileName: "/badDir/test.txt", 51 | Content: []byte("hello world"), 52 | ExpectedError: "no such file or directory", 53 | }, 54 | } 55 | for idx, test := range cases { 56 | writer := NewFileWriter() 57 | testPath, _ := 
afero.TempDir(writer.Fs, "testDir", "testPrefix") 58 | _, _ = writer.Write(test.Content) 59 | err := writer.Save(testPath + test.FileName) 60 | if !strings.Contains(err.Error(), test.ExpectedError) { 61 | t.Errorf("[Case %d]: Unexpected error occurred while writing file: %t", idx, err) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /pkg/proxy/config_test.go: -------------------------------------------------------------------------------- 1 | package proxy_test 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | . "github.com/mothership/rds-auth-proxy/pkg/proxy" 9 | ) 10 | 11 | func TestConfigOptionErrors(t *testing.T) { 12 | config := Config{} 13 | 14 | cases := []struct { 15 | Option Option 16 | Error error 17 | }{ 18 | // Valid listen address 19 | { 20 | Option: WithListenAddress("0.0.0.0:8000"), 21 | Error: nil, 22 | }, 23 | // Missing port 24 | { 25 | Option: WithListenAddress("bah"), 26 | Error: fmt.Errorf("missing port in address"), 27 | }, 28 | // Bad Host 29 | { 30 | Option: WithListenAddress("bah:80"), 31 | // XXX: On OSX (local dev env) we get a different error message than testing on linux (CI) 32 | // for now, just assert that we got an error :/ 33 | Error: fmt.Errorf(""), 34 | }, 35 | // Bad Port 36 | { 37 | Option: WithListenAddress("0.0.0.0:bah"), 38 | // XXX: On OSX (local dev env) we get a different error message than testing on linux (CI) 39 | // for now, just assert that we got an error :/ 40 | Error: fmt.Errorf(""), 41 | }, 42 | // valid credential getter 43 | { 44 | Option: WithCredentialInterceptor(func(creds *Credentials) error { 45 | return nil 46 | }), 47 | Error: nil, 48 | }, 49 | // valid mode 50 | { 51 | Option: WithMode(ServerSide), 52 | Error: nil, 53 | }, 54 | // invalid mode 55 | { 56 | Option: WithMode(Mode(10)), 57 | Error: fmt.Errorf("invalid mode"), 58 | }, 59 | } 60 | 61 | for idx, test := range cases { 62 | err := test.Option(&config) 63 | if 
!errorContains(err, test.Error) { 64 | t.Errorf("[Case %d] expected %+v, got %+v", idx, test.Error, err) 65 | } 66 | } 67 | } 68 | 69 | // errorContains checks if the error message in out contains the text in 70 | // want. 71 | // 72 | // This is safe when out is nil. Use an empty string for want if you want to 73 | // test that err is nil. 74 | func errorContains(out error, want error) bool { 75 | if want == nil && out == nil { 76 | return true 77 | } else if want == nil || out == nil { 78 | return false 79 | } 80 | return strings.Contains(out.Error(), want.Error()) 81 | } 82 | -------------------------------------------------------------------------------- /pkg/proxy/manager.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "sync" 8 | 9 | "github.com/mothership/rds-auth-proxy/pkg/log" 10 | "go.uber.org/zap" 11 | ) 12 | 13 | // errorWrapper wraps an error from a particular proxy 14 | type errorWrapper struct { 15 | ConnectionID uint64 16 | Error error 17 | } 18 | 19 | // Manager watches a group of proxies 20 | type Manager struct { 21 | ActiveSessions sync.Map 22 | errorCh chan errorWrapper 23 | cfg *Config 24 | } 25 | 26 | // NewManager returns an instance of Manager 27 | func NewManager(opts ...Option) (*Manager, error) { 28 | cfg := &Config{} 29 | for _, opt := range opts { 30 | err := opt(cfg) 31 | if err != nil { 32 | return nil, err 33 | } 34 | } 35 | return &Manager{ 36 | ActiveSessions: sync.Map{}, 37 | // Should probably have a similar buffer size to active sessions? 
38 | errorCh: make(chan errorWrapper, 10), 39 | cfg: cfg, 40 | }, nil 41 | } 42 | 43 | // Start starts the proxy server 44 | func (m *Manager) Start(ctx context.Context) error { 45 | go m.errorHandler(ctx) 46 | listener, err := net.ListenTCP("tcp", m.cfg.ListenAddress) 47 | if err != nil { 48 | return err 49 | } 50 | 51 | for { 52 | conn, err := listener.AcceptTCP() 53 | if err != nil { 54 | log.Error("error accepting connection from client", zap.Error(err)) 55 | continue 56 | } 57 | log.Info( 58 | "accepted connection from client", 59 | zap.String("client_address", conn.RemoteAddr().String()), 60 | ) 61 | p := newProxy(conn, m.errorCh, m.cfg) 62 | m.ActiveSessions.Store(p.ID, p) 63 | //nolint:errcheck // Errors are handled in m.errorCh 64 | go p.Start() 65 | } 66 | } 67 | 68 | func (m *Manager) errorHandler(ctx context.Context) { 69 | log.Info("starting error handler") 70 | defer log.Debug("shut down error handler") 71 | for { 72 | select { 73 | case <-ctx.Done(): 74 | return 75 | case err := <-m.errorCh: 76 | if err.Error != nil && err.Error != io.ErrUnexpectedEOF { 77 | log.Error("proxy caught error", 78 | zap.Uint64("connectionID", err.ConnectionID), 79 | zap.Error(err.Error), 80 | ) 81 | } 82 | if p, loaded := m.ActiveSessions.LoadAndDelete(err.ConnectionID); loaded { 83 | proxy, _ := p.(*Proxy) 84 | log.Info("stopping proxy", zap.Uint64("connectionID", err.ConnectionID)) 85 | proxy.Stop() 86 | log.Info("proxy stopped", zap.Uint64("connectionID", err.ConnectionID)) 87 | } 88 | } 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /cmd/completion.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | // completionCmd represents the completion command 10 | var completionCmd = &cobra.Command{ 11 | Use: "completion", 12 | Short: "Generates CLI completion scripts", 13 | Long: `To load completion 
in bash run 14 | 15 | echo '. <(rds-auth-proxy completion bash)' >> ~/.bash_profile 16 | source ~/.bash_profile 17 | 18 | Run 'rds-auth-proxy completion bash --help' for more information. 19 | 20 | 21 | To load completion in zsh, run the following command to generate a completion file: 22 | 23 | rds-auth-proxy completion zsh > _rds-auth-proxy 24 | 25 | 26 | Then move this file somewhere along your $fpath and source your ~/.zshrc again. Run 'rds-auth-proxy completion zsh --help' for more information. 27 | `, 28 | RunE: func(cmd *cobra.Command, args []string) error { 29 | err := rootCmd.GenBashCompletion(os.Stdout) 30 | if err != nil { 31 | return err 32 | } 33 | return rootCmd.GenZshCompletion(os.Stdout) 34 | }, 35 | } 36 | 37 | var bashCompletionCmd = &cobra.Command{ 38 | Use: "bash", 39 | Short: "Generates bash completion scripts", 40 | Long: `To load completion in bash run 41 | 42 | echo '. <(rds-auth-proxy completion bash)' >> ~/.bash_profile 43 | source ~/.bash_profile 44 | 45 | If this gives you trouble, ensure you're not using the version of bash bundled 46 | with OSX (OSX bundles 3.2, Homebrew bundles version 5). 47 | `, 48 | RunE: func(cmd *cobra.Command, args []string) error { 49 | return rootCmd.GenBashCompletion(os.Stdout) 50 | }, 51 | } 52 | 53 | var zshCompletionCmd = &cobra.Command{ 54 | Use: "zsh", 55 | Short: "Generates zsh completion scripts", 56 | Long: `To load completion in ZSH, run the following command to generate a completion file: 57 | 58 | rds-auth-proxy completion zsh > _rds-auth-proxy 59 | 60 | Then move this file somewhere along your $fpath and source your ~/.zshrc again. 
If you're 61 | still not getting completion behavior, ensure you have the following in your ~/.zshrc 62 | 63 | autoload -U compaudit && compinit 64 | 65 | If you do and it's still not working, try removing the completion cache: 66 | 67 | rm ~/.zcompdump* && source ~/.zshrc 68 | `, 69 | RunE: func(cmd *cobra.Command, args []string) error { 70 | return rootCmd.GenZshCompletion(os.Stdout) 71 | }, 72 | } 73 | 74 | func init() { 75 | completionCmd.AddCommand(bashCompletionCmd) 76 | completionCmd.AddCommand(zshCompletionCmd) 77 | rootCmd.AddCommand(completionCmd) 78 | } 79 | -------------------------------------------------------------------------------- /pkg/cert/cert.go: -------------------------------------------------------------------------------- 1 | package cert 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "crypto/x509/pkix" 8 | "encoding/pem" 9 | "math/big" 10 | "net" 11 | "strings" 12 | "time" 13 | 14 | "github.com/mothership/rds-auth-proxy/pkg/file" 15 | ) 16 | 17 | // GenerateSelfSignedCert creates a self signed RSA cert and returns it 18 | // along with the private key or any errors that occurred while generating 19 | // it. 
20 | func GenerateSelfSignedCert(host string, isCA bool) ([]byte, []byte, error) { 21 | notBefore := time.Now() 22 | // 5 years later 23 | notAfter := notBefore.Add(5 * 365 * 24 * time.Hour) 24 | 25 | var priv *rsa.PrivateKey 26 | var err error 27 | priv, err = rsa.GenerateKey(rand.Reader, 2048) 28 | if err != nil { 29 | return nil, nil, err 30 | } 31 | 32 | serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 33 | serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 34 | if err != nil { 35 | return nil, nil, err 36 | } 37 | 38 | template := x509.Certificate{ 39 | NotBefore: notBefore, 40 | NotAfter: notAfter, 41 | KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 42 | ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 43 | BasicConstraintsValid: true, 44 | SerialNumber: serialNumber, 45 | Subject: pkix.Name{ 46 | Organization: []string{"Mothership"}, 47 | }, 48 | } 49 | 50 | hosts := strings.Split(host, ",") 51 | for _, h := range hosts { 52 | if ip := net.ParseIP(h); ip != nil { 53 | template.IPAddresses = append(template.IPAddresses, ip) 54 | } else { 55 | template.DNSNames = append(template.DNSNames, h) 56 | } 57 | } 58 | 59 | if isCA { 60 | template.IsCA = true 61 | template.KeyUsage |= x509.KeyUsageCertSign 62 | } 63 | 64 | certBytes, err := x509.CreateCertificate( 65 | rand.Reader, 66 | &template, 67 | &template, 68 | &priv.PublicKey, 69 | priv, 70 | ) 71 | if err != nil { 72 | return nil, nil, err 73 | } 74 | pemKeyBytes := pem.EncodeToMemory( 75 | &pem.Block{ 76 | Type: "RSA PRIVATE KEY", 77 | Bytes: x509.MarshalPKCS1PrivateKey(priv), 78 | }, 79 | ) 80 | pemCertBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) 81 | return pemCertBytes, pemKeyBytes, err 82 | 83 | } 84 | 85 | // Save saves a file to disk 86 | func Save(path string, pemBytes []byte) error { 87 | fileWriter := file.NewFileWriter() 88 | _, err := fileWriter.Write(pemBytes) 89 | if err != nil { 90 | return err 91 | } 92 | 
return fileWriter.Save(path) 93 | } 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | 4 |

RDS Auth Proxy

5 |
6 |
7 | 8 |

9 |

10 | 11 |

12 |

13 | 14 | ![GitHub tag (latest SemVer)](https://img.shields.io/github/v/tag/mothership/rds-auth-proxy) 15 | ![GitHub branch checks state](https://img.shields.io/github/checks-status/mothership/rds-auth-proxy/main) 16 | [![Go Report Card](https://goreportcard.com/badge/github.com/mothership/rds-auth-proxy)](https://goreportcard.com/report/github.com/mothership/rds-auth-proxy) 17 | 18 | A two-layer proxy for connecting into RDS postgres databases 19 | based on IAM authentication. 20 | 21 | This tool allows you to keep your databases firewalled off, 22 | manage database access through IAM policies, and no developer 23 | will ever have to share or type a password. 24 | 25 | ![Running the proxy](./docs/images/proxy.gif) 26 | 27 | ![Connecting with psql](./docs/images/psql.gif) 28 | 29 | This pairs extremely well with a tool like [saml2aws](https://github.com/Versent/saml2aws) 30 | to ensure all AWS/database access uses temporary credentials. 31 | 32 | ## Documentation 33 | 34 | End user documentation is available on our [project site](https://mothership.github.io/rds-auth-proxy/). 35 | 36 | ## Design 37 | 38 | One proxy is run in your VPC subnet that can reach your RDS instances, 39 | the other on your client machine (dev laptop, etc.) with access to 40 | aws credentials. 41 | 42 | The client proxy is responsible for picking a host (RDS instance), and 43 | generating a temporary password based on the local IAM identity. The 44 | client proxy injects the host and password into the postgres startup 45 | message as additional parameters. 46 | 47 | ![Client startup flow](./docs/images/rds-proxy-client-startup-flow.png) 48 | 49 | The server proxy accepts a connection from the client proxy, and 50 | unpacks the host and password parameters. It then opens a connection 51 | to the RDS database and intercepts the authentication request. It then 52 | passes along the password it received from the client, and forwards the 53 | result to the client. 
54 | 55 | ![Auth overview](./docs/images/rds-proxy-auth-flow.png) 56 | 57 | 58 | ## Releasing 59 | 60 | CI handles building binaries and images on tag events. 61 | 62 | To create a release, start with a dry-run on the main branch: 63 | 64 | ```bash 65 | git checkout main 66 | ./build/release.sh --dry-run 67 | ``` 68 | 69 | Ensure that the changelog looks as expected, then run it for real: 70 | 71 | ```bash 72 | ./build/release.sh 73 | ``` 74 | 75 | -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/mothership/rds-auth-proxy/pkg/pg" 5 | "github.com/spf13/viper" 6 | ) 7 | 8 | const ( 9 | defaultKubeConfigPath = "$HOME/.kube/config" 10 | defaultListenAddr = "0.0.0.0:8000" 11 | ) 12 | // ConfigFile is the top-level structure of the proxy config file. 13 | type ConfigFile struct { 14 | Proxy Proxy `mapstructure:"proxy"` 15 | Targets map[string]*Target `mapstructure:"targets"` 16 | ProxyTargets map[string]*ProxyTarget `mapstructure:"upstream_proxies"` 17 | } 18 | // Proxy configures this proxy's listen address, SSL settings, and target ACL. 19 | type Proxy struct { 20 | ListenAddr string `mapstructure:"listen_addr"` 21 | SSL ServerSSL `mapstructure:"ssl"` 22 | ACL ACL `mapstructure:"target_acl"` 23 | } 24 | // LoadConfig reads filepath (or "config" from the default search paths when filepath is empty), unmarshals it, and applies defaults via Init. 25 | func LoadConfig(filepath string) (ConfigFile, error) { 26 | var config ConfigFile 27 | if filepath != "" { 28 | viper.SetConfigFile(filepath) 29 | } else { 30 | viper.SetConfigName("config") 31 | viper.AddConfigPath(".") 32 | viper.AddConfigPath("$XDG_CONFIG_HOME/rds-auth-proxy") 33 | viper.AddConfigPath("$HOME/.config/rds-auth-proxy") 34 | } 35 | if err := viper.ReadInConfig(); err != nil { 36 | return config, err 37 | } 38 | if err := viper.Unmarshal(&config); err != nil { 39 | return config, err 40 | } 41 | config.Init() 42 | return config, nil 43 | } 44 | 45 | // Init sets up defaults for the config file 46 | func (c *ConfigFile) Init() { 47 | if c.Targets == nil { 48 | c.Targets = map[string]*Target{} 49 | } 50 | 51 | 
if c.ProxyTargets == nil { 52 | c.ProxyTargets = map[string]*ProxyTarget{} 53 | } 54 | 55 | if c.Proxy.ListenAddr == "" { 56 | c.Proxy.ListenAddr = defaultListenAddr 57 | } 58 | 59 | c.Proxy.ACL.Init() 60 | // Set up SSL defaults for all targets if not set 61 | for key, target := range c.Targets { 62 | target.Name = key 63 | // require SSL unless a mode was configured 64 | if target.SSL.Mode == "" { 65 | target.SSL.Mode = pg.SSLRequired 66 | } 67 | if target.SSL.Mode != pg.SSLDisabled && target.SSL.ClientCertificatePath == nil { 68 | target.SSL.ClientCertificatePath = c.Proxy.SSL.ClientCertificatePath 69 | target.SSL.ClientPrivateKeyPath = c.Proxy.SSL.ClientPrivateKeyPath 70 | } 71 | } 72 | 73 | // Set up SSL defaults for all proxies if not set 74 | for key, target := range c.ProxyTargets { 75 | target.Name = key 76 | if target.PortForward != nil && target.PortForward.KubeConfigFilePath == "" { 77 | target.PortForward.KubeConfigFilePath = defaultKubeConfigPath 78 | } 79 | // Port-forwarded traffic rides the Kubernetes tunnel, so SSL defaults to off whether or not a kubeconfig path was given 80 | if target.PortForward != nil && target.SSL.Mode == "" { 81 | target.SSL.Mode = pg.SSLDisabled 82 | } 83 | if target.SSL.Mode == "" { 84 | target.SSL.Mode = pg.SSLRequired 85 | } 86 | // inherit the proxy's client cert/key when SSL is on and none is set 87 | if target.SSL.Mode != pg.SSLDisabled && target.SSL.ClientCertificatePath == nil { 88 | target.SSL.ClientCertificatePath = c.Proxy.SSL.ClientCertificatePath 89 | target.SSL.ClientPrivateKeyPath = c.Proxy.SSL.ClientPrivateKeyPath 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /pkg/aws/rds.go: -------------------------------------------------------------------------------- 1 | package aws 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/aws/aws-sdk-go-v2/aws" 7 | "github.com/aws/aws-sdk-go-v2/aws/arn" 8 | "github.com/aws/aws-sdk-go-v2/config" 9 | "github.com/aws/aws-sdk-go-v2/feature/rds/auth" 10 | "github.com/aws/aws-sdk-go-v2/service/rds" 11 | "github.com/aws/aws-sdk-go-v2/service/rds/types" 12 | ) 13 | 14 | const ( 15 | filterEngine = "engine" 16 | enginePostgres = "postgres" 17 | 
engineAuroraPostgres = "aurora-postgresql" 18 | ) 19 | 20 | // DBInstanceResult is a wrapper around a DBInstance or error 21 | // as a result of listing RDS Instances 22 | type DBInstanceResult struct { 23 | Instance types.DBInstance 24 | Error error 25 | } 26 | 27 | // RDSClient is our wrapper around the RDS library, allows us to 28 | // mock this for testing 29 | type RDSClient interface { 30 | GetPostgresInstances(ctx context.Context) <-chan DBInstanceResult 31 | NewAuthToken(ctx context.Context, host, region, user string) (string, error) 32 | RegionForInstance(inst types.DBInstance) (string, error) 33 | } 34 | 35 | type rdsClient struct { 36 | cfg aws.Config 37 | svc *rds.Client 38 | } 39 | 40 | // NewRDSClient loads AWS Config and creds, and returns an RDS client 41 | func NewRDSClient(ctx context.Context) (RDSClient, error) { 42 | cfg, err := config.LoadDefaultConfig(ctx) 43 | if err != nil { 44 | return nil, err 45 | } 46 | return &rdsClient{cfg: cfg, svc: rds.NewFromConfig(cfg)}, nil 47 | } 48 | 49 | // GetPostgresInstances lists all db instances with engine "postgres" or "aurora-postgresql" and publishes 50 | // them to the result channel, which is closed once listing finishes or fails; callers should drain it so the sender can exit 51 | func (r *rdsClient) GetPostgresInstances(ctx context.Context) <-chan DBInstanceResult { 52 | resChan := make(chan DBInstanceResult, 1) 53 | go func() { 54 | defer close(resChan) 55 | paginator := r.rdsPaginator([]types.Filter{ 56 | { 57 | Name: strPtr(filterEngine), 58 | Values: []string{enginePostgres, engineAuroraPostgres}, 59 | }, 60 | }) 61 | for paginator.HasMorePages() { 62 | page, err := paginator.NextPage(ctx) 63 | if err != nil { 64 | resChan <- DBInstanceResult{Error: err} 65 | return 66 | } 67 | for _, d := range page.DBInstances { 68 | resChan <- DBInstanceResult{Instance: d} 69 | } 70 | } 71 | }() 72 | return resChan 73 | } 74 | // rdsPaginator pages through DescribeDBInstances with the given filters, 100 instances per page 75 | func (r *rdsClient) rdsPaginator(filters []types.Filter) (paginator *rds.DescribeDBInstancesPaginator) { 76 | paginator = rds.NewDescribeDBInstancesPaginator(r.svc, 
&rds.DescribeDBInstancesInput{ 77 | Filters: filters, 78 | }, func(o *rds.DescribeDBInstancesPaginatorOptions) { 79 | o.Limit = 100 80 | }) 81 | return 82 | } 83 | // NewAuthToken builds an IAM auth token for user on host in region using the loaded AWS credentials 84 | func (r *rdsClient) NewAuthToken(ctx context.Context, host, region, user string) (string, error) { 85 | return auth.BuildAuthToken(ctx, host, region, user, r.cfg.Credentials) 86 | } 87 | // RegionForInstance returns the region parsed from the instance ARN; a nil ARN is parsed as "" and yields an error instead of a panic 88 | func (r *rdsClient) RegionForInstance(inst types.DBInstance) (string, error) { 89 | parsed, err := arn.Parse(aws.ToString(inst.DBInstanceArn)) 90 | if err != nil { 91 | return "", err 92 | } 93 | return parsed.Region, nil 94 | } 95 | // strPtr returns a pointer to a copy of val 96 | func strPtr(val string) *string { 97 | return &val 98 | } 99 | -------------------------------------------------------------------------------- /cmd/proxy_server.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/mothership/rds-auth-proxy/pkg/aws" 9 | "github.com/mothership/rds-auth-proxy/pkg/config" 10 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 11 | discoveryFactory "github.com/mothership/rds-auth-proxy/pkg/discovery/factory" 12 | "github.com/mothership/rds-auth-proxy/pkg/log" 13 | "github.com/mothership/rds-auth-proxy/pkg/proxy" 14 | "github.com/spf13/cobra" 15 | "go.uber.org/zap" 16 | ) 17 | 18 | var proxyServerCommand = &cobra.Command{ 19 | Use: "server", 20 | Short: "Launches the server proxy", 21 | Long: `Runs a proxy service in-cluster for connecting to RDS.`, 22 | RunE: func(cmd *cobra.Command, args []string) error { 23 | // TODO: make this gracefully shutdown on sigterm / sigint 24 | ctx, cancel := context.WithCancel(context.Background()) 25 | defer cancel() 26 | 27 | logger := log.NewLogger() 28 | filepath, err := cmd.Flags().GetString("configfile") 29 | if err != nil { 30 | return err 31 | } 32 | rdsClient, err := aws.NewRDSClient(ctx) 33 | if err != nil { 34 | return err 35 | } 36 | cfg, err := config.LoadConfig(filepath) 37 | if err != nil { 38 | 
return err 39 | } 40 | discoveryClient := discoveryFactory.FromConfig(rdsClient, &cfg) 41 | if err := discoveryClient.Refresh(ctx); err != nil { 42 | return err 43 | } 44 | 45 | opts, err := proxySSLOptions(cfg.Proxy.SSL) 46 | if err != nil { 47 | return err 48 | } 49 | logger.Info("starting server", zap.String("listen_addr", cfg.Proxy.ListenAddr)) 50 | manager, err := proxy.NewManager(proxy.MergeOptions(opts, []proxy.Option{ 51 | proxy.WithListenAddress(cfg.Proxy.ListenAddr), 52 | proxy.WithMode(proxy.ServerSide), 53 | proxy.WithCredentialInterceptor(func(creds *proxy.Credentials) error { 54 | hostConfig, err := discoveryClient.LookupTargetByHost(creds.Host) 55 | if err != nil { 56 | logger.Warn("client attempted to login to unknown host", zap.String("host", creds.Host)) 57 | return fmt.Errorf("host not allowed by ACL, or not configured for this proxy") 58 | } 59 | return overrideSSLConfig(creds, hostConfig.SSL) 60 | })})..., 61 | ) 62 | if err != nil { 63 | return err 64 | } 65 | // Keep discovery targets fresh in the background until ctx is canceled 66 | RefreshTargets(ctx, discoveryClient, 1*time.Minute) 67 | err = manager.Start(ctx) 68 | return err 69 | }, 70 | } 71 | // RefreshTargets calls client.Refresh every period in a background goroutine until ctx is canceled, logging each outcome 72 | func RefreshTargets(ctx context.Context, client discovery.Client, period time.Duration) { 73 | go func() { 74 | t := time.NewTicker(period) 75 | for { 76 | select { 77 | case <-ctx.Done(): 78 | t.Stop() 79 | return 80 | case <-t.C: 81 | log.Info("starting target refresh", zap.Strings("targets", targetNames(client.GetTargets()))) 82 | if err := client.Refresh(ctx); err != nil { 83 | log.Warn("refresh failed", zap.Error(err), zap.Strings("targets", targetNames(client.GetTargets()))) 84 | } else { 85 | log.Info("refresh done", zap.Strings("targets", targetNames(client.GetTargets()))) 86 | } 87 | } 88 | } 89 | }() 90 | } 91 | // targetNames returns the Name of each target, preserving order 92 | func targetNames(targets []config.Target) []string { 93 | instances := make([]string, 0, len(targets)) 94 | for _, target := range targets { 95 | instances = append(instances, target.Name) 96 | 
} 97 | return instances 98 | } 99 | 100 | func init() { 101 | proxyServerCommand.PersistentFlags().String("configfile", "", "Filepath for proxy config file") 102 | rootCmd.AddCommand(proxyServerCommand) 103 | } 104 | -------------------------------------------------------------------------------- /pkg/kubernetes/port_forward.go: -------------------------------------------------------------------------------- 1 | package kubernetes 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "math/rand" 8 | "net/http" 9 | "time" 10 | 11 | "github.com/mothership/rds-auth-proxy/pkg/file" 12 | v1 "k8s.io/apimachinery/pkg/apis/meta/v1" 13 | "k8s.io/client-go/kubernetes" 14 | corev1client "k8s.io/client-go/kubernetes/typed/core/v1" 15 | restclient "k8s.io/client-go/rest" 16 | clientcmd "k8s.io/client-go/tools/clientcmd" 17 | "k8s.io/client-go/tools/portforward" 18 | "k8s.io/client-go/transport/spdy" 19 | ) 20 | // PortForwardCommand holds the Kubernetes clients, target pod, and I/O channels for one port-forward session 21 | type PortForwardCommand struct { 22 | Namespace string 23 | PodName string 24 | Config *restclient.Config 25 | Client restclient.Interface 26 | PodClient corev1client.PodsGetter 27 | Ports []string 28 | Address []string 29 | Out, ErrOut *bytes.Buffer 30 | PortForwarder *portforward.PortForwarder 31 | StopChannel chan struct{} 32 | ReadyChannel chan struct{} 33 | } 34 | // PortForwardOptions selects the namespace, deployment (matched by its app.kubernetes.io/name label), ports, and kube context to forward through 35 | type PortForwardOptions struct { 36 | Namespace string 37 | Deployment string 38 | Ports []string 39 | Context string 40 | } 41 | // loadConfig loads the kubeconfig at path, overriding the current context when context is non-empty 42 | func loadConfig(path, context string) (*restclient.Config, error) { 43 | loadingRules := &clientcmd.ClientConfigLoadingRules{ExplicitPath: path} 44 | overrides := &clientcmd.ConfigOverrides{} 45 | if context != "" { 46 | overrides.CurrentContext = context 47 | } 48 | return clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides).ClientConfig() 49 | } 50 | // BuildPortForwardCommand lists up to 3 running pods of opts.Deployment and targets one at random; it errors if none are running 51 | func BuildPortForwardCommand(ctx context.Context, kubeConfigPath string, opts PortForwardOptions) (*PortForwardCommand, error) { 52 | cmd := &PortForwardCommand{ 53 | Namespace: 
opts.Namespace, 54 | Ports: opts.Ports, 55 | Address: []string{"localhost"}, 56 | Out: new(bytes.Buffer), 57 | ErrOut: new(bytes.Buffer), 58 | StopChannel: make(chan struct{}, 1), 59 | ReadyChannel: make(chan struct{}), 60 | } 61 | path, err := file.ExpandPath(kubeConfigPath) 62 | if err != nil { 63 | return nil, err 64 | } 65 | config, err := loadConfig(path, opts.Context) 66 | if err != nil { 67 | return nil, err 68 | } 69 | cmd.Config = config 70 | clientset, err := kubernetes.NewForConfig(config) 71 | if err != nil { 72 | return nil, err 73 | } 74 | cmd.Client = clientset.CoreV1().RESTClient() 75 | cmd.PodClient = clientset.CoreV1() 76 | pods, err := cmd.PodClient.Pods(opts.Namespace).List(ctx, v1.ListOptions{ 77 | LabelSelector: fmt.Sprintf("app.kubernetes.io/name=%s", opts.Deployment), 78 | FieldSelector: "status.phase=Running", 79 | Limit: 3, 80 | }) 81 | if err != nil { 82 | return nil, err 83 | } 84 | if len(pods.Items) == 0 { 85 | return nil, fmt.Errorf("no running pods with label app.kubernetes.io/name=%s in namespace %q", opts.Deployment, opts.Namespace) 86 | } 87 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 88 | cmd.PodName = pods.Items[rng.Intn(len(pods.Items))].Name 89 | return cmd, nil 90 | } 91 | 92 | // ForwardPort forwards a port until context is canceled 93 | func ForwardPort(ctx context.Context, cmd *PortForwardCommand) error { 94 | go func() { 95 | <-ctx.Done() 96 | if cmd.StopChannel != nil { 97 | close(cmd.StopChannel) 98 | } 99 | }() 100 | 101 | req := cmd.Client.Post(). 102 | Resource("pods"). 103 | Namespace(cmd.Namespace). 104 | Name(cmd.PodName). 
105 | SubResource("portforward") 106 | 107 | transport, upgrader, err := spdy.RoundTripperFor(cmd.Config) 108 | if err != nil { 109 | return err 110 | } 111 | dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL()) 112 | fw, err := portforward.NewOnAddresses(dialer, cmd.Address, cmd.Ports, cmd.StopChannel, cmd.ReadyChannel, cmd.Out, cmd.ErrOut) 113 | if err != nil { 114 | return err 115 | } 116 | cmd.PortForwarder = fw 117 | return fw.ForwardPorts() 118 | } 119 | -------------------------------------------------------------------------------- /docs/security.md: -------------------------------------------------------------------------------- 1 | # Security 2 | 3 | Found an issue? Send us an email at [security@mothership.com](mailto:security@mothership.com). 4 | 5 | As this software is in early development, there are no bounties, but you'll get credit and 6 | have our eternal gratitude! 7 | 8 | # Security Model 9 | 10 | The security of `rds-auth-proxy` depends on the following assumptions: 11 | 12 | * Users do not share laptops / leave them unlocked. 13 | * No untrusted process is running on the client machine. 14 | * No untrusted process can read the memory of the proxy process. 15 | * You have a secure connection between the client and server proxies. 16 | * You have IAM policies restricting which roles/databases a developer may use. 17 | 18 | ### Users do not share laptops / leave them unlocked 19 | 20 | `rds-auth-proxy` has no means of securing access to the proxy based on who is 21 | using the laptop. 22 | 23 | ### No untrusted process is running on the client machine 24 | 25 | Any process running locally, with network access, can connect to the proxy. 26 | 27 | ### No untrusted process can read the memory of the proxy process 28 | 29 | In client mode, `rds-auth-proxy` generates database passwords for the currently 30 | logged in AWS user. In server mode, `rds-auth-proxy` has to pass along the 31 | password to RDS.
32 | 33 | If an untrusted process can read the memory of the proxy, it can read the generated 34 | password. Additionally, as Go is a garbage-collected language, it is difficult 35 | to ensure all copies of the password are cleared from memory. 36 | 37 | ### You have a secure connection between the client and server proxies 38 | 39 | The postgres protocol transports passwords in plaintext and as an md5 hash. 40 | Neither format is suitable for transport over the public internet. We recommend 41 | tunneling the protocol over a kubernetes port-forward, or setting up SSL on the 42 | server-side proxy. 43 | 44 | ### You have IAM policies restricting which roles/databases a developer may use 45 | 46 | The server proxy, by default, will allow connections into any RDS postgres 47 | database with IAM authentication enabled. Developer access to the databases is 48 | controlled by their ability to generate temporary passwords on the client, which 49 | in turn is controlled by IAM policies and the AWS credentials they have access to. 50 | 51 | # Options for TLS 52 | 53 | ### Protecting the connection between the client and client proxy 54 | 55 | We support TLS between the client and client proxy; however, the client proxy is 56 | designed to be run on the same machine as the client, so we don't recommend 57 | setting up TLS in that scenario for a few reasons: 58 | 59 | 1. If an attacker can already listen to local sockets, they 60 | are already in a privileged position; TLS will not help. 61 | 2. Self-signed certificates are the only way to do this, but they are harder 62 | to manage and distribute securely. 63 | 64 | ### Protecting the connection between the client and server proxies 65 | 66 | The client proxy connection can be tunneled over a Kubernetes port-forward, and/or 67 | protected with TLS via the postgres protocol.
68 | 69 | While the server proxy supports postgres over TLS, the postgres protocol has 70 | its own handshake prior to upgrading to TLS, making it difficult to work with 71 | reverse proxies or ingress resources. 72 | 73 | By using a port-forward, we can piggyback off of the encrypted tunnel Kubernetes 74 | provides. You can still enable TLS between the client and server over the 75 | port-forward if desired. 76 | 77 | ### Protecting the connection between the server and database 78 | 79 | For RDS instances, we require full verification of the RDS certificate. Our docker 80 | images are built with the RDS root CA certificates pre-installed. You may bring your 81 | own client certificate, but `rds-auth-proxy` will generate a self-signed client 82 | certificate if one isn't provided. 83 | -------------------------------------------------------------------------------- /pkg/discovery/rds/rds.go: -------------------------------------------------------------------------------- 1 | package rds 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "sync" 8 | 9 | "github.com/mothership/rds-auth-proxy/pkg/aws" 10 | "github.com/mothership/rds-auth-proxy/pkg/config" 11 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 12 | "github.com/mothership/rds-auth-proxy/pkg/log" 13 | "github.com/mothership/rds-auth-proxy/pkg/pg" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | const ( 18 | defaultDatabaseTag = "rds-auth-proxy:db-name" 19 | localPortTag = "rds-auth-proxy:local-port" 20 | ) 21 | // RdsDiscoveryClient discovers proxy targets from RDS instances that pass the configured ACL 22 | type RdsDiscoveryClient struct { 23 | targetLock *sync.RWMutex 24 | config *config.ConfigFile 25 | client aws.RDSClient 26 | rdsTargets map[string]config.Target 27 | } 28 | 29 | var _ discovery.Client = (*RdsDiscoveryClient)(nil) 30 | // NewRdsDiscoveryClient returns a client with no targets; call Refresh to populate them 31 | func NewRdsDiscoveryClient(client aws.RDSClient, cfg *config.ConfigFile) *RdsDiscoveryClient { 32 | return &RdsDiscoveryClient{ 33 | targetLock: &sync.RWMutex{}, 34 | config: cfg, 35 | client: client, 36 | rdsTargets: map[string]config.Target{}, 37
| } 38 | } 39 | // LookupTargetByHost returns the discovered target whose "address:port" key equals host 40 | func (r *RdsDiscoveryClient) LookupTargetByHost(host string) (config.Target, error) { 41 | r.targetLock.RLock() 42 | defer r.targetLock.RUnlock() 43 | target, ok := r.rdsTargets[host] 44 | if !ok { 45 | return config.Target{}, discovery.ErrTargetNotFound 46 | } 47 | return target, nil 48 | } 49 | // LookupTargetByName returns a discovered target whose Name (the DB instance identifier) equals name 50 | func (r *RdsDiscoveryClient) LookupTargetByName(name string) (config.Target, error) { 51 | r.targetLock.RLock() 52 | defer r.targetLock.RUnlock() 53 | for _, target := range r.rdsTargets { 54 | if target.Name == name { 55 | return target, nil 56 | } 57 | } 58 | return config.Target{}, discovery.ErrTargetNotFound 59 | } 60 | // GetTargets returns a copy of the discovered targets in unspecified order 61 | func (r *RdsDiscoveryClient) GetTargets() []config.Target { 62 | r.targetLock.RLock() 63 | defer r.targetLock.RUnlock() 64 | targetList := make([]config.Target, 0, len(r.rdsTargets)) 65 | for _, target := range r.rdsTargets { 66 | targetList = append(targetList, target) 67 | } 68 | return targetList 69 | } 70 | 71 | // Refresh searches AWS for allowed dbs and replaces the target list; on any error the previous list is kept 72 | func (r *RdsDiscoveryClient) Refresh(ctx context.Context) (err error) { 73 | // Drain the channel completely, even after an error, so a sender blocked on it can exit 74 | resChan := r.client.GetPostgresInstances(ctx) 75 | rdsTargets := map[string]config.Target{} 76 | for result := range resChan { 77 | if result.Error != nil { 78 | err = result.Error 79 | continue 80 | } 81 | d := result.Instance 82 | if d.Endpoint == nil { 83 | log.Warn("db instance missing endpoint, skipping", zap.String("name", *d.DBInstanceIdentifier)) 84 | continue 85 | } 86 | 87 | if tmpErr := r.config.Proxy.ACL.IsAllowed(d.TagList); tmpErr != nil { 88 | log.Debug("db instance not allowed by acl", zap.String("name", *d.DBInstanceIdentifier)) 89 | continue 90 | } 91 | 92 | region, regionErr := r.client.RegionForInstance(d) 93 | if regionErr != nil { 94 | log.Error("failed to detect db region, skipping", zap.Error(regionErr), zap.String("name", *d.DBInstanceIdentifier)) 95 | continue 
96 | } 97 | 98 | if !d.IAMDatabaseAuthenticationEnabled { 99 | log.Warn("db instance does not have IAM auth enabled, skipping", zap.String("name", *d.DBInstanceIdentifier)) 100 | continue 101 | } 102 | 103 | target := config.Target{ 104 | Name: *d.DBInstanceIdentifier, 105 | Host: fmt.Sprintf("%+v:%+v", *d.Endpoint.Address, strconv.FormatInt(int64(d.Endpoint.Port), 10)), 106 | DefaultDatabase: d.DBName, 107 | SSL: config.SSL{ 108 | Mode: pg.SSLVerifyFull, 109 | ClientCertificatePath: r.config.Proxy.SSL.ClientCertificatePath, 110 | ClientPrivateKeyPath: r.config.Proxy.SSL.ClientPrivateKeyPath, 111 | }, 112 | Region: region, 113 | IsRDS: true, 114 | } 115 | for _, tag := range d.TagList { 116 | if *tag.Key == defaultDatabaseTag { 117 | target.DefaultDatabase = tag.Value 118 | } else if *tag.Key == localPortTag { 119 | target.LocalPort = tag.Value 120 | } 121 | } 122 | rdsTargets[target.Host] = target 123 | } 124 | 125 | if err == nil { 126 | r.targetLock.Lock() 127 | defer r.targetLock.Unlock() 128 | r.rdsTargets = rdsTargets 129 | } 130 | return err 131 | } 132 | -------------------------------------------------------------------------------- /pkg/pg/backend.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "crypto/tls" 5 | "io" 6 | "net" 7 | "sync" 8 | "time" 9 | 10 | pgproto3 "github.com/jackc/pgproto3/v2" 11 | ) 12 | // Single-byte replies sent while negotiating SSL/GSS encryption in SetupConnection 13 | const GSSENCecNotAllowed byte = 'N' 14 | const SSLNotAllowed byte = 'N' 15 | const SSLAllowed byte = 'S' 16 | 17 | // Backend is the server side of a postgres connection: it receives frontend messages from a client (ex: psql) and sends backend messages to it 18 | type Backend interface { 19 | io.Closer 20 | Send(msg pgproto3.BackendMessage) error 21 | SendRaw([]byte) error 22 | Receive() (pgproto3.FrontendMessage, error) 23 | ReceiveRaw() ([]byte, error) 24 | } 25 | 26 | // SendOnlyBackend allows only the send operation to be accessed for network safety 27 | type SendOnlyBackend interface { 28 | Send(msg pgproto3.BackendMessage) error 29 | } 30 | 31 | //
PostgresBackend implements Backend over a net.Conn 32 | type PostgresBackend struct { 33 | backend *pgproto3.Backend 34 | connection net.Conn 35 | IdleTimeout time.Duration 36 | mutex sync.Mutex 37 | } 38 | 39 | // BackendOption allows us to specify options 40 | type BackendOption func(f *PostgresBackend) error 41 | 42 | // NewBackend wraps conn with the default idle timeout, then applies opts in order, returning the first option error 43 | func NewBackend(conn net.Conn, opts ...BackendOption) (*PostgresBackend, error) { 44 | f := &PostgresBackend{ 45 | backend: pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn), 46 | connection: conn, 47 | IdleTimeout: readTimeout, 48 | mutex: sync.Mutex{}, 49 | } 50 | 51 | for _, opt := range opts { 52 | err := opt(f) 53 | if err != nil { 54 | return nil, err 55 | } 56 | } 57 | return f, nil 58 | } 59 | 60 | // Send writes a backend message to the client; concurrent Sends are serialized by the mutex 61 | func (b *PostgresBackend) Send(msg pgproto3.BackendMessage) error { 62 | b.mutex.Lock() 63 | defer b.mutex.Unlock() 64 | return b.backend.Send(msg) 65 | } 66 | 67 | // SendRaw writes arbitrary bytes directly to the client connection (without taking the mutex) 68 | func (b *PostgresBackend) SendRaw(msg []byte) error { 69 | _, err := b.connection.Write(msg) 70 | return err 71 | } 72 | 73 | // Receive reads a message from the client, or errors if nothing is read 74 | // within the idle timeout. Returns io.ErrUnexpectedEOF if the connection has 75 | // been closed. 76 | func (b *PostgresBackend) Receive() (pgproto3.FrontendMessage, error) { 77 | _ = b.connection.SetReadDeadline(time.Now().Add(b.IdleTimeout)) 78 | return b.backend.Receive() 79 | } 80 | 81 | // ReceiveRaw reads up to 8192 raw bytes from the client, or errors if nothing 82 | // is read within the idle timeout. Returns io.ErrUnexpectedEOF if the 83 | // connection has been closed.
84 | func (b *PostgresBackend) ReceiveRaw() ([]byte, error) { 85 | // Postgres send buffers are at least this large 86 | response := make([]byte, 8192) 87 | _ = b.connection.SetReadDeadline(time.Now().Add(b.IdleTimeout)) 88 | readBytes, err := b.connection.Read(response) 89 | if err == io.EOF { 90 | err = io.ErrUnexpectedEOF 91 | } 92 | return response[:readBytes], err 93 | } 94 | 95 | // Close closes the underlying connection 96 | func (b *PostgresBackend) Close() error { 97 | return b.connection.Close() 98 | } 99 | 100 | // SetupConnection answers SSL/GSS encryption requests and returns the startup message parameters. 101 | // When the client requests SSL and cert is non-nil, the underlying connection is upgraded to TLS 102 | // in place before the startup message is read; otherwise SSL and GSS encryption are declined. 103 | func (b *PostgresBackend) SetupConnection(cert *tls.Certificate) (map[string]string, error) { 104 | for { 105 | message, err := b.backend.ReceiveStartupMessage() 106 | if err != nil { 107 | return nil, err 108 | } 109 | switch msg := message.(type) { 110 | case *pgproto3.StartupMessage: 111 | return msg.Parameters, nil 112 | case *pgproto3.SSLRequest: 113 | if cert == nil { 114 | err = b.SendRaw([]byte{SSLNotAllowed}) 115 | if err != nil { 116 | return nil, err 117 | } 118 | continue 119 | } 120 | err = b.SendRaw([]byte{SSLAllowed}) 121 | if err != nil { 122 | return nil, err 123 | } 124 | b.connection = UpgradeServer(b.connection, cert) 125 | b.backend = pgproto3.NewBackend(pgproto3.NewChunkReader(b.connection), b.connection) 126 | continue 127 | case *pgproto3.GSSEncRequest: 128 | // Would need more research to offer GSS enc.
129 | err = b.SendRaw([]byte{GSSENCecNotAllowed}) 130 | if err != nil { 131 | return nil, err 132 | } 133 | continue 134 | } 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /pkg/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | . "github.com/mothership/rds-auth-proxy/pkg/config" 9 | "github.com/mothership/rds-auth-proxy/pkg/pg" 10 | ) 11 | 12 | func TestProxyConfigLoad(t *testing.T) { 13 | cases := []struct { 14 | FileName string 15 | Error error 16 | }{ 17 | { 18 | FileName: "foo", 19 | Error: fmt.Errorf("Unsupported Config Type"), 20 | }, 21 | { 22 | FileName: "foo.yaml", 23 | Error: fmt.Errorf("no such file or directory"), 24 | }, 25 | } 26 | for idx, test := range cases { 27 | _, err := LoadConfig(test.FileName) 28 | if !errorContains(err, test.Error) { 29 | t.Errorf("[Case %d] expected %+v, got %+v", idx, test.Error, err) 30 | } 31 | } 32 | } 33 | 34 | func TestConfigInit(t *testing.T) { 35 | var cfg ConfigFile 36 | cfg.Init() 37 | 38 | if cfg.Targets == nil { 39 | t.Errorf("expected targets to be initialized") 40 | } 41 | if cfg.ProxyTargets == nil { 42 | t.Errorf("expected proxy targets to be initialized") 43 | } 44 | } 45 | 46 | func TestTargetsGetDefaults(t *testing.T) { 47 | cfg := ConfigFile{ 48 | Proxy: Proxy{ 49 | SSL: ServerSSL{ 50 | ClientCertificatePath: strPtr("/app/cert.pem"), 51 | ClientPrivateKeyPath: strPtr("/app/key.pem"), 52 | }, 53 | }, 54 | Targets: map[string]*Target{ 55 | "empty": {Host: "0"}, 56 | "override": { 57 | Host: "1", 58 | SSL: SSL{ 59 | ClientCertificatePath: strPtr("/tls/cert.pem"), 60 | ClientPrivateKeyPath: strPtr("/tls/key.pem"), 61 | }, 62 | }, 63 | }, 64 | } 65 | cfg.Init() 66 | if cfg.Targets["empty"].Name != "empty" { 67 | t.Errorf("Expected empty to have name populated") 68 | } 69 | 70 | if
cfg.Targets["override"].Name != "override" { 71 | t.Errorf("Expected override to have name populated") 72 | } 73 | 74 | if cfg.Targets["empty"].SSL.Mode != pg.SSLRequired { 75 | t.Errorf("Expected empty to require SSL") 76 | } 77 | 78 | if cfg.Targets["empty"].SSL.ClientCertificatePath != cfg.Proxy.SSL.ClientCertificatePath { 79 | t.Errorf("Expected SSL cert to have taken value from parent") 80 | } 81 | 82 | if cfg.Targets["empty"].SSL.ClientPrivateKeyPath != cfg.Proxy.SSL.ClientPrivateKeyPath { 83 | t.Errorf("Expected SSL key to have taken value from parent") 84 | } 85 | } 86 | 87 | func TestProxyTargetsGetDefault(t *testing.T) { 88 | cfg := ConfigFile{ 89 | Proxy: Proxy{ 90 | SSL: ServerSSL{ 91 | ClientCertificatePath: strPtr("/app/cert.pem"), 92 | ClientPrivateKeyPath: strPtr("/app/key.pem"), 93 | }, 94 | }, 95 | ProxyTargets: map[string]*ProxyTarget{ 96 | "empty": {Host: "0"}, 97 | "portforward": { 98 | PortForward: &PortForward{}, 99 | }, 100 | "portforward_kubeconfig": {PortForward: &PortForward{KubeConfigFilePath: "/tmp/kubeconfig"}}, 101 | "override": { 102 | Host: "1", 103 | SSL: SSL{ 104 | ClientCertificatePath: strPtr("/tls/cert.pem"), 105 | ClientPrivateKeyPath: strPtr("/tls/key.pem"), 106 | }, 107 | }, 108 | }, 109 | } 110 | cfg.Init() 111 | if cfg.ProxyTargets["empty"].Name != "empty" { 112 | t.Errorf("Expected empty to have name populated") 113 | } 114 | 115 | if cfg.ProxyTargets["override"].Name != "override" { 116 | t.Errorf("Expected override to have name populated") 117 | } 118 | 119 | if cfg.ProxyTargets["empty"].SSL.Mode != pg.SSLRequired { 120 | t.Errorf("Expected empty to require SSL") 121 | } 122 | 123 | if cfg.ProxyTargets["empty"].SSL.ClientCertificatePath != cfg.Proxy.SSL.ClientCertificatePath { 124 | t.Errorf("Expected SSL cert to have taken value from parent") 125 | } 126 | 127 | if cfg.ProxyTargets["empty"].SSL.ClientPrivateKeyPath != cfg.Proxy.SSL.ClientPrivateKeyPath { 128 | t.Errorf("Expected SSL key to have taken value from parent") 129 | } 130 | for _, name := range []string{"portforward", "portforward_kubeconfig"} { 131 | if cfg.ProxyTargets[name].SSL.Mode != 
pg.SSLDisabled { 132 | t.Errorf("Expected SSL to be disabled on %s if not set", name) 133 | } 134 | } 135 | } 136 | // errorContains checks if the error message in out contains the text in 137 | // want. 138 | // 139 | // This is safe when out is nil. Pass a nil want to assert that out is also 140 | // nil. 141 | func errorContains(out error, want error) bool { 142 | if want == nil && out == nil { 143 | return true 144 | } else if want == nil || out == nil { 145 | return false 146 | } 147 | return strings.Contains(out.Error(), want.Error()) 148 | } 149 | // strPtr returns a pointer to a copy of val 150 | func strPtr(val string) *string { 151 | return &val 152 | } 153 | -------------------------------------------------------------------------------- /pkg/discovery/static/client_test.go: -------------------------------------------------------------------------------- 1 | package static_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/config" 7 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 8 | . "github.com/mothership/rds-auth-proxy/pkg/discovery/static" 9 | ) 10 | 11 | func TestStaticDiscoveryClientHostLookupFailures(t *testing.T) { 12 | staticTargets := map[string]config.Target{ 13 | "test.com:5432": makeTarget("test", "test.com:5432"), 14 | } 15 | cases := []struct { 16 | Targets map[string]config.Target 17 | Host string 18 | Expected error 19 | }{ 20 | { 21 | Targets: nil, 22 | Host: "test.com:5432", 23 | Expected: discovery.ErrTargetNotFound, 24 | }, 25 | { 26 | Targets: staticTargets, 27 | Host: "missing", 28 | Expected: discovery.ErrTargetNotFound, 29 | }, 30 | } 31 | for idx, test := range cases { 32 | client := NewStaticDiscoveryClient(test.Targets) 33 | _, err := client.LookupTargetByHost(test.Host) 34 | if test.Expected != err { 35 | t.Errorf("[Case %d] expected %+v.
Got %+v.", idx, test.Expected, err) 36 | } 37 | } 38 | } 39 | 40 | func TestStaticDiscoveryClientHostLookupSuccess(t *testing.T) { 41 | staticTargets := map[string]config.Target{ 42 | "test.com:5432": makeTarget("test", "test.com:5432"), 43 | } 44 | cases := []struct { 45 | Targets map[string]config.Target 46 | Host string 47 | Expected config.Target 48 | }{ 49 | { 50 | Targets: staticTargets, 51 | Host: "test.com:5432", 52 | Expected: staticTargets["test.com:5432"], 53 | }, 54 | } 55 | for idx, test := range cases { 56 | client := NewStaticDiscoveryClient(test.Targets) 57 | target, err := client.LookupTargetByHost(test.Host) 58 | if err != nil { 59 | t.Fatalf("[Case %d] unexpected error: %s", idx, err) 60 | } 61 | if test.Expected != target { 62 | t.Errorf("[Case %d] expected %+v. Got %+v.", idx, test.Expected, target) 63 | } 64 | } 65 | } 66 | 67 | func TestStaticDiscoveryClientNameLookupSuccess(t *testing.T) { 68 | staticTargets := map[string]config.Target{ 69 | "test.com:5432": makeTarget("test", "test.com:5432"), 70 | } 71 | cases := []struct { 72 | Targets map[string]config.Target 73 | Name string 74 | Expected config.Target 75 | }{ 76 | { 77 | Targets: staticTargets, 78 | Name: "test", 79 | Expected: staticTargets["test.com:5432"], 80 | }, 81 | } 82 | for idx, test := range cases { 83 | client := NewStaticDiscoveryClient(test.Targets) 84 | target, err := client.LookupTargetByName(test.Name) 85 | if err != nil { 86 | t.Fatalf("[Case %d] unexpected error: %s", idx, err) 87 | } 88 | if test.Expected != target { 89 | t.Errorf("[Case %d] expected %+v.
Got %+v.", idx, test.Expected, target) 90 | } 91 | } 92 | } 93 | 94 | func TestStaticDiscoveryClientNameLookupFailures(t *testing.T) { 95 | staticTargets := map[string]config.Target{ 96 | "test.com:5432": makeTarget("test", "test.com:5432"), 97 | } 98 | cases := []struct { 99 | Targets map[string]config.Target 100 | Name string 101 | Expected error 102 | }{ 103 | { 104 | Targets: nil, 105 | Name: "test", 106 | Expected: discovery.ErrTargetNotFound, 107 | }, 108 | { 109 | Targets: staticTargets, 110 | Name: "missing", 111 | Expected: discovery.ErrTargetNotFound, 112 | }, 113 | } 114 | for idx, test := range cases { 115 | client := NewStaticDiscoveryClient(test.Targets) 116 | _, err := client.LookupTargetByName(test.Name) 117 | if test.Expected != err { 118 | t.Errorf("[Case %d] expected %+v. Got %+v.", idx, test.Expected, err) 119 | } 120 | } 121 | } 122 | 123 | func TestStaticDiscoveryClientGetTargets(t *testing.T) { 124 | staticTargets := map[string]config.Target{ 125 | "test.com:5432": makeTarget("test", "test.com:5432"), 126 | } 127 | client := NewStaticDiscoveryClient(staticTargets) 128 | targets := client.GetTargets() 129 | if len(targets) != len(staticTargets) { 130 | t.Fatalf("expected %d targets, got %d", len(staticTargets), len(targets)) 131 | } 132 | 133 | for _, target := range targets { 134 | found, ok := staticTargets[target.Host] 135 | if !ok { 136 | t.Fatalf("failed to find target %q", target.Host) 137 | } 138 | if found.Host != target.Host || target.Name != found.Name { 139 | t.Fatalf("found wrong target %+v, expected %+v", target, found) 140 | } 141 | } 142 | 143 | } 144 | // makeTarget builds a config.Target with only Host and Name set 145 | func makeTarget(name string, host string) config.Target { 146 | return config.Target{ 147 | Host: host, 148 | Name: name, 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /pkg/pg/frontend.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "crypto/md5" 5 | "fmt" 6 | "io" 7 | "net" 8 |
"sync" 9 | "time" 10 | 11 | pgproto3 "github.com/jackc/pgproto3/v2" 12 | ) 13 | 14 | const ( 15 | // TODO: should this be configurable? 16 | readTimeout = 3 * time.Second 17 | ) 18 | 19 | // Frontend acts as the postgres front-end client (ex: psql) 20 | type Frontend interface { 21 | io.Closer 22 | Send(msg pgproto3.FrontendMessage) error 23 | SendRaw([]byte) error 24 | Receive() (pgproto3.BackendMessage, error) 25 | ReceiveRaw() ([]byte, error) 26 | } 27 | 28 | // SendOnlyFrontend allows only the send operation to be accessed for network safety 29 | type SendOnlyFrontend interface { 30 | Send(msg pgproto3.FrontendMessage) error 31 | } 32 | 33 | type AuthFailedError struct { 34 | ErrMsg *pgproto3.ErrorResponse 35 | } 36 | 37 | func (a *AuthFailedError) Error() string { 38 | return "auth failed" 39 | } 40 | 41 | // PostgresFrontend implements a postgres frontend client 42 | type PostgresFrontend struct { 43 | frontend *pgproto3.Frontend 44 | connection net.Conn 45 | IdleTimeout time.Duration 46 | mutex sync.Mutex 47 | } 48 | 49 | // FrontendOption allows us to specify options 50 | type FrontendOption func(f *PostgresFrontend) error 51 | 52 | // NewFrontend returns a new postgres frontend 53 | func NewFrontend(conn net.Conn, opts ...FrontendOption) (*PostgresFrontend, error) { 54 | f := &PostgresFrontend{ 55 | frontend: pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn), 56 | connection: conn, 57 | IdleTimeout: readTimeout, 58 | mutex: sync.Mutex{}, 59 | } 60 | 61 | for _, opt := range opts { 62 | err := opt(f) 63 | if err != nil { 64 | return nil, err 65 | } 66 | } 67 | return f, nil 68 | } 69 | 70 | // Send sends a frontend message to the backend 71 | func (f *PostgresFrontend) Send(msg pgproto3.FrontendMessage) error { 72 | f.mutex.Lock() 73 | defer f.mutex.Unlock() 74 | return f.frontend.Send(msg) 75 | } 76 | 77 | // SendRaw sends arbitrary bytes to a backend 78 | func (f *PostgresFrontend) SendRaw(b []byte) error { 79 | _, err := f.connection.Write(b) 80 
| return err 81 | } 82 | 83 | // Receive accepts a message from the backend, or errors if nothing is read 84 | // within the idle timeout. Returns io.ErrUnexpectedEOF if the connection has 85 | // been closed. 86 | func (f *PostgresFrontend) Receive() (pgproto3.BackendMessage, error) { 87 | _ = f.connection.SetReadDeadline(time.Now().Add(f.IdleTimeout)) 88 | return f.frontend.Receive() 89 | } 90 | 91 | // ReceiveRaw accepts a message from the backend, or errors if nothing 92 | // is read within the idle timeout. Returns io.ErrUnexpectedEOF if the 93 | // connection has been closed. 94 | func (f *PostgresFrontend) ReceiveRaw() ([]byte, error) { 95 | // Postgres send buffers are at least this large 96 | response := make([]byte, 8192) 97 | _ = f.connection.SetReadDeadline(time.Now().Add(f.IdleTimeout)) 98 | readBytes, err := f.connection.Read(response) 99 | if err != nil && err == io.EOF { 100 | err = io.ErrUnexpectedEOF 101 | } 102 | return response[:readBytes], err 103 | } 104 | 105 | // Close closes the underlying connection 106 | func (b *PostgresFrontend) Close() error { 107 | return b.connection.Close() 108 | } 109 | 110 | func (f *PostgresFrontend) HandleAuthenticationRequest(username, password string) error { 111 | // TODO: max tries / exit condition? 
112 | for { 113 | message, err := f.frontend.Receive() 114 | if err != nil { 115 | return err 116 | } 117 | 118 | switch msg := message.(type) { 119 | case *pgproto3.AuthenticationOk: 120 | return nil 121 | case *pgproto3.ReadyForQuery: 122 | return nil 123 | case *pgproto3.AuthenticationMD5Password: 124 | if err = f.Send(createMd5(msg, username, password)); err != nil { 125 | return err 126 | } 127 | continue 128 | case *pgproto3.AuthenticationCleartextPassword: 129 | if err := f.Send(createCleartext(msg, username, password)); err != nil { 130 | return err 131 | } 132 | continue 133 | case *pgproto3.ErrorResponse: 134 | return &AuthFailedError{ErrMsg: msg} 135 | case *pgproto3.AuthenticationSASL: 136 | return fmt.Errorf("SASL auth not supported") 137 | default: 138 | return fmt.Errorf("unsupported auth request, or unexpected message") 139 | } 140 | } 141 | } 142 | 143 | func createMD5Password(username string, password string, salt string) string { 144 | // Concatenate the password and the username together. 145 | passwordString := fmt.Sprintf("%s%s", password, username) 146 | 147 | // Compute the MD5 sum of the password+username string. 
148 | passwordString = fmt.Sprintf("%x", md5.Sum([]byte(passwordString))) 149 | 150 | // Compute the MD5 sum of the password hash and the salt 151 | passwordString = fmt.Sprintf("%s%s", passwordString, salt) 152 | return fmt.Sprintf("md5%x", md5.Sum([]byte(passwordString))) 153 | } 154 | 155 | func createMd5(msg *pgproto3.AuthenticationMD5Password, username, password string) *pgproto3.PasswordMessage { 156 | return &pgproto3.PasswordMessage{ 157 | Password: createMD5Password(username, password, string(msg.Salt[:])), 158 | } 159 | } 160 | 161 | func createCleartext(msg *pgproto3.AuthenticationCleartextPassword, username, password string) *pgproto3.PasswordMessage { 162 | return &pgproto3.PasswordMessage{ 163 | Password: password, 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /pkg/pg/ssl.go: -------------------------------------------------------------------------------- 1 | package pg 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "net" 7 | 8 | "github.com/jackc/pgproto3/v2" 9 | ) 10 | 11 | // SSLMode is the type of SSL required 12 | // https://www.postgresql.org/docs/8.4/libpq-connect.html#LIBPQ-CONNECT-SSLMODE 13 | type SSLMode string 14 | 15 | const ( 16 | // SSLDisabled only tries a non-SSL connection 17 | SSLDisabled SSLMode = "disable" 18 | // SSLAllow first try a non-SSL connection, if that fails, tries an SSL connection 19 | // XXX: Not allowed at this time 20 | SSLAllow = "allow" 21 | // SSLPreferred is like allow, but tries an SSL connection first -- default behavior of psql 22 | SSLPreferred = "preferred" 23 | // SSLRequired only tries an SSL connection. If a root CA file is present, verify the certificate in the same way as if verify-ca was specified 24 | SSLRequired = "require" 25 | // SSLVerifyCA only tries an SSL connection, and verifies that the server certificate is issued by a trusted CA. 
26 | SSLVerifyCA = "verify-ca" 27 | // SSLVerifyFull only tries an SSL connection, verifies that the server certificate is issued by a trusted CA and that 28 | // the server hostname matches that in the certificate. 29 | SSLVerifyFull = "verify-full" 30 | ) 31 | 32 | // Connect connects to an upstream database 33 | func Connect(host string, mode SSLMode, cert *tls.Certificate, rootCert *x509.Certificate) (net.Conn, error) { 34 | connection, err := net.Dial("tcp", host) 35 | if err != nil { 36 | return nil, err 37 | } 38 | 39 | backend := pgproto3.NewFrontend(pgproto3.NewChunkReader(connection), connection) 40 | if mode != SSLDisabled { 41 | // log.Info("SSL connections are enabled.") 42 | 43 | /* 44 | * First determine if SSL is allowed by the backend. To do this, send an 45 | * SSL request. The response from the backend will be a single byte 46 | * message. If the value is 'S', then SSL connections are allowed and an 47 | * upgrade to the connection should be attempted. If the value is 'N', 48 | * then the backend does not support SSL connections. 
49 | */ 50 | sslRequest := &pgproto3.SSLRequest{} 51 | err = backend.Send(sslRequest) 52 | if err != nil { 53 | return nil, err 54 | } 55 | 56 | response := make([]byte, 4096) 57 | _, err = connection.Read(response) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | if len(response) > 0 && response[0] == SSLAllowed { 63 | // TODO: should probably decide whether or not to error based on SSL mode 64 | // but we'll pass the error back anyhow 65 | connection, err = UpgradeClient(host, connection, mode, cert, rootCert) 66 | } else if mode != SSLPreferred { 67 | // Close the connection only if we wanted required or higher 68 | connection.Close() 69 | } 70 | } 71 | 72 | return connection, err 73 | } 74 | 75 | // UpgradeServer upgrades a server connection with SSL 76 | func UpgradeServer(client net.Conn, cert *tls.Certificate) net.Conn { 77 | if cert == nil { 78 | return client 79 | } 80 | tlsConfig := tls.Config{} 81 | tlsConfig.Certificates = []tls.Certificate{*cert} 82 | return tls.Server(client, &tlsConfig) 83 | } 84 | 85 | // UpgradeClient upgrades a client connection with SSL 86 | func UpgradeClient(hostPort string, connection net.Conn, mode SSLMode, cert *tls.Certificate, rootCert *x509.Certificate) (net.Conn, error) { 87 | if mode == SSLDisabled { 88 | return connection, nil 89 | } 90 | 91 | tlsConfig := tls.Config{} 92 | if mode == SSLPreferred || mode == SSLRequired || mode == SSLVerifyCA { 93 | tlsConfig.InsecureSkipVerify = true 94 | } 95 | 96 | if mode == SSLVerifyFull { 97 | hostname, _, err := net.SplitHostPort(hostPort) 98 | if err != nil { 99 | return connection, err 100 | } 101 | tlsConfig.ServerName = hostname 102 | } 103 | 104 | var err error 105 | tlsConfig.Certificates = []tls.Certificate{*cert} 106 | tlsConfig.RootCAs, err = x509.SystemCertPool() 107 | if err != nil { 108 | return connection, err 109 | } 110 | 111 | if rootCert != nil { 112 | tlsConfig.RootCAs.AddCert(rootCert) 113 | } 114 | 115 | // do the upgrade 116 | client := 
tls.Client(connection, &tlsConfig) 117 | if mode == SSLVerifyCA || (mode == SSLRequired && rootCert != nil) { 118 | err := verifyCA(client, &tlsConfig) 119 | if err != nil { 120 | return connection, err 121 | } 122 | } 123 | 124 | return client, nil 125 | } 126 | 127 | // verifyCA explicitly does the handshake and certificate chain validation in the case that we need to validate 128 | // the CA, or we have a CA cert to validate against and the mode is require. 129 | func verifyCA(client *tls.Conn, tlsConf *tls.Config) error { 130 | err := client.Handshake() 131 | if err != nil { 132 | return err 133 | } 134 | 135 | // Get the peer/CA certificates from the connection state. 136 | peerCerts := client.ConnectionState().PeerCertificates 137 | caCert := peerCerts[0] 138 | peerCerts = peerCerts[1:] 139 | 140 | options := x509.VerifyOptions{ 141 | DNSName: client.ConnectionState().ServerName, 142 | Intermediates: x509.NewCertPool(), 143 | Roots: tlsConf.RootCAs, 144 | } 145 | // build the intermediate chain for verification 146 | for _, certificate := range peerCerts { 147 | options.Intermediates.AddCert(certificate) 148 | } 149 | 150 | // verify the CA cert is legitimate by building a path between it and the root 151 | // certificates we have, using the intermediates provided by the peer certificates. 
152 | _, err = caCert.Verify(options) 153 | return err 154 | } 155 | -------------------------------------------------------------------------------- /pkg/proxy/config.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "fmt" 7 | "net" 8 | 9 | pgproto3 "github.com/jackc/pgproto3/v2" 10 | "github.com/mothership/rds-auth-proxy/pkg/cert" 11 | "github.com/mothership/rds-auth-proxy/pkg/pg" 12 | ) 13 | 14 | // Credentials represents connection details to an upstream database or proxy 15 | type Credentials struct { 16 | Host string 17 | Database string 18 | Username string 19 | Password string 20 | // Misc connection parameters to be passed along 21 | Options map[string]string 22 | // SSL Settings for the outbound connection 23 | SSLMode pg.SSLMode 24 | ClientCertificate *tls.Certificate 25 | RootCertificate *x509.Certificate 26 | } 27 | 28 | // CredentialInterceptor provides a way to update credentials being forwarded 29 | // to the server proxy 30 | type CredentialInterceptor func(creds *Credentials) error 31 | 32 | // Mode indicates what kind of mode the proxy is in 33 | type Mode int 34 | 35 | const ( 36 | // ClientSide proxy mode is for running on the end-user laptop 37 | ClientSide Mode = iota 38 | // ServerSide proxy mode is for the in-cluster 39 | ServerSide 40 | ) 41 | 42 | // Config contains the various options for setting up the proxy 43 | type Config struct { 44 | ServerCertificate *tls.Certificate 45 | DefaultClientCertificate *tls.Certificate 46 | ListenAddress *net.TCPAddr 47 | CredentialInterceptor CredentialInterceptor 48 | QueryInterceptor QueryInterceptor 49 | Mode Mode 50 | } 51 | 52 | // QueryInterceptor provides a way to define custom behavior for handling messages 53 | type QueryInterceptor func(frontend pg.SendOnlyFrontend, backend pg.SendOnlyBackend, msg *pgproto3.Query) error 54 | 55 | // WillSendManually lets the proxy know that 
QueryInterceptor will handle sending the message 56 | var WillSendManually = fmt.Errorf("sending manually") 57 | 58 | // Option lets you set a config option 59 | type Option func(*Config) error 60 | 61 | // WithMode sets the mode of the proxy 62 | func WithMode(mode Mode) Option { 63 | return func(c *Config) error { 64 | if mode != ServerSide && mode != ClientSide { 65 | return fmt.Errorf("invalid mode: %d", mode) 66 | } 67 | c.Mode = mode 68 | return nil 69 | } 70 | } 71 | 72 | // WithListenAddress sets the IP/port that the proxy will accept connections on 73 | func WithListenAddress(addr string) Option { 74 | return func(c *Config) error { 75 | listenAddr, err := net.ResolveTCPAddr("tcp", addr) 76 | if err != nil { 77 | return err 78 | } 79 | 80 | c.ListenAddress = listenAddr 81 | return nil 82 | } 83 | } 84 | 85 | // WithCredentialInterceptor sets the credential retrieval strategy 86 | func WithCredentialInterceptor(credFactory CredentialInterceptor) Option { 87 | return func(c *Config) error { 88 | c.CredentialInterceptor = credFactory 89 | return nil 90 | } 91 | } 92 | 93 | // WithServerCertificate sets the SSL settings for the proxy 94 | func WithServerCertificate(certPath, keyPath string) Option { 95 | return func(c *Config) (err error) { 96 | if certPath == "" { 97 | return fmt.Errorf("certificate path not set") 98 | } 99 | if keyPath == "" { 100 | return fmt.Errorf("private key path not set") 101 | } 102 | cert, err := tls.LoadX509KeyPair(certPath, keyPath) 103 | if err != nil { 104 | return err 105 | } 106 | c.ServerCertificate = &cert 107 | return nil 108 | } 109 | } 110 | 111 | // WithGeneratedServerCertificate generates a self-signed server certificate for the proxy 112 | func WithGeneratedServerCertificate() Option { 113 | return func(c *Config) (err error) { 114 | certBytes, keyBytes, err := cert.GenerateSelfSignedCert("localhost,127.0.0.1", false) 115 | if err != nil { 116 | return err 117 | } 118 | cert, err := tls.X509KeyPair(certBytes, keyBytes) 
119 | if err != nil { 120 | return err 121 | } 122 | c.ServerCertificate = &cert 123 | return nil 124 | } 125 | } 126 | 127 | // WithClientCertificate sets up the default client certificates 128 | func WithClientCertificate(certPath, keyPath string) Option { 129 | return func(c *Config) (err error) { 130 | if certPath == "" { 131 | return fmt.Errorf("client certificate path not set") 132 | } 133 | if keyPath == "" { 134 | return fmt.Errorf("client private key path not set") 135 | } 136 | cert, err := tls.LoadX509KeyPair(certPath, keyPath) 137 | if err != nil { 138 | return err 139 | } 140 | c.DefaultClientCertificate = &cert 141 | return nil 142 | } 143 | } 144 | 145 | // WithGeneratedClientCertificate generates the default client certificates 146 | func WithGeneratedClientCertificate() Option { 147 | return func(c *Config) (err error) { 148 | certBytes, keyBytes, err := cert.GenerateSelfSignedCert("localhost,127.0.0.1", false) 149 | if err != nil { 150 | return err 151 | } 152 | cert, err := tls.X509KeyPair(certBytes, keyBytes) 153 | if err != nil { 154 | return err 155 | } 156 | c.DefaultClientCertificate = &cert 157 | return nil 158 | } 159 | } 160 | 161 | // WithQueryInterceptor adds a function for custom message handling 162 | func WithQueryInterceptor(interceptor QueryInterceptor) Option { 163 | return func(c *Config) (err error) { 164 | c.QueryInterceptor = interceptor 165 | return nil 166 | } 167 | } 168 | 169 | // MergeOptions is a helper to merge an option list 170 | func MergeOptions(lists ...[]Option) []Option { 171 | opts := []Option{} 172 | for _, l := range lists { 173 | opts = append(opts, l...) 
174 | } 175 | return opts 176 | } 177 | -------------------------------------------------------------------------------- /pkg/discovery/combined/client_test.go: -------------------------------------------------------------------------------- 1 | package combined_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/mothership/rds-auth-proxy/pkg/config" 7 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 8 | . "github.com/mothership/rds-auth-proxy/pkg/discovery/combined" 9 | "github.com/mothership/rds-auth-proxy/pkg/discovery/static" 10 | ) 11 | 12 | func TestCombinedDiscoveryClientHostLookupFailures(t *testing.T) { 13 | staticOne := makeStatic(makeTarget("db-1", "db-1:5432")) 14 | staticTwo := makeStatic(makeTarget("db-2", "db-2:5432"), makeTarget("db-3", "db-3:5432")) 15 | cases := []struct { 16 | Clients []discovery.Client 17 | Host string 18 | Expected error 19 | }{ 20 | { 21 | Clients: nil, 22 | Host: "db-1:5432", 23 | Expected: discovery.ErrTargetNotFound, 24 | }, 25 | { 26 | Clients: []discovery.Client{staticOne, staticTwo}, 27 | Host: "db-4:5432", 28 | Expected: discovery.ErrTargetNotFound, 29 | }, 30 | } 31 | for idx, test := range cases { 32 | client := NewCombinedDiscoveryClient(test.Clients) 33 | _, err := client.LookupTargetByHost(test.Host) 34 | if test.Expected != err { 35 | t.Errorf("[Case %d] expected %+v. 
Got %+v.", idx, test.Expected, err) 36 | } 37 | } 38 | } 39 | 40 | func TestCombinedDiscoveryClientHostLookupSuccess(t *testing.T) { 41 | staticOne := makeStatic(makeTarget("db-1", "db-1:5432")) 42 | staticTwo := makeStatic(makeTarget("db-2", "db-2:5432"), makeTarget("db-3", "db-3:5432")) 43 | cases := []struct { 44 | Clients []discovery.Client 45 | Host string 46 | Expected config.Target 47 | }{ 48 | { 49 | Clients: []discovery.Client{staticOne, staticTwo}, 50 | Host: "db-1:5432", 51 | Expected: ensureTarget(staticOne.LookupTargetByHost("db-1:5432")), 52 | }, 53 | { 54 | Clients: []discovery.Client{staticOne, staticTwo}, 55 | Host: "db-2:5432", 56 | Expected: ensureTarget(staticTwo.LookupTargetByHost("db-2:5432")), 57 | }, 58 | } 59 | for idx, test := range cases { 60 | client := NewCombinedDiscoveryClient(test.Clients) 61 | target, err := client.LookupTargetByHost(test.Host) 62 | if err != nil { 63 | t.Fatalf("[Case %d] unexpected error: %s", idx, err) 64 | } 65 | if test.Expected != target { 66 | t.Errorf("[Case %d] expected %+v. 
Got %+v.", idx, test.Expected, target) 67 | } 68 | } 69 | } 70 | 71 | func TestCombinedDiscoveryClientNameLookupSuccess(t *testing.T) { 72 | staticOne := makeStatic(makeTarget("db-1", "db-1:5432")) 73 | staticTwo := makeStatic(makeTarget("db-2", "db-2:5432"), makeTarget("db-3", "db-3:5432")) 74 | cases := []struct { 75 | Clients []discovery.Client 76 | Name string 77 | Expected config.Target 78 | }{ 79 | { 80 | Clients: []discovery.Client{staticOne, staticTwo}, 81 | Name: "db-1", 82 | Expected: ensureTarget(staticOne.LookupTargetByHost("db-1:5432")), 83 | }, 84 | { 85 | Clients: []discovery.Client{staticOne, staticTwo}, 86 | Name: "db-3", 87 | Expected: ensureTarget(staticTwo.LookupTargetByHost("db-3:5432")), 88 | }, 89 | } 90 | for idx, test := range cases { 91 | client := NewCombinedDiscoveryClient(test.Clients) 92 | target, err := client.LookupTargetByName(test.Name) 93 | if err != nil { 94 | t.Fatalf("[Case %d] unexpected error: %s", idx, err) 95 | } 96 | if test.Expected != target { 97 | t.Errorf("[Case %d] expected %+v. Got %+v.", idx, test.Expected, target) 98 | } 99 | } 100 | } 101 | 102 | func TestCombinedDiscoveryClientNameLookupFailures(t *testing.T) { 103 | staticOne := makeStatic(makeTarget("db-1", "db-1:5432")) 104 | staticTwo := makeStatic(makeTarget("db-2", "db-2:5432"), makeTarget("db-3", "db-3:5432")) 105 | cases := []struct { 106 | Clients []discovery.Client 107 | Name string 108 | Expected error 109 | }{ 110 | { 111 | Clients: nil, 112 | Name: "db-1", 113 | Expected: discovery.ErrTargetNotFound, 114 | }, 115 | { 116 | Clients: []discovery.Client{staticOne, staticTwo}, 117 | Name: "db-4", 118 | Expected: discovery.ErrTargetNotFound, 119 | }, 120 | } 121 | for idx, test := range cases { 122 | client := NewCombinedDiscoveryClient(test.Clients) 123 | _, err := client.LookupTargetByName(test.Name) 124 | if test.Expected != err { 125 | t.Errorf("[Case %d] expected %+v. 
Got %+v.", idx, test.Expected, err) 126 | } 127 | } 128 | } 129 | 130 | func TestCombinedDiscoveryClientGetTargets(t *testing.T) { 131 | staticOne := makeStatic(makeTarget("db-1", "db-1:5432")) 132 | staticTwo := makeStatic(makeTarget("db-2", "db-2:5432"), makeTarget("db-3", "db-3:5432")) 133 | client := NewCombinedDiscoveryClient([]discovery.Client{staticOne, staticTwo}) 134 | targets := client.GetTargets() 135 | if len(targets) != len(staticOne.GetTargets())+len(staticTwo.GetTargets()) { 136 | t.Fatalf("missing targets") 137 | } 138 | 139 | for _, target := range targets { 140 | found, err := client.LookupTargetByHost(target.Host) 141 | if err != nil { 142 | t.Fatalf("unexpected error %+v", err) 143 | } 144 | if found != target { 145 | t.Fatalf("found wrong target: %+v, expected %+v", found, target) 146 | } 147 | } 148 | } 149 | 150 | func makeStatic(targets ...config.Target) discovery.Client { 151 | hostMap := map[string]config.Target{} 152 | for _, target := range targets { 153 | hostMap[target.Host] = target 154 | } 155 | return static.NewStaticDiscoveryClient(hostMap) 156 | } 157 | 158 | func makeTarget(name string, host string) config.Target { 159 | return config.Target{ 160 | Host: host, 161 | Name: name, 162 | } 163 | } 164 | 165 | func ensureTarget(target config.Target, err error) config.Target { 166 | if err != nil { 167 | panic(err) 168 | } 169 | return target 170 | } 171 | -------------------------------------------------------------------------------- /docs/images/shield-icon-gradient.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/quickstart.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | !> At the moment, `rds-auth-proxy` only supports PostgreSQL-flavored RDS instances.
4 | 5 | `rds-auth-proxy` is a binary that contains two major components: a server-side proxy 6 | and a client-side proxy. This guide takes you through deploying the server-side proxy in 7 | your cluster, and connecting to it using the client-side proxy. 8 | 9 | For more information about the design, see the [architecture](./architecture.md) docs. 10 | 11 | ## Deploying the Server 12 | 13 | In order to deploy the server proxy successfully, you must ensure the following: 14 | 15 | 1. The proxy will be deployed into a subnet with security group rules that allow access to the RDS instances 16 | 2. The proxy has AWS credentials for database discovery 17 | 18 | The recommended way to install the proxy is to use our [helm chart](https://github.com/mothership/helm-charts/tree/master/charts/rds-auth-proxy). 19 | 20 | ### Setting up AWS Permissions 21 | 22 | The server-side proxy needs to be able to look up database instances to validate that it's allowed 23 | to complete the connection. 24 | 25 | In order to do this, it must be able to list RDS instances and read their tags. An example IAM policy may look like this: 26 | 27 | ```json 28 | { 29 | "Version":"2012-10-17", 30 | "Statement":[ 31 | { 32 | "Sid":"AllowRDSDescribe", 33 | "Effect":"Allow", 34 | "Action": [ 35 | "rds:DescribeDBInstances", 36 | "rds:ListTagsForResource" 37 | ], 38 | "Resource": [ 39 | "arn:aws:rds:*:*:db:*" 40 | ] 41 | } 42 | ] 43 | } 44 | ``` 45 | 46 | You can get more granular with this policy by only allowing certain tags, AWS accounts, etc. Attach 47 | this policy to the user or role that will be used by the server-side proxy. 48 | 49 | ### Adding our chart repository 50 | 51 | Use Helm 3 to add the mothership repository: 52 | 53 | ```bash 54 | helm repo add mothership https://mothership.github.io/helm-charts/ 55 | helm repo update 56 | ``` 57 | 58 | ### Installing 59 | 60 | Start by creating a values file for the helm chart.
An example file using [IRSA](https://aws.amazon.com/blogs/opensource/introducing-fine-grained-iam-roles-service-accounts/) 61 | is provided below. 62 | 63 | A similar approach should also work for [kube2iam](https://github.com/jtblin/kube2iam) or [kiam](https://github.com/uswitch/kiam), 64 | but the annotations would need to be added to the deployment instead of the service account. 65 | 66 | ```yaml 67 | # values.yaml 68 | fullnameOverride: "rds-auth-proxy" 69 | deployment: 70 | # deploy at least two pods 71 | replicaCount: 2 72 | 73 | # create a service account, and pass AWS credentials using IRSA 74 | serviceAccount: 75 | create: true 76 | # IRSA controller will pick up this annotation and inject AWS credentials into the pod. 77 | annotations: 78 | eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/rds-auth-proxy 79 | 80 | rbac: 81 | # This creates a role with permissions to get/list pods + portforward to them in 82 | # the release namespace. 83 | create: true 84 | # Add your developer group, or list of users here to give them permissions 85 | # to port-forward to the server proxy 86 | portforwardSubjects: 87 | - name: my-developer-group 88 | kind: Group 89 | apiGroup: rbac.authorization.k8s.io 90 | 91 | 92 | proxy: 93 | # Disable SSL/TLS to the proxy itself; we'll tunnel the connection over a port-forward 94 | ssl: 95 | enabled: false 96 | # Disable cert manager integration; the proxy will generate a self-signed client 97 | # certificate and key 98 | certManager: 99 | enabled: false 100 | allowedRDSTags: 101 | - name: "rds_proxy_enabled" 102 | value: "true" 103 | ``` 104 | 105 | Apply the helm chart with this configuration to the cluster: 106 | 107 | ```bash 108 | kubectl create namespace rds-auth-proxy 109 | helm install rds-auth-proxy --namespace rds-auth-proxy mothership/rds-auth-proxy -f values.yaml 110 | ``` 111 | 112 | ## Preparing your database 113 | 114 | Enable RDS IAM authentication for one of your databases that the server proxy can
reach. 115 | 116 | ```bash 117 | aws rds modify-db-instance \ 118 | --db-instance-identifier {my-db-instance} \ 119 | --apply-immediately \ 120 | --enable-iam-database-authentication 121 | ``` 122 | 123 | Add the tag `rds_proxy_enabled:true` to your database instance. 124 | 125 | ```bash 126 | aws rds add-tags-to-resource \ 127 | --resource-name {my-db-arn} \ 128 | --tags "[{\"Key\": \"rds_proxy_enabled\",\"Value\": \"true\"}]" 129 | ``` 130 | 131 | ### Granting the IAM Role 132 | 133 | Log in to your database and grant the `rds_iam` role to a user that you want to be accessible over 134 | the proxy. 135 | 136 | ```sql 137 | GRANT rds_iam TO postgres; 138 | ``` 139 | 140 | ### Granting IAM Permissions 141 | 142 | Ensure that you have access to AWS credentials with permissions to log in as that user. Your IAM 143 | policy would need to include a clause like this: 144 | 145 | ```json 146 | { 147 | "Version": "2012-10-17", 148 | "Statement": [ 149 | { 150 | "Effect": "Allow", 151 | "Action": [ 152 | "rds-db:connect" 153 | ], 154 | "Resource": [ 155 | "arn:aws:rds-db:{region}:{account}:dbuser:{db-resource-id}/{db-user}" 156 | ] 157 | } 158 | ] 159 | } 160 | ``` 161 | 162 | ## Setting up the client 163 | 164 | ### Download the client binary 165 | 166 | Check the [release page](https://github.com/mothership/rds-auth-proxy/releases) for the latest 167 | binaries. Download and install the one for your platform and architecture. 168 | 169 | ### Create your local config file 170 | 171 | Now we need to tell our client proxy about the server proxy. Drop this file at 172 | `~/.config/rds-auth-proxy/config.yaml`. 
173 | 174 | ```yaml 175 | # ~/.config/rds-auth-proxy/config.yaml 176 | proxy: 177 | # this is the host/port that psql should connect with 178 | listen_addr: "0.0.0.0:8001" 179 | # don't use SSL between local proxy and psql 180 | ssl: 181 | enabled: false 182 | # only look at rds instances the server proxy can connect to 183 | target_acl: 184 | allowed_rds_tags: 185 | - name: rds_proxy_enabled 186 | value: "true" 187 | blocked_rds_tags: [] 188 | 189 | upstream_proxies: 190 | default: 191 | # configure a kubernetes port-forward tunnel to the in-cluster proxy 192 | port_forward: 193 | # context: some-other-kube-context 194 | # kube_config: /path/to/alternate_config_file 195 | deployment: rds-auth-proxy 196 | namespace: rds-auth-proxy 197 | local_port: "8000" 198 | remote_port: "8000" 199 | ssl: 200 | # since we disabled SSL on the in-cluster proxy, don't try SSL between 201 | # the client proxy and server proxy 202 | mode: "disable" 203 | ``` 204 | 205 | ## Testing your installation 206 | 207 | Run the following to start the client proxy: 208 | 209 | ```bash 210 | rds-auth-proxy client --target {dbinstanceidentifier} 211 | ``` 212 | 213 | In another shell, you should be able to connect: 214 | 215 | ``` 216 | psql -h localhost -p 8001 -U {db-user-with-iam-auth} 217 | ``` 218 | -------------------------------------------------------------------------------- /docs/reference.md: -------------------------------------------------------------------------------- 1 | # Config File Reference 2 | 3 | ## Terms 4 | 5 | **Targets** - Databases 6 | 7 | **Proxy** - The host launched by the rds-auth-proxy binary, can refer to either the client or server proxy. 8 | 9 | **Upstream Proxy** - The server proxy. 10 | 11 | ## Database Tags 12 | 13 | There are a few database tags that can change the behavior of 14 | `rds-auth-proxy` on the client proxy. 
15 | 16 | | Tag | Behavior | 17 | | --- | -------- | 18 | | `rds-auth-proxy:db-name` | Provides the end user a hint about the default database name | 19 | | `rds-auth-proxy:local-port` | Sets the local port used by the client proxy for that database. Having a static local port per database allows developers to share connection configurations for various database tools | 20 | 21 | ## Client Config 22 | 23 | A full example of every option available: 24 | 25 | ```yaml 26 | # ~/.config/rds-auth-proxy/config.yaml 27 | 28 | # Options for the local proxy server 29 | proxy: 30 | # The listen address of this proxy 31 | listen_addr: 0.0.0.0:8001 32 | # SSL/TLS config for the proxy itself. 33 | ssl: 34 | # If set to true, without specifying a server 35 | # certificate/private key, will generate a self-signed 36 | # certificate for localhost 37 | enabled: false 38 | 39 | # Path to a pem-encoded certificate for the proxy 40 | certificate: ~/.config/rds-auth-proxy/server-cert.pem 41 | # Path to a pem-encoded private key for the certificate 42 | private_key: ~/.config/rds-auth-proxy/server-key.pem 43 | 44 | # Path to a pem-encoded certificate for upstream connections 45 | # 46 | # If not set, the proxy will generate a self-signed cert for 47 | # targets that require TLS/SSL. This cert can be overridden 48 | # on a per-host basis in the target config block below. 49 | # 50 | # This can be set regardless of whether or not you enable 51 | # TLS/SSL. 52 | client_certificate: ~/.config/rds-auth-proxy/client-cert.pem 53 | # Path to a pem-encoded key for the client certificate 54 | # 55 | # Leave unset, unless you're also providing the 56 | # client_certificate. 57 | client_private_key: ~/.config/rds-auth-proxy/client-key.pem 58 | 59 | # Effectively service-discovery for the proxy. These should 60 | # match the configuration for the upstream proxy. 
61 | # 62 | # In the case that you have multiple upstream proxies, this can 63 | # be more permissive, but this has the current limitation that 64 | # your user must know which upstream handles which database 65 | # instances. 66 | target_acl: 67 | # RDS instances must have ALL of these tags to be connectable 68 | # An empty list means ALL instances that the proxy can see are 69 | # connectable 70 | allowed_rds_tags: 71 | - name: "rds_proxy_enabled" 72 | value: "true" # currently, must be an exact match 73 | # RDS instances must not have ANY of these tags to be connectable 74 | # An empty list means ALL instances that the proxy can see are 75 | # connectable 76 | blocked_rds_tags: 77 | - name: "rds_proxy_disabled" 78 | value: "true" # currently, must be an exact match 79 | 80 | # This is where you can specify upstream proxy settings 81 | upstream_proxies: 82 | # The 'default' proxy is used when no --proxy-target flag is 83 | # passed to the CLI 84 | default: 85 | # You can set up a kubernetes port-forward to the deployment 86 | # using a port-forward config block. 87 | # 88 | # In this case, the host will be set to 0.0.0.0 89 | port_forward: 90 | # The kubernetes config file to use when establishing the port-forward 91 | # If unset, uses ~/.kube/config 92 | kube_config: ~/.config/kube/kube_config 93 | # The context to use within the kube config file 94 | context: development 95 | # The name of your server proxy deployment 96 | deployment: rds-auth-proxy 97 | # The namespace of your server proxy 98 | namespace: rds-auth-proxy 99 | # Optional, the local port for the port-forward tunnel 100 | # if not specified, a random unused port will be used. If you have 101 | # multiple upstream proxies, leave this unset! 102 | local_port: 8000 103 | # The remote port of the proxy 104 | remote_port: 8000 105 | ssl: 106 | # You can enable SSL over a port-forward, but it's not required 107 | # as the port-forward is over a TLS connection. 
108 | mode: "disable" # options are "disable", "verify-full", "verify-ca", or "require" 109 | # Additional upstream proxies can be specified as arbitrary keys in 110 | # this block. 111 | with_ssl: 112 | host: example.com:8000 113 | ssl: 114 | # Expects the server certificate to be signed by a CA in the system trust store, 115 | # and the certificate's common name to match the hostname. 116 | mode: "verify-full" 117 | # Optionally provide a root CA that the certificate chain must validate up to, 118 | # rather than the system trust store. 119 | root_certificate: ~/.config/rds-auth-proxy/root-ca.pem 120 | # Path to a pem-encoded client certificate that should be used instead of the 121 | # proxy's default client certificate for this host 122 | client_cert: ~/.config/rds-auth-proxy/my-client-cert.pem 123 | # Path to the pem-encoded private key for the certificate 124 | client_private_key: ~/.config/rds-auth-proxy/my-client-key.pem 125 | 126 | # This is where you can specify SSL settings for the upstream 127 | # (non-RDS) databases. 128 | # 129 | # RDS databases are discovered/added automatically at runtime. 130 | targets: 131 | in-cluster-postgres: 132 | # This should be the in-cluster hostname / port that the server proxy 133 | # will use. 134 | host: postgres:5432 135 | ``` 136 | 137 | ## Server Config 138 | 139 | A full example of every option available: 140 | 141 | ```yaml 142 | # /etc/rds-auth-proxy/config.yaml 143 | 144 | # Options for the server proxy 145 | proxy: 146 | # The listen address of this proxy 147 | listen_addr: 0.0.0.0:8000 148 | # SSL/TLS config for the proxy itself.
149 | ssl: 150 | # If set to true without specifying a server 151 | # certificate/private key, the proxy generates a self-signed 152 | # certificate for localhost. 153 | enabled: false 154 | 155 | # Path to a pem-encoded certificate for the proxy 156 | certificate: /etc/rds-auth-proxy/server-cert.pem 157 | # Path to a pem-encoded private key for the certificate 158 | private_key: /etc/rds-auth-proxy/server-key.pem 159 | 160 | # Path to a pem-encoded client certificate for upstream connections 161 | # 162 | # If not set, the proxy will generate a self-signed cert for 163 | # targets that require TLS/SSL. This cert can be overridden 164 | # on a per-host basis in the target config block below. 165 | # 166 | # This can be set regardless of whether or not you enable 167 | # TLS/SSL. 168 | client_certificate: /etc/rds-auth-proxy/client-cert.pem 169 | # Path to a pem-encoded key for the client certificate 170 | # 171 | # Leave unset unless you're also providing 172 | # client_certificate. 173 | client_private_key: /etc/rds-auth-proxy/client-key.pem 174 | 175 | # Effectively service discovery for the proxy. Before making an 176 | # outbound connection, the proxy verifies that the host either 177 | # satisfies the tag rules below or is listed in the 178 | # targets list further down.
179 | target_acl: 180 | # RDS instances must have ALL of these tags to be connectable 181 | # An empty list means ALL instances that the proxy can see are 182 | # connectable 183 | allowed_rds_tags: 184 | - name: "rds_proxy_enabled" 185 | value: "true" # currently, must be an exact match 186 | # RDS instances must not have ANY of these tags to be connectable 187 | # An empty list means ALL instances that the proxy can see are 188 | # connectable 189 | blocked_rds_tags: 190 | - name: "rds_proxy_disabled" 191 | value: "true" # currently, must be an exact match 192 | 193 | # This is where you can specify SSL settings for the upstream 194 | # databases 195 | # 196 | # RDS databases are discovered/added automatically at runtime, if you 197 | # add them yourself, you can override SSL settings. 198 | targets: 199 | in-cluster-postgres: 200 | # This should be the in-cluster hostname / port that the server-proxy 201 | # will use. 202 | host: postgres:5432 203 | overriden-rds-ssl: 204 | host: test-rds.aws.com:5432 205 | ssl: 206 | mode: "verify-full" 207 | # Path to a pem encoded client certificate that should be used instead of the 208 | # proxies default client certificate for this host 209 | client_cert: /etc/rds-auth-proxy/my-client-cert.pem 210 | # Path to the pem encoded private key for the certificate 211 | client_private_key: /etc/rds-auth-proxy/my-client-key.pem 212 | ``` 213 | -------------------------------------------------------------------------------- /pkg/proxy/proxy.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "os" 7 | "sync" 8 | "time" 9 | 10 | pgproto3 "github.com/jackc/pgproto3/v2" 11 | "github.com/mothership/rds-auth-proxy/pkg/log" 12 | "github.com/mothership/rds-auth-proxy/pkg/pg" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | var connectionID = uint64(0) 17 | 18 | // Proxy - Manages a Proxy connection, piping data between proxy and remote. 
19 | type Proxy struct { 20 | ID uint64 21 | logger *zap.Logger 22 | backend *pg.PostgresBackend 23 | frontend *pg.PostgresFrontend 24 | waiter sync.WaitGroup 25 | errChan chan errorWrapper 26 | shutdownChan chan bool 27 | config *Config 28 | } 29 | 30 | // newProxy returns a new Proxy that will handle a client connection and open 31 | // a downstream connection to the Postgres server 32 | func newProxy(clientConn net.Conn, errChan chan errorWrapper, config *Config) *Proxy { 33 | // XXX: can't error if no options are passed 34 | backend, _ := pg.NewBackend(clientConn) 35 | shutdownChan := make(chan bool, 1) 36 | connectionID++ 37 | return &Proxy{ 38 | ID: connectionID, 39 | shutdownChan: shutdownChan, 40 | backend: backend, 41 | logger: log.With(zap.Uint64("connectionID", connectionID)), 42 | errChan: errChan, 43 | waiter: sync.WaitGroup{}, 44 | config: config, 45 | } 46 | } 47 | 48 | func (p *Proxy) notifyError(err error) error { 49 | msg := &pgproto3.ErrorResponse{Severity: "FATAL", Message: err.Error()} 50 | if authErr, ok := err.(*pg.AuthFailedError); ok { 51 | msg = authErr.ErrMsg 52 | } 53 | _ = p.backend.Send(msg) 54 | p.errChan <- errorWrapper{ConnectionID: p.ID, Error: err} 55 | return err 56 | } 57 | 58 | func (p *Proxy) notifyStopped() { 59 | p.errChan <- errorWrapper{ConnectionID: p.ID, Error: nil} 60 | } 61 | 62 | // Stop shuts the proxy down and cleans up the connections 63 | func (p *Proxy) Stop() { 64 | close(p.shutdownChan) 65 | p.waiter.Wait() 66 | } 67 | 68 | // Start boots the proxy 69 | func (p *Proxy) Start() error { 70 | defer p.backend.Close() 71 | p.logger.Info("starting connection") 72 | // First, set up the connection with our client (ex: psql) 73 | // and extract the connection parameters from the startup message 74 | connectParams, err := p.backend.SetupConnection(p.config.ServerCertificate) 75 | if err != nil { 76 | return p.notifyError(err) 77 | } 78 | // Get credentials 79 | creds := p.ParseCredentials(connectParams) 80 | if err := 
p.config.CredentialInterceptor(&creds); err != nil { 81 | return p.notifyError(err) 82 | } 83 | // Next, establish a connection with the upstream database 84 | p.logger.Info("connecting to upstream postgres server", zap.String("postgres_server", creds.Host)) 85 | connection, err := pg.Connect(creds.Host, creds.SSLMode, creds.ClientCertificate, creds.RootCertificate) 86 | if err != nil { 87 | return p.notifyError(err) 88 | } 89 | 90 | // XXX: can't error without options 91 | frontend, _ := pg.NewFrontend(connection) 92 | p.frontend = frontend 93 | defer p.frontend.Close() 94 | 95 | p.logger.Info("connected to upstream postgres server", zap.String("postgres_server", creds.Host)) 96 | 97 | p.logger.Debug("sending startup message", 98 | zap.String("postgres_server", creds.Host), 99 | zap.String("user", creds.Username), 100 | zap.Any("options", creds.Options), 101 | ) 102 | 103 | // If we're in client proxy mode, forward the password in the StartupMessage 104 | if p.config.Mode == ClientSide && creds.Password != "" { 105 | creds.Options["password"] = creds.Password 106 | } 107 | // Now send our own StartupMessage, and pass thru any remaining connection parameters. 108 | startupMessage := createStartupMessage(creds.Username, creds.Database, creds.Options) 109 | if err = frontend.SendRaw(startupMessage.Encode(nil)); err != nil { 110 | return p.notifyError(err) 111 | } 112 | 113 | // Even if we're in server mode, don't bother intercepting the startup message response 114 | // UNLESS we have the password/auth credentials to handle it. This lets the user use 115 | // the proxy normally, for instance, if they are using it without IAM auth 116 | if p.config.Mode == ServerSide && creds.Password != "" { 117 | // Fetch the response to our startup message, most likely this is going to be a request 118 | // for us to authenticate. Assuming it is, forward the password we collected. 
119 | p.logger.Debug("handling upstream authentication", zap.String("postgres_server", creds.Host)) 120 | err = p.frontend.HandleAuthenticationRequest(creds.Username, creds.Password) 121 | if err != nil { 122 | return p.notifyError(err) 123 | } 124 | p.logger.Debug("authed successfully with upstream", zap.String("postgres_server", creds.Host)) 125 | // Send the auth result down to the client (ex: psql) 126 | err = p.backend.Send(&pgproto3.AuthenticationOk{}) 127 | if err != nil { 128 | return p.notifyError(errors.New("failed to send auth")) 129 | } 130 | p.logger.Debug("notified client of auth result", zap.String("postgres_server", creds.Host)) 131 | } 132 | 133 | // Now move to generic TLS/TCP proxy 134 | p.logger.Info("startup success, starting full proxy", zap.String("postgres_server", creds.Host)) 135 | p.waiter.Add(2) 136 | go p.proxyToServer() 137 | go p.proxyToClient() 138 | // wait for close... 139 | p.waiter.Wait() 140 | return nil 141 | } 142 | 143 | func (p *Proxy) proxyToServer() { 144 | idleTimeout := 5 * time.Minute 145 | maxTimeouts := int64(int64(idleTimeout) / int64(p.backend.IdleTimeout)) 146 | timeouts := int64(0) 147 | defer p.waiter.Done() 148 | for { 149 | select { 150 | case <-p.shutdownChan: 151 | return 152 | default: 153 | msg, err := p.backend.Receive() 154 | if err != nil { 155 | if isRetryableError(err) { 156 | timeouts++ 157 | if timeouts < maxTimeouts { 158 | continue 159 | } 160 | } 161 | _ = p.notifyError(err) 162 | return 163 | } 164 | timeouts = 0 165 | 166 | switch castedMsg := msg.(type) { 167 | case *pgproto3.Terminate: 168 | err := p.frontend.Send(castedMsg) 169 | if err != nil { 170 | _ = p.notifyError(err) 171 | return 172 | } 173 | p.logger.Debug("got disconnected message") 174 | p.notifyStopped() 175 | return 176 | case *pgproto3.Query: 177 | p.logger.Debug("got query message from client") 178 | if p.config.QueryInterceptor != nil { 179 | if err := p.config.QueryInterceptor(p.frontend, p.backend, castedMsg); err != nil { 
180 | if err != WillSendManually { 181 | _ = p.notifyError(err) 182 | return 183 | } 184 | continue 185 | } 186 | } 187 | err := p.frontend.Send(castedMsg) 188 | if err != nil { 189 | _ = p.notifyError(err) 190 | return 191 | } 192 | default: 193 | p.logger.Debug("got message from client") 194 | err := p.frontend.Send(castedMsg) 195 | if err != nil { 196 | _ = p.notifyError(err) 197 | return 198 | } 199 | } 200 | } 201 | } 202 | } 203 | 204 | func (p *Proxy) proxyToClient() { 205 | idleTimeout := 5 * time.Minute 206 | maxTimeouts := int64(int64(idleTimeout) / int64(p.frontend.IdleTimeout)) 207 | timeouts := int64(0) 208 | defer p.waiter.Done() 209 | for { 210 | select { 211 | case <-p.shutdownChan: 212 | return 213 | default: 214 | msg, err := p.frontend.Receive() 215 | if err != nil { 216 | if isRetryableError(err) { 217 | timeouts++ 218 | if timeouts < maxTimeouts { 219 | continue 220 | } 221 | } 222 | _ = p.notifyError(err) 223 | return 224 | } 225 | timeouts = 0 226 | p.logger.Debug("got message from server") 227 | err = p.backend.Send(msg) 228 | if err != nil { 229 | _ = p.notifyError(err) 230 | return 231 | } 232 | } 233 | } 234 | } 235 | 236 | func createStartupMessage(username string, database string, options map[string]string) pgproto3.StartupMessage { 237 | params := map[string]string{ 238 | "user": username, 239 | "database": database, 240 | } 241 | for key, value := range options { 242 | params[key] = value 243 | } 244 | 245 | return pgproto3.StartupMessage{ 246 | ProtocolVersion: pgproto3.ProtocolVersionNumber, 247 | Parameters: params, 248 | } 249 | } 250 | 251 | func isRetryableError(err error) bool { 252 | // These errors are expected in periods of no query activity. 
253 | if errors.Is(err, os.ErrDeadlineExceeded) { 254 | return true 255 | } 256 | if netErr, ok := err.(net.Error); ok { 257 | return netErr.Timeout() || netErr.Temporary() //nolint 258 | } 259 | return false 260 | } 261 | 262 | // ParseCredentials takes connection parameters and turns them into Credentials 263 | func (p *Proxy) ParseCredentials(connectionParams map[string]string) Credentials { 264 | extracted := []string{"host", "password", "user", "database"} 265 | creds := Credentials{ 266 | Host: connectionParams["host"], 267 | Password: connectionParams["password"], 268 | Username: connectionParams["user"], 269 | Database: connectionParams["database"], 270 | SSLMode: pg.SSLRequired, 271 | ClientCertificate: p.config.DefaultClientCertificate, 272 | } 273 | for _, key := range extracted { 274 | delete(connectionParams, key) 275 | } 276 | creds.Options = connectionParams 277 | return creds 278 | } 279 | -------------------------------------------------------------------------------- /cmd/proxy_client.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "crypto/x509" 7 | "encoding/pem" 8 | "fmt" 9 | "io/ioutil" 10 | "net" 11 | "os" 12 | "os/signal" 13 | 14 | "github.com/AlecAivazis/survey/v2" 15 | "github.com/mothership/rds-auth-proxy/pkg/aws" 16 | "github.com/mothership/rds-auth-proxy/pkg/config" 17 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 18 | discoveryFactory "github.com/mothership/rds-auth-proxy/pkg/discovery/factory" 19 | "github.com/mothership/rds-auth-proxy/pkg/kubernetes" 20 | "github.com/mothership/rds-auth-proxy/pkg/log" 21 | "github.com/mothership/rds-auth-proxy/pkg/proxy" 22 | "github.com/spf13/cobra" 23 | "go.uber.org/zap" 24 | "go.uber.org/zap/zapcore" 25 | ) 26 | 27 | var proxyClientCommand = &cobra.Command{ 28 | Use: "client", 29 | Short: "Launches the localhost proxy", 30 | Long: `Runs a localhost proxy service in-cluster for 
connecting to RDS.`, 31 | RunE: func(cmd *cobra.Command, args []string) error { 32 | logCfg := zap.NewDevelopmentConfig() 33 | logCfg.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel) 34 | logCfg.Development = false 35 | logger, err := logCfg.Build(zap.WithCaller(false)) 36 | if err != nil { 37 | return err 38 | } 39 | log.SetLogger(logger) 40 | ctx, cancel := context.WithCancel(context.Background()) 41 | defer cancel() 42 | rdsClient, err := aws.NewRDSClient(ctx) 43 | if err != nil { 44 | return err 45 | } 46 | filepath, err := cmd.Flags().GetString("configfile") 47 | if err != nil { 48 | return err 49 | } 50 | cfg, err := config.LoadConfig(filepath) 51 | if err != nil { 52 | return err 53 | } 54 | discoveryClient := discoveryFactory.FromConfig(rdsClient, &cfg) 55 | if err := discoveryClient.Refresh(ctx); err != nil { 56 | return err 57 | } 58 | 59 | // Look up the proxy target 60 | proxyTarget, err := getProxyTarget(cmd, cfg.ProxyTargets) 61 | if err != nil { 62 | return err 63 | } 64 | 65 | // Look up the real target name in the target list 66 | target, err := getTarget(cmd, discoveryClient) 67 | if err != nil { 68 | return err 69 | } 70 | // Override local port if needed 71 | if target.LocalPort != nil { 72 | addr, err := net.ResolveTCPAddr("tcp", cfg.Proxy.ListenAddr) 73 | if err != nil { 74 | return err 75 | } 76 | cfg.Proxy.ListenAddr = fmt.Sprintf("%s:%s", addr.IP, *target.LocalPort) 77 | } 78 | 79 | // Optionally grab the password 80 | pass, err := cmd.Flags().GetString("password") 81 | if err != nil { 82 | return err 83 | } 84 | 85 | err = printConnectionString(cfg.Proxy.ListenAddr, target) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | if proxyTarget.PortForward != nil { 91 | // setup port-forward 92 | prtCmd, err := kubernetes.BuildPortForwardCommand(ctx, proxyTarget.PortForward.KubeConfigFilePath, kubernetes.PortForwardOptions{ 93 | Namespace: proxyTarget.PortForward.Namespace, 94 | Deployment: proxyTarget.PortForward.DeploymentName, 95 | Ports: 
[]string{fmt.Sprintf("%s:%s", proxyTarget.PortForward.GetLocalPort(), proxyTarget.PortForward.RemotePort)}, 96 | Context: proxyTarget.PortForward.Context, 97 | }) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | go func() { 103 | if err := kubernetes.ForwardPort(ctx, prtCmd); err != nil { 104 | // TODO: blow this up gracefully 105 | log.Error("k8s port-forward caught error", zap.Error(err), zap.String("listen_addr", proxyTarget.GetHost())) 106 | panic(err) 107 | } 108 | log.Info("k8s port-forward exited", zap.String("listen_addr", proxyTarget.GetHost())) 109 | }() 110 | <-prtCmd.ReadyChannel 111 | ports, err := prtCmd.PortForwarder.GetPorts() 112 | if err != nil { 113 | return err 114 | } 115 | portUsed := fmt.Sprintf("%d", ports[0].Local) 116 | proxyTarget.PortForward.LocalPort = &portUsed 117 | log.Info("started k8s port-forward", zap.String("listen_addr", proxyTarget.GetHost())) 118 | } 119 | 120 | log.Info("starting client proxy", zap.String("listen_addr", cfg.Proxy.ListenAddr)) 121 | opts, err := proxySSLOptions(cfg.Proxy.SSL) 122 | if err != nil { 123 | return err 124 | } 125 | 126 | manager, err := proxy.NewManager(proxy.MergeOptions(opts, []proxy.Option{ 127 | proxy.WithListenAddress(cfg.Proxy.ListenAddr), 128 | proxy.WithMode(proxy.ClientSide), 129 | proxy.WithCredentialInterceptor(func(creds *proxy.Credentials) error { 130 | // Send this connection to the proxy host 131 | creds.Host = proxyTarget.GetHost() 132 | // But tell the server proxy to forward to the target host 133 | creds.Options["host"] = target.Host 134 | 135 | // Use provided password, or generate an RDS password to forward through 136 | if pass != "" { 137 | creds.Password = pass 138 | } else if target.IsRDS { 139 | authToken, err := rdsClient.NewAuthToken(ctx, target.Host, target.Region, creds.Username) 140 | if err != nil { 141 | return err 142 | } 143 | creds.Password = authToken 144 | } 145 | 146 | return overrideSSLConfig(creds, proxyTarget.SSL) 147 | })})..., 148 | ) 149 | if 
err != nil { 150 | return err 151 | } 152 | 153 | // Shutdown app on SIGINT/SIGTERM 154 | signals := make(chan os.Signal, 1) 155 | go func() { 156 | _ = manager.Start(ctx) 157 | close(signals) 158 | }() 159 | signal.Notify(signals, os.Interrupt) 160 | <-signals 161 | cancel() 162 | return nil 163 | }, 164 | } 165 | 166 | func printConnectionString(listenAddr string, target config.Target) error { 167 | addr, err := net.ResolveTCPAddr("tcp", listenAddr) 168 | if err != nil { 169 | return err 170 | } 171 | start := fmt.Sprintf("psql -h %s -p %d", addr.IP, addr.Port) 172 | if target.DefaultDatabase != nil && *target.DefaultDatabase != "" { 173 | start += fmt.Sprintf(" -d %s", *target.DefaultDatabase) 174 | } else { 175 | start += " -d {your_database}" 176 | } 177 | start += " -U {your user}" 178 | fmt.Printf("Setting up a tunnel to %s\n\nGive this a second, then in a new shell, connect with:\n\n\t%s\n\n", target.Name, start) 179 | return nil 180 | } 181 | 182 | func getProxyTarget(cmd *cobra.Command, targets map[string]*config.ProxyTarget) (*config.ProxyTarget, error) { 183 | // Look up the proxy target 184 | proxyName, err := cmd.Flags().GetString("proxy-target") 185 | if err != nil { 186 | return nil, err 187 | } 188 | proxyTarget, ok := targets[proxyName] 189 | if ok { 190 | return proxyTarget, nil 191 | } 192 | 193 | opts := make([]string, 0, len(targets)) 194 | for name := range targets { 195 | opts = append(opts, name) 196 | } 197 | 198 | prompt := &survey.Select{ 199 | Message: "Select an upstream proxy", 200 | Options: opts, 201 | } 202 | 203 | err = survey.AskOne(prompt, &proxyName) 204 | if err != nil { 205 | return nil, err 206 | } 207 | proxyTarget, ok = targets[proxyName] 208 | if ok { 209 | return proxyTarget, nil 210 | } 211 | return nil, fmt.Errorf("couldn't find a proxy target") 212 | } 213 | 214 | func getTarget(cmd *cobra.Command, discoveryClient discovery.Client) (config.Target, error) { 215 | targetName, err := cmd.Flags().GetString("target") 216 | 
if err != nil { 217 | return config.Target{}, err 218 | } 219 | 220 | if targetName == "" { 221 | targets := discoveryClient.GetTargets() 222 | opts := make([]string, 0, len(targets)) 223 | for _, target := range targets { 224 | opts = append(opts, target.Name) 225 | } 226 | prompt := &survey.Select{ 227 | Message: "Select a database", 228 | Options: opts, 229 | } 230 | 231 | if err := survey.AskOne(prompt, &targetName); err != nil { 232 | return config.Target{}, err 233 | } 234 | } 235 | return discoveryClient.LookupTargetByName(targetName) 236 | } 237 | 238 | func overrideSSLConfig(creds *proxy.Credentials, ssl config.SSL) error { 239 | creds.SSLMode = ssl.Mode 240 | // If the config wants us to use a specific SSL client cert, load it 241 | if ssl.ClientCertificatePath != nil { 242 | // TODO: load sooner / cache 243 | cert, err := tls.LoadX509KeyPair(*ssl.ClientCertificatePath, *ssl.ClientPrivateKeyPath) 244 | if err != nil { 245 | return err 246 | } 247 | creds.ClientCertificate = &cert 248 | } 249 | 250 | // If the config wants us to validate the cert chain goes to a specific root cert for the server proxy 251 | // load it, and set it 252 | if ssl.RootCertificatePath != nil { 253 | rootCABytes, err := ioutil.ReadFile(*ssl.RootCertificatePath) 254 | if err != nil { 255 | return err 256 | } 257 | decoded, _ := pem.Decode(rootCABytes) 258 | cert, err := x509.ParseCertificate(decoded.Bytes) 259 | if err != nil { 260 | return err 261 | } 262 | creds.RootCertificate = cert 263 | } 264 | return nil 265 | } 266 | 267 | func proxySSLOptions(ssl config.ServerSSL) ([]proxy.Option, error) { 268 | opts := make([]proxy.Option, 0, 2) 269 | if ssl.Enabled { 270 | if ssl.CertificatePath == nil && ssl.PrivateKeyPath == nil { 271 | opts = append(opts, proxy.WithGeneratedServerCertificate()) 272 | } else if ssl.CertificatePath != nil && ssl.PrivateKeyPath != nil { 273 | opts = append(opts, proxy.WithServerCertificate(*ssl.CertificatePath, *ssl.PrivateKeyPath)) 274 | } else { 275 | 
return opts, fmt.Errorf("bad options: when ssl is enabled, either both a certificate and key must be provided, or neither provided") 276 | } 277 | } 278 | 279 | if ssl.ClientCertificatePath == nil && ssl.ClientPrivateKeyPath == nil { 280 | opts = append(opts, proxy.WithGeneratedClientCertificate()) 281 | } else if ssl.ClientCertificatePath != nil && ssl.ClientPrivateKeyPath != nil { 282 | opts = append(opts, proxy.WithClientCertificate(*ssl.ClientCertificatePath, *ssl.ClientPrivateKeyPath)) 283 | } else { 284 | return opts, fmt.Errorf("bad options: either both a client certificate and key must be provided, or neither provided") 285 | } 286 | return opts, nil 287 | 288 | } 289 | 290 | func init() { 291 | proxyClientCommand.PersistentFlags().String("proxy-target", "default", "Name of the proxy target in the configfile") 292 | proxyClientCommand.PersistentFlags().String("target", "", "Name of the target, or db instance identifier that you wish to connect to") 293 | proxyClientCommand.PersistentFlags().String("configfile", "", "Path to the proxy config file") 294 | _ = proxyClientCommand.MarkPersistentFlagDirname("configfile") 295 | proxyClientCommand.PersistentFlags().String("password", "", "Password for the user if IAM auth is not set up") 296 | rootCmd.AddCommand(proxyClientCommand) 297 | } 298 | -------------------------------------------------------------------------------- /pkg/discovery/rds/rds_test.go: -------------------------------------------------------------------------------- 1 | package rds_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/aws/aws-sdk-go-v2/service/rds/types" 9 | "github.com/mothership/rds-auth-proxy/pkg/aws" 10 | "github.com/mothership/rds-auth-proxy/pkg/config" 11 | "github.com/mothership/rds-auth-proxy/pkg/discovery" 12 | . 
"github.com/mothership/rds-auth-proxy/pkg/discovery/rds" 13 | ) 14 | 15 | func TestRefreshBehaviorWithACLFailures(t *testing.T) { 16 | instances := []aws.DBInstanceResult{ 17 | instance(types.DBInstance{ 18 | DBInstanceIdentifier: strPtr("db-1"), 19 | Endpoint: endpoint("db-1", 5000), 20 | TagList: rdsTags("enabled", "false"), 21 | }), 22 | instance(types.DBInstance{ 23 | DBInstanceIdentifier: strPtr("db-2"), 24 | Endpoint: endpoint("db-2", 5000), 25 | TagList: rdsTags("enabled", "true"), 26 | }), 27 | instance(types.DBInstance{ 28 | DBInstanceIdentifier: strPtr("db-3"), 29 | Endpoint: endpoint("db-3", 5000), 30 | TagList: rdsTags("region", "east", "enabled", "true"), 31 | }), 32 | instance(types.DBInstance{ 33 | DBInstanceIdentifier: strPtr("db-4"), 34 | Endpoint: endpoint("db-4", 5000), 35 | TagList: rdsTags("region", "west", "enabled", "true"), 36 | }), 37 | } 38 | 39 | cases := []struct { 40 | Config config.ConfigFile 41 | Host string 42 | Expected error 43 | }{ 44 | // Case 0: Test no allow/block list 45 | { 46 | Config: configFromACL(nil, nil), 47 | Host: "db-2:5000", 48 | Expected: nil, 49 | }, 50 | // Case 1: Test allow list, no block list 51 | { 52 | Config: configFromACL(tags("enabled", "true"), nil), 53 | Host: "db-1:5000", 54 | Expected: discovery.ErrTargetNotFound, 55 | }, 56 | // Case 2: Test no allow list, block list 57 | { 58 | Config: configFromACL(nil, tags("enabled", "true")), 59 | Host: "db-1:5000", 60 | Expected: nil, 61 | }, 62 | // Case 3: Test blocklist overrides allow list 63 | { 64 | Config: configFromACL(tags("enabled", "true"), tags("enabled", "true")), 65 | Host: "db-1:5000", 66 | Expected: discovery.ErrTargetNotFound, 67 | }, 68 | // Case 4: Test multi-tags in allow list require all to be met 69 | { 70 | Config: configFromACL(tags("enabled", "true", "region", "west"), nil), 71 | Host: "db-3:5000", 72 | Expected: discovery.ErrTargetNotFound, 73 | }, 74 | // Case 5: Test multi-tags in block list require any to be met 75 | { 76 | Config: 
configFromACL(nil, tags("enabled", "false", "region", "west")), 77 | Host: "db-4:5000", 78 | Expected: discovery.ErrTargetNotFound, 79 | }, 80 | } 81 | 82 | for idx, test := range cases { 83 | client := NewRdsDiscoveryClient(&mockRDSClient{Return: instances}, &test.Config) 84 | if err := client.Refresh(context.Background()); err != nil { 85 | t.Fatalf("[Case %d] expected no error, got: %+v", idx, err) 86 | } 87 | _, err := client.LookupTargetByHost(test.Host) 88 | if err != test.Expected { 89 | t.Errorf("[Case %d] expected %+v. Got %+v.", idx, test.Expected, err) 90 | } 91 | } 92 | } 93 | 94 | func TestRefreshBehaviorWithACLSuccesses(t *testing.T) { 95 | instances := []aws.DBInstanceResult{ 96 | instance(types.DBInstance{ 97 | DBInstanceIdentifier: strPtr("db-1"), 98 | Endpoint: endpoint("db-1", 5000), 99 | TagList: rdsTags("enabled", "false"), 100 | }), 101 | instance(types.DBInstance{ 102 | DBInstanceIdentifier: strPtr("db-2"), 103 | Endpoint: endpoint("db-2", 5000), 104 | TagList: rdsTags("enabled", "true"), 105 | }), 106 | instance(types.DBInstance{ 107 | DBInstanceIdentifier: strPtr("db-3"), 108 | Endpoint: endpoint("db-3", 5000), 109 | TagList: rdsTags("region", "east", "enabled", "true"), 110 | }), 111 | instance(types.DBInstance{ 112 | DBInstanceIdentifier: strPtr("db-4"), 113 | Endpoint: endpoint("db-4", 5000), 114 | TagList: rdsTags("region", "west", "enabled", "true"), 115 | }), 116 | } 117 | cases := []struct { 118 | Config config.ConfigFile 119 | Host string 120 | Expected config.Target 121 | }{ 122 | // Case 0: Test no allow/block list 123 | { 124 | Config: configFromACL(nil, nil), 125 | Host: "db-2:5000", 126 | Expected: config.Target{Name: "db-2"}, 127 | }, 128 | // Case 1: Test block list still allows non-blocked dbs 129 | { 130 | Config: configFromACL(nil, tags("enabled", "true")), 131 | Host: "db-1:5000", 132 | Expected: config.Target{Name: "db-1"}, 133 | }, 134 | // Case 2: Test multi-tags in block list doesn't affect other dbs 135 | { 136 | 
Config: configFromACL(nil, tags("enabled", "false", "region", "west")), 137 | Host: "db-3:5000", 138 | Expected: config.Target{Name: "db-3"}, 139 | }, 140 | } 141 | for idx, test := range cases { 142 | client := NewRdsDiscoveryClient(&mockRDSClient{Return: instances}, &test.Config) 143 | if err := client.Refresh(context.Background()); err != nil { 144 | t.Fatalf("[Case %d] expected no error, got: %+v", idx, err) 145 | } 146 | target, err := client.LookupTargetByHost(test.Host) 147 | if err != nil { 148 | t.Fatalf("[Case %d] got unexpected error: %s", idx, err) 149 | } 150 | if target.Name != test.Expected.Name { 151 | t.Errorf("[Case %d] expected %+v. Got %+v.", idx, test.Expected, target) 152 | } 153 | } 154 | } 155 | 156 | func TestGetTargetByNameSuccesses(t *testing.T) { 157 | instances := []aws.DBInstanceResult{ 158 | instance(types.DBInstance{ 159 | DBInstanceIdentifier: strPtr("db-1"), 160 | Endpoint: endpoint("db-1", 5000), 161 | TagList: rdsTags("enabled", "false"), 162 | }), 163 | instance(types.DBInstance{ 164 | DBInstanceIdentifier: strPtr("db-2"), 165 | Endpoint: endpoint("db-2", 5000), 166 | TagList: rdsTags("enabled", "true"), 167 | }), 168 | } 169 | cases := []struct { 170 | Name string 171 | Expected config.Target 172 | }{ 173 | { 174 | Name: "db-2", 175 | Expected: config.Target{Name: "db-2"}, 176 | }, 177 | { 178 | Name: "db-1", 179 | Expected: config.Target{Name: "db-1"}, 180 | }, 181 | } 182 | for idx, test := range cases { 183 | config := configFromACL(nil, nil) 184 | client := NewRdsDiscoveryClient(&mockRDSClient{Return: instances}, &config) 185 | if err := client.Refresh(context.Background()); err != nil { 186 | t.Fatalf("[Case %d] expected no error, got: %+v", idx, err) 187 | } 188 | target, err := client.LookupTargetByName(test.Name) 189 | if err != nil { 190 | t.Fatalf("[Case %d] got unexpected error: %s", idx, err) 191 | } 192 | if target.Name != test.Expected.Name { 193 | t.Errorf("[Case %d] expected %+v. 
Got %+v.", idx, test.Expected, target) 194 | } 195 | } 196 | } 197 | 198 | func TestGetTargetByNameFailures(t *testing.T) { 199 | instances := []aws.DBInstanceResult{ 200 | instance(types.DBInstance{ 201 | DBInstanceIdentifier: strPtr("db-1"), 202 | Endpoint: endpoint("db-1", 5000), 203 | TagList: rdsTags("enabled", "false"), 204 | }), 205 | instance(types.DBInstance{ 206 | DBInstanceIdentifier: strPtr("db-2"), 207 | Endpoint: endpoint("db-2", 5000), 208 | TagList: rdsTags("enabled", "true"), 209 | }), 210 | } 211 | cases := []struct { 212 | Name string 213 | Expected error 214 | }{ 215 | { 216 | Name: "db-3", 217 | Expected: discovery.ErrTargetNotFound, 218 | }, 219 | } 220 | for idx, test := range cases { 221 | config := configFromACL(nil, nil) 222 | client := NewRdsDiscoveryClient(&mockRDSClient{Return: instances}, &config) 223 | if err := client.Refresh(context.Background()); err != nil { 224 | t.Fatalf("[Case %d] expected no error, got: %+v", idx, err) 225 | } 226 | _, err := client.LookupTargetByName(test.Name) 227 | if err != test.Expected { 228 | t.Fatalf("[Case %d] got %+v, expected error: %s", idx, err, test.Expected) 229 | } 230 | } 231 | } 232 | 233 | func TestRdsDiscoveryClientGetTargets(t *testing.T) { 234 | instances := []aws.DBInstanceResult{ 235 | instance(types.DBInstance{ 236 | DBInstanceIdentifier: strPtr("db-1"), 237 | Endpoint: endpoint("db-1", 5000), 238 | TagList: rdsTags("enabled", "false"), 239 | }), 240 | instance(types.DBInstance{ 241 | DBInstanceIdentifier: strPtr("db-2"), 242 | Endpoint: endpoint("db-2", 5000), 243 | TagList: rdsTags("enabled", "true"), 244 | }), 245 | } 246 | config := configFromACL(nil, nil) 247 | client := NewRdsDiscoveryClient(&mockRDSClient{Return: instances}, &config) 248 | if err := client.Refresh(context.Background()); err != nil { 249 | t.Fatalf("expected no error, got: %+v", err) 250 | } 251 | targets := client.GetTargets() 252 | if len(targets) != len(instances) { 253 | t.Fatalf("missing targets") 254 | } 
255 | 256 | // TODO: sort and test that each instance was present, currently a low quality test 257 | } 258 | 259 | type mockRDSClient struct { 260 | Return []aws.DBInstanceResult 261 | } 262 | 263 | var _ aws.RDSClient = (*mockRDSClient)(nil) 264 | 265 | func (m *mockRDSClient) GetPostgresInstances(ctx context.Context) <-chan aws.DBInstanceResult { 266 | retChan := make(chan aws.DBInstanceResult, 1) 267 | go func() { 268 | defer close(retChan) 269 | for _, r := range m.Return { 270 | retChan <- r 271 | if r.Error != nil { 272 | return 273 | } 274 | } 275 | }() 276 | return retChan 277 | } 278 | 279 | func (m *mockRDSClient) NewAuthToken(ctx context.Context, host, region, user string) (string, error) { 280 | return "", nil 281 | } 282 | 283 | func (m *mockRDSClient) RegionForInstance(d types.DBInstance) (string, error) { 284 | return "us-west-2", nil 285 | } 286 | 287 | func instance(inst types.DBInstance) aws.DBInstanceResult { 288 | inst.IAMDatabaseAuthenticationEnabled = true 289 | return aws.DBInstanceResult{Instance: inst} 290 | } 291 | 292 | func rdsTags(pairs ...string) []types.Tag { 293 | if len(pairs)%2 != 0 { 294 | panic(fmt.Errorf("must pass key value pairs to rdsTags")) 295 | } 296 | tags := make([]types.Tag, 0, len(pairs)/2) 297 | for i := 0; i < len(pairs)/2+1; i += 2 { 298 | tags = append(tags, types.Tag{Key: &pairs[i], Value: &pairs[i+1]}) 299 | } 300 | return tags 301 | } 302 | 303 | func tags(pairs ...string) []*config.Tag { 304 | if len(pairs)%2 != 0 { 305 | panic(fmt.Errorf("must pass key value pairs to tags")) 306 | } 307 | tags := make([]*config.Tag, 0, len(pairs)/2) 308 | for i := 0; i < len(pairs)/2+1; i += 2 { 309 | tags = append(tags, &config.Tag{Name: pairs[i], Value: pairs[i+1]}) 310 | } 311 | return tags 312 | } 313 | 314 | func endpoint(host string, port int32) *types.Endpoint { 315 | return &types.Endpoint{Address: &host, Port: port} 316 | } 317 | 318 | func configFromACL(allowed []*config.Tag, blocked []*config.Tag) config.ConfigFile 
{ 319 | return config.ConfigFile{ 320 | Targets: map[string]*config.Target{}, 321 | Proxy: config.Proxy{ 322 | SSL: config.ServerSSL{}, 323 | ACL: config.ACL{ 324 | AllowedRDSTags: config.TagList(allowed), 325 | BlockedRDSTags: config.TagList(blocked), 326 | }, 327 | }, 328 | } 329 | } 330 | 331 | func strPtr(val string) *string { 332 | return &val 333 | } 334 | --------------------------------------------------------------------------------