├── .github ├── SECURITY.md └── workflows │ ├── lint.yaml │ └── test.yaml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── auth ├── audit │ ├── audit.go │ ├── audit_test.go │ ├── mocks │ │ └── repository.go │ ├── model.go │ └── repositories │ │ ├── dockertest_test.go │ │ ├── postgres.go │ │ └── postgres_test.go └── oidc │ ├── _example │ └── main.go │ ├── cobra.go │ ├── redirect.html │ ├── source_gsa.go │ ├── source_oidc.go │ └── utils.go ├── cli ├── commander │ ├── codex.go │ ├── completion.go │ ├── hooks.go │ ├── layout.go │ ├── manager.go │ ├── reference.go │ └── topics.go ├── printer │ ├── colors.go │ ├── markdown.go │ ├── progress.go │ ├── spinner.go │ ├── structured.go │ ├── table.go │ └── text.go ├── prompter │ └── prompt.go ├── releaser │ └── release.go └── terminator │ ├── brew.go │ ├── browser.go │ ├── pager.go │ └── term.go ├── config ├── config.go ├── config_test.go ├── doc.go └── helpers.go ├── db ├── config.go ├── db.go ├── db_test.go ├── migrate.go ├── migrate_test.go └── migrations │ ├── 1481574547_create_users_table.down.sql │ └── 1481574547_create_users_table.up.sql ├── go.mod ├── go.sum ├── log ├── logger.go ├── logrus.go ├── logrus_test.go ├── noop.go ├── zap.go └── zap_test.go ├── rql ├── README.md ├── parser.go └── parser_test.go ├── server ├── mux │ ├── README.md │ ├── mux.go │ ├── option.go │ └── serve_target.go └── spa │ ├── doc.go │ ├── handler.go │ └── router.go ├── telemetry ├── opentelemetry.go ├── otelgrpc │ ├── otelgrpc.go │ └── otelgrpc_test.go ├── otelhhtpclient │ ├── annotations.go │ ├── http_transport.go │ └── http_transport_test.go └── telemetry.go ├── testing └── dockertestx │ ├── README.md │ ├── configs │ ├── cortex │ │ └── single_process_cortex.yaml │ └── nginx │ │ └── cortex_nginx.conf │ ├── cortex.go │ ├── dockertestx.go │ ├── minio.go │ ├── minio_migrate.go │ ├── nginx.go │ ├── postgres.go │ ├── spicedb.go │ └── spicedb_migrate.go └── utils ├── error_status.go └── error_status_test.go 
/.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | Raystack takes the security of our software products and services seriously. 2 | 3 | If you believe you have found a security vulnerability in this project, you can report it to us directly using [private vulnerability reporting][]. 4 | 5 | - Include a description of your investigation of this project's codebase and why you believe an exploit is possible. 6 | - POCs and links to code are greatly encouraged. 7 | - Such reports are not eligible for a bounty reward. 8 | 9 | **Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** 10 | 11 | Thanks for helping keep Raystack projects and their users safe. 12 | 13 | [private vulnerability reporting]: https://github.com/raystack/salt/security/advisories -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | push: 4 | paths: 5 | - "**.go" 6 | - go.mod 7 | - go.sum 8 | pull_request: 9 | paths: 10 | - "**.go" 11 | - go.mod 12 | - go.sum 13 | 14 | jobs: 15 | checks: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | - name: Setup Go 23 | uses: actions/setup-go@v5 24 | with: 25 | go-version-file: "go.mod" 26 | - name: Run linter 27 | uses: golangci/golangci-lint-action@v6 28 | with: 29 | version: v1.60 -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | - name: Setup Go 18 |
uses: actions/setup-go@v5 19 | with: 20 | go-version-file: "go.mod" 21 | - name: Run unit tests 22 | run: make test 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories 15 | vendor/ 16 | 17 | .idea 18 | .vscode 19 | expt/ 20 | temp.env 21 | .DS_Store 22 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | output: 2 | formats: 3 | - format: line-number 4 | linters: 5 | enable-all: false 6 | disable-all: true 7 | enable: 8 | - govet 9 | - goimports 10 | - thelper 11 | - tparallel 12 | - unconvert 13 | - wastedassign 14 | - revive 15 | - unused 16 | - gofmt 17 | - whitespace 18 | - misspell 19 | linters-settings: 20 | revive: 21 | ignore-generated-header: true 22 | severity: warning 23 | severity: 24 | default-severity: error 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: check fmt lint test vet help 2 | .DEFAULT_GOAL := help 3 | 4 | check: test lint ## Run tests and linters 5 | 6 | test: ## Run tests 7 | go test ./...
-race 8 | 9 | lint: ## Run linter 10 | golangci-lint run 11 | 12 | help: 13 | @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # salt 2 | 3 | [![GoDoc reference](https://img.shields.io/badge/godoc-reference-5272B4.svg)](https://godoc.org/github.com/raystack/salt) 4 | ![test workflow](https://github.com/raystack/salt/actions/workflows/test.yaml/badge.svg) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/raystack/salt)](https://goreportcard.com/report/github.com/raystack/salt) 6 | 7 | Salt is a Golang utility library offering a variety of packages to simplify and enhance application development. It provides modular and reusable components for common tasks, including configuration management, CLI utilities, authentication, logging, and more. 8 | 9 | ## Installation 10 | 11 | To use, run the following command: 12 | 13 | ``` 14 | go get github.com/raystack/salt 15 | ``` 16 | 17 | ## Packages 18 | 19 | ### Configuration and Environment 20 | - **`config`** 21 | Utilities for managing application configurations using environment variables, files, or defaults. 22 | 23 | ### CLI Utilities 24 | - **`cli/commander`** 25 | Command execution and management tools. 26 | 27 | - **`cli/printer`** 28 | Utilities for formatting and printing output to the terminal. 29 | 30 | - **`cli/prompter`** 31 | Interactive CLI prompts for user input. 32 | 33 | - **`cli/terminator`** 34 | Terminal utilities for colors, cursor management, and formatting. 35 | 36 | - **`cli/releaser`** 37 | Utilities for displaying and managing CLI tool versions. 38 | 39 | ### Authentication and Security 40 | - **`auth/oidc`** 41 | Helpers for integrating OpenID Connect authentication flows.
42 | 43 | - **`auth/audit`** 44 | Auditing tools for tracking security events and compliance. 45 | 46 | ### Server and Infrastructure 47 | - **`server`** 48 | Utilities for setting up and managing HTTP or RPC servers. 49 | 50 | - **`db`** 51 | Helpers for database connections, migrations, and query execution. 52 | 53 | - **`telemetry`** 54 | Observability tools for capturing application metrics and traces. 55 | 56 | ### Development and Testing 57 | - **`testing/dockertestx`** 58 | Tools for creating and managing Docker-based testing environments. 59 | 60 | ### Utilities 61 | - **`log`** 62 | Simplified logging utilities for structured and unstructured log messages. 63 | 64 | - **`utils`** 65 | General-purpose utility functions for common programming tasks. 66 |
-------------------------------------------------------------------------------- /auth/audit/audit.go: -------------------------------------------------------------------------------- 1 | //go:generate mockery --name=repository --exported 2 | 3 | package audit 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "fmt" 9 | "time" 10 | ) 11 | 12 | var ( 13 | // TimeNow returns the timestamp recorded on each Log entry. It is a 14 | // variable so tests can replace it with a fixed clock. 15 | TimeNow = time.Now 16 | 17 | // ErrInvalidMetadata is returned by WithMetadata when the context already 18 | // carries metadata that is not a map[string]interface{}. 19 | ErrInvalidMetadata = errors.New("failed to cast existing metadata to map[string]interface{} type") 20 | 21 | // ErrNoRepository is returned by Service.Log when the Service was created 22 | // without WithRepository. 23 | ErrNoRepository = errors.New("audit: no repository configured") 24 | ) 25 | 26 | type actorContextKey struct{} 27 | type metadataContextKey struct{} 28 | 29 | // WithActor returns a copy of ctx carrying actor; the default actor extractor 30 | // used by Service.Log reads it back. 31 | func WithActor(ctx context.Context, actor string) context.Context { 32 | return context.WithValue(ctx, actorContextKey{}, actor) 33 | } 34 | 35 | // WithMetadata returns a copy of ctx carrying md merged over any metadata 36 | // already stored in ctx; on key collisions the values in md win. 37 | // 38 | // The merge is written into a new map, so the map held by ctx (which may be 39 | // shared with other derived contexts or goroutines) is never modified. 40 | // It returns ErrInvalidMetadata if ctx holds metadata of another type. 41 | func WithMetadata(ctx context.Context, md map[string]interface{}) (context.Context, error) { 42 | existingMetadata := ctx.Value(metadataContextKey{}) 43 | if existingMetadata == nil { 44 | return context.WithValue(ctx, metadataContextKey{}, md), nil 45 | } 46 | 47 | mapMd, ok := existingMetadata.(map[string]interface{}) 48 | if !ok { 49 | return nil, ErrInvalidMetadata 50 | } 51 | 52 | // append new metadata into a fresh map instead of mutating mapMd in place 53 | merged := make(map[string]interface{}, len(mapMd)+len(md)) 54 | for k, v := range mapMd { 55 | merged[k] = v 56 | } 57 | for k, v := range md { 58 | merged[k] = v 59 | } 60 | 61 | return context.WithValue(ctx, metadataContextKey{}, merged), nil 62 | } 63 | 64 | type repository interface { 65 | Init(context.Context) error 66 | Insert(context.Context, *Log) error 67 | } 68 | 69 | // AuditOption configures a Service created by New. 70 | type AuditOption func(*Service) 71 | 72 | // WithRepository sets the repository that Service.Log inserts entries into. 73 | func WithRepository(r repository) AuditOption { 74 | return func(s *Service) { 75 | s.repository = r 76 | } 77 | } 78 | 79 | // WithMetadataExtractor registers fn to derive metadata from the context on 80 | // every Service.Log call; the result is merged in via WithMetadata. 81 | func WithMetadataExtractor(fn func(context.Context) map[string]interface{}) AuditOption { 82 | return func(s *Service) { 83 | s.withMetadata = func(ctx context.Context) (context.Context, error) { 84 | md := fn(ctx) 85 | return WithMetadata(ctx, md) 86 | } 87 | } 88 | } 89 | 90 | // WithActorExtractor replaces the default actor extractor, which reads the 91 | // value stored by WithActor. 92 | func WithActorExtractor(fn func(context.Context) (string, error)) AuditOption { 93 | return func(s *Service) { 94 | s.actorExtractor = fn 95 | } 96 | } 97 | 98 | func defaultActorExtractor(ctx context.Context) (string, error) { 99 | if actor, ok := ctx.Value(actorContextKey{}).(string); ok { 100 | return actor, nil 101 | } 102 | return "", nil 103 | } 104 | 105 | // Service builds audit Log entries and persists them through a repository. 106 | type Service struct { 107 | repository repository 108 | actorExtractor func(context.Context) (string, error) 109 | withMetadata func(context.Context) (context.Context, error) 110 | } 111 | 112 | // New creates a Service that uses the default actor extractor, then applies 113 | // opts in order. 114 | func New(opts ...AuditOption) *Service { 115 | svc := &Service{ 116 | actorExtractor: defaultActorExtractor, 117 | } 118 | for _, o := range opts { 119 | o(svc) 120 | } 121 | 122 | return svc 123 | } 124 | 125 | // Log builds a Log entry for action and data, stamped with TimeNow, the actor 126 | // returned by the actor extractor and the metadata carried by ctx (after the 127 | // optional metadata extractor has run), and inserts it into the repository. 128 | // 129 | // It returns the metadata extractor's error, the actor extractor's error 130 | // wrapped as "extracting actor: ...", ErrNoRepository when no repository was 131 | // configured, or the error returned by the repository's Insert. 132 | func (s *Service) Log(ctx context.Context, action string, data interface{}) error { 133 | if s.withMetadata != nil { 134 | var err error 135 | if ctx, err = s.withMetadata(ctx); err != nil { 136 | return err 137 | } 138 | } 139 | 140 | l := &Log{ 141 | Timestamp: TimeNow(), 142 | Action: action, 143 | Data: data, 144 | } 145 | 146 | if md, ok := ctx.Value(metadataContextKey{}).(map[string]interface{}); ok { 147 | l.Metadata = md 148 | } 149 | 150 | if s.actorExtractor != nil { 151 | actor, err := s.actorExtractor(ctx) 152 | if err != nil { 153 | return fmt.Errorf("extracting actor: %w", err) 154 | } 155 | l.Actor = actor 156 | } 157 | 158 | // checked after the extractors so their errors keep precedence; calling 159 | // Insert on a nil repository would panic 160 | if s.repository == nil { 161 | return ErrNoRepository 162 | } 163 | 164 | return s.repository.Insert(ctx, l) 165 | } 166 |
-------------------------------------------------------------------------------- /auth/audit/audit_test.go: -------------------------------------------------------------------------------- 1 | package audit_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/raystack/salt/auth/audit" 10 | "github.com/raystack/salt/auth/audit/mocks" 11 | 12 | "github.com/stretchr/testify/mock" 13 | "github.com/stretchr/testify/suite" 14 | ) 15 | 16 | type AuditTestSuite struct { 17 | suite.Suite 18 | 19 | now time.Time 20 | 21 | mockRepository *mocks.Repository 22 | service *audit.Service 23 | } 24 | 25 | func (s *AuditTestSuite) setupTest() { 26 | s.mockRepository = new(mocks.Repository) 27 | s.service = audit.New( 28 | audit.WithMetadataExtractor(func(context.Context) map[string]interface{} { 29 | return map[string]interface{}{ 30 | "trace_id": "test-trace-id", 31 | "app_name": "guardian_test", 32 | "app_version": 1, 33 | } 34 | }), 35 | audit.WithRepository(s.mockRepository), 36 | ) 37 | 38 | s.now = time.Now() 39 | audit.TimeNow = func() time.Time { 40 | return s.now 41 | } 42 | } 43 | 44 | func TestAudit(t *testing.T) { 45 | suite.Run(t, new(AuditTestSuite)) 46 | } 47 | 48 | func (s *AuditTestSuite) TestLog() { 49 | s.Run("should insert to repository", func() { 50 | s.setupTest() 51 | 52 | s.mockRepository.On("Insert", mock.Anything, &audit.Log{ 53 | Timestamp: s.now, 54 | Action: "action", 55 | Actor: "user@example.com", 56 | Data: map[string]interface{}{"foo": "bar"}, 57 | Metadata: map[string]interface{}{ 58 | "trace_id": "test-trace-id", 59 | "app_name": "guardian_test", 60 | "app_version": 1, 61 | }, 62 | }).Return(nil) 63 | 64 | ctx := context.Background() 65 | ctx = audit.WithActor(ctx, "user@example.com") 66 | err := s.service.Log(ctx, "action", map[string]interface{}{"foo": "bar"}) 67 | s.NoError(err) 68 | }) 69 | 70 | s.Run("actor extractor", func() { 71 | s.Run("should use actor extractor if option given", func() { 72 | expectedActor := "test-actor" 73 | s.service = audit.New( 74 | audit.WithActorExtractor(func(ctx context.Context) (string, error) { 75 | return expectedActor, nil 76 | }), 77 | audit.WithRepository(s.mockRepository), 78 | ) 79 | 80 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 81 | log := args.Get(1).(*audit.Log) 82 | s.Equal(expectedActor, log.Actor) 83 | }).Return(nil).Once() 84 | 85 | err := s.service.Log(context.Background(), "", nil) 86 | s.NoError(err) 87 | }) 88 | 89 | s.Run("should return error if extractor returns error", func() { 90 | expectedError := errors.New("test error") 91 | s.service = audit.New( 92 | audit.WithActorExtractor(func(ctx context.Context) (string, error) { 93 | return "", expectedError 94 | }), 95 | ) 96 | 97 | err := s.service.Log(context.Background(), "", nil) 98 | s.ErrorIs(err, expectedError) 99 | }) 100 | }) 101 | 102 | s.Run("metadata", func() { 103 | s.Run("should pass empty trace id if extractor not found", func() { 104 | s.service = audit.New( 105 | audit.WithMetadataExtractor(func(ctx context.Context) map[string]interface{} { 106 | return map[string]interface{}{ 107 | "app_name": "guardian_test", 108 | "app_version": 1, 109 | } 110 | }), 111 | audit.WithRepository(s.mockRepository), 112 | ) 113 | 114 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 115 | l := args.Get(1).(*audit.Log) 116 | s.IsType(map[string]interface{}{}, l.Metadata) 117 | 118 | md := l.Metadata.(map[string]interface{}) 119 | s.Empty(md["trace_id"]) 120 | s.NotEmpty(md["app_name"]) 121 | s.NotEmpty(md["app_version"]) 122 | }).Return(nil).Once() 123 | 124 | err := s.service.Log(context.Background(), "", nil) 125 | s.NoError(err) 126 | }) 127 | 128 | s.Run("should append new metadata to existing one", func() { 129 | s.service = audit.New( 130 | audit.WithMetadataExtractor(func(ctx context.Context) map[string]interface{} { 131 | return map[string]interface{}{ 132 | "existing": "foobar", 133 | } 134 | }), 135 | audit.WithRepository(s.mockRepository), 136 | ) 137 | 138 | expectedMetadata := map[string]interface{}{ 139 | "existing": "foobar", 140 | "new": "foobar", 141 | } 142 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 143 | log := args.Get(1).(*audit.Log) 144 | s.Equal(expectedMetadata, log.Metadata) 145 | }).Return(nil).Once() 146 | 147 | ctx, err := audit.WithMetadata(context.Background(), map[string]interface{}{ 148 | "new": "foobar", 149 | }) 150 | s.Require().NoError(err) 151 | 152 | err = s.service.Log(ctx, "", nil) 153 | s.NoError(err) 154 | }) 155 | }) 156 | 157 | s.Run("should return error if repository.Insert fails", func() { 158 | s.setupTest() 159 | 160 | expectedError := errors.New("test error") 161 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Return(expectedError) 162 | 163 | err := s.service.Log(context.Background(), "", nil) 164 | s.ErrorIs(err, expectedError) 165 | }) 166 | } 167 | 168 | func (s *AuditTestSuite) TestWithMetadata() { 169 | s.Run("should not mutate metadata stored in the parent context", func() { 170 | s.setupTest() 171 | s.service = audit.New(audit.WithRepository(s.mockRepository)) 172 | 173 | parent, err := audit.WithMetadata(context.Background(), map[string]interface{}{"parent": "value"}) 174 | s.Require().NoError(err) 175 | _, err = audit.WithMetadata(parent, map[string]interface{}{"child": "value"}) 176 | s.Require().NoError(err) 177 | 178 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 179 | log := args.Get(1).(*audit.Log) 180 | s.Equal(map[string]interface{}{"parent": "value"}, log.Metadata) 181 | }).Return(nil).Once() 182 | 183 | err = s.service.Log(parent, "", nil) 184 | s.NoError(err) 185 | }) 186 | } 187 | 188 | func (s *AuditTestSuite) TestLogWithoutRepository() { 189 | err := audit.New().Log(context.Background(), "", nil) 190 | s.ErrorIs(err, audit.ErrNoRepository) 191 | } 192 |
-------------------------------------------------------------------------------- /auth/audit/mocks/repository.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.10.0. DO NOT EDIT.
2 | 3 | package mocks 4 | 5 | import ( 6 | context "context" 7 | "github.com/raystack/salt/auth/audit" 8 | 9 | mock "github.com/stretchr/testify/mock" 10 | ) 11 | 12 | // Repository is an autogenerated mock type for the repository type 13 | type Repository struct { 14 | mock.Mock 15 | } 16 | 17 | // Init provides a mock function with given fields: _a0 18 | func (_m *Repository) Init(_a0 context.Context) error { 19 | ret := _m.Called(_a0) 20 | 21 | var r0 error 22 | if rf, ok := ret.Get(0).(func(context.Context) error); ok { 23 | r0 = rf(_a0) 24 | } else { 25 | r0 = ret.Error(0) 26 | } 27 | 28 | return r0 29 | } 30 | 31 | // Insert provides a mock function with given fields: _a0, _a1 32 | func (_m *Repository) Insert(_a0 context.Context, _a1 *audit.Log) error { 33 | ret := _m.Called(_a0, _a1) 34 | 35 | var r0 error 36 | if rf, ok := ret.Get(0).(func(context.Context, *audit.Log) error); ok { 37 | r0 = rf(_a0, _a1) 38 | } else { 39 | r0 = ret.Error(0) 40 | } 41 | 42 | return r0 43 | } 44 |
-------------------------------------------------------------------------------- /auth/audit/model.go: -------------------------------------------------------------------------------- 1 | package audit 2 | 3 | import "time" 4 | 5 | // Log is a single audit entry built by Service.Log and persisted by a repository. 6 | type Log struct { 7 | Timestamp time.Time `json:"timestamp"` 8 | Action string `json:"action"` 9 | Actor string `json:"actor"` 10 | Data interface{} `json:"data"` 11 | Metadata interface{} `json:"metadata"` 12 | } 13 |
-------------------------------------------------------------------------------- /auth/audit/repositories/dockertest_test.go: -------------------------------------------------------------------------------- 1 | package repositories_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/raystack/salt/auth/audit/repositories" 10 | 11 | _ "github.com/lib/pq" 12 | "github.com/ory/dockertest/v3" 13 | "github.com/ory/dockertest/v3/docker" 14 | "github.com/raystack/salt/log" 15 | ) 16 | 17 | // newTestRepository starts a disposable postgres:13 container, waits until it 18 | // accepts connections and returns a PostgresRepository whose schema has been 19 | // (re)created by setup. The container should be removed with purgeTestDocker. 20 | func newTestRepository(logger log.Logger) (*repositories.PostgresRepository, *dockertest.Pool, *dockertest.Resource, error) { 21 | host := "localhost" 22 | port := "5433" 23 | user := "test_user" 24 | password := "test_pass" 25 | dbName := "test_db" 26 | sslMode := "disable" 27 | 28 | opts := &dockertest.RunOptions{ 29 | Repository: "postgres", 30 | Tag: "13", 31 | Env: []string{ 32 | "POSTGRES_PASSWORD=" + password, 33 | "POSTGRES_USER=" + user, 34 | "POSTGRES_DB=" + dbName, 35 | }, 36 | PortBindings: map[docker.Port][]docker.PortBinding{ 37 | "5432": { 38 | {HostIP: "0.0.0.0", HostPort: port}, 39 | }, 40 | }, 41 | } 42 | 43 | // uses a sensible default on windows (tcp/http) and linux/osx (socket) 44 | pool, err := dockertest.NewPool("") 45 | if err != nil { 46 | return nil, nil, nil, fmt.Errorf("could not create dockertest pool: %w", err) 47 | } 48 | 49 | resource, err := pool.RunWithOptions(opts, func(config *docker.HostConfig) { 50 | config.AutoRemove = true 51 | config.RestartPolicy = docker.RestartPolicy{Name: "no"} 52 | }) 53 | if err != nil { 54 | return nil, nil, nil, fmt.Errorf("could not start resource: %w", err) 55 | } 56 | 57 | port = resource.GetPort("5432/tcp") 58 | 59 | // attach terminal logger to container if exists 60 | // for debugging purpose 61 | if logger.Level() == "debug" { 62 | logWaiter, err := pool.Client.AttachToContainerNonBlocking(docker.AttachToContainerOptions{ 63 | Container: resource.Container.ID, 64 | OutputStream: logger.Writer(), 65 | ErrorStream: logger.Writer(), 66 | Stderr: true, 67 | Stdout: true, 68 | Stream: true, 69 | }) 70 | if err != nil { 71 | logger.Fatal("could not connect to postgres container log output", "error", err) 72 | } 73 | defer func() { 74 | if err = logWaiter.Close(); err != nil { 75 | logger.Fatal("could not close container log", "error", err) 76 | } 77 | 78 | if err = logWaiter.Wait(); err != nil { 79 | logger.Fatal("could not wait for container log to close", "error", err) 80 | } 81 | }() 82 | } 83 | 84 | // Tell docker to hard kill the container in 120 seconds 85 | if err := resource.Expire(120); err != nil { 86 | return nil, nil, nil, err 87 | } 88 | 89 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet 90 | pool.MaxWait = 60 * time.Second 91 | 92 | var repo *repositories.PostgresRepository 93 | time.Sleep(5 * time.Second) 94 | if err := pool.Retry(func() error { 95 | db, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", host, port, user, password, dbName, sslMode)) 96 | if err != nil { 97 | return err 98 | } 99 | if err := db.Ping(); err != nil { 100 | // close the handle of a failed attempt so retries do not leak connection pools 101 | _ = db.Close() 102 | return err 103 | } 104 | repo = repositories.NewPostgresRepository(db) 105 | 106 | return nil 107 | }); err != nil { 108 | return nil, nil, nil, fmt.Errorf("could not connect to docker: %w", err) 109 | } 110 | 111 | if err := setup(repo); err != nil { 112 | // return the error instead of logger.Fatal so the caller can fail the test via testing.T 113 | return nil, nil, nil, fmt.Errorf("failed to setup and migrate DB: %w", err) 114 | } 115 | return repo, pool, resource, nil 116 | } 117 | 118 | // setup recreates the public schema and runs the repository's Init migration. 119 | func setup(repo *repositories.PostgresRepository) error { 120 | var queries = []string{ 121 | "DROP SCHEMA public CASCADE", 122 | "CREATE SCHEMA public", 123 | } 124 | for _, query := range queries { 125 | // errors are ignored here; only the Init result below is checked 126 | _, _ = repo.DB().Exec(query) 127 | } 128 | 129 | if err := repo.Init(context.Background()); err != nil { 130 | return err 131 | } 132 | 133 | return nil 134 | } 135 | 136 | // purgeTestDocker removes the container started by newTestRepository. 137 | func purgeTestDocker(pool *dockertest.Pool, resource *dockertest.Resource) error { 138 | if err := pool.Purge(resource); err != nil { 139 | return fmt.Errorf("could not purge resource: %w", err) 140 | } 141 | return nil 142 | } 143 |
-------------------------------------------------------------------------------- /auth/audit/repositories/postgres.go: -------------------------------------------------------------------------------- 1 | package repositories 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "encoding/json" 7 | "fmt" 8 | "time" 9 | 10 | "github.com/raystack/salt/auth/audit" 11 | 12 | "github.com/jmoiron/sqlx/types" 13 | ) 14 | 15 | // AuditModel is the row representation of an audit.Log in the audit_logs table. 16 | type AuditModel struct { 17 | Timestamp time.Time `db:"timestamp"` 18 | Action string `db:"action"` 19 | Actor string `db:"actor"` 20 | Data types.NullJSONText `db:"data"` 21 | Metadata types.NullJSONText `db:"metadata"` 22 | } 23 | 24 | // PostgresRepository stores audit logs in the audit_logs table of a Postgres database. 25 | type PostgresRepository struct { 26 | db *sql.DB 27 | } 28 | 29 | // NewPostgresRepository returns a repository backed by db. Init creates the table it writes to. 30 | func NewPostgresRepository(db *sql.DB) *PostgresRepository { 31 | return &PostgresRepository{db} 32 | } 33 | 34 | // DB returns the underlying database handle. 35 | func (r *PostgresRepository) DB() *sql.DB { 36 | return r.db 37 | } 38 | 39 | // Init creates the audit_logs table and its indexes if they do not exist yet. 40 | func (r *PostgresRepository) Init(ctx context.Context) error { 41 | query := ` 42 | CREATE TABLE IF NOT EXISTS audit_logs ( 43 | timestamp TIMESTAMP WITH TIME ZONE NOT NULL, 44 | action TEXT NOT NULL, 45 | actor TEXT NOT NULL, 46 | data JSONB NOT NULL, 47 | metadata JSONB NOT NULL 48 | ); 49 | 50 | CREATE INDEX IF NOT EXISTS audit_logs_timestamp_idx ON audit_logs (timestamp); 51 | CREATE INDEX IF NOT EXISTS audit_logs_action_idx ON audit_logs (action); 52 | CREATE INDEX IF NOT EXISTS audit_logs_actor_idx ON audit_logs (actor); 53 | ` 54 | if _, err := r.db.ExecContext(ctx, query); err != nil { 55 | return fmt.Errorf("migrating audit model to postgres db: %w", err) 56 | } 57 | return nil 58 | } 59 | 60 | // Insert stores l as a new row. Data and Metadata are always JSON-encoded, so a 61 | // nil value is stored as the JSON literal null: both columns are NOT NULL and a 62 | // SQL NULL would make the insert fail. 63 | func (r *PostgresRepository) Insert(ctx context.Context, l *audit.Log) error { 64 | data, err := json.Marshal(l.Data) 65 | if err != nil { 66 | return fmt.Errorf("marshalling data: %w", err) 67 | } 68 | 69 | metadata, err := json.Marshal(l.Metadata) 70 | if err != nil { 71 | return fmt.Errorf("marshalling metadata: %w", err) 72 | } 73 | 74 | m := &AuditModel{ 75 | Timestamp: l.Timestamp, 76 | Action: l.Action, 77 | Actor: l.Actor, 78 | Data: types.NullJSONText{JSONText: data, Valid: true}, 79 | Metadata: types.NullJSONText{JSONText: metadata, Valid: true}, 80 | } 81 | 82 | if _, err := r.db.ExecContext(ctx, "INSERT INTO audit_logs (timestamp, action, actor, data, metadata) VALUES ($1, $2, $3, $4, $5)", m.Timestamp, m.Action, m.Actor, m.Data, m.Metadata); err != nil { 83 | return fmt.Errorf("inserting to db: %w", err) 84 | } 85 | 86 | return nil 87 | } 88 |
-------------------------------------------------------------------------------- /auth/audit/repositories/postgres_test.go: -------------------------------------------------------------------------------- 1 | package repositories_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/raystack/salt/auth/audit" 9 | "github.com/raystack/salt/auth/audit/repositories" 10 | 11 | "github.com/google/go-cmp/cmp" 12 | "github.com/google/go-cmp/cmp/cmpopts" 13 | "github.com/jmoiron/sqlx/types" 14 | "github.com/raystack/salt/log" 15 | "github.com/stretchr/testify/suite" 16 | ) 17 | 18 | type PostgresRepositoryTestSuite struct { 19 | suite.Suite 20 | 21 | repository *repositories.PostgresRepository 22 | } 23 | 24 | func TestPostgresRepository(t *testing.T) { 25 | suite.Run(t, new(PostgresRepositoryTestSuite)) 26 | } 27 | 28 | func (s *PostgresRepositoryTestSuite) SetupSuite() { 29 | repository, pool, dockerResource, err := newTestRepository(log.NewLogrus()) 30 | if err != nil { 31 | s.T().Fatal(err) 32 | } 33 | s.repository = repository 34 | 35 | s.T().Cleanup(func() { 36 | if err := s.repository.DB().Close(); err != nil { 37 | s.T().Fatal(err) 38 | } 39 | if err := purgeTestDocker(pool, dockerResource); err != nil { 40 | s.T().Fatal(err) 41 | } 42 | }) 43 | } 44 | 45 | // fetchByAction reads back the audit_logs row stored for action, failing the test if there is none. 46 | func (s *PostgresRepositoryTestSuite) fetchByAction(action string) repositories.AuditModel { 47 | var m repositories.AuditModel 48 | err := s.repository.DB().QueryRow("SELECT timestamp, action, actor, data, metadata FROM audit_logs WHERE action = $1", action).Scan(&m.Timestamp, &m.Action, &m.Actor, &m.Data, &m.Metadata) 49 | s.Require().NoError(err) 50 | return m 51 | } 52 | 53 | func (s *PostgresRepositoryTestSuite) TestInsert() { 54 | s.Run("should insert record to db", func() { 55 | l := &audit.Log{ 56 | Timestamp: time.Now(), 57 | Action: "test-action", 58 | Actor: "user@example.com", 59 | Data: types.NullJSONText{ 60 | JSONText: []byte(`{"test": "data"}`), 61 | Valid: true, 62 | }, 63 | Metadata: types.NullJSONText{ 64 | JSONText: []byte(`{"test": "metadata"}`), 65 | Valid: true, 66 | }, 67 | } 68 | 69 | err := s.repository.Insert(context.Background(), l) 70 | s.Require().NoError(err) 71 | 72 | actualResult := s.fetchByAction(l.Action) 73 | if diff := cmp.Diff(l.Timestamp, actualResult.Timestamp, cmpopts.EquateApproxTime(time.Microsecond)); diff != "" { 74 | s.T().Errorf("result not match, diff: %v", diff) 75 | } 76 | s.Equal(l.Action, actualResult.Action) 77 | s.Equal(l.Actor, actualResult.Actor) 78 | s.Equal(l.Data, actualResult.Data) 79 | s.Equal(l.Metadata, actualResult.Metadata) 80 | }) 81 | 82 | s.Run("should store nil data and metadata as JSON null", func() { 83 | l := &audit.Log{ 84 | Timestamp: time.Now(), 85 | Action: "test-action-without-payload", 86 | Actor: "user@example.com", 87 | } 88 | 89 | err := s.repository.Insert(context.Background(), l) 90 | s.Require().NoError(err) 91 | 92 | actualResult := s.fetchByAction(l.Action) 93 | s.JSONEq("null", string(actualResult.Data.JSONText)) 94 | s.JSONEq("null", string(actualResult.Metadata.JSONText)) 95 | }) 96 | 97 | s.Run("should return error if data marshalling returns error", func() { 98 | l := &audit.Log{ 99 | Data: make(chan int), 100 | } 101 | 102 | err := s.repository.Insert(context.Background(), l) 103 | s.EqualError(err, "marshalling data: json: unsupported type: chan int") 104 | }) 105 | 106 | s.Run("should return error if metadata marshalling returns error", func() { 107 | l := &audit.Log{ 108 | Metadata: map[string]interface{}{ 109 | "foo": make(chan int), 110 | }, 111 | } 112 | 113 | err := s.repository.Insert(context.Background(), l) 114 | s.EqualError(err, "marshalling metadata: json: unsupported type: chan int") 115 | }) 116 | } 117 |
-------------------------------------------------------------------------------- /auth/oidc/_example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "log" 6 | "os" 7 | "strings" 8 | 9 | "github.com/raystack/salt/auth/oidc" 10 | "golang.org/x/oauth2" 11 | "golang.org/x/oauth2/google" 12 | ) 13 | 14 | // main runs the login command from oidc.LoginCmd, configured from environment variables, and prints the resulting token as JSON on stdout. 15 | func main() { 16 | cfg := &oauth2.Config{ 17 | ClientID: os.Getenv("CLIENT_ID"), 18 | ClientSecret: os.Getenv("CLIENT_SECRET"), 19 | Endpoint: google.Endpoint, 20 | RedirectURL: "http://localhost:5454", 21 | Scopes: strings.Split(os.Getenv("OIDC_SCOPES"), ","), 22 | } 23 | aud := os.Getenv("OIDC_AUDIENCE") 24 | keyFile := os.Getenv("GOOGLE_SERVICE_ACCOUNT") 25 | 26 | onTokenOrErr := func(t *oauth2.Token, err error) { 27 | if err != nil { 28 | log.Fatalf("oidc login failed: %v", err) 29 | } 30 | 31 | _ = json.NewEncoder(os.Stdout).Encode(map[string]interface{}{ 32 | "token_type": t.TokenType, 33 | "access_token": t.AccessToken, 34 | "expiry": t.Expiry, 35 | "refresh_token": t.RefreshToken, 36 | "id_token": t.Extra("id_token"), 37 | }) 38 | } 39 | 40 | _ = oidc.LoginCmd(cfg, aud, keyFile, onTokenOrErr).Execute() 41 | } 42 |
-------------------------------------------------------------------------------- /auth/oidc/cobra.go: -------------------------------------------------------------------------------- 1 | package oidc 2 | 3 | import ( 4 | "context" 5 | "os/signal" 6 | "syscall" 7 | 8 | "github.com/spf13/cobra" 9 | "golang.org/x/oauth2" 10 | ) 11 | 12 | // LoginCmd returns a "login" command that obtains a token and hands it, or the 13 | // error, to onTokenOrErr. When keyFilePath is set, a Google service-account ID 14 | // token source for aud is used; otherwise the interactive OIDC flow runs with 15 | // cfg. The token flow's context is cancelled on SIGINT or SIGTERM. 16 | func LoginCmd(cfg *oauth2.Config, aud, keyFilePath string, onTokenOrErr func(t *oauth2.Token, err error)) *cobra.Command { 17 | cmd := &cobra.Command{ 18 | Use: "login", 19 | Short: "Login with your Google account.", 20 | Run: func(cmd *cobra.Command, args []string) { 21 | ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 22 | defer cancel() 23 | 24 | var ts oauth2.TokenSource 25 | if keyFilePath != "" { 26 | var err error 27 | ts, err = NewGoogleServiceAccountTokenSource(ctx, keyFilePath, aud) 28 | if err != nil { 29 | onTokenOrErr(nil, err) 30 | return 31 | } 32 | } else { 33 | ts = NewTokenSource(ctx, cfg, aud) 34 | } 35 | onTokenOrErr(ts.Token()) 36 | }, 37 | } 38 | 39 | return cmd 40 | } 41 |
-------------------------------------------------------------------------------- /auth/oidc/redirect.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Login 4 | 19 | 20 | 21 |
22 |

✅ Done

23 |

It is safe to close this window now.

24 |

Go back to the console to continue.

25 |
26 | 27 | 28 |
-------------------------------------------------------------------------------- /auth/oidc/source_gsa.go: -------------------------------------------------------------------------------- 1 | package oidc 2 | 3 | import ( 4 | "context" 5 | 6 | "golang.org/x/oauth2" 7 | "google.golang.org/api/idtoken" 8 | ) 9 | 10 | // NewGoogleServiceAccountTokenSource returns an ID-token source (package 11 | // google.golang.org/api/idtoken) for the audience aud, authenticated with the 12 | // service account credentials file at keyFile. 13 | func NewGoogleServiceAccountTokenSource(ctx context.Context, keyFile, aud string) (oauth2.TokenSource, error) { 14 | return idtoken.NewTokenSource(ctx, aud, idtoken.WithCredentialsFile(keyFile)) 15 | } 16 |
-------------------------------------------------------------------------------- /auth/oidc/source_oidc.go: -------------------------------------------------------------------------------- 1 | package oidc 2 | 3 | import ( 4 | "context" 5 | "crypto/sha256" 6 | "errors" 7 | "fmt" 8 | 9 | "golang.org/x/oauth2" 10 | ) 11 | 12 | const ( 13 | // Values from OpenID Connect. 14 | scopeOpenID = "openid" 15 | audienceKey = "audience" 16 | 17 | // Values used in PKCE implementation. 18 | // Refer https://www.rfc-editor.org/rfc/rfc7636 19 | pkceS256 = "S256" 20 | codeVerifierLen = 32 21 | codeChallengeKey = "code_challenge" 22 | codeVerifierKey = "code_verifier" 23 | codeChallengeMethodKey = "code_challenge_method" 24 | ) 25 | 26 | // NewTokenSource returns a token source that runs the interactive OIDC 27 | // authorization-code flow with PKCE on every Token call and returns the 28 | // id_token in place of the access token. 29 | // 30 | // The "openid" scope is added when missing. conf is copied first, so the 31 | // caller's config (including its Scopes slice) is never modified. 32 | func NewTokenSource(ctx context.Context, conf *oauth2.Config, audience string) oauth2.TokenSource { 33 | cfg := *conf 34 | cfg.Scopes = withScope(conf.Scopes, scopeOpenID) 35 | return &authHandlerSource{ 36 | ctx: ctx, 37 | config: &cfg, 38 | audience: audience, 39 | } 40 | } 41 | 42 | // withScope returns a copy of scopes with scope appended, unless it is already present. 43 | func withScope(scopes []string, scope string) []string { 44 | out := append(make([]string, 0, len(scopes)+1), scopes...) 45 | for _, s := range scopes { 46 | if s == scope { 47 | return out 48 | } 49 | } 50 | return append(out, scope) 51 | } 52 | 53 | type authHandlerSource struct { 54 | ctx context.Context 55 | config *oauth2.Config 56 | audience string 57 | } 58 | 59 | // Token sends the user to the authorization page through browserAuthzHandler, 60 | // verifies the returned state, exchanges the code (with the PKCE verifier) and 61 | // returns the token with AccessToken replaced by the id_token. 62 | func (source *authHandlerSource) Token() (*oauth2.Token, error) { 63 | stateBytes, err := randomBytes(10) 64 | if err != nil { 65 | return nil, err 66 | } 67 | // encode the random bytes so the state is URL-safe ASCII; raw random bytes are 68 | // usually invalid UTF-8 and may be rejected or altered by the provider 69 | actualState := encode(stateBytes) 70 | 71 | codeVerifier, codeChallenge, challengeMethod, err := newPKCEParams() 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | // Step 1. Send user to authorization page for obtaining consent. 77 | url := source.config.AuthCodeURL(actualState, 78 | oauth2.SetAuthURLParam(audienceKey, source.audience), 79 | oauth2.SetAuthURLParam(codeChallengeKey, codeChallenge), 80 | oauth2.SetAuthURLParam(codeChallengeMethodKey, challengeMethod), 81 | ) 82 | 83 | code, receivedState, err := browserAuthzHandler(source.ctx, source.config.RedirectURL, url) 84 | if err != nil { 85 | return nil, err 86 | } 87 | if receivedState != actualState { 88 | return nil, errors.New("state received in redirection does not match") 89 | } 90 | 91 | // Step 2. Exchange code-grant for tokens (access_token, refresh_token, id_token). 92 | tok, err := source.config.Exchange(source.ctx, code, 93 | oauth2.SetAuthURLParam(audienceKey, source.audience), 94 | oauth2.SetAuthURLParam(codeVerifierKey, codeVerifier), 95 | ) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | idToken, ok := tok.Extra("id_token").(string) 101 | if !ok { 102 | return nil, errors.New("id_token not found in token response") 103 | } 104 | tok.AccessToken = idToken 105 | 106 | return tok, nil 107 | } 108 | 109 | // newPKCEParams generates parameters for 'Proof Key for Code Exchange'. 110 | // Refer https://www.rfc-editor.org/rfc/rfc7636#section-4.2 111 | func newPKCEParams() (verifier, challenge, method string, err error) { 112 | // generate 'verifier' string. 113 | verifierBytes, err := randomBytes(codeVerifierLen) 114 | if err != nil { 115 | return "", "", "", fmt.Errorf("failed to generate random bytes: %w", err) 116 | } 117 | verifier = encode(verifierBytes) 118 | 119 | // generate S256 challenge: BASE64URL(SHA256(verifier)). 120 | sum := sha256.Sum256([]byte(verifier)) 121 | challenge = encode(sum[:]) 122 | 123 | return verifier, challenge, pkceS256, nil 124 | } 125 |
-------------------------------------------------------------------------------- /auth/oidc/utils.go: -------------------------------------------------------------------------------- 1 | package oidc 2 | 3 | import ( 4 | "context" 5 | "crypto/rand" 6 | _ "embed" // for embedded html 7 | "encoding/base64" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "net/url" 13 | "os/exec" 14 | "runtime" 15 | "strings" 16 | ) 17 | 18 | const ( 19 | codeParam = "code" 20 | stateParam = "state" 21 | 22 | errParam = "error" 23 | errDescParam = "error_description" 24 | ) 25 | 26 | //go:embed redirect.html 27 | var callbackResponsePage string 28 | 29 | func encode(msg []byte) string { 30 | encoded := base64.StdEncoding.EncodeToString(msg) 31 | encoded = strings.Replace(encoded, "+", "-", -1) 32 | encoded = strings.Replace(encoded, "/", "_", -1) 33 | encoded = strings.Replace(encoded, "=", "", -1) 34 |
return encoded 35 | } 36 | 37 | // https://tools.ietf.org/html/rfc7636#section-4.1) 38 | func randomBytes(length int) ([]byte, error) { 39 | const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" 40 | const csLen = byte(len(charset)) 41 | output := make([]byte, 0, length) 42 | for { 43 | buf := make([]byte, length) 44 | if _, err := io.ReadFull(rand.Reader, buf); err != nil { 45 | return nil, fmt.Errorf("failed to read random bytes: %v", err) 46 | } 47 | for _, b := range buf { 48 | // Avoid bias by using a value range that's a multiple of 62 49 | if b < (csLen * 4) { 50 | output = append(output, charset[b%csLen]) 51 | 52 | if len(output) == length { 53 | return output, nil 54 | } 55 | } 56 | } 57 | } 58 | } 59 | 60 | func browserAuthzHandler(ctx context.Context, redirectURL, authCodeURL string) (code string, state string, err error) { 61 | if err := openURL(authCodeURL); err != nil { 62 | return "", "", err 63 | } 64 | 65 | u, err := url.Parse(redirectURL) 66 | if err != nil { 67 | return "", "", err 68 | } 69 | 70 | code, state, err = waitForCallback(ctx, fmt.Sprintf(":%s", u.Port())) 71 | if err != nil { 72 | return "", "", err 73 | } 74 | return code, state, nil 75 | } 76 | 77 | func waitForCallback(ctx context.Context, addr string) (code, state string, err error) { 78 | var cb struct { 79 | code string 80 | state string 81 | err error 82 | } 83 | 84 | stopCh := make(chan struct{}) 85 | srv := &http.Server{ 86 | Addr: addr, 87 | Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 88 | cb.code, cb.state, cb.err = parseCallbackRequest(r) 89 | 90 | w.WriteHeader(http.StatusOK) 91 | w.Header().Set("content-type", "text/html") 92 | _, _ = w.Write([]byte(callbackResponsePage)) 93 | 94 | // try to flush to ensure the page is shown to user before we close 95 | // the server. 
96 | if fl, ok := w.(http.Flusher); ok { 97 | fl.Flush() 98 | } 99 | 100 | close(stopCh) 101 | }), 102 | } 103 | 104 | go func() { 105 | select { 106 | case <-stopCh: 107 | _ = srv.Close() 108 | 109 | case <-ctx.Done(): 110 | cb.err = ctx.Err() 111 | _ = srv.Close() 112 | } 113 | }() 114 | 115 | if serveErr := srv.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) { 116 | return "", "", serveErr 117 | } 118 | return cb.code, cb.state, cb.err 119 | } 120 | 121 | func parseCallbackRequest(r *http.Request) (code string, state string, err error) { 122 | if err = r.ParseForm(); err != nil { 123 | return "", "", err 124 | } 125 | 126 | state = r.Form.Get(stateParam) 127 | if state == "" { 128 | return "", "", errors.New("missing state parameter") 129 | } 130 | 131 | if errorCode := r.Form.Get(errParam); errorCode != "" { 132 | // Got error from provider. Passing through. 133 | return "", "", fmt.Errorf("%s: %s", errorCode, r.Form.Get(errDescParam)) 134 | } 135 | 136 | code = r.Form.Get(codeParam) 137 | if code == "" { 138 | return "", "", errors.New("missing code parameter") 139 | } 140 | 141 | return code, state, nil 142 | } 143 | 144 | // openURL opens the specified URL in the default application registered for 145 | // the URL scheme. 146 | func openURL(url string) error { 147 | var cmd string 148 | var args []string 149 | 150 | switch runtime.GOOS { 151 | case "windows": 152 | cmd = "cmd" 153 | args = []string{"/c", "start"} 154 | // If we don't escape &, cmd will ignore everything after the first &. 
155 | url = strings.Replace(url, "&", "^&", -1) 156 | 157 | case "darwin": 158 | cmd = "open" 159 | 160 | default: // "linux", "freebsd", "openbsd", "netbsd" 161 | cmd = "xdg-open" 162 | } 163 | 164 | args = append(args, url) 165 | return exec.Command(cmd, args...).Start() 166 | } 167 | -------------------------------------------------------------------------------- /cli/commander/codex.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | 8 | "github.com/spf13/cobra" 9 | "github.com/spf13/cobra/doc" 10 | ) 11 | 12 | // addMarkdownCommand integrates a hidden `markdown` command into the root command. 13 | // This command generates a Markdown documentation tree for all commands in the hierarchy. 14 | func (m *Manager) addMarkdownCommand(outputPath string) { 15 | markdownCmd := &cobra.Command{ 16 | Use: "markdown", 17 | Short: "Generate Markdown documentation for all commands", 18 | Hidden: true, 19 | Annotations: map[string]string{ 20 | "group": "help", 21 | }, 22 | RunE: func(cmd *cobra.Command, args []string) error { 23 | return m.generateMarkdownTree(outputPath, m.RootCmd) 24 | }, 25 | } 26 | 27 | m.RootCmd.AddCommand(markdownCmd) 28 | } 29 | 30 | // generateMarkdownTree generates a Markdown documentation tree for the given command hierarchy. 31 | // 32 | // Parameters: 33 | // - rootOutputPath: The root directory where the Markdown files will be generated. 34 | // - cmd: The root Cobra command whose hierarchy will be documented. 35 | // 36 | // Returns: 37 | // - An error if any part of the process (file creation, directory creation) fails. 38 | func (m *Manager) generateMarkdownTree(rootOutputPath string, cmd *cobra.Command) error { 39 | dirFilePath := filepath.Join(rootOutputPath, cmd.Name()) 40 | 41 | // Handle subcommands by creating a directory and iterating through subcommands. 
42 | if len(cmd.Commands()) > 0 { 43 | if err := ensureDir(dirFilePath); err != nil { 44 | return fmt.Errorf("failed to create directory for command %q: %w", cmd.Name(), err) 45 | } 46 | for _, subCmd := range cmd.Commands() { 47 | if err := m.generateMarkdownTree(dirFilePath, subCmd); err != nil { 48 | return err 49 | } 50 | } 51 | } else { 52 | // Generate a Markdown file for leaf commands. 53 | outFilePath := filepath.Join(rootOutputPath, cmd.Name()+".md") 54 | 55 | f, err := os.Create(outFilePath) 56 | if err != nil { 57 | return err 58 | } 59 | defer f.Close() 60 | 61 | // Generate Markdown with a custom link handler. 62 | return doc.GenMarkdownCustom(cmd, f, func(s string) string { 63 | return filepath.Join(dirFilePath, s) 64 | }) 65 | } 66 | 67 | return nil 68 | } 69 | 70 | // ensureDir ensures that the given directory exists, creating it if necessary. 71 | func ensureDir(path string) error { 72 | if _, err := os.Stat(path); os.IsNotExist(err) { 73 | if err := os.MkdirAll(path, os.ModePerm); err != nil { 74 | return err 75 | } 76 | } 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /cli/commander/completion.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/MakeNowJust/heredoc" 7 | "github.com/spf13/cobra" 8 | ) 9 | 10 | // addCompletionCommand adds a `completion` command to the CLI. 11 | // The `completion` command generates shell completion scripts 12 | // for Bash, Zsh, Fish, and PowerShell. 
13 | // Usage: 14 | // 15 | // $ mycli completion bash 16 | // $ mycli completion zsh 17 | func (m *Manager) addCompletionCommand() { 18 | summary := m.generateCompletionSummary(m.RootCmd.Use) 19 | 20 | completionCmd := &cobra.Command{ 21 | Use: "completion [bash|zsh|fish|powershell]", 22 | Short: "Generate shell completion scripts", 23 | Long: summary, 24 | DisableFlagsInUseLine: true, 25 | ValidArgs: []string{"bash", "zsh", "fish", "powershell"}, 26 | Args: cobra.ExactValidArgs(1), 27 | Run: func(cmd *cobra.Command, args []string) { 28 | switch args[0] { 29 | case "bash": 30 | cmd.Root().GenBashCompletion(os.Stdout) 31 | case "zsh": 32 | cmd.Root().GenZshCompletion(os.Stdout) 33 | case "fish": 34 | cmd.Root().GenFishCompletion(os.Stdout, true) 35 | case "powershell": 36 | cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout) 37 | } 38 | }, 39 | } 40 | 41 | m.RootCmd.AddCommand(completionCmd) 42 | } 43 | 44 | // generateCompletionSummary creates the long description for the `completion` command. 45 | func (m *Manager) generateCompletionSummary(exec string) string { 46 | var execs []interface{} 47 | for i := 0; i < 12; i++ { 48 | execs = append(execs, exec) 49 | } 50 | return heredoc.Docf(`To load completions: 51 | `+"```"+` 52 | Bash: 53 | 54 | $ source <(%s completion bash) 55 | 56 | # To load completions for each session, execute once: 57 | # Linux: 58 | $ %s completion bash > /etc/bash_completion.d/%s 59 | # macOS: 60 | $ %s completion bash > /usr/local/etc/bash_completion.d/%s 61 | 62 | Zsh: 63 | 64 | # If shell completion is not already enabled in your environment, 65 | # you will need to enable it. You can execute the following once: 66 | 67 | $ echo "autoload -U compinit; compinit" >> ~/.zshrc 68 | 69 | # To load completions for each session, execute once: 70 | $ %s completion zsh > "${fpath[1]}/_yourprogram" 71 | 72 | # You will need to start a new shell for this setup to take effect. 
73 | 74 | Fish: 75 | 76 | $ %s completion fish | source 77 | 78 | # To load completions for each session, execute once: 79 | $ %s completion fish > ~/.config/fish/completions/%s.fish 80 | 81 | PowerShell: 82 | 83 | PS> %s completion powershell | Out-String | Invoke-Expression 84 | 85 | # To load completions for every new session, run: 86 | PS> %s completion powershell > %s.ps1 87 | # and source this file from your PowerShell profile. 88 | `+"```"+` 89 | `, execs...) 90 | } 91 | -------------------------------------------------------------------------------- /cli/commander/hooks.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | // addClientHooks applies all configured hooks to commands annotated with `client:true`. 4 | func (m *Manager) addClientHooks() { 5 | for _, cmd := range m.RootCmd.Commands() { 6 | for _, hook := range m.Hooks { 7 | if cmd.Annotations["client"] == "true" { 8 | hook.Behavior(cmd) 9 | } 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /cli/commander/manager.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | // Manager manages and configures features for a CLI tool. 10 | type Manager struct { 11 | RootCmd *cobra.Command 12 | Help bool // Enable custom help. 13 | Reference bool // Enable reference command. 14 | Completion bool // Enable shell completion. 15 | Config bool // Enable configuration management. 16 | Docs bool // Enable markdown documentation 17 | Hooks []HookBehavior // Hook behaviors to apply to commands 18 | Topics []HelpTopic // Help topics with their details. 19 | } 20 | 21 | // HelpTopic defines a single help topic with its details. 
22 | type HelpTopic struct { 23 | Name string 24 | Short string 25 | Long string 26 | Example string 27 | } 28 | 29 | // HookBehavior defines a specific behavior applied to commands. 30 | type HookBehavior struct { 31 | Name string // Name of the hook (e.g., "setup", "auth"). 32 | Behavior func(cmd *cobra.Command) // Function to apply to commands. 33 | } 34 | 35 | // New creates a new CLI Manager using the provided root command and optional configurations. 36 | // 37 | // Parameters: 38 | // - rootCmd: The root Cobra command for the CLI. 39 | // - options: Functional options for configuring the Manager. 40 | // 41 | // Example: 42 | // 43 | // rootCmd := &cobra.Command{Use: "mycli"} 44 | // manager := cmdx.NewCommander(rootCmd, cmdx.WithTopics(...), cmdx.WithHooks(...)) 45 | func New(rootCmd *cobra.Command, options ...func(*Manager)) *Manager { 46 | // Create Manager with defaults 47 | manager := &Manager{ 48 | RootCmd: rootCmd, 49 | Help: true, // Default enabled 50 | Reference: true, // Default enabled 51 | Completion: true, // Default enabled 52 | Docs: false, // Default disabled 53 | Topics: []HelpTopic{}, 54 | Hooks: []HookBehavior{}, 55 | } 56 | 57 | // Apply functional options 58 | for _, opt := range options { 59 | opt(manager) 60 | } 61 | 62 | return manager 63 | } 64 | 65 | // Init sets up the CLI features based on the Manager's configuration. 66 | // It enables or disables features like custom help, reference documentation, 67 | // shell completion, help topics, and client hooks based on the Manager's settings. 68 | func (m *Manager) Init() { 69 | if m.Help { 70 | m.setCustomHelp() 71 | } 72 | if m.Reference { 73 | m.addReferenceCommand() 74 | } 75 | if m.Completion { 76 | m.addCompletionCommand() 77 | } 78 | if m.Docs { 79 | m.addMarkdownCommand("./docs") 80 | } 81 | if len(m.Topics) > 0 { 82 | m.addHelpTopics() 83 | } 84 | 85 | if len(m.Hooks) > 0 { 86 | m.addClientHooks() 87 | } 88 | } 89 | 90 | // WithTopics sets the help topics for the Manager. 
91 | func WithTopics(topics []HelpTopic) func(*Manager) { 92 | return func(m *Manager) { 93 | m.Topics = topics 94 | } 95 | } 96 | 97 | // WithHooks sets the hook behaviors for the Manager. 98 | func WithHooks(hooks []HookBehavior) func(*Manager) { 99 | return func(m *Manager) { 100 | m.Hooks = hooks 101 | } 102 | } 103 | 104 | // IsCommandErr checks if the given error is related to a Cobra command error. 105 | // This is useful for distinguishing between user errors (e.g., incorrect commands or flags) 106 | // and program errors, allowing the application to display appropriate messages. 107 | func IsCommandErr(err error) bool { 108 | if err == nil { 109 | return false 110 | } 111 | 112 | // Known Cobra command error keywords 113 | cmdErrorKeywords := []string{ 114 | "unknown command", 115 | "unknown flag", 116 | "unknown shorthand flag", 117 | } 118 | 119 | errMessage := err.Error() 120 | for _, keyword := range cmdErrorKeywords { 121 | if strings.Contains(errMessage, keyword) { 122 | return true 123 | } 124 | } 125 | return false 126 | } 127 | -------------------------------------------------------------------------------- /cli/commander/reference.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "strings" 8 | 9 | "github.com/raystack/salt/cli/printer" 10 | 11 | "github.com/spf13/cobra" 12 | ) 13 | 14 | // addReferenceCommand adds a `reference` command to the CLI. 15 | // The `reference` command generates markdown documentation for all commands 16 | // in the CLI command tree. 
17 | func (m *Manager) addReferenceCommand() { 18 | var isPlain bool 19 | refCmd := &cobra.Command{ 20 | Use: "reference", 21 | Short: "Comprehensive reference of all commands", 22 | Long: m.generateReferenceMarkdown(), 23 | Run: m.runReferenceCommand(&isPlain), 24 | Annotations: map[string]string{ 25 | "group": "help", 26 | }, 27 | } 28 | refCmd.SetHelpFunc(m.runReferenceCommand(&isPlain)) 29 | refCmd.Flags().BoolVarP(&isPlain, "plain", "p", true, "output in plain markdown (without ANSI color)") 30 | 31 | m.RootCmd.AddCommand(refCmd) 32 | } 33 | 34 | // runReferenceCommand handles the output generation for the `reference` command. 35 | // It renders the documentation either as plain markdown or with ANSI color. 36 | func (m *Manager) runReferenceCommand(isPlain *bool) func(cmd *cobra.Command, args []string) { 37 | return func(cmd *cobra.Command, args []string) { 38 | var ( 39 | output string 40 | err error 41 | ) 42 | 43 | if *isPlain { 44 | output = cmd.Long 45 | } else { 46 | output, err = printer.Markdown(cmd.Long) 47 | if err != nil { 48 | fmt.Println("Error generating markdown:", err) 49 | return 50 | } 51 | } 52 | 53 | fmt.Print(output) 54 | } 55 | } 56 | 57 | // generateReferenceMarkdown generates a complete markdown representation 58 | // of the command tree for the `reference` command. 59 | func (m *Manager) generateReferenceMarkdown() string { 60 | buf := bytes.NewBufferString(fmt.Sprintf("# %s reference\n\n", m.RootCmd.Name())) 61 | for _, c := range m.RootCmd.Commands() { 62 | if c.Hidden { 63 | continue 64 | } 65 | m.generateCommandReference(buf, c, 2) 66 | } 67 | return buf.String() 68 | } 69 | 70 | // generateCommandReference recursively generates markdown for a given command 71 | // and its subcommands. 
72 | func (m *Manager) generateCommandReference(w io.Writer, cmd *cobra.Command, depth int) { 73 | // Name + Description 74 | fmt.Fprintf(w, "%s `%s`\n\n", strings.Repeat("#", depth), cmd.UseLine()) 75 | fmt.Fprintf(w, "%s\n\n", cmd.Short) 76 | 77 | // Flags 78 | if flagUsages := cmd.Flags().FlagUsages(); flagUsages != "" { 79 | fmt.Fprintf(w, "```\n%s```\n\n", dedent(flagUsages)) 80 | } 81 | 82 | // Subcommands 83 | for _, c := range cmd.Commands() { 84 | if c.Hidden { 85 | continue 86 | } 87 | m.generateCommandReference(w, c, depth+1) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /cli/commander/topics.go: -------------------------------------------------------------------------------- 1 | package commander 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/spf13/cobra" 7 | ) 8 | 9 | // addHelpTopics adds all configured help topics to the CLI. 10 | // 11 | // Help topics provide detailed information about specific subjects, 12 | // such as environment variables or configuration. 13 | func (m *Manager) addHelpTopics() { 14 | for _, topic := range m.Topics { 15 | m.addHelpTopicCommand(topic) 16 | } 17 | } 18 | 19 | // addHelpTopicCommand adds a single help topic command to the CLI. 20 | func (m *Manager) addHelpTopicCommand(topic HelpTopic) { 21 | helpCmd := &cobra.Command{ 22 | Use: topic.Name, 23 | Short: topic.Short, 24 | Long: topic.Long, 25 | Example: topic.Example, 26 | Hidden: false, 27 | Annotations: map[string]string{ 28 | "group": "help", 29 | }, 30 | } 31 | 32 | helpCmd.SetHelpFunc(helpTopicHelpFunc) 33 | helpCmd.SetUsageFunc(helpTopicUsageFunc) 34 | 35 | m.RootCmd.AddCommand(helpCmd) 36 | } 37 | 38 | // helpTopicHelpFunc customizes the help message for a help topic command. 
39 | func helpTopicHelpFunc(cmd *cobra.Command, args []string) { 40 | fmt.Fprintln(cmd.OutOrStdout(), cmd.Long) 41 | if cmd.Example != "" { 42 | fmt.Fprintln(cmd.OutOrStdout(), "\nEXAMPLES") 43 | fmt.Fprintln(cmd.OutOrStdout(), indent(cmd.Example, " ")) 44 | } 45 | } 46 | 47 | // helpTopicUsageFunc customizes the usage message for a help topic command. 48 | func helpTopicUsageFunc(cmd *cobra.Command) error { 49 | fmt.Fprintf(cmd.OutOrStdout(), "Usage: %s help %s\n", cmd.Root().Name(), cmd.Use) 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /cli/printer/colors.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/muesli/termenv" 7 | ) 8 | 9 | var tp = termenv.EnvColorProfile() 10 | 11 | // Theme defines a collection of colors for terminal outputs. 12 | type Theme struct { 13 | Green termenv.Color 14 | Yellow termenv.Color 15 | Cyan termenv.Color 16 | Red termenv.Color 17 | Grey termenv.Color 18 | Blue termenv.Color 19 | Magenta termenv.Color 20 | } 21 | 22 | var themes = map[string]Theme{ 23 | "light": { 24 | Green: tp.Color("#005F00"), 25 | Yellow: tp.Color("#FFAF00"), 26 | Cyan: tp.Color("#0087FF"), 27 | Red: tp.Color("#D70000"), 28 | Grey: tp.Color("#303030"), 29 | Blue: tp.Color("#000087"), 30 | Magenta: tp.Color("#AF00FF"), 31 | }, 32 | "dark": { 33 | Green: tp.Color("#A8CC8C"), 34 | Yellow: tp.Color("#DBAB79"), 35 | Cyan: tp.Color("#66C2CD"), 36 | Red: tp.Color("#E88388"), 37 | Grey: tp.Color("#B9BFCA"), 38 | Blue: tp.Color("#71BEF2"), 39 | Magenta: tp.Color("#D290E4"), 40 | }, 41 | } 42 | 43 | // NewTheme initializes a Theme based on the terminal background (light or dark). 
44 | func NewTheme() Theme { 45 | if !termenv.HasDarkBackground() { 46 | return themes["light"] 47 | } 48 | return themes["dark"] 49 | } 50 | 51 | var theme = NewTheme() 52 | 53 | // formatColorize applies the given color to the formatted text. 54 | func formatColorize(color termenv.Color, t string, args ...interface{}) string { 55 | return colorize(color, fmt.Sprintf(t, args...)) 56 | } 57 | 58 | func Green(t ...string) string { 59 | return colorize(theme.Green, t...) 60 | } 61 | 62 | func Greenf(t string, args ...interface{}) string { 63 | return formatColorize(theme.Green, t, args...) 64 | } 65 | 66 | func Yellow(t ...string) string { 67 | return colorize(theme.Yellow, t...) 68 | } 69 | 70 | func Yellowf(t string, args ...interface{}) string { 71 | return formatColorize(theme.Yellow, t, args...) 72 | } 73 | 74 | func Cyan(t ...string) string { 75 | return colorize(theme.Cyan, t...) 76 | } 77 | 78 | func Cyanf(t string, args ...interface{}) string { 79 | return formatColorize(theme.Cyan, t, args...) 80 | } 81 | 82 | func Red(t ...string) string { 83 | return colorize(theme.Red, t...) 84 | } 85 | 86 | func Redf(t string, args ...interface{}) string { 87 | return formatColorize(theme.Red, t, args...) 88 | } 89 | 90 | func Grey(t ...string) string { 91 | return colorize(theme.Grey, t...) 92 | } 93 | 94 | func Greyf(t string, args ...interface{}) string { 95 | return formatColorize(theme.Grey, t, args...) 96 | } 97 | 98 | func Blue(t ...string) string { 99 | return colorize(theme.Blue, t...) 100 | } 101 | 102 | func Bluef(t string, args ...interface{}) string { 103 | return formatColorize(theme.Blue, t, args...) 104 | } 105 | 106 | func Magenta(t ...string) string { 107 | return colorize(theme.Magenta, t...) 108 | } 109 | 110 | func Magentaf(t string, args ...interface{}) string { 111 | return formatColorize(theme.Magenta, t, args...) 
112 | } 113 | 114 | func Icon(name string) string { 115 | icons := map[string]string{"failure": "✘", "success": "✔", "info": "ℹ", "warning": "⚠"} 116 | if icon, exists := icons[name]; exists { 117 | return icon 118 | } 119 | return "" 120 | } 121 | 122 | // colorize applies the given color to the text. 123 | func colorize(color termenv.Color, t ...string) string { 124 | return termenv.String(t...).Foreground(color).String() 125 | } 126 | -------------------------------------------------------------------------------- /cli/printer/markdown.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/charmbracelet/glamour" 7 | ) 8 | 9 | // RenderOpts is a type alias for a slice of glamour.TermRendererOption, 10 | // representing the rendering options for the markdown renderer. 11 | type RenderOpts []glamour.TermRendererOption 12 | 13 | // This ensures the rendered markdown has no extra indentation or margins, providing a compact view. 14 | func withoutIndentation() glamour.TermRendererOption { 15 | overrides := []byte(` 16 | { 17 | "document": { 18 | "margin": 0 19 | }, 20 | "code_block": { 21 | "margin": 0 22 | } 23 | }`) 24 | 25 | return glamour.WithStylesFromJSONBytes(overrides) 26 | } 27 | 28 | // withoutWrap ensures the rendered markdown does not wrap lines, useful for wide terminals. 29 | func withoutWrap() glamour.TermRendererOption { 30 | return glamour.WithWordWrap(0) 31 | } 32 | 33 | // render applies the given rendering options to the provided markdown text. 34 | func render(text string, opts RenderOpts) (string, error) { 35 | // Ensure input text uses consistent line endings. 36 | text = strings.ReplaceAll(text, "\r\n", "\n") 37 | 38 | tr, err := glamour.NewTermRenderer(opts...) 39 | if err != nil { 40 | return "", err 41 | } 42 | 43 | return tr.Render(text) 44 | } 45 | 46 | // Markdown renders the given markdown text with default options. 
47 | func Markdown(text string) (string, error) { 48 | opts := RenderOpts{ 49 | glamour.WithAutoStyle(), // Automatically determine styling based on terminal settings. 50 | glamour.WithEmoji(), // Enable emoji rendering. 51 | withoutIndentation(), // Disable indentation for a compact view. 52 | withoutWrap(), // Disable word wrapping. 53 | } 54 | 55 | return render(text, opts) 56 | } 57 | 58 | // MarkdownWithWrap renders the given markdown text with a specified word wrapping width. 59 | func MarkdownWithWrap(text string, wrap int) (string, error) { 60 | opts := RenderOpts{ 61 | glamour.WithAutoStyle(), // Automatically determine styling based on terminal settings. 62 | glamour.WithEmoji(), // Enable emoji rendering. 63 | glamour.WithWordWrap(wrap), // Enable word wrapping with the specified width. 64 | withoutIndentation(), // Disable indentation for a compact view. 65 | } 66 | 67 | return render(text, opts) 68 | } 69 | -------------------------------------------------------------------------------- /cli/printer/progress.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "github.com/schollz/progressbar/v3" 5 | ) 6 | 7 | // Progress creates and returns a progress bar for tracking the progress of an operation. 8 | // 9 | // Parameters: 10 | // - max: The maximum value of the progress bar, indicating 100% completion. 11 | // - description: A brief description of the task associated with the progress bar. 12 | 13 | // Example Usage: 14 | // 15 | // bar := printer.Progress(100, "Downloading files") 16 | // for i := 0; i < 100; i++ { 17 | // bar.Add(1) // Increment progress by 1. 
18 | // } 19 | func Progress(max int, description string) *progressbar.ProgressBar { 20 | bar := progressbar.NewOptions( 21 | max, 22 | progressbar.OptionEnableColorCodes(true), 23 | progressbar.OptionSetDescription(description), 24 | progressbar.OptionShowCount(), 25 | ) 26 | return bar 27 | } 28 | -------------------------------------------------------------------------------- /cli/printer/spinner.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/briandowns/spinner" 7 | "github.com/raystack/salt/cli/terminator" 8 | ) 9 | 10 | // Indicator represents a terminal spinner used for indicating progress or ongoing operations. 11 | type Indicator struct { 12 | spinner *spinner.Spinner // The spinner instance. 13 | } 14 | 15 | // Stop halts the spinner animation. 16 | // 17 | // This method ensures the spinner is stopped gracefully. If the spinner is nil (e.g., when the 18 | // terminal does not support TTY), the method does nothing. 19 | // 20 | // Example Usage: 21 | // 22 | // indicator := printer.Spin("Loading") 23 | // // Perform some operation... 24 | // indicator.Stop() 25 | func (s *Indicator) Stop() { 26 | if s.spinner == nil { 27 | return 28 | } 29 | s.spinner.Stop() 30 | } 31 | 32 | // Spin creates and starts a terminal spinner to indicate an ongoing operation. 33 | // 34 | // The spinner uses a predefined character set and updates at a fixed interval. It automatically 35 | // disables itself if the terminal does not support TTY. 36 | // 37 | // Parameters: 38 | // - label: A string to prefix the spinner (e.g., "Loading"). 39 | // 40 | // Returns: 41 | // - An *Indicator instance that manages the spinner lifecycle. 42 | // 43 | // Example Usage: 44 | // 45 | // indicator := printer.Spin("Processing data") 46 | // // Perform some long-running operation... 
47 | // indicator.Stop() 48 | func Spin(label string) *Indicator { 49 | // Predefined spinner character set (dots style). 50 | set := spinner.CharSets[11] 51 | 52 | // Check if the terminal supports TTY; if not, return a no-op Indicator. 53 | if !terminator.IsTTY() { 54 | return &Indicator{} 55 | } 56 | 57 | // Create a new spinner instance with a 120ms update interval and cyan color. 58 | s := spinner.New(set, 120*time.Millisecond, spinner.WithColor("fgCyan")) 59 | 60 | // Add a label prefix if provided. 61 | if label != "" { 62 | s.Prefix = label + " " 63 | } 64 | 65 | // Start the spinner animation. 66 | s.Start() 67 | 68 | // Return the Indicator wrapping the spinner instance. 69 | return &Indicator{s} 70 | } 71 | -------------------------------------------------------------------------------- /cli/printer/structured.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "gopkg.in/yaml.v3" 8 | ) 9 | 10 | // YAML prints the given data in YAML format. 11 | func YAML(data interface{}) error { 12 | return File(data, "yaml") 13 | } 14 | 15 | // JSON prints the given data in JSON format. 16 | func JSON(data interface{}) error { 17 | return File(data, "json") 18 | } 19 | 20 | // PrettyJSON prints the given data in pretty-printed JSON format. 21 | func PrettyJSON(data interface{}) error { 22 | return File(data, "prettyjson") 23 | } 24 | 25 | // File marshals and prints the given data in the specified format. 
26 | func File(data interface{}, format string) (err error) { 27 | var output []byte 28 | switch format { 29 | case "yaml": 30 | output, err = yaml.Marshal(data) 31 | case "json": 32 | output, err = json.Marshal(data) 33 | case "prettyjson": 34 | output, err = json.MarshalIndent(data, "", "\t") 35 | default: 36 | return fmt.Errorf("unknown format: %v", format) 37 | } 38 | 39 | if err != nil { 40 | return err 41 | } 42 | 43 | fmt.Println(string(output)) 44 | return nil 45 | } 46 | -------------------------------------------------------------------------------- /cli/printer/table.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/olekukonko/tablewriter" 7 | ) 8 | 9 | // Table renders a terminal-friendly table to the provided writer. 10 | // 11 | // Create a table with customized formatting and styles, 12 | // suitable for displaying data in CLI applications. 13 | // 14 | // Parameters: 15 | // - target: The `io.Writer` where the table will be written (e.g., os.Stdout). 16 | // - rows: A 2D slice of strings representing the table rows and columns. 17 | // Each inner slice represents a single row, with its elements as column values. 18 | // 19 | // Example Usage: 20 | // 21 | // rows := [][]string{ 22 | // {"ID", "Name", "Age"}, 23 | // {"1", "Alice", "30"}, 24 | // {"2", "Bob", "25"}, 25 | // } 26 | // printer.Table(os.Stdout, rows) 27 | // 28 | // Behavior: 29 | // - Disables text wrapping for better terminal rendering. 30 | // - Aligns headers and rows to the left. 31 | // - Removes borders and separators for a clean look. 32 | // - Formats the table using tab padding for better alignment in terminals. 
33 | func Table(target io.Writer, rows [][]string) { 34 | table := tablewriter.NewWriter(target) 35 | table.SetAutoWrapText(false) 36 | table.SetAutoFormatHeaders(true) 37 | table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) 38 | table.SetAlignment(tablewriter.ALIGN_LEFT) 39 | table.SetCenterSeparator("") 40 | table.SetColumnSeparator("") 41 | table.SetRowSeparator("") 42 | table.SetHeaderLine(false) 43 | table.SetBorder(false) 44 | table.SetTablePadding("\t") 45 | table.SetNoWhiteSpace(true) 46 | table.AppendBulk(rows) 47 | table.Render() 48 | } 49 | -------------------------------------------------------------------------------- /cli/printer/text.go: -------------------------------------------------------------------------------- 1 | package printer 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/muesli/termenv" 7 | ) 8 | 9 | // Success prints the given message(s) in green to indicate success. 10 | func Success(t ...string) { 11 | printWithColor(Green, t...) 12 | } 13 | 14 | // Successln prints the given message(s) in green with a newline. 15 | func Successln(t ...string) { 16 | printWithColorln(Green, t...) 17 | } 18 | 19 | // Successf formats and prints the success message in green. 20 | func Successf(t string, args ...interface{}) { 21 | printWithColorf(Greenf, t, args...) 22 | } 23 | 24 | // Warning prints the given message(s) in yellow to indicate a warning. 25 | func Warning(t ...string) { 26 | printWithColor(Yellow, t...) 27 | } 28 | 29 | // Warningln prints the given message(s) in yellow with a newline. 30 | func Warningln(t ...string) { 31 | printWithColorln(Yellow, t...) 32 | } 33 | 34 | // Warningf formats and prints the warning message in yellow. 35 | func Warningf(t string, args ...interface{}) { 36 | printWithColorf(Yellowf, t, args...) 37 | } 38 | 39 | // Error prints the given message(s) in red to indicate an error. 40 | func Error(t ...string) { 41 | printWithColor(Red, t...) 
42 | } 43 | 44 | // Errorln prints the given message(s) in red with a newline. 45 | func Errorln(t ...string) { 46 | printWithColorln(Red, t...) 47 | } 48 | 49 | // Errorf formats and prints the error message in red. 50 | func Errorf(t string, args ...interface{}) { 51 | printWithColorf(Redf, t, args...) 52 | } 53 | 54 | // Info prints the given message(s) in cyan to indicate informational messages. 55 | func Info(t ...string) { 56 | printWithColor(Cyan, t...) 57 | } 58 | 59 | // Infoln prints the given message(s) in cyan with a newline. 60 | func Infoln(t ...string) { 61 | printWithColorln(Cyan, t...) 62 | } 63 | 64 | // Infof formats and prints the informational message in cyan. 65 | func Infof(t string, args ...interface{}) { 66 | printWithColorf(Cyanf, t, args...) 67 | } 68 | 69 | // Bold prints the given message(s) in bold style. 70 | func Bold(t ...string) string { 71 | return termenv.String(t...).Bold().String() 72 | } 73 | 74 | // Boldln prints the given message(s) in bold style with a newline. 75 | func Boldln(t ...string) { 76 | fmt.Println(Bold(t...)) 77 | } 78 | 79 | // Boldf formats and prints the message in bold style. 80 | func Boldf(t string, args ...interface{}) string { 81 | return Bold(fmt.Sprintf(t, args...)) 82 | } 83 | 84 | // Italic prints the given message(s) in italic style. 85 | func Italic(t ...string) string { 86 | return termenv.String(t...).Italic().String() 87 | } 88 | 89 | // Italicln prints the given message(s) in italic style with a newline. 90 | func Italicln(t ...string) { 91 | fmt.Println(Italic(t...)) 92 | } 93 | 94 | // Italicf formats and prints the message in italic style. 95 | func Italicf(t string, args ...interface{}) string { 96 | return Italic(fmt.Sprintf(t, args...)) 97 | } 98 | 99 | // Space prints a single space to the output. 100 | func Space() { 101 | fmt.Print(" ") 102 | } 103 | 104 | // printWithColor prints the given message(s) with the specified color function. 
105 | func printWithColor(colorFunc func(...string) string, t ...string) { 106 | fmt.Print(colorFunc(t...)) 107 | } 108 | 109 | // printWithColorln prints the given message(s) with the specified color function and a newline. 110 | func printWithColorln(colorFunc func(...string) string, t ...string) { 111 | fmt.Println(colorFunc(t...)) 112 | } 113 | 114 | // printWithColorf formats and prints the message with the specified color function. 115 | func printWithColorf(colorFunc func(string, ...interface{}) string, t string, args ...interface{}) { 116 | fmt.Print(colorFunc(t, args...)) 117 | } 118 | -------------------------------------------------------------------------------- /cli/prompter/prompt.go: -------------------------------------------------------------------------------- 1 | package prompter 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/AlecAivazis/survey/v2" 7 | ) 8 | 9 | // Prompter defines an interface for user input interactions. 10 | type Prompter interface { 11 | Select(message, defaultValue string, options []string) (int, error) 12 | MultiSelect(message, defaultValue string, options []string) ([]int, error) 13 | Input(message, defaultValue string) (string, error) 14 | Confirm(message string, defaultValue bool) (bool, error) 15 | } 16 | 17 | // New creates and returns a new Prompter instance. 18 | func New() Prompter { 19 | return &surveyPrompter{} 20 | } 21 | 22 | type surveyPrompter struct { 23 | } 24 | 25 | // ask is a helper function to prompt the user and capture the response. 26 | func (p *surveyPrompter) ask(q survey.Prompt, response interface{}) error { 27 | err := survey.AskOne(q, response) 28 | if err != nil { 29 | return fmt.Errorf("prompt error: %w", err) 30 | } 31 | return nil 32 | } 33 | 34 | // Select prompts the user to select one option from a list. 35 | // 36 | // Parameters: 37 | // - message: The prompt message to display. 38 | // - defaultValue: The default selected value. 39 | // - options: The list of options to display. 
40 | // 41 | // Returns: 42 | // - The index of the selected option. 43 | // - An error, if any. 44 | func (p *surveyPrompter) Select(message, defaultValue string, options []string) (int, error) { 45 | var result int 46 | err := p.ask(&survey.Select{ 47 | Message: message, 48 | Default: defaultValue, 49 | Options: options, 50 | PageSize: 20, 51 | }, &result) 52 | return result, err 53 | } 54 | 55 | // MultiSelect prompts the user to select multiple options from a list. 56 | // 57 | // Parameters: 58 | // - message: The prompt message to display. 59 | // - defaultValue: The default selected values. 60 | // - options: The list of options to display. 61 | // 62 | // Returns: 63 | // - A slice of indices representing the selected options. 64 | // - An error, if any. 65 | func (p *surveyPrompter) MultiSelect(message, defaultValue string, options []string) ([]int, error) { 66 | var result []int 67 | err := p.ask(&survey.MultiSelect{ 68 | Message: message, 69 | Default: defaultValue, 70 | Options: options, 71 | PageSize: 20, 72 | }, &result) 73 | return result, err 74 | } 75 | 76 | // Input prompts the user for a text input. 77 | // 78 | // Parameters: 79 | // - message: The prompt message to display. 80 | // - defaultValue: The default input value. 81 | // 82 | // Returns: 83 | // - The user's input as a string. 84 | // - An error, if any. 85 | func (p *surveyPrompter) Input(message, defaultValue string) (string, error) { 86 | var result string 87 | err := p.ask(&survey.Input{ 88 | Message: message, 89 | Default: defaultValue, 90 | }, &result) 91 | return result, err 92 | } 93 | 94 | // Confirm prompts the user for a yes/no confirmation. 95 | // 96 | // Parameters: 97 | // - message: The prompt message to display. 98 | // - defaultValue: The default confirmation value. 99 | // 100 | // Returns: 101 | // - A boolean indicating the user's choice. 102 | // - An error, if any. 
103 | func (p *surveyPrompter) Confirm(message string, defaultValue bool) (bool, error) { 104 | var result bool 105 | err := p.ask(&survey.Confirm{ 106 | Message: message, 107 | Default: defaultValue, 108 | }, &result) 109 | return result, err 110 | } 111 | -------------------------------------------------------------------------------- /cli/releaser/release.go: -------------------------------------------------------------------------------- 1 | package releaser 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/hashicorp/go-version" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | var ( 15 | // Timeout sets the HTTP client timeout for fetching release info. 16 | Timeout = time.Second * 1 17 | 18 | // APIFormat is the GitHub API URL template to fetch the latest release of a repository. 19 | APIFormat = "https://api.github.com/repos/%s/releases/latest" 20 | ) 21 | 22 | // Info holds information about a software release. 23 | type Info struct { 24 | Version string // Version of the release 25 | TarURL string // Tarball URL of the release 26 | } 27 | 28 | // FetchInfo fetches details related to the latest release from the provided URL. 29 | // 30 | // Parameters: 31 | // - releaseURL: The URL to fetch the latest release information from. 32 | // Example: "https://api.github.com/repos/raystack/optimus/releases/latest" 33 | // 34 | // Returns: 35 | // - An *Info struct containing the release and tarball URL. 36 | // - An error if the HTTP request or response parsing fails. 
37 | func FetchInfo(url string) (*Info, error) { 38 | httpClient := http.Client{Timeout: Timeout} 39 | req, err := http.NewRequest(http.MethodGet, url, nil) 40 | if err != nil { 41 | return nil, errors.Wrap(err, "failed to create HTTP request") 42 | } 43 | req.Header.Set("User-Agent", "raystack/salt") 44 | 45 | resp, err := httpClient.Do(req) 46 | if err != nil { 47 | return nil, errors.Wrapf(err, "failed to fetch release information from URL: %s", url) 48 | } 49 | defer func() { 50 | if resp.Body != nil { 51 | resp.Body.Close() 52 | } 53 | }() 54 | if resp.StatusCode != http.StatusOK { 55 | return nil, fmt.Errorf("unexpected status code %d from URL: %s", resp.StatusCode, url) 56 | } 57 | 58 | body, err := io.ReadAll(resp.Body) 59 | if err != nil { 60 | return nil, errors.Wrap(err, "failed to read response body") 61 | } 62 | 63 | var data struct { 64 | TagName string `json:"tag_name"` 65 | Tarball string `json:"tarball_url"` 66 | } 67 | if err = json.Unmarshal(body, &data); err != nil { 68 | return nil, errors.Wrapf(err, "failed to parse JSON response") 69 | } 70 | 71 | return &Info{ 72 | Version: data.TagName, 73 | TarURL: data.Tarball, 74 | }, nil 75 | } 76 | 77 | // CompareVersions compares the current release with the latest release. 78 | // 79 | // Parameters: 80 | // - currVersion: The current release string. 81 | // - latestVersion: The latest release string. 82 | // 83 | // Returns: 84 | // - true if the current release is greater than or equal to the latest release. 85 | // - An error if release parsing fails. 
86 | func CompareVersions(current, latest string) (bool, error) { 87 | currentVersion, err := version.NewVersion(current) 88 | if err != nil { 89 | return false, errors.Wrap(err, "invalid current version format") 90 | } 91 | 92 | latestVersion, err := version.NewVersion(latest) 93 | if err != nil { 94 | return false, errors.Wrap(err, "invalid latest version format") 95 | } 96 | 97 | return currentVersion.GreaterThanOrEqual(latestVersion), nil 98 | } 99 | 100 | // CheckForUpdate generates a message indicating if an update is available. 101 | // 102 | // Parameters: 103 | // - currentVersion: The current version string (e.g., "v1.0.0"). 104 | // - repo: The GitHub repository in the format "owner/repo". 105 | // 106 | // Returns: 107 | // - A string containing the update message if a newer version is available. 108 | // - An empty string if the current version is up-to-date or if an error occurs. 109 | func CheckForUpdate(currentVersion, repo string) string { 110 | releaseURL := fmt.Sprintf(APIFormat, repo) 111 | info, err := FetchInfo(releaseURL) 112 | if err != nil { 113 | return "" 114 | } 115 | 116 | isLatest, err := CompareVersions(currentVersion, info.Version) 117 | if err != nil || isLatest { 118 | return "" 119 | } 120 | 121 | return fmt.Sprintf("A new release (%s) is available. consider updating to latest version.", info.Version) 122 | } 123 | -------------------------------------------------------------------------------- /cli/terminator/brew.go: -------------------------------------------------------------------------------- 1 | package terminator 2 | 3 | import ( 4 | "os/exec" 5 | "path/filepath" 6 | "strings" 7 | ) 8 | 9 | // IsUnderHomebrew checks if a given binary path is managed under the Homebrew path. 10 | // This function is useful to verify if a binary is installed via Homebrew 11 | // by comparing its location to the Homebrew binary directory. 
12 | func IsUnderHomebrew(path string) bool { 13 | if path == "" { 14 | return false 15 | } 16 | 17 | brewExe, err := exec.LookPath("brew") 18 | if err != nil { 19 | return false 20 | } 21 | 22 | brewPrefixBytes, err := exec.Command(brewExe, "--prefix").Output() 23 | if err != nil { 24 | return false 25 | } 26 | 27 | brewBinPrefix := filepath.Join(strings.TrimSpace(string(brewPrefixBytes)), "bin") + string(filepath.Separator) 28 | return strings.HasPrefix(path, brewBinPrefix) 29 | } 30 | 31 | // HasHomebrew checks if Homebrew is installed on the user's system. 32 | // This function determines the presence of Homebrew by looking for the "brew" 33 | // executable in the system's PATH. It is useful to ensure Homebrew dependencies 34 | // can be managed before executing related commands. 35 | func HasHomebrew() bool { 36 | _, err := exec.LookPath("brew") 37 | return err == nil 38 | } 39 | -------------------------------------------------------------------------------- /cli/terminator/browser.go: -------------------------------------------------------------------------------- 1 | package terminator 2 | 3 | import ( 4 | "os" 5 | "os/exec" 6 | "strings" 7 | ) 8 | 9 | // OpenBrowser opens the default web browser at the specified URL. 10 | // 11 | // Parameters: 12 | // - goos: The operating system name (e.g., "darwin", "windows", or "linux"). 13 | // - url: The URL to open in the web browser. 14 | // 15 | // Returns: 16 | // - An *exec.Cmd configured to open the URL. Note that you must call `cmd.Run()` 17 | // or `cmd.Start()` on the returned command to execute it. 18 | // 19 | // Panics: 20 | // - This function will panic if called without a TTY (e.g., not running in a terminal). 21 | func OpenBrowser(goos, url string) *exec.Cmd { 22 | if !IsTTY() { 23 | panic("OpenBrowser called without a TTY") 24 | } 25 | 26 | exe := "open" 27 | var args []string 28 | 29 | switch goos { 30 | case "darwin": 31 | // macOS: Use the "open" command to open the URL. 
32 | args = append(args, url) 33 | case "windows": 34 | // Windows: Use "cmd /c start" to open the URL. 35 | exe, _ = exec.LookPath("cmd") 36 | replacer := strings.NewReplacer("&", "^&") 37 | args = append(args, "/c", "start", replacer.Replace(url)) 38 | default: 39 | // Linux: Use "xdg-open" or fallback to "wslview" for WSL environments. 40 | exe = linuxExe() 41 | args = append(args, url) 42 | } 43 | 44 | // Create the command to open the browser and set stderr for error reporting. 45 | cmd := exec.Command(exe, args...) 46 | cmd.Stderr = os.Stderr 47 | return cmd 48 | } 49 | 50 | // linuxExe determines the appropriate command to open a web browser on Linux. 51 | func linuxExe() string { 52 | exe := "xdg-open" 53 | 54 | _, err := exec.LookPath(exe) 55 | if err != nil { 56 | _, err := exec.LookPath("wslview") 57 | if err == nil { 58 | exe = "wslview" 59 | } 60 | } 61 | 62 | return exe 63 | } 64 | -------------------------------------------------------------------------------- /cli/terminator/pager.go: -------------------------------------------------------------------------------- 1 | package terminator 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "os" 7 | "os/exec" 8 | "strings" 9 | "syscall" 10 | 11 | "github.com/cli/safeexec" 12 | "github.com/google/shlex" 13 | ) 14 | 15 | // Pager manages a pager process for displaying output in a paginated format. 16 | // 17 | // It supports configuring the pager command, starting the pager process, 18 | // and ensuring proper cleanup when the pager is no longer needed. 19 | type Pager struct { 20 | Out io.Writer // The writer to send output to the pager. 21 | ErrOut io.Writer // The writer to send error output to. 22 | pagerCommand string // The command to run the pager (e.g., "less", "more"). 23 | pagerProcess *os.Process // The running pager process, if any. 24 | } 25 | 26 | // NewPager creates a new Pager instance with default settings. 
27 | // 28 | // If the "PAGER" environment variable is not set, the default command is "more". 29 | func NewPager() *Pager { 30 | pagerCmd := os.Getenv("PAGER") 31 | if pagerCmd == "" { 32 | pagerCmd = "more" 33 | } 34 | 35 | return &Pager{ 36 | pagerCommand: pagerCmd, 37 | Out: os.Stdout, 38 | ErrOut: os.Stderr, 39 | } 40 | } 41 | 42 | // Set updates the pager command used to display output. 43 | // 44 | // Parameters: 45 | // - cmd: The pager command (e.g., "less", "more"). 46 | func (p *Pager) Set(cmd string) { 47 | p.pagerCommand = cmd 48 | } 49 | 50 | // Get returns the current pager command. 51 | // 52 | // Returns: 53 | // - The pager command as a string. 54 | func (p *Pager) Get() string { 55 | return p.pagerCommand 56 | } 57 | 58 | // Start begins the pager process to display output. 59 | // 60 | // If the pager command is "cat" or empty, it does nothing. 61 | // The function also sets environment variables to optimize the behavior of 62 | // certain pagers, like "less" and "lv". 63 | // 64 | // Returns: 65 | // - An error if the pager command fails to start or if arguments cannot be parsed. 66 | func (p *Pager) Start() error { 67 | if p.pagerCommand == "" || p.pagerCommand == "cat" { 68 | return nil 69 | } 70 | 71 | pagerArgs, err := shlex.Split(p.pagerCommand) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | // Prepare the environment variables for the pager process. 77 | pagerEnv := os.Environ() 78 | for i := len(pagerEnv) - 1; i >= 0; i-- { 79 | if strings.HasPrefix(pagerEnv[i], "PAGER=") { 80 | pagerEnv = append(pagerEnv[0:i], pagerEnv[i+1:]...) 81 | } 82 | } 83 | if _, ok := os.LookupEnv("LESS"); !ok { 84 | pagerEnv = append(pagerEnv, "LESS=FRX") 85 | } 86 | if _, ok := os.LookupEnv("LV"); !ok { 87 | pagerEnv = append(pagerEnv, "LV=-c") 88 | } 89 | 90 | // Locate the pager executable using safeexec for added security. 
91 | pagerExe, err := safeexec.LookPath(pagerArgs[0]) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | pagerCmd := exec.Command(pagerExe, pagerArgs[1:]...) 97 | pagerCmd.Env = pagerEnv 98 | pagerCmd.Stdout = p.Out 99 | pagerCmd.Stderr = p.ErrOut 100 | pagedOut, err := pagerCmd.StdinPipe() 101 | if err != nil { 102 | return err 103 | } 104 | p.Out = &pagerWriter{pagedOut} 105 | 106 | // Start the pager process. 107 | err = pagerCmd.Start() 108 | if err != nil { 109 | return err 110 | } 111 | p.pagerProcess = pagerCmd.Process 112 | return nil 113 | } 114 | 115 | // Stop terminates the running pager process and cleans up resources. 116 | func (p *Pager) Stop() { 117 | if p.pagerProcess == nil { 118 | return 119 | } 120 | 121 | // Close the output writer and wait for the process to exit. 122 | _ = p.Out.(io.WriteCloser).Close() 123 | _, _ = p.pagerProcess.Wait() 124 | p.pagerProcess = nil 125 | } 126 | 127 | // pagerWriter is a custom writer that wraps WriteCloser and handles EPIPE errors. 128 | // 129 | // If a write fails due to a closed pipe, it returns an ErrClosedPagerPipe error. 130 | type pagerWriter struct { 131 | io.WriteCloser 132 | } 133 | 134 | // Write writes data to the underlying WriteCloser and handles EPIPE errors. 135 | // 136 | // Parameters: 137 | // - d: The data to write. 138 | // 139 | // Returns: 140 | // - The number of bytes written and an error if the write fails. 141 | func (w *pagerWriter) Write(d []byte) (int, error) { 142 | n, err := w.WriteCloser.Write(d) 143 | if err != nil && (errors.Is(err, io.ErrClosedPipe) || isEpipeError(err)) { 144 | return n, &ErrClosedPagerPipe{err} 145 | } 146 | return n, err 147 | } 148 | 149 | // isEpipeError checks if an error is a broken pipe (EPIPE) error. 150 | // 151 | // Returns: 152 | // - A boolean indicating whether the error is an EPIPE error. 
153 | func isEpipeError(err error) bool { 154 | return errors.Is(err, syscall.EPIPE) 155 | } 156 | 157 | // ErrClosedPagerPipe is an error type returned when writing to a closed pager pipe. 158 | type ErrClosedPagerPipe struct { 159 | error 160 | } 161 | -------------------------------------------------------------------------------- /cli/terminator/term.go: -------------------------------------------------------------------------------- 1 | package terminator 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/mattn/go-isatty" 7 | "github.com/muesli/termenv" 8 | ) 9 | 10 | // IsTTY checks if the current output is a TTY (teletypewriter) or a Cygwin terminal. 11 | // This function is useful for determining if the program is running in a terminal 12 | // environment, which is important for features like colored output or interactive prompts. 13 | func IsTTY() bool { 14 | return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd()) 15 | } 16 | 17 | // IsColorDisabled checks if color output is disabled based on the environment settings. 18 | // This function uses the `termenv` library to determine if the NO_COLOR environment 19 | // variable is set, which is a common way to disable colored output. 20 | func IsColorDisabled() bool { 21 | return termenv.EnvNoColor() 22 | } 23 | 24 | // IsCI checks if the code is running in a Continuous Integration (CI) environment. 25 | // This function checks for common environment variables used by popular CI systems 26 | // like GitHub Actions, Travis CI, CircleCI, Jenkins, TeamCity, and others. 
27 | func IsCI() bool { 28 | return os.Getenv("CI") != "" || // GitHub Actions, Travis CI, CircleCI, Cirrus CI, GitLab CI, AppVeyor, CodeShip, dsari 29 | os.Getenv("BUILD_NUMBER") != "" || // Jenkins, TeamCity 30 | os.Getenv("RUN_ID") != "" // TaskCluster, dsari 31 | } 32 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "os" 8 | "path/filepath" 9 | "reflect" 10 | "strings" 11 | 12 | "github.com/go-playground/validator" 13 | "github.com/mcuadros/go-defaults" 14 | "github.com/spf13/pflag" 15 | "github.com/spf13/viper" 16 | "gopkg.in/yaml.v3" 17 | ) 18 | 19 | // Loader is responsible for managing configuration 20 | type Loader struct { 21 | v *viper.Viper 22 | flags *pflag.FlagSet 23 | } 24 | 25 | // Option defines a functional option for configuring the Loader. 26 | type Option func(c *Loader) 27 | 28 | // NewLoader creates a new Loader instance with the provided options. 29 | // It initializes Viper with defaults for YAML configuration files and environment variable handling. 30 | // 31 | // Example: 32 | // 33 | // loader := config.NewLoader( 34 | // config.WithFile("./config.yaml"), 35 | // config.WithEnvPrefix("MYAPP"), 36 | // ) 37 | func NewLoader(options ...Option) *Loader { 38 | v := viper.New() 39 | 40 | v.SetConfigName("config") 41 | v.SetConfigType("yaml") 42 | v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 43 | v.AutomaticEnv() 44 | 45 | loader := &Loader{v: v} 46 | for _, opt := range options { 47 | opt(loader) 48 | } 49 | return loader 50 | } 51 | 52 | // WithFile specifies the configuration file to use. 53 | func WithFile(configFilePath string) Option { 54 | return func(l *Loader) { 55 | l.v.SetConfigFile(configFilePath) 56 | } 57 | } 58 | 59 | // WithEnvPrefix specifies a prefix for ENV variables. 
60 | func WithEnvPrefix(prefix string) Option { 61 | return func(l *Loader) { 62 | l.v.SetEnvPrefix(prefix) 63 | } 64 | } 65 | 66 | // WithFlags specifies a command-line flag set to bind dynamically based on `cmdx` tags. 67 | func WithFlags(flags *pflag.FlagSet) Option { 68 | return func(l *Loader) { 69 | l.flags = flags 70 | } 71 | } 72 | 73 | // WithAppConfig sets up application-specific configuration file handling. 74 | func WithAppConfig(app string) Option { 75 | return func(l *Loader) { 76 | filePath, err := getConfigFilePath(app) 77 | if err != nil { 78 | panic(fmt.Errorf("failed to determine config file path: %w", err)) 79 | } 80 | l.v.SetConfigFile(filePath) 81 | } 82 | } 83 | 84 | // Load reads the configuration from the file, environment variables, and command-line flags, 85 | // and merges them into the provided configuration struct. It validates the configuration 86 | // using struct tags. 87 | // 88 | // The priority order is: 89 | // 1. Command-line flags 90 | // 2. Environment variables 91 | // 3. Configuration file 92 | // 4. 
Default values 93 | func (l *Loader) Load(config interface{}) error { 94 | if err := validateStructPtr(config); err != nil { 95 | return err 96 | } 97 | 98 | // Apply default values before reading configuration 99 | defaults.SetDefaults(config) 100 | 101 | // Bind flags dynamically using reflection on `cmdx` tags if a flag set is provided 102 | if l.flags != nil { 103 | if err := bindFlags(l.v, l.flags, reflect.TypeOf(config).Elem(), ""); err != nil { 104 | return fmt.Errorf("failed to bind flags: %w", err) 105 | } 106 | } 107 | 108 | // Bind environment variables for all keys in the config 109 | keys, err := extractFlattenedKeys(config) 110 | if err != nil { 111 | return fmt.Errorf("failed to extract config keys: %w", err) 112 | } 113 | for _, key := range keys { 114 | if err := l.v.BindEnv(key); err != nil { 115 | return fmt.Errorf("failed to bind environment variable for key %q: %w", key, err) 116 | } 117 | } 118 | 119 | // Attempt to read the configuration file 120 | if err := l.v.ReadInConfig(); err != nil { 121 | var configFileNotFoundError viper.ConfigFileNotFoundError 122 | if errors.As(err, &configFileNotFoundError) { 123 | fmt.Println("Warning: Config file not found. Falling back to defaults and environment variables.") 124 | } 125 | } 126 | 127 | // Unmarshal the merged configuration into the provided struct 128 | if err := l.v.Unmarshal(config); err != nil { 129 | return fmt.Errorf("failed to unmarshal config: %w", err) 130 | } 131 | 132 | // Validate the resulting configuration 133 | if err := validator.New().Struct(config); err != nil { 134 | return fmt.Errorf("invalid configuration: %w", err) 135 | } 136 | 137 | return nil 138 | } 139 | 140 | // Init initializes the configuration file with default values. 
141 | func (l *Loader) Init(config interface{}) error { 142 | defaults.SetDefaults(config) 143 | 144 | path := l.v.ConfigFileUsed() 145 | if fileExists(path) { 146 | return errors.New("configuration file already exists") 147 | } 148 | 149 | data, err := yaml.Marshal(config) 150 | if err != nil { 151 | return fmt.Errorf("failed to marshal configuration: %w", err) 152 | } 153 | 154 | if err := ensureDir(filepath.Dir(path)); err != nil { 155 | return fmt.Errorf("failed to create directory: %w", err) 156 | } 157 | 158 | if err := os.WriteFile(path, data, 0644); err != nil { 159 | return fmt.Errorf("failed to write configuration file: %w", err) 160 | } 161 | return nil 162 | } 163 | 164 | // Get retrieves a configuration value by key. 165 | func (l *Loader) Get(key string) interface{} { 166 | return l.v.Get(key) 167 | } 168 | 169 | // Set updates a configuration value in memory (not persisted to file). 170 | func (l *Loader) Set(key string, value interface{}) { 171 | l.v.Set(key, value) 172 | } 173 | 174 | // Save writes the current configuration to the file specified during initialization. 175 | func (l *Loader) Save() error { 176 | configFile := l.v.ConfigFileUsed() 177 | if configFile == "" { 178 | return errors.New("no configuration file specified for saving") 179 | } 180 | 181 | settings := l.v.AllSettings() 182 | content, err := yaml.Marshal(settings) 183 | if err != nil { 184 | return fmt.Errorf("failed to marshal configuration: %w", err) 185 | } 186 | 187 | if err := os.WriteFile(configFile, content, 0644); err != nil { 188 | return fmt.Errorf("failed to write configuration to file: %w", err) 189 | } 190 | return nil 191 | } 192 | 193 | // View returns the current configuration as a formatted JSON string. 
194 | func (l *Loader) View() (string, error) { 195 | settings := l.v.AllSettings() 196 | data, err := json.MarshalIndent(settings, "", " ") 197 | if err != nil { 198 | return "", fmt.Errorf("failed to format configuration as JSON: %w", err) 199 | } 200 | return string(data), nil 201 | } 202 | -------------------------------------------------------------------------------- /config/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package config provides a flexible and extensible configuration management solution for Go applications. 3 | 4 | It integrates configuration files, environment variables, command-line flags, and default values to populate 5 | and validate user-defined structs. 6 | 7 | Configuration Precedence: 8 | The `Loader` merges configuration values from multiple sources in the following order of precedence (highest to lowest): 9 | 1. Command-line flags: Defined using `pflag.FlagSet` and dynamically bound via `cmdx` tags. 10 | 2. Environment variables: Dynamically bound to configuration keys, optionally prefixed using `WithEnvPrefix`. 11 | 3. Configuration file: YAML configuration files specified via `WithFile`. 12 | 4. Default values: Struct fields annotated with `default` tags are populated if no other source provides a value. 13 | 14 | Defaults: 15 | Default values are specified using the `default` struct tag. Fields annotated with `default` are populated 16 | before any other source (flags, environment variables, or files). 17 | 18 | Example: 19 | 20 | type Config struct { 21 | ServerPort int `mapstructure:"server.port" default:"8080"` 22 | LogLevel string `mapstructure:"log.level" default:"info"` 23 | } 24 | 25 | In the absence of higher-priority sources, `ServerPort` will default to `8080` and `LogLevel` to `info`. 26 | 27 | Validation: 28 | Validation is performed using the `go-playground/validator` package. 
Fields annotated with `validate` tags 29 | are validated after merging all configuration sources. 30 | 31 | Example: 32 | 33 | type Config struct { 34 | ServerPort int `mapstructure:"server.port" validate:"required,min=1"` 35 | LogLevel string `mapstructure:"log.level" validate:"required,oneof=debug info warn error"` 36 | } 37 | 38 | If validation fails, the `Load` method returns a detailed error indicating the invalid fields. 39 | 40 | Annotations: 41 | Configuration structs use the following struct tags to define behavior: 42 | - `mapstructure`: Maps YAML or environment variables to struct fields. 43 | - `default`: Provides fallback values for fields when no source overrides them. 44 | - `validate`: Ensures the final configuration meets application-specific requirements. 45 | 46 | Example: 47 | 48 | type Config struct { 49 | Server struct { 50 | Port int `mapstructure:"server.port" default:"8080" validate:"required,min=1"` 51 | Host string `mapstructure:"server.host" default:"localhost" validate:"required"` 52 | } `mapstructure:"server"` 53 | 54 | LogLevel string `mapstructure:"log.level" default:"info" validate:"required,oneof=debug info warn error"` 55 | } 56 | 57 | The `Loader` will merge all sources, apply defaults, and validate the result in a single call to `Load`. 58 | 59 | Features: 60 | - Merges configurations from multiple sources: flags, environment variables, files, and defaults. 61 | - Supports nested structs with dynamic field mapping using `cmdx` tags. 62 | - Validates fields with constraints defined in `validate` tags. 63 | - Saves and views the final configuration in YAML or JSON formats. 
64 | 65 | Example Usage: 66 | 67 | type Config struct { 68 | ServerPort int `mapstructure:"server.port" cmdx:"server.port" default:"8080" validate:"required,min=1"` 69 | LogLevel string `mapstructure:"log.level" cmdx:"log.level" default:"info" validate:"required,oneof=debug info warn error"` 70 | } 71 | 72 | func main() { 73 | flags := pflag.NewFlagSet("example", pflag.ExitOnError) 74 | flags.Int("server.port", 8080, "Server port") 75 | flags.String("log.level", "info", "Log level") 76 | 77 | loader := config.NewLoader( 78 | config.WithFile("./config.yaml"), 79 | config.WithEnvPrefix("MYAPP"), 80 | config.WithFlags(flags), 81 | ) 82 | 83 | flags.Parse(os.Args[1:]) 84 | 85 | cfg := &Config{} 86 | if err := loader.Load(cfg); err != nil { 87 | log.Fatalf("Failed to load configuration: %v", err) 88 | } 89 | 90 | fmt.Printf("Configuration: %+v\n", cfg) 91 | } 92 | */ 93 | package config 94 | -------------------------------------------------------------------------------- /config/helpers.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "reflect" 9 | "runtime" 10 | 11 | "github.com/jeremywohl/flatten" 12 | "github.com/mitchellh/mapstructure" 13 | "github.com/spf13/pflag" 14 | "github.com/spf13/viper" 15 | ) 16 | 17 | // bindFlags dynamically binds flags to configuration fields based on `cmdx` tags. 18 | func bindFlags(v *viper.Viper, flagSet *pflag.FlagSet, structType reflect.Type, parentKey string) error { 19 | for i := 0; i < structType.NumField(); i++ { 20 | field := structType.Field(i) 21 | tag := field.Tag.Get("cmdx") 22 | if tag == "" { 23 | continue 24 | } 25 | 26 | if parentKey != "" { 27 | tag = parentKey + "." 
+ tag 28 | } 29 | 30 | if field.Type.Kind() == reflect.Struct { 31 | // Recurse into nested structs 32 | if err := bindFlags(v, flagSet, field.Type, tag); err != nil { 33 | return err 34 | } 35 | } else { 36 | flag := flagSet.Lookup(tag) 37 | if flag == nil { 38 | return fmt.Errorf("missing flag for tag: %s", tag) 39 | } 40 | if err := v.BindPFlag(tag, flag); err != nil { 41 | return fmt.Errorf("failed to bind flag for tag: %s, error: %w", tag, err) 42 | } 43 | } 44 | } 45 | return nil 46 | } 47 | 48 | // validateStructPtr ensures the provided value is a pointer to a struct. 49 | func validateStructPtr(value interface{}) error { 50 | val := reflect.ValueOf(value) 51 | if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct { 52 | return errors.New("load requires a pointer to a struct") 53 | } 54 | return nil 55 | } 56 | 57 | // extractFlattenedKeys retrieves all keys from the struct in a flattened format. 58 | func extractFlattenedKeys(config interface{}) ([]string, error) { 59 | var structMap map[string]interface{} 60 | if err := mapstructure.Decode(config, &structMap); err != nil { 61 | return nil, err 62 | } 63 | flatMap, err := flatten.Flatten(structMap, "", flatten.DotStyle) 64 | if err != nil { 65 | return nil, err 66 | } 67 | keys := make([]string, 0, len(flatMap)) 68 | for k := range flatMap { 69 | keys = append(keys, k) 70 | } 71 | return keys, nil 72 | } 73 | 74 | // Utilities for app-specific configuration paths 75 | func getConfigFilePath(app string) (string, error) { 76 | dirPath := getConfigDir("raystack") 77 | if err := ensureDir(dirPath); err != nil { 78 | return "", err 79 | } 80 | return filepath.Join(dirPath, app+".yml"), nil 81 | } 82 | 83 | func getConfigDir(root string) string { 84 | switch { 85 | case envSet("RAYSTACK_CONFIG_DIR"): 86 | return filepath.Join(os.Getenv("RAYSTACK_CONFIG_DIR"), root) 87 | case envSet("XDG_CONFIG_HOME"): 88 | return filepath.Join(os.Getenv("XDG_CONFIG_HOME"), root) 89 | case runtime.GOOS == "windows" 
&& envSet("APPDATA"): 90 | return filepath.Join(os.Getenv("APPDATA"), root) 91 | default: 92 | home, _ := os.UserHomeDir() 93 | return filepath.Join(home, ".config", root) 94 | } 95 | } 96 | 97 | func ensureDir(dir string) error { 98 | if err := os.MkdirAll(dir, 0755); err != nil { 99 | return fmt.Errorf("failed to create directory: %w", err) 100 | } 101 | return nil 102 | } 103 | 104 | func fileExists(filename string) bool { 105 | _, err := os.Stat(filename) 106 | return err == nil 107 | } 108 | 109 | func envSet(key string) bool { 110 | return os.Getenv(key) != "" 111 | } 112 | -------------------------------------------------------------------------------- /db/config.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type Config struct { 8 | Driver string `yaml:"driver" mapstructure:"driver"` 9 | URL string `yaml:"url" mapstructure:"url"` 10 | MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" default:"10"` 11 | MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" default:"10"` 12 | ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"10ms"` // NOTE(review): a 10ms lifetime recycles pooled connections almost immediately; confirm whether 10m was intended. 13 | MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"100ms"` 14 | } 15 | -------------------------------------------------------------------------------- /db/db.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "net/url" 8 | "time" 9 | 10 | "github.com/jmoiron/sqlx" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | type Client struct { 15 | *sqlx.DB 16 | queryTimeOut time.Duration 17 | cfg Config 18 | host string 19 | } 20 | 21 | // New connects an sqlx database client using cfg and applies its connection-pool settings. 22 | func New(cfg Config) (*Client, error) { 23 | dbURL, err := url.Parse(cfg.URL) 24 | if err != nil { 25 | return nil,
err 26 | } 27 | host := dbURL.Host 28 | 29 | db, err := sqlx.Connect(cfg.Driver, cfg.URL) 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | db.SetMaxIdleConns(cfg.MaxIdleConns) 35 | db.SetMaxOpenConns(cfg.MaxOpenConns) 36 | db.SetConnMaxLifetime(cfg.ConnMaxLifeTime) 37 | 38 | return &Client{DB: db, queryTimeOut: cfg.MaxQueryTimeout, cfg: cfg, host: host}, err 39 | } 40 | 41 | func (c Client) WithTimeout(ctx context.Context, op func(ctx context.Context) error) (err error) { 42 | ctxWithTimeout, cancel := context.WithTimeout(ctx, c.queryTimeOut) 43 | defer cancel() 44 | 45 | return op(ctxWithTimeout) 46 | } 47 | 48 | func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc func(*sqlx.Tx) error) (err error) { 49 | txn, err := c.BeginTxx(ctx, &txnOptions) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | defer func() { 55 | if p := recover(); p != nil { 56 | switch p := p.(type) { 57 | case error: 58 | err = p 59 | default: 60 | err = errors.Errorf("%s", p) 61 | } 62 | err = txn.Rollback() 63 | panic(p) 64 | } else if err != nil { 65 | if rlbErr := txn.Rollback(); rlbErr != nil { 66 | err = fmt.Errorf("rollback error: %s while executing: %w", rlbErr, err) 67 | } else { 68 | err = fmt.Errorf("rollback: %w", err) 69 | } 70 | } else { 71 | err = txn.Commit() 72 | } 73 | }() 74 | 75 | err = txFunc(txn) 76 | return err 77 | } 78 | 79 | // ConnectionURL fetch the database connection url 80 | func (c *Client) ConnectionURL() string { 81 | return c.cfg.URL 82 | } 83 | 84 | // Host fetch the database host information 85 | func (c *Client) Host() string { 86 | return c.host 87 | } 88 | 89 | // Close closes the database connection 90 | func (c *Client) Close() error { 91 | return c.DB.Close() 92 | } 93 | -------------------------------------------------------------------------------- /db/db_test.go: -------------------------------------------------------------------------------- 1 | package db_test 2 | 3 | import ( 4 | "context" 5 | 
"database/sql" 6 | "fmt" 7 | "log" 8 | "os" 9 | "testing" 10 | "time" 11 | 12 | "github.com/jmoiron/sqlx" 13 | "github.com/ory/dockertest/v3" 14 | "github.com/ory/dockertest/v3/docker" 15 | "github.com/raystack/salt/db" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | const ( 20 | dialect = "postgres" 21 | user = "root" 22 | password = "pass" 23 | database = "postgres" 24 | host = "localhost" 25 | port = "5432" 26 | dsn = "postgres://%s:%s@localhost:%s/%s?sslmode=disable" 27 | ) 28 | 29 | var ( 30 | createTableQuery = "CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))" 31 | dropTableQuery = "DROP TABLE IF EXISTS users" 32 | checkTableQuery = "SELECT EXISTS(SELECT * FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users');" 33 | ) 34 | 35 | var client *db.Client 36 | 37 | func TestMain(m *testing.M) { 38 | pool, err := dockertest.NewPool("") 39 | if err != nil { 40 | log.Fatalf("Could not connect to docker: %s", err) 41 | } 42 | 43 | opts := dockertest.RunOptions{ 44 | Repository: "postgres", 45 | Tag: "14", 46 | Env: []string{ 47 | "POSTGRES_USER=" + user, 48 | "POSTGRES_PASSWORD=" + password, 49 | "POSTGRES_DB=" + database, 50 | }, 51 | ExposedPorts: []string{"5432"}, 52 | PortBindings: map[docker.Port][]docker.PortBinding{ 53 | "5432": { 54 | {HostIP: "0.0.0.0", HostPort: port}, 55 | }, 56 | }, 57 | } 58 | 59 | resource, err := pool.RunWithOptions(&opts, func(config *docker.HostConfig) { 60 | config.AutoRemove = true 61 | config.RestartPolicy = docker.RestartPolicy{Name: "no"} 62 | }) 63 | if err != nil { 64 | log.Fatalf("Could not start resource: %s", err.Error()) 65 | } 66 | 67 | fmt.Println(resource.GetPort("5432/tcp")) 68 | 69 | if err := resource.Expire(120); err != nil { 70 | log.Fatalf("Could not expire resource: %s", err.Error()) 71 | } 72 | 73 | pool.MaxWait = 60 * time.Second 74 | 75 | dsn := fmt.Sprintf(dsn, user, password, port, database) 76 | var ( 77 | pgConfig = db.Config{ 78 | 
Driver: "postgres", 79 | URL: dsn, 80 | } 81 | ) 82 | 83 | if err = pool.Retry(func() error { 84 | client, err = db.New(pgConfig) 85 | return err 86 | }); err != nil { 87 | log.Fatalf("Could not connect to docker: %s", err.Error()) 88 | } 89 | 90 | defer func() { 91 | client.Close() 92 | }() 93 | 94 | code := m.Run() 95 | 96 | if err := pool.Purge(resource); err != nil { 97 | log.Fatalf("Could not purge resource: %s", err) 98 | } 99 | 100 | os.Exit(code) 101 | } 102 | 103 | func TestWithTxn(t *testing.T) { 104 | if _, err := client.Exec(dropTableQuery); err != nil { 105 | log.Fatalf("Could not cleanup: %s", err) 106 | } 107 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error { 108 | if _, err := tx.Exec(createTableQuery); err != nil { 109 | return err 110 | } 111 | if _, err := tx.Exec(dropTableQuery); err != nil { 112 | return err 113 | } 114 | 115 | return nil 116 | }) 117 | assert.NoError(t, err) 118 | 119 | // Table should be dropped 120 | var tableExist bool 121 | result := client.QueryRow(checkTableQuery) 122 | result.Scan(&tableExist) 123 | assert.Equal(t, false, tableExist) 124 | } 125 | 126 | func TestWithTxnCommit(t *testing.T) { 127 | if _, err := client.Exec(dropTableQuery); err != nil { 128 | log.Fatalf("Could not cleanup: %s", err) 129 | } 130 | query2 := "SELECT 1" 131 | 132 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error { 133 | if _, err := tx.Exec(createTableQuery); err != nil { 134 | return err 135 | } 136 | if _, err := tx.Exec(query2); err != nil { 137 | return err 138 | } 139 | 140 | return nil 141 | }) 142 | // WithTx should not return an error 143 | assert.NoError(t, err) 144 | 145 | // User table should exist 146 | var tableExist bool 147 | result := client.QueryRow(checkTableQuery) 148 | result.Scan(&tableExist) 149 | assert.Equal(t, true, tableExist) 150 | } 151 | 152 | func TestWithTxnRollback(t *testing.T) { 153 | if _, err := client.Exec(dropTableQuery); err != 
nil { 154 | log.Fatalf("Could not cleanup: %s", err) 155 | } 156 | query2 := "WRONG QUERY" 157 | 158 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error { 159 | if _, err := tx.Exec(createTableQuery); err != nil { 160 | return err 161 | } 162 | if _, err := tx.Exec(query2); err != nil { 163 | return err 164 | } 165 | 166 | return nil 167 | }) 168 | // WithTx should return an error 169 | assert.Error(t, err) 170 | 171 | // Table should not be created 172 | var tableExist bool 173 | result := client.QueryRow(checkTableQuery) 174 | result.Scan(&tableExist) 175 | assert.Equal(t, false, tableExist) 176 | } 177 | -------------------------------------------------------------------------------- /db/migrate.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "io/fs" 6 | 7 | "github.com/golang-migrate/migrate/v4" 8 | _ "github.com/golang-migrate/migrate/v4/database" 9 | _ "github.com/golang-migrate/migrate/v4/database/mysql" 10 | _ "github.com/golang-migrate/migrate/v4/database/postgres" 11 | _ "github.com/golang-migrate/migrate/v4/source/file" 12 | "github.com/golang-migrate/migrate/v4/source/iofs" 13 | ) 14 | 15 | func RunMigrations(config Config, embeddedMigrations fs.FS, resourcePath string) error { 16 | m, err := getMigrationInstance(config, embeddedMigrations, resourcePath) 17 | if err != nil { 18 | return err 19 | } 20 | 21 | err = m.Up() 22 | if err == migrate.ErrNoChange || err == nil { 23 | return nil 24 | } 25 | 26 | return err 27 | } 28 | 29 | func RunRollback(config Config, embeddedMigrations fs.FS, resourcePath string) error { 30 | m, err := getMigrationInstance(config, embeddedMigrations, resourcePath) 31 | if err != nil { 32 | return err 33 | } 34 | 35 | err = m.Steps(-1) 36 | if err == migrate.ErrNoChange || err == nil { 37 | return nil 38 | } 39 | 40 | return err 41 | } 42 | 43 | func getMigrationInstance(config Config, embeddedMigrations fs.FS, 
resourcePath string) (*migrate.Migrate, error) { 44 | src, err := iofs.New(embeddedMigrations, resourcePath) 45 | if err != nil { 46 | return nil, fmt.Errorf("db migrator: %v", err) 47 | } 48 | return migrate.NewWithSourceInstance("iofs", src, config.URL) 49 | } 50 | -------------------------------------------------------------------------------- /db/migrate_test.go: -------------------------------------------------------------------------------- 1 | package db_test 2 | 3 | import ( 4 | "embed" 5 | "fmt" 6 | "log" 7 | "testing" 8 | 9 | "github.com/raystack/salt/db" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | //go:embed migrations/*.sql 14 | var migrationFs embed.FS 15 | 16 | func TestRunMigrations(t *testing.T) { 17 | if _, err := client.Exec(dropTableQuery); err != nil { 18 | log.Fatalf("Could not cleanup: %s", err) 19 | } 20 | 21 | dsn := fmt.Sprintf(dsn, user, password, port, database) 22 | var ( 23 | pgConfig = db.Config{ 24 | Driver: "postgres", 25 | URL: dsn, 26 | } 27 | ) 28 | 29 | err := db.RunMigrations(pgConfig, migrationFs, "migrations") 30 | assert.NoError(t, err) 31 | 32 | // User table should exist 33 | var tableExist bool 34 | result := client.QueryRow(checkTableQuery) 35 | result.Scan(&tableExist) 36 | assert.Equal(t, true, tableExist) 37 | } 38 | 39 | func TestRunRollback(t *testing.T) { 40 | if _, err := client.Exec(dropTableQuery); err != nil { 41 | log.Fatalf("Could not cleanup: %s", err) 42 | } 43 | 44 | dsn := fmt.Sprintf(dsn, user, password, port, database) 45 | var ( 46 | pgConfig = db.Config{ 47 | Driver: "postgres", 48 | URL: dsn, 49 | } 50 | ) 51 | 52 | err := db.RunRollback(pgConfig, migrationFs, "migrations") 53 | assert.NoError(t, err) 54 | 55 | // User table should not exist 56 | var tableExist bool 57 | result := client.QueryRow(checkTableQuery) 58 | result.Scan(&tableExist) 59 | assert.Equal(t, false, tableExist) 60 | } 61 | -------------------------------------------------------------------------------- 
/db/migrations/1481574547_create_users_table.down.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS users -------------------------------------------------------------------------------- /db/migrations/1481574547_create_users_table.up.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50)) -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/raystack/salt 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/AlecAivazis/survey/v2 v2.3.6 7 | github.com/MakeNowJust/heredoc v1.0.0 8 | github.com/NYTimes/gziphandler v1.1.1 9 | github.com/authzed/authzed-go v0.7.0 10 | github.com/authzed/grpcutil v0.0.0-20230908193239-4286bb1d6403 11 | github.com/briandowns/spinner v1.18.0 12 | github.com/charmbracelet/glamour v0.3.0 13 | github.com/cli/safeexec v1.0.0 14 | github.com/go-playground/validator v9.31.0+incompatible 15 | github.com/golang-migrate/migrate/v4 v4.16.0 16 | github.com/google/go-cmp v0.6.0 17 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 18 | github.com/google/uuid v1.6.0 19 | github.com/hashicorp/go-version v1.3.0 20 | github.com/jeremywohl/flatten v1.0.1 21 | github.com/jmoiron/sqlx v1.3.5 22 | github.com/lib/pq v1.10.4 23 | github.com/mattn/go-isatty v0.0.19 24 | github.com/mcuadros/go-defaults v1.2.0 25 | github.com/mitchellh/mapstructure v1.5.0 26 | github.com/muesli/termenv v0.11.1-0.20220212125758-44cd13922739 27 | github.com/oklog/run v1.1.0 28 | github.com/olekukonko/tablewriter v0.0.5 29 | github.com/ory/dockertest/v3 v3.9.1 30 | github.com/pkg/errors v0.9.1 31 | github.com/schollz/progressbar/v3 v3.8.5 32 | github.com/sirupsen/logrus v1.9.2 33 | github.com/spf13/cobra v1.8.1 34 | github.com/spf13/pflag v1.0.5 35 | 
github.com/spf13/viper v1.19.0 36 | github.com/stretchr/testify v1.9.0 37 | go.opentelemetry.io/contrib/instrumentation/host v0.56.0 38 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0 39 | go.opentelemetry.io/contrib/instrumentation/runtime v0.56.0 40 | go.opentelemetry.io/contrib/samplers/probability/consistent v0.25.0 41 | go.opentelemetry.io/otel v1.31.0 42 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 43 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 44 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 45 | go.opentelemetry.io/otel/metric v1.31.0 46 | go.opentelemetry.io/otel/sdk v1.31.0 47 | go.opentelemetry.io/otel/sdk/metric v1.31.0 48 | go.uber.org/zap v1.21.0 49 | golang.org/x/oauth2 v0.22.0 50 | golang.org/x/text v0.19.0 51 | google.golang.org/api v0.171.0 52 | google.golang.org/grpc v1.67.1 53 | google.golang.org/protobuf v1.35.1 54 | gopkg.in/yaml.v3 v3.0.1 55 | ) 56 | 57 | require ( 58 | cloud.google.com/go/compute/metadata v0.5.0 // indirect 59 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect 60 | github.com/Microsoft/go-winio v0.6.1 // indirect 61 | github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect 62 | github.com/alecthomas/chroma v0.8.2 // indirect 63 | github.com/alecthomas/repr v0.2.0 // indirect 64 | github.com/aymerick/douceur v0.2.0 // indirect 65 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 66 | github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect 67 | github.com/containerd/continuity v0.3.0 // indirect 68 | github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect 69 | github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect 70 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 71 | github.com/dlclark/regexp2 v1.2.0 // indirect 72 | github.com/docker/cli v20.10.14+incompatible // indirect 73 | github.com/docker/docker 
v20.10.24+incompatible // indirect 74 | github.com/docker/go-connections v0.4.0 // indirect 75 | github.com/docker/go-units v0.5.0 // indirect 76 | github.com/ebitengine/purego v0.8.0 // indirect 77 | github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect 78 | github.com/fatih/color v1.15.0 // indirect 79 | github.com/felixge/httpsnoop v1.0.4 // indirect 80 | github.com/fsnotify/fsnotify v1.7.0 // indirect 81 | github.com/go-logr/logr v1.4.2 // indirect 82 | github.com/go-logr/stdr v1.2.2 // indirect 83 | github.com/go-ole/go-ole v1.3.0 // indirect 84 | github.com/go-playground/locales v0.14.1 // indirect 85 | github.com/go-playground/universal-translator v0.18.1 // indirect 86 | github.com/go-sql-driver/mysql v1.6.0 // indirect 87 | github.com/gogo/protobuf v1.3.2 // indirect 88 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 89 | github.com/golang/protobuf v1.5.4 // indirect 90 | github.com/google/s2a-go v0.1.7 // indirect 91 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect 92 | github.com/gorilla/css v1.0.0 // indirect 93 | github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect 94 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect 95 | github.com/hashicorp/errwrap v1.1.0 // indirect 96 | github.com/hashicorp/go-multierror v1.1.1 // indirect 97 | github.com/hashicorp/hcl v1.0.0 // indirect 98 | github.com/imdario/mergo v0.3.12 // indirect 99 | github.com/inconshreveable/mousetrap v1.1.0 // indirect 100 | github.com/jzelinskie/stringz v0.0.0-20210414224931-d6a8ce844a70 // indirect 101 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect 102 | github.com/leodido/go-urn v1.4.0 // indirect 103 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 104 | github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect 105 | github.com/magiconair/properties v1.8.7 // indirect 106 | github.com/mattn/go-colorable v0.1.13 // indirect 107 | 
github.com/mattn/go-runewidth v0.0.13 // indirect 108 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect 109 | github.com/microcosm-cc/bluemonday v1.0.6 // indirect 110 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect 111 | github.com/moby/term v0.5.0 // indirect 112 | github.com/muesli/reflow v0.3.0 // indirect 113 | github.com/opencontainers/go-digest v1.0.0 // indirect 114 | github.com/opencontainers/image-spec v1.0.2 // indirect 115 | github.com/opencontainers/runc v1.1.2 // indirect 116 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect 117 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 118 | github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect 119 | github.com/rivo/uniseg v0.2.0 // indirect 120 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 121 | github.com/sagikazarmark/locafero v0.4.0 // indirect 122 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect 123 | github.com/shirou/gopsutil/v4 v4.24.9 // indirect 124 | github.com/sourcegraph/conc v0.3.0 // indirect 125 | github.com/spf13/afero v1.11.0 // indirect 126 | github.com/spf13/cast v1.6.0 // indirect 127 | github.com/stretchr/objx v0.5.2 // indirect 128 | github.com/subosito/gotenv v1.6.0 // indirect 129 | github.com/tklauser/go-sysconf v0.3.14 // indirect 130 | github.com/tklauser/numcpus v0.9.0 // indirect 131 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect 132 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect 133 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect 134 | github.com/yuin/goldmark v1.4.13 // indirect 135 | github.com/yuin/goldmark-emoji v1.0.1 // indirect 136 | github.com/yusufpapurcu/wmi v1.2.4 // indirect 137 | go.opencensus.io v0.24.0 // indirect 138 | go.opentelemetry.io/otel/trace v1.31.0 // indirect 139 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect 140 | go.uber.org/atomic v1.10.0 
// indirect 141 | go.uber.org/multierr v1.9.0 // indirect 142 | golang.org/x/crypto v0.28.0 // indirect 143 | golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect 144 | golang.org/x/mod v0.18.0 // indirect 145 | golang.org/x/net v0.30.0 // indirect 146 | golang.org/x/sync v0.8.0 // indirect 147 | golang.org/x/sys v0.26.0 // indirect 148 | golang.org/x/term v0.25.0 // indirect 149 | golang.org/x/tools v0.22.0 // indirect 150 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect 151 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect 152 | gopkg.in/go-playground/assert.v1 v1.2.1 // indirect 153 | gopkg.in/ini.v1 v1.67.0 // indirect 154 | gopkg.in/yaml.v2 v2.4.0 // indirect 155 | gotest.tools/v3 v3.5.1 // indirect 156 | ) 157 | -------------------------------------------------------------------------------- /log/logger.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // Option modifies the logger behavior 8 | type Option func(interface{}) 9 | 10 | // Logger is a convenient interface to use provided loggers 11 | // either use it as it is or implement your own interface where 12 | // the logging implementations are used 13 | // Each log method must take first string as message and then one or 14 | // more key,value arguments. 15 | // For example: 16 | // 17 | // timeTaken := time.Duration(time.Second * 1) 18 | // l.Debug("processed request", "time taken", timeTaken) 19 | // 20 | // here key should always be a `string` and value could be of any type as 21 | // long as it is printable. 
22 | // 23 | // l.Info("processed request", "time taken", timeTaken, "started at", startedAt) 24 | type Logger interface { 25 | 26 | // Debug level message with alternating key/value pairs 27 | // key should be string, value could be anything printable 28 | Debug(msg string, args ...interface{}) 29 | 30 | // Info level message with alternating key/value pairs 31 | // key should be string, value could be anything printable 32 | Info(msg string, args ...interface{}) 33 | 34 | // Warn level message with alternating key/value pairs 35 | // key should be string, value could be anything printable 36 | Warn(msg string, args ...interface{}) 37 | 38 | // Error level message with alternating key/value pairs 39 | // key should be string, value could be anything printable 40 | Error(msg string, args ...interface{}) 41 | 42 | // Fatal level message with alternating key/value pairs 43 | // key should be string, value could be anything printable 44 | Fatal(msg string, args ...interface{}) 45 | 46 | // Level returns priority level for which this logger will filter logs 47 | Level() string 48 | 49 | // Writer used to print logs 50 | Writer() io.Writer 51 | } 52 | -------------------------------------------------------------------------------- /log/logrus.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "io" 5 | 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | type Logrus struct { 10 | log *logrus.Logger 11 | } 12 | 13 | func (l Logrus) getFields(args ...interface{}) map[string]interface{} { 14 | fieldMap := map[string]interface{}{} 15 | if len(args) > 1 && len(args)%2 == 0 { 16 | for i := 1; i < len(args); i += 2 { 17 | fieldMap[args[i-1].(string)] = args[i] 18 | } 19 | } 20 | return fieldMap 21 | } 22 | 23 | func (l *Logrus) Info(msg string, args ...interface{}) { 24 | l.log.WithFields(l.getFields(args...)).Info(msg) 25 | } 26 | 27 | func (l *Logrus) Debug(msg string, args ...interface{}) { 28 | 
l.log.WithFields(l.getFields(args...)).Debug(msg) 29 | } 30 | 31 | func (l *Logrus) Warn(msg string, args ...interface{}) { 32 | l.log.WithFields(l.getFields(args...)).Warn(msg) 33 | } 34 | 35 | func (l *Logrus) Error(msg string, args ...interface{}) { 36 | l.log.WithFields(l.getFields(args...)).Error(msg) 37 | } 38 | 39 | func (l *Logrus) Fatal(msg string, args ...interface{}) { 40 | l.log.WithFields(l.getFields(args...)).Fatal(msg) 41 | } 42 | 43 | func (l *Logrus) Level() string { 44 | return l.log.Level.String() 45 | } 46 | 47 | func (l *Logrus) Writer() io.Writer { 48 | return l.log.Writer() 49 | } 50 | 51 | func (l *Logrus) Entry(args ...interface{}) *logrus.Entry { 52 | return l.log.WithFields(l.getFields(args...)) 53 | } 54 | 55 | func LogrusWithLevel(level string) Option { 56 | return func(logger interface{}) { 57 | logLevel, err := logrus.ParseLevel(level) 58 | if err != nil { 59 | panic(err) 60 | } 61 | logger.(*Logrus).log.SetLevel(logLevel) 62 | } 63 | } 64 | 65 | func LogrusWithWriter(writer io.Writer) Option { 66 | return func(logger interface{}) { 67 | logger.(*Logrus).log.SetOutput(writer) 68 | } 69 | } 70 | 71 | // LogrusWithFormatter can be used to change default formatting 72 | // by implementing logrus.Formatter 73 | // For example: 74 | // 75 | // type PlainFormatter struct{} 76 | // func (p *PlainFormatter) Format(entry *logrus.Entry) ([]byte, error) { 77 | // return []byte(entry.Message), nil 78 | // } 79 | // l := log.NewLogrus(log.LogrusWithFormatter(&PlainFormatter{})) 80 | func LogrusWithFormatter(f logrus.Formatter) Option { 81 | return func(logger interface{}) { 82 | logger.(*Logrus).log.SetFormatter(f) 83 | } 84 | } 85 | 86 | // NewLogrus returns a logrus logger instance with info level as default log level 87 | func NewLogrus(opts ...Option) *Logrus { 88 | logger := &Logrus{ 89 | log: logrus.New(), 90 | } 91 | logger.log.Level = logrus.InfoLevel 92 | for _, opt := range opts { 93 | opt(logger) 94 | } 95 | return logger 96 | } 97 | 
-------------------------------------------------------------------------------- /log/logrus_test.go: -------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/sirupsen/logrus" 10 | 11 | "github.com/raystack/salt/log" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestLogrus(t *testing.T) { 17 | t.Run("should parse info messages at debug level correctly", func(t *testing.T) { 18 | var b bytes.Buffer 19 | foo := bufio.NewWriter(&b) 20 | 21 | logger := log.NewLogrus(log.LogrusWithLevel("debug"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{ 22 | DisableTimestamp: true, 23 | })) 24 | logger.Info("hello world") 25 | foo.Flush() 26 | 27 | assert.Equal(t, "level=info msg=\"hello world\"\n", b.String()) 28 | }) 29 | t.Run("should not parse debug messages at info level correctly", func(t *testing.T) { 30 | var b bytes.Buffer 31 | foo := bufio.NewWriter(&b) 32 | 33 | logger := log.NewLogrus(log.LogrusWithLevel("info"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{ 34 | DisableTimestamp: true, 35 | })) 36 | logger.Debug("hello world") 37 | foo.Flush() 38 | 39 | assert.Equal(t, "", b.String()) 40 | }) 41 | t.Run("should parse field maps correctly", func(t *testing.T) { 42 | var b bytes.Buffer 43 | foo := bufio.NewWriter(&b) 44 | 45 | logger := log.NewLogrus(log.LogrusWithLevel("debug"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{ 46 | DisableTimestamp: true, 47 | })) 48 | logger.Debug("current values", "day", 11, "month", "aug") 49 | foo.Flush() 50 | 51 | assert.Equal(t, "level=debug msg=\"current values\" day=11 month=aug\n", b.String()) 52 | }) 53 | t.Run("should handle errors correctly", func(t *testing.T) { 54 | var b bytes.Buffer 55 | foo := bufio.NewWriter(&b) 56 | 57 | logger := log.NewLogrus(log.LogrusWithLevel("info"), 
log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{ 58 | DisableTimestamp: true, 59 | })) 60 | var err = fmt.Errorf("request failed") 61 | logger.Error(err.Error(), "hello", "world") 62 | foo.Flush() 63 | assert.Equal(t, "level=error msg=\"request failed\" hello=world\n", b.String()) 64 | }) 65 | t.Run("should ignore params if malformed", func(t *testing.T) { 66 | var b bytes.Buffer 67 | foo := bufio.NewWriter(&b) 68 | 69 | logger := log.NewLogrus(log.LogrusWithLevel("info"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{ 70 | DisableTimestamp: true, 71 | })) 72 | var err = fmt.Errorf("request failed") 73 | logger.Error(err.Error(), "hello", "world", "!") 74 | foo.Flush() 75 | assert.Equal(t, "level=error msg=\"request failed\"\n", b.String()) 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /log/noop.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // Noop is a Logger whose methods do nothing and whose Writer discards all output. 8 | type Noop struct{} 9 | 10 | func (n *Noop) Info(msg string, args ...interface{}) {} 11 | func (n *Noop) Debug(msg string, args ...interface{}) {} 12 | func (n *Noop) Warn(msg string, args ...interface{}) {} 13 | func (n *Noop) Error(msg string, args ...interface{}) {} 14 | func (n *Noop) Fatal(msg string, args ...interface{}) {} 15 | 16 | func (n *Noop) Level() string { 17 | return "unsupported" 18 | } 19 | func (n *Noop) Writer() io.Writer { 20 | return io.Discard 21 | } 22 | 23 | // NewNoop returns a no operation logger, useful in tests 24 | func NewNoop(opts ...Option) *Noop { 25 | return &Noop{} 26 | } 27 | -------------------------------------------------------------------------------- /log/zap.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "context" 5 | "io" 6 | 7 | "go.uber.org/zap" 8 | ) 9 | 10 | type Zap struct { 11 | log
*zap.SugaredLogger 12 | conf zap.Config 13 | } 14 | 15 | type ctxKey string 16 | 17 | var loggerCtxKey = ctxKey("zapLoggerCtxKey") 18 | 19 | func (z Zap) Debug(msg string, args ...interface{}) { 20 | z.log.With(args...).Debug(msg) 21 | } 22 | 23 | func (z Zap) Info(msg string, args ...interface{}) { 24 | z.log.With(args...).Info(msg) 25 | } 26 | 27 | func (z Zap) Warn(msg string, args ...interface{}) { 28 | z.log.With(args...).Warn(msg) // args are already attached as fields by With; passing them again would append them to the message text 29 | } 30 | 31 | func (z Zap) Error(msg string, args ...interface{}) { 32 | z.log.With(args...).Error(msg) 33 | } 34 | 35 | func (z Zap) Fatal(msg string, args ...interface{}) { 36 | z.log.With(args...).Fatal(msg) 37 | } 38 | 39 | func (z Zap) Level() string { 40 | return z.conf.Level.String() 41 | } 42 | 43 | func (z Zap) Writer() io.Writer { 44 | panic("not supported") 45 | } 46 | 47 | func ZapWithConfig(conf zap.Config, opts ...zap.Option) Option { 48 | return func(z interface{}) { 49 | z.(*Zap).conf = conf 50 | prodLogger, err := z.(*Zap).conf.Build(opts...)
51 | if err != nil { 52 | panic(err) 53 | } 54 | z.(*Zap).log = prodLogger.Sugar() 55 | } 56 | } 57 | 58 | // GetInternalZapLogger Gets internal SugaredLogger instance 59 | func (z Zap) GetInternalZapLogger() *zap.SugaredLogger { 60 | return z.log 61 | } 62 | 63 | // NewContext will add Zap inside context 64 | func (z Zap) NewContext(ctx context.Context) context.Context { 65 | return context.WithValue(ctx, loggerCtxKey, z) 66 | } 67 | 68 | // ZapContextWithFields will add Zap Fields to logger in Context 69 | func ZapContextWithFields(ctx context.Context, fields ...zap.Field) context.Context { 70 | return context.WithValue(ctx, loggerCtxKey, Zap{ 71 | // Error when not Desugaring when adding fields: github.com/ipfs/go-log/issues/85 72 | log: ZapFromContext(ctx).GetInternalZapLogger().Desugar().With(fields...).Sugar(), 73 | conf: ZapFromContext(ctx).conf, 74 | }) 75 | } 76 | 77 | // ZapFromContext will help in fetching back zap logger from context 78 | func ZapFromContext(ctx context.Context) Zap { 79 | if ctxLogger, ok := ctx.Value(loggerCtxKey).(Zap); ok { 80 | return ctxLogger 81 | } 82 | 83 | return Zap{} 84 | } 85 | 86 | func ZapWithNoop() Option { 87 | return func(z interface{}) { 88 | z.(*Zap).log = zap.NewNop().Sugar() 89 | z.(*Zap).conf = zap.Config{} 90 | } 91 | } 92 | 93 | // NewZap returns a zap logger instance with info level as default log level 94 | func NewZap(opts ...Option) *Zap { 95 | defaultConfig := zap.NewProductionConfig() 96 | defaultConfig.Level.SetLevel(zap.InfoLevel) 97 | logger, err := defaultConfig.Build() 98 | if err != nil { 99 | panic(err) 100 | } 101 | 102 | zapper := &Zap{ 103 | log: logger.Sugar(), 104 | conf: defaultConfig, 105 | } 106 | for _, opt := range opts { 107 | opt(zapper) 108 | } 109 | return zapper 110 | } 111 | -------------------------------------------------------------------------------- /log/zap_test.go: -------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | 
import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "crypto/rand" 8 | "fmt" 9 | "io" 10 | "net/url" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | 16 | "go.uber.org/zap" 17 | 18 | "github.com/raystack/salt/log" 19 | ) 20 | 21 | type zapBufWriter struct { 22 | io.Writer 23 | } 24 | 25 | func (cw zapBufWriter) Close() error { 26 | return nil 27 | } 28 | func (cw zapBufWriter) Sync() error { 29 | return nil 30 | } 31 | 32 | type zapClock struct { 33 | t time.Time 34 | } 35 | 36 | func (m zapClock) Now() time.Time { 37 | return m.t 38 | } 39 | 40 | func (m zapClock) NewTicker(duration time.Duration) *time.Ticker { 41 | return time.NewTicker(duration) 42 | } 43 | 44 | func buildBufferedZapOption(writer io.Writer, t time.Time, bufWriterKey string) log.Option { 45 | config := zap.NewDevelopmentConfig() 46 | config.DisableCaller = true 47 | // register mock writer 48 | _ = zap.RegisterSink(bufWriterKey, func(u *url.URL) (zap.Sink, error) { 49 | return zapBufWriter{writer}, nil 50 | }) 51 | // build a valid custom path 52 | customPath := fmt.Sprintf("%s:", bufWriterKey) 53 | config.OutputPaths = []string{customPath} 54 | 55 | return log.ZapWithConfig(config, zap.WithClock(&zapClock{ 56 | t: t, 57 | })) 58 | } 59 | 60 | func TestZap(t *testing.T) { 61 | mockedTime := time.Date(2021, 6, 10, 11, 55, 0, 0, time.UTC) 62 | 63 | t.Run("should successfully print at info level", func(t *testing.T) { 64 | var b bytes.Buffer 65 | bWriter := bufio.NewWriter(&b) 66 | 67 | zapper := log.NewZap(buildBufferedZapOption(bWriter, mockedTime, randomString(10))) 68 | zapper.Info("hello", "wor", "ld") 69 | bWriter.Flush() 70 | 71 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"wor\": \"ld\"}\n", b.String()) 72 | }) 73 | 74 | t.Run("should successfully print log from context", func(t *testing.T) { 75 | var b bytes.Buffer 76 | bWriter := bufio.NewWriter(&b) 77 | 78 | zapper := log.NewZap(buildBufferedZapOption(bWriter, 
mockedTime, randomString(10))) 79 | ctx := zapper.NewContext(context.Background()) 80 | contextualLog := log.ZapFromContext(ctx) 81 | contextualLog.Info("hello", "wor", "ld") 82 | bWriter.Flush() 83 | 84 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"wor\": \"ld\"}\n", b.String()) 85 | }) 86 | 87 | t.Run("should successfully print log from context with fields", func(t *testing.T) { 88 | var b bytes.Buffer 89 | bWriter := bufio.NewWriter(&b) 90 | 91 | zapper := log.NewZap(buildBufferedZapOption(bWriter, mockedTime, randomString(10))) 92 | ctx := zapper.NewContext(context.Background()) 93 | ctx = log.ZapContextWithFields(ctx, zap.Int("one", 1)) 94 | ctx = log.ZapContextWithFields(ctx, zap.String("two", "two")) 95 | log.ZapFromContext(ctx).Info("hello", "wor", "ld") 96 | bWriter.Flush() 97 | 98 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"one\": 1, \"two\": \"two\", \"wor\": \"ld\"}\n", b.String()) 99 | }) 100 | } 101 | 102 | func randomString(n int) string { 103 | const alphabets = "abcdefghijklmnopqrstuvwxyz" 104 | var alphaBytes = make([]byte, n) 105 | rand.Read(alphaBytes) 106 | for i, b := range alphaBytes { 107 | alphaBytes[i] = alphabets[b%byte(len(alphabets))] 108 | } 109 | return string(alphaBytes) 110 | } 111 | -------------------------------------------------------------------------------- /rql/README.md: -------------------------------------------------------------------------------- 1 | # RQL (Rest Query Language) 2 | 3 | A library to parse support advanced REST API query parameters like (filter, pagination, sort, group, search) and logical operators on the keys (like eq, neq, like, gt, lt etc) 4 | 5 | It takes a Golang struct and a json string as input and returns a Golang object that can be used to prepare SQL Statements (using raw sql or ORM Query builders). 
6 | 7 | ### Usage 8 | 9 | The frontend should send the parameters and operators in the following schema to a backend route using the `POST` HTTP method: 10 | 11 | ```json 12 | { 13 | "filters": [ 14 | { "name": "id", "operator": "neq", "value": 20 }, 15 | { "name": "title", "operator": "neq", "value": "nasa" }, 16 | { "name": "enabled", "operator": "eq", "value": false }, 17 | { 18 | "name": "created_at", 19 | "operator": "gte", 20 | "value": "2025-02-05T11:25:37.957Z" 21 | }, 22 | { "name": "title", "operator": "like", "value": "xyz" } 23 | ], 24 | "group_by": ["plan_name"], 25 | "offset": 20, 26 | "limit": 50, 27 | "search": "abcd", 28 | "sort": [ 29 | { "name": "title", "order": "desc" }, 30 | { "name": "created_at", "order": "asc" } 31 | ] 32 | } 33 | ``` 34 | 35 | The `rql` library parses this JSON, validates it, and returns a struct containing all the information needed to generate the SQL operations and values. 36 | 37 | Validation is driven by struct tags defined on your model. Example: 38 | 39 | ```golang 40 | type Organization struct { 41 | Id int `rql:"name=id,type=number,min=10,max=200"` 42 | BillingPlanName string `rql:"name=plan_name,type=string"` 43 | CreatedAt time.Time `rql:"name=created_at,type=datetime"` 44 | MemberCount int `rql:"name=member_count,type=number"` 45 | Title string `rql:"name=title,type=string"` 46 | Enabled bool `rql:"name=enabled,type=bool"` 47 | } 48 | 49 | ``` 50 | 51 | **Supported data types:** 52 | 53 | 1. number 54 | 2. string 55 | 3. datetime 56 | 4. bool 57 | 58 | See `parser_test.go` for more usage examples. 59 | 60 | Using this struct, a SQL query can be generated.
Here is an example using the `goqu` SQL builder: 61 | 62 | ```go 63 | //init the library's "Query" object with input json bytes 64 | userInput := &rql.Query{} 65 | 66 | //assuming jsonBytes is defined earlier 67 | err := json.Unmarshal(jsonBytes, userInput) 68 | if err != nil { 69 | panic(fmt.Sprintf("failed to unmarshal query string to parser query struct, err:%s", err.Error())) 70 | } 71 | 72 | //validate the json input 73 | err = rql.ValidateQuery(userInput, Organization{}) 74 | if err != nil { 75 | panic(err) 76 | } 77 | 78 | //userInput object can be utilized to prepare SQL statement 79 | query := goqu.From("organizations") 80 | 81 | fuzzySearchColumns := []string{"id", "billing_plan_name", "title"} 82 | 83 | for _, filter_item := range userInput.Filters { 84 | query = query.Where(goqu.Ex{ 85 | filter_item.Name: goqu.Op{filter_item.Operator: filter_item.Value}, 86 | }) 87 | } 88 | 89 | listOfExpressions := make([]goqu.Expression, 0) 90 | 91 | if userInput.Search != "" { 92 | for _, col := range fuzzySearchColumns { 93 | listOfExpressions = append(listOfExpressions, goqu.Ex{ 94 | col: goqu.Op{"LIKE": userInput.Search}, 95 | }) 96 | } 97 | } 98 | 99 | query = query.Where(goqu.Or(listOfExpressions...)) 100 | 101 | query = query.Offset(uint(userInput.Offset)) 102 | for _, sort_item := range userInput.Sort { 103 | switch sort_item.Order { 104 | case "asc": 105 | query = query.OrderAppend(goqu.C(sort_item.Name).Asc()) 106 | case "desc": 107 | query = query.OrderAppend(goqu.C(sort_item.Name).Desc()) 108 | default: 109 | } 110 | } 111 | query = query.Limit(uint(userInput.Limit)) 112 | sql, _, _ := query.ToSQL() 113 | fmt.Println(sql) 114 | 115 | 116 | ``` 117 | 118 | which produces the following output: 119 | 120 | ```sql 121 | SELECT * FROM "organizations" WHERE (("id" != 20) AND ("title" != 'nasa') AND ("enabled" IS FALSE) AND ("createdAt" >= '2025-02-05T11:25:37.957Z') AND ("title" LIKE 'xyz') AND (("id" LIKE 'abcd') OR ("billing_plan_name" LIKE 'abcd') OR ("title" LIKE 'abcd')))
ORDER BY "title" DESC, "createdAt" ASC LIMIT 50 OFFSET 20 122 | ``` 123 | 124 | ### Improvements 125 | 126 | 1. The operators need to mapped with SQL operators like (`eq` should be converted to `=` etc). Right now we are relying on GoQU to do that, but we can make it SQL ORL lib agnostic. 127 | 128 | 2. Support validation on the range or values of the data. Like `min`, `max` on number etc. 129 | -------------------------------------------------------------------------------- /rql/parser.go: -------------------------------------------------------------------------------- 1 | package rql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "slices" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | var validNumberOperations = []string{"eq", "neq", "gt", "lt", "gte", "lte"} 12 | var validStringOperations = []string{"eq", "neq", "like", "in", "notin", "notlike", "empty", "notempty"} 13 | var validBoolOperations = []string{"eq", "neq"} 14 | var validDatetimeOperations = []string{"eq", "neq", "gt", "lt", "gte", "lte"} 15 | 16 | const TAG = "rql" 17 | const DATATYPE_NUMBER = "number" 18 | const DATATYPE_DATETIME = "datetime" 19 | const DATATYPE_STRING = "string" 20 | const DATATYPE_BOOL = "bool" 21 | const SORT_ORDER_ASC = "asc" 22 | const SORT_ORDER_DESC = "desc" 23 | 24 | var validSortOrder = []string{SORT_ORDER_ASC, SORT_ORDER_DESC} 25 | 26 | type Query struct { 27 | Filters []Filter `json:"filters"` 28 | GroupBy []string `json:"group_by"` 29 | Offset int `json:"offset"` 30 | Limit int `json:"limit"` 31 | Search string `json:"search"` 32 | Sort []Sort `json:"sort"` 33 | } 34 | 35 | type Filter struct { 36 | Name string `json:"name"` 37 | Operator string `json:"operator"` 38 | dataType string 39 | Value any `json:"value"` 40 | } 41 | 42 | type Sort struct { 43 | Name string `json:"name"` 44 | Order string `json:"order"` 45 | } 46 | 47 | func ValidateQuery(q *Query, checkStruct interface{}) error { 48 | val := reflect.ValueOf(checkStruct) 49 | 50 | // validate filters 51 | for _, 
filterItem := range q.Filters { 52 | //validate filter key name 53 | filterIdx := searchKeyInsideStruct(filterItem.Name, val) 54 | if filterIdx < 0 { 55 | return fmt.Errorf("'%s' is not a valid filter key", filterItem.Name) 56 | } 57 | structKeyTag := val.Type().Field(filterIdx).Tag.Get(TAG) 58 | 59 | // validate filter key data type 60 | allowedDataType := getDataTypeOfField(structKeyTag) 61 | filterItem.dataType = allowedDataType 62 | switch allowedDataType { 63 | case DATATYPE_NUMBER: 64 | err := validateNumberType(filterItem) 65 | if err != nil { 66 | return err 67 | } 68 | case DATATYPE_BOOL: 69 | err := validateBoolType(filterItem) 70 | if err != nil { 71 | return err 72 | } 73 | case DATATYPE_DATETIME: 74 | err := validateDatetimeType(filterItem) 75 | if err != nil { 76 | return err 77 | } 78 | case DATATYPE_STRING: 79 | err := validateStringType(filterItem) 80 | if err != nil { 81 | return err 82 | } 83 | default: 84 | return fmt.Errorf("type '%s' is not recognized", allowedDataType) 85 | } 86 | 87 | if !isValidOperator(filterItem) { 88 | return fmt.Errorf("value '%s' for key '%s' is valid string", filterItem.Operator, filterItem.Name) 89 | } 90 | } 91 | 92 | err := validateGroupByKeys(q, val) 93 | if err != nil { 94 | return err 95 | } 96 | return validateSortKey(q, val) 97 | } 98 | 99 | func validateNumberType(filterItem Filter) error { 100 | // check if the type is any of Golang numeric types 101 | // if not, return error 102 | switch filterItem.Value.(type) { 103 | case uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, int, uint: 104 | return nil 105 | default: 106 | return fmt.Errorf("value %v for key '%s' is not int type", filterItem.Value, filterItem.Name) 107 | } 108 | } 109 | 110 | func validateDatetimeType(filterItem Filter) error { 111 | // cast the value to datetime 112 | // if failed, return error 113 | castedVal, ok := filterItem.Value.(string) 114 | if !ok { 115 | return fmt.Errorf("value %s for key '%s' is not a 
valid ISO datetime string", filterItem.Value, filterItem.Name) 116 | } 117 | _, err := time.Parse(time.RFC3339, castedVal) 118 | if err != nil { 119 | return fmt.Errorf("value %s for key '%s' is not a valid ISO datetime string", filterItem.Value, filterItem.Name) 120 | } 121 | return nil 122 | } 123 | 124 | func validateBoolType(filterItem Filter) error { 125 | // cast the value to bool 126 | // if failed, return error 127 | _, ok := filterItem.Value.(bool) 128 | if !ok { 129 | return fmt.Errorf("value %v for key '%s' is not bool type", filterItem.Value, filterItem.Name) 130 | } 131 | return nil 132 | } 133 | 134 | func validateStringType(filterItem Filter) error { 135 | // cast the value to string 136 | // if failed, return error 137 | _, ok := filterItem.Value.(string) 138 | if !ok { 139 | return fmt.Errorf("value %s for key '%s' is valid string type", filterItem.Value, filterItem.Name) 140 | } 141 | return nil 142 | } 143 | 144 | func searchKeyInsideStruct(keyName string, val reflect.Value) int { 145 | normalizedKey := strings.ToLower(keyName) 146 | 147 | for i := 0; i < val.NumField(); i++ { 148 | field := val.Type().Field(i) 149 | 150 | // Check field name 151 | if strings.ToLower(field.Name) == normalizedKey { 152 | return i 153 | } 154 | 155 | // Check rql tag 156 | if tag, ok := field.Tag.Lookup("rql"); ok { 157 | // Parse the tag string 158 | tagParts := strings.Split(tag, ",") 159 | for _, part := range tagParts { 160 | if strings.HasPrefix(part, "name=") { 161 | tagName := strings.TrimPrefix(part, "name=") 162 | if strings.ToLower(tagName) == normalizedKey { 163 | return i 164 | } 165 | } 166 | } 167 | } 168 | } 169 | 170 | return -1 171 | } 172 | 173 | // parse the tag schema which is of the format 174 | // type=int,min=10,max=200 175 | // to extract type else fallback to string 176 | func getDataTypeOfField(tagString string) string { 177 | res := DATATYPE_STRING 178 | splitted := strings.Split(tagString, ",") 179 | for _, item := range splitted { 180 | 
kvSplitted := strings.Split(item, "=") 181 | if len(kvSplitted) == 2 { 182 | if kvSplitted[0] == "type" { 183 | return kvSplitted[1] 184 | } 185 | } 186 | } 187 | //fallback to string if type not found in tag value 188 | return res 189 | } 190 | 191 | func GetDataTypeOfField(fieldName string, checkStruct interface{}) (string, error) { 192 | val := reflect.ValueOf(checkStruct) 193 | filterIdx := searchKeyInsideStruct(fieldName, val) 194 | if filterIdx < 0 { 195 | return "", fmt.Errorf("'%s' is not a valid field", fieldName) 196 | } 197 | structKeyTag := val.Type().Field(filterIdx).Tag.Get(TAG) 198 | dataType := getDataTypeOfField(structKeyTag) 199 | if !slices.Contains([]string{DATATYPE_STRING, DATATYPE_BOOL, DATATYPE_NUMBER, DATATYPE_DATETIME}, dataType) { 200 | return "", fmt.Errorf("invalid datatype '%s' is for field %s", dataType, fieldName) 201 | } 202 | return dataType, nil 203 | } 204 | 205 | func isValidOperator(filterItem Filter) bool { 206 | switch filterItem.dataType { 207 | case DATATYPE_NUMBER: 208 | return slices.Contains(validNumberOperations, filterItem.Operator) 209 | case DATATYPE_DATETIME: 210 | return slices.Contains(validDatetimeOperations, filterItem.Operator) 211 | case DATATYPE_STRING: 212 | return slices.Contains(validStringOperations, filterItem.Operator) 213 | case DATATYPE_BOOL: 214 | return slices.Contains(validBoolOperations, filterItem.Operator) 215 | default: 216 | return false 217 | } 218 | } 219 | 220 | func validateSortKey(q *Query, val reflect.Value) error { 221 | for _, item := range q.Sort { 222 | filterIdx := searchKeyInsideStruct(item.Name, val) 223 | if filterIdx < 0 { 224 | return fmt.Errorf("'%s' is not a valid sort key", item.Name) 225 | } 226 | if !slices.Contains(validSortOrder, item.Order) { 227 | return fmt.Errorf("'%s' is not a valid sort key", item.Name) 228 | } 229 | } 230 | return nil 231 | } 232 | 233 | func validateGroupByKeys(q *Query, val reflect.Value) error { 234 | for _, item := range q.GroupBy { 235 | 
filterIdx := searchKeyInsideStruct(item, val) 236 | if filterIdx < 0 { 237 | return fmt.Errorf("'%s' is not a valid sort key", item) 238 | } 239 | } 240 | return nil 241 | } 242 | -------------------------------------------------------------------------------- /rql/parser_test.go: -------------------------------------------------------------------------------- 1 | package rql 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestValidateQuery(t *testing.T) { 10 | type TestStruct struct { 11 | ID int32 `rql:"name=id,type=number"` 12 | Name string `rql:"name=name,type=string"` 13 | IsActive bool `rql:"name=is_active,type=bool"` 14 | CreatedAt time.Time `rql:"name=created_at,type=datetime"` 15 | } 16 | 17 | tests := []struct { 18 | name string 19 | query Query 20 | checkStruct TestStruct 21 | expectErr bool 22 | }{ 23 | { 24 | name: "Valid filters and sort", 25 | query: Query{ 26 | Filters: []Filter{ 27 | {Name: "ID", Operator: "eq", Value: 123}, 28 | {Name: "Name", Operator: "like", Value: "test"}, 29 | {Name: "is_active", Operator: "eq", Value: true}, 30 | {Name: "created_at", Operator: "eq", Value: "2021-09-15T15:53:00Z"}, 31 | }, 32 | Sort: []Sort{ 33 | {Name: "ID", Order: "asc"}, 34 | }, 35 | }, 36 | checkStruct: TestStruct{}, 37 | expectErr: false, 38 | }, 39 | { 40 | name: "Invalid filter key", 41 | query: Query{ 42 | Filters: []Filter{ 43 | {Name: "NonExistentKey", Operator: "eq", Value: "test"}, 44 | }, 45 | }, 46 | checkStruct: TestStruct{}, 47 | expectErr: true, 48 | }, 49 | { 50 | name: "Invalid filter operator", 51 | query: Query{ 52 | Filters: []Filter{ 53 | {Name: "ID", Operator: "invalid", Value: 123}, 54 | }, 55 | }, 56 | checkStruct: TestStruct{}, 57 | expectErr: true, 58 | }, 59 | { 60 | name: "Invalid filter value type", 61 | query: Query{ 62 | Filters: []Filter{ 63 | {Name: "ID", Operator: "eq", Value: "invalid"}, 64 | }, 65 | }, 66 | checkStruct: TestStruct{}, 67 | expectErr: true, 68 | }, 69 | { 70 | name: "Invalid sort 
key", 71 | query: Query{ 72 | Sort: []Sort{ 73 | {Name: "NonExistentKey", Order: "asc"}, 74 | }, 75 | }, 76 | checkStruct: TestStruct{}, 77 | expectErr: true, 78 | }, 79 | } 80 | 81 | for _, tt := range tests { 82 | t.Run(tt.name, func(t *testing.T) { 83 | err := ValidateQuery(&tt.query, tt.checkStruct) 84 | if (err != nil) != tt.expectErr { 85 | t.Errorf("ValidateQuery() error = %v, expectErr %v", err, tt.expectErr) 86 | } 87 | }) 88 | } 89 | } 90 | 91 | func TestGetDataTypeOfField(t *testing.T) { 92 | type TestStruct struct { 93 | StringField string `rql:"name=string_field,type=string"` 94 | NumberField int `rql:"name=number_field,type=number"` 95 | BoolField bool `rql:"name=bool_field,type=bool"` 96 | DateTimeField time.Time `rql:"name=datetime_field,type=datetime"` 97 | InvalidField string `rql:"name=invalid_field,type=invalid"` 98 | NoTypeField string `rql:"name=no_type_field"` // No type specified 99 | NoTagField string // No tag at all 100 | } 101 | 102 | tests := []struct { 103 | name string 104 | fieldName string 105 | expectedType string 106 | expectedError bool 107 | errorContains string 108 | }{ 109 | { 110 | name: "String field by struct name", 111 | fieldName: "StringField", 112 | expectedType: "string", 113 | expectedError: false, 114 | }, 115 | { 116 | name: "String field by tag name", 117 | fieldName: "string_field", 118 | expectedType: "string", 119 | expectedError: false, 120 | }, 121 | { 122 | name: "Number field by struct name", 123 | fieldName: "NumberField", 124 | expectedType: "number", 125 | expectedError: false, 126 | }, 127 | { 128 | name: "Number field by tag name", 129 | fieldName: "number_field", 130 | expectedType: "number", 131 | expectedError: false, 132 | }, 133 | { 134 | name: "Bool field by struct name", 135 | fieldName: "BoolField", 136 | expectedType: "bool", 137 | expectedError: false, 138 | }, 139 | { 140 | name: "DateTime field by struct name", 141 | fieldName: "DateTimeField", 142 | expectedType: "datetime", 143 | 
expectedError: false, 144 | }, 145 | { 146 | name: "Invalid field name", 147 | fieldName: "NonExistentField", 148 | expectedType: "", 149 | expectedError: true, 150 | errorContains: "is not a valid field", 151 | }, 152 | { 153 | name: "No type specified in tag", 154 | fieldName: "NoTypeField", 155 | expectedType: "string", // Should default to string 156 | expectedError: false, 157 | }, 158 | { 159 | name: "No tag field", 160 | fieldName: "NoTagField", 161 | expectedType: "string", // Should default to string 162 | expectedError: false, 163 | }, 164 | } 165 | 166 | testStruct := TestStruct{} 167 | 168 | for _, tt := range tests { 169 | t.Run(tt.name, func(t *testing.T) { 170 | dataType, err := GetDataTypeOfField(tt.fieldName, testStruct) 171 | 172 | // Check error cases 173 | if tt.expectedError { 174 | if err == nil { 175 | t.Errorf("Expected error but got none") 176 | return 177 | } 178 | if !strings.Contains(err.Error(), tt.errorContains) { 179 | t.Errorf("Expected error containing '%s', got '%s'", tt.errorContains, err.Error()) 180 | } 181 | return 182 | } 183 | 184 | // Check success cases 185 | if err != nil { 186 | t.Errorf("Unexpected error: %v", err) 187 | return 188 | } 189 | 190 | if dataType != tt.expectedType { 191 | t.Errorf("Expected type '%s', got '%s'", tt.expectedType, dataType) 192 | } 193 | }) 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /server/mux/README.md: -------------------------------------------------------------------------------- 1 | # Mux 2 | 3 | `mux` package provides helpers for starting multiple servers. HTTP and gRPC 4 | servers are supported currently. 
5 | 6 | ## Usage 7 | 8 | ```go 9 | package main 10 | 11 | import ( 12 | "context" 13 | "log" 14 | "net/http" 15 | "os/signal" 16 | "syscall" 17 | "time" 18 | 19 | "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" 20 | "github.com/raystack/salt/common" 21 | "github.com/raystack/salt/mux" 22 | "google.golang.org/grpc" 23 | "google.golang.org/grpc/reflection" 24 | ) 25 | 26 | func main() { 27 | ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 28 | defer cancel() 29 | 30 | grpcServer := grpc.NewServer() 31 | 32 | reflection.Register(grpcServer) 33 | 34 | grpcGateway := runtime.NewServeMux() 35 | 36 | httpMux := http.NewServeMux() 37 | httpMux.Handle("/api/", http.StripPrefix("/api", grpcGateway)) 38 | 39 | log.Fatalf("server exited: %v", mux.Serve( 40 | ctx, 41 | mux.WithHTTPTarget(":8080", &http.Server{ 42 | Handler: httpMux, 43 | ReadTimeout: 120 * time.Second, 44 | WriteTimeout: 120 * time.Second, 45 | MaxHeaderBytes: 1 << 20, 46 | }), 47 | mux.WithGRPCTarget(":8081", grpcServer), 48 | mux.WithGracePeriod(5*time.Second), 49 | )) 50 | } 51 | 52 | type SlowCommonService struct { 53 | *common.CommonService 54 | } 55 | 56 | func (s SlowCommonService) GetVersion(ctx context.Context, req *commonv1.GetVersionRequest) (*commonv1.GetVersionResponse, error) { 57 | for i := 0; i < 5; i++ { 58 | log.Printf("dooing stuff") 59 | time.Sleep(1 * time.Second) 60 | } 61 | return s.CommonService.GetVersion(ctx, req) 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /server/mux/mux.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "net" 9 | "time" 10 | 11 | "github.com/oklog/run" 12 | ) 13 | 14 | const ( 15 | defaultGracePeriod = 10 * time.Second 16 | ) 17 | 18 | // Serve starts TCP listeners and serves the registered protocol servers of the 19 | // given serveTarget(s) and 
blocks until the servers exit. Context can be 20 | // cancelled to perform graceful shutdown. 21 | func Serve(ctx context.Context, opts ...Option) error { 22 | mux := muxServer{gracePeriod: defaultGracePeriod} 23 | for _, opt := range opts { 24 | if err := opt(&mux); err != nil { 25 | return err 26 | } 27 | } 28 | 29 | if len(mux.targets) == 0 { 30 | return errors.New("mux serve: at least one serve target must be set") 31 | } 32 | 33 | return mux.Serve(ctx) 34 | } 35 | 36 | type muxServer struct { 37 | targets []serveTarget 38 | gracePeriod time.Duration 39 | } 40 | 41 | func (mux *muxServer) Serve(ctx context.Context) error { 42 | var g run.Group 43 | for _, t := range mux.targets { 44 | l, err := net.Listen("tcp", t.Address()) 45 | if err != nil { 46 | return fmt.Errorf("mux serve: %w", err) 47 | } 48 | 49 | t := t // redeclare to avoid referring to updated value inside closures. 50 | g.Add(func() error { 51 | err := t.Serve(l) 52 | if err != nil { 53 | log.Print("[ERROR] Serve:", err) 54 | } 55 | return err 56 | }, func(error) { 57 | ctx, cancel := context.WithTimeout(context.Background(), mux.gracePeriod) 58 | defer cancel() 59 | 60 | if err := t.Shutdown(ctx); err != nil { 61 | log.Print("[ERROR] Shutdown server gracefully:", err) 62 | } 63 | }) 64 | } 65 | 66 | g.Add(func() error { 67 | <-ctx.Done() 68 | return ctx.Err() 69 | }, func(error) { 70 | }) 71 | 72 | return g.Run() 73 | } 74 | -------------------------------------------------------------------------------- /server/mux/option.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "google.golang.org/grpc" 8 | ) 9 | 10 | // Option values can be used with Serve() for customisation. 
11 | type Option func(m *muxServer) error 12 | 13 | func WithHTTPTarget(addr string, srv *http.Server) Option { 14 | srv.Addr = addr 15 | return func(m *muxServer) error { 16 | m.targets = append(m.targets, httpServeTarget{Server: srv}) 17 | return nil 18 | } 19 | } 20 | 21 | func WithGRPCTarget(addr string, srv *grpc.Server) Option { 22 | return func(m *muxServer) error { 23 | m.targets = append(m.targets, gRPCServeTarget{ 24 | Addr: addr, 25 | Server: srv, 26 | }) 27 | return nil 28 | } 29 | } 30 | 31 | // WithGracePeriod sets the wait duration for graceful shutdown. 32 | func WithGracePeriod(d time.Duration) Option { 33 | return func(m *muxServer) error { 34 | if d <= 0 { 35 | d = defaultGracePeriod 36 | } 37 | m.gracePeriod = d 38 | return nil 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /server/mux/serve_target.go: -------------------------------------------------------------------------------- 1 | package mux 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "net/http" 8 | 9 | "google.golang.org/grpc" 10 | ) 11 | 12 | type serveTarget interface { 13 | Address() string 14 | Serve(l net.Listener) error 15 | Shutdown(ctx context.Context) error 16 | } 17 | 18 | type httpServeTarget struct { 19 | *http.Server 20 | } 21 | 22 | func (h httpServeTarget) Address() string { return h.Addr } 23 | 24 | func (h httpServeTarget) Serve(l net.Listener) error { 25 | if err := h.Server.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { 26 | return err 27 | } 28 | return nil 29 | } 30 | 31 | type gRPCServeTarget struct { 32 | Addr string 33 | *grpc.Server 34 | } 35 | 36 | func (g gRPCServeTarget) Address() string { return g.Addr } 37 | 38 | func (g gRPCServeTarget) Shutdown(ctx context.Context) error { 39 | signal := make(chan struct{}) 40 | go func() { 41 | defer close(signal) 42 | 43 | g.GracefulStop() 44 | }() 45 | 46 | select { 47 | case <-ctx.Done(): 48 | g.Stop() 49 | return errors.New("graceful 
stop failed") 50 | 51 | case <-signal: 52 | } 53 | 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /server/spa/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package spa provides a simple and efficient HTTP handler for serving 3 | Single Page Applications (SPAs). 4 | 5 | The handler serves static files from an embedded file system and falls 6 | back to serving an index file for client-side routing. Optionally, it 7 | supports gzip compression for optimizing responses. 8 | 9 | Features: 10 | - Serves static assets from an embedded file system. 11 | - Fallback to an index file for client-side routing. 12 | - Optional gzip compression for supported clients. 13 | 14 | Usage: 15 | 16 | To use this package, embed your SPA's build assets into your binary using 17 | the `embed` package. Then, create an SPA handler using the `Handler` function 18 | and register it with an HTTP server. 19 | 20 | Example: 21 | 22 | package main 23 | 24 | import ( 25 | "embed" 26 | "log" 27 | "net/http" 28 | 29 | "yourmodule/spa" 30 | ) 31 | 32 | //go:embed build/* 33 | var build embed.FS 34 | 35 | func main() { 36 | handler, err := spa.Handler(build, "build", "index.html", true) 37 | if err != nil { 38 | log.Fatalf("Failed to initialize SPA handler: %v", err) 39 | } 40 | 41 | log.Println("Serving SPA on http://localhost:8080") 42 | http.ListenAndServe(":8080", handler) 43 | } 44 | */ 45 | package spa 46 | -------------------------------------------------------------------------------- /server/spa/handler.go: -------------------------------------------------------------------------------- 1 | package spa 2 | 3 | import ( 4 | "embed" 5 | "errors" 6 | "fmt" 7 | "io/fs" 8 | "net/http" 9 | 10 | "github.com/NYTimes/gziphandler" 11 | ) 12 | 13 | // Handler returns an HTTP handler for serving a Single Page Application (SPA). 
14 | // 15 | // The handler serves static files from the specified directory in the embedded 16 | // file system and falls back to serving the index file if a requested file is not found. 17 | // This is useful for client-side routing in SPAs. 18 | // 19 | // Parameters: 20 | // - build: An embedded file system containing the build assets. 21 | // - dir: The directory within the embedded file system where the static files are located. 22 | // - index: The name of the index file (usually "index.html"). 23 | // - gzip: If true, the response body will be compressed using gzip for clients that support it. 24 | // 25 | // Returns: 26 | // - An http.Handler that serves the SPA and optional gzip compression. 27 | // - An error if the file system or index file cannot be initialized. 28 | func Handler(build embed.FS, dir string, index string, gzip bool) (http.Handler, error) { 29 | fsys, err := fs.Sub(build, dir) 30 | if err != nil { 31 | return nil, fmt.Errorf("couldn't create sub filesystem: %w", err) 32 | } 33 | 34 | if _, err = fsys.Open(index); err != nil { 35 | if errors.Is(err, fs.ErrNotExist) { 36 | return nil, fmt.Errorf("ui is enabled but no index.html found: %w", err) 37 | } else { 38 | return nil, fmt.Errorf("ui assets error: %w", err) 39 | } 40 | } 41 | router := &router{index: index, fs: http.FS(fsys)} 42 | 43 | hlr := http.FileServer(router) 44 | 45 | if !gzip { 46 | return hlr, nil 47 | } 48 | return gziphandler.GzipHandler(hlr), nil 49 | } 50 | -------------------------------------------------------------------------------- /server/spa/router.go: -------------------------------------------------------------------------------- 1 | package spa 2 | 3 | import ( 4 | "errors" 5 | "io/fs" 6 | "net/http" 7 | ) 8 | 9 | // router is the http filesystem which only serves files 10 | // and prevent the directory traversal. 
11 | type router struct {
12 | 	index string
13 | 	fs    http.FileSystem
14 | }
15 | 
16 | // Open inspects the URL path to locate a file within the static dir.
17 | // If the file exists it is served. If it does not exist, the file at the
18 | // index path is served instead; any other error is returned unchanged.
19 | func (r *router) Open(name string) (http.File, error) {
20 | 	file, err := r.fs.Open(name)
21 | 	switch {
22 | 	case err == nil:
23 | 		return file, nil
24 | 	case errors.Is(err, fs.ErrNotExist):
25 | 		// Unknown paths are treated as client-side routes of the SPA.
26 | 		return r.fs.Open(r.index)
27 | 	default:
28 | 		return nil, err
29 | 	}
30 | }
31 | 
--------------------------------------------------------------------------------
/telemetry/opentelemetry.go:
--------------------------------------------------------------------------------
1 | package telemetry
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"time"
7 | 
8 | 	"github.com/raystack/salt/log"
9 | 	"go.opentelemetry.io/contrib/instrumentation/host"
10 | 	"go.opentelemetry.io/contrib/instrumentation/runtime"
11 | 	"go.opentelemetry.io/contrib/samplers/probability/consistent"
12 | 	"go.opentelemetry.io/otel"
13 | 	"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
14 | 	"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
15 | 	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
16 | 	"go.opentelemetry.io/otel/propagation"
17 | 	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
18 | 	"go.opentelemetry.io/otel/sdk/resource"
19 | 	sdktrace "go.opentelemetry.io/otel/sdk/trace"
20 | 	semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
21 | 	"google.golang.org/grpc/encoding/gzip"
22 | )
23 | 
24 | // OpenTelemetryConfig configures the OTLP/gRPC metric and trace exporters.
25 | type OpenTelemetryConfig struct {
26 | 	Enabled                      bool          `yaml:"enabled" mapstructure:"enabled" default:"false"`
27 | 	CollectorAddr                string        `yaml:"collector_addr" mapstructure:"collector_addr" default:"localhost:4317"`
28 | 	PeriodicReadInterval         time.Duration `yaml:"periodic_read_interval" mapstructure:"periodic_read_interval" default:"1s"`
29 | 	TraceSampleProbability       float64       `yaml:"trace_sample_probability" mapstructure:"trace_sample_probability" default:"1"`
30 | 	VerboseResourceLabelsEnabled bool          `yaml:"verbose_resource_labels_enabled" mapstructure:"verbose_resource_labels_enabled" default:"false"`
31 | }
32 | 
33 | // initOTLP installs global meter and tracer providers that export to
34 | // cfg.OpenTelemetry.CollectorAddr, then starts the host and Go runtime metric
35 | // instrumentation. It returns a function that shuts both providers down.
36 | // When OpenTelemetry is disabled it only logs and returns noOp. If a later
37 | // step fails, the providers created so far are shut down before returning.
38 | func initOTLP(ctx context.Context, cfg Config, logger log.Logger) (func(), error) {
39 | 	if !cfg.OpenTelemetry.Enabled {
40 | 		logger.Info("OpenTelemetry monitoring is disabled.")
41 | 		return noOp, nil
42 | 	}
43 | 	resourceOptions := []resource.Option{
44 | 		resource.WithFromEnv(),
45 | 		resource.WithAttributes(
46 | 			semconv.ServiceName(cfg.AppName),
47 | 			semconv.ServiceVersion(cfg.AppVersion),
48 | 		),
49 | 	}
50 | 	if cfg.OpenTelemetry.VerboseResourceLabelsEnabled {
51 | 		resourceOptions = append(resourceOptions,
52 | 			resource.WithTelemetrySDK(),
53 | 			resource.WithOS(),
54 | 			resource.WithHost(),
55 | 			resource.WithProcess(),
56 | 			resource.WithProcessRuntimeName(),
57 | 			resource.WithProcessRuntimeVersion(),
58 | 		)
59 | 	}
60 | 	res, err := resource.New(ctx, resourceOptions...)
61 | 	if err != nil {
62 | 		return nil, fmt.Errorf("create resource: %w", err)
63 | 	}
64 | 
65 | 	shutdownMetric, err := initGlobalMetrics(ctx, res, cfg.OpenTelemetry, logger)
66 | 	if err != nil {
67 | 		return nil, err
68 | 	}
69 | 	shutdownTracer, err := initGlobalTracer(ctx, res, cfg.OpenTelemetry, logger)
70 | 	if err != nil {
71 | 		shutdownMetric()
72 | 		return nil, err
73 | 	}
74 | 	shutdownProviders := func() {
75 | 		shutdownTracer()
76 | 		shutdownMetric()
77 | 	}
78 | 
79 | 	if err := host.Start(); err != nil {
80 | 		shutdownProviders()
81 | 		return nil, fmt.Errorf("start host instrumentation: %w", err)
82 | 	}
83 | 	if err := runtime.Start(); err != nil {
84 | 		shutdownProviders()
85 | 		return nil, fmt.Errorf("start runtime instrumentation: %w", err)
86 | 	}
87 | 	return shutdownProviders, nil
88 | }
89 | 
90 | // initGlobalMetrics installs a global MeterProvider whose periodic reader
91 | // pushes metrics every cfg.PeriodicReadInterval to cfg.CollectorAddr over
92 | // gzip-compressed OTLP/gRPC without TLS (WithInsecure). The returned function
93 | // shuts the provider down, waiting at most gracePeriod.
94 | func initGlobalMetrics(ctx context.Context, res *resource.Resource, cfg OpenTelemetryConfig, logger log.Logger) (func(), error) {
95 | 	exporter, err := otlpmetricgrpc.New(ctx,
96 | 		otlpmetricgrpc.WithEndpoint(cfg.CollectorAddr),
97 | 		otlpmetricgrpc.WithCompressor(gzip.Name),
98 | 		otlpmetricgrpc.WithInsecure(),
99 | 	)
100 | 	if err != nil {
101 | 		return nil, fmt.Errorf("create metric exporter: %w", err)
102 | 	}
103 | 	reader := sdkmetric.NewPeriodicReader(exporter, sdkmetric.WithInterval(cfg.PeriodicReadInterval))
104 | 	provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader), sdkmetric.WithResource(res))
105 | 	otel.SetMeterProvider(provider)
106 | 	return func() {
107 | 		shutdownCtx, cancel := context.WithTimeout(context.Background(), gracePeriod)
108 | 		defer cancel()
109 | 		if err := provider.Shutdown(shutdownCtx); err != nil {
110 | 			logger.Error("otlp metric-provider failed to shutdown", "err", err)
111 | 		}
112 | 	}, nil
113 | }
114 | 
115 | // initGlobalTracer installs a global TracerProvider that batches spans to
116 | // cfg.CollectorAddr over gzip-compressed OTLP/gRPC without TLS, sampling with
117 | // cfg.TraceSampleProbability, and sets the W3C TraceContext + Baggage
118 | // propagator. The returned function shuts the provider down within gracePeriod.
119 | func initGlobalTracer(ctx context.Context, res *resource.Resource, cfg OpenTelemetryConfig, logger log.Logger) (func(), error) {
120 | 	exporter, err := otlptrace.New(ctx, otlptracegrpc.NewClient(
121 | 		otlptracegrpc.WithEndpoint(cfg.CollectorAddr),
122 | 		otlptracegrpc.WithInsecure(),
123 | 		otlptracegrpc.WithCompressor(gzip.Name),
124 | 	))
125 | 	if err != nil {
126 | 		return nil, fmt.Errorf("create trace exporter: %w", err)
127 | 	}
128 | 	tracerProvider := sdktrace.NewTracerProvider(
129 | 		sdktrace.WithSampler(consistent.ProbabilityBased(cfg.TraceSampleProbability)),
130 | 		sdktrace.WithResource(res),
131 | 		sdktrace.WithSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter)),
132 | 	)
133 | 	otel.SetTracerProvider(tracerProvider)
134 | 	otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
135 | 		propagation.TraceContext{}, propagation.Baggage{},
136 | 	))
137 | 	return func() {
138 | 		shutdownCtx, cancel := context.WithTimeout(context.Background(), gracePeriod)
139 | 		defer cancel()
140 | 		if err := tracerProvider.Shutdown(shutdownCtx); err != nil {
141 | 			logger.Error("otlp trace-provider failed to shutdown", "err", err)
142 | 		}
143 | 	}, nil
144 | }
145 | 
146 | // noOp is the cleanup function returned when there is nothing to shut down.
147 | func noOp() {}
148 | 
--------------------------------------------------------------------------------
/telemetry/otelgrpc/otelgrpc.go:
--------------------------------------------------------------------------------
1 | package otelgrpc
2 | 
3 | import (
4 | 	"context"
5 | 	"net"
6 | 	"strings"
7 | 	"time"
8 | 
9 | 	"github.com/raystack/salt/utils"
10 | 	"go.opentelemetry.io/otel"
11 | 	"go.opentelemetry.io/otel/attribute"
12 | 	"go.opentelemetry.io/otel/metric"
13 | 	semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
14 | 	"google.golang.org/grpc"
15 | 	"google.golang.org/grpc/peer"
16 | 	"google.golang.org/protobuf/proto"
17 | )
18 | 
19 | // defaultMeterName is the instrumentation scope name NewMeter uses when
20 | // WithMeterName is not supplied.
21 | const defaultMeterName = "github.com/raystack/salt/telemetry/otelgrpc"
22 | 
23 | // UnaryParams describes a completed unary RPC for RecordUnary.
24 | type UnaryParams struct {
25 | 	Start  time.Time
26 | 	Method string
27 | 	Req    any
28 | 	Res    any
29 | 	Err    error
30 | }
31 | 
32 | // Meter records client-side gRPC call duration and message-size histograms.
33 | type Meter struct {
34 | 	duration     metric.Int64Histogram
35 | 	requestSize  metric.Int64Histogram
36 | 	responseSize metric.Int64Histogram
37 | 	attributes   []attribute.KeyValue
38 | }
39 | 
40 | // MeterOpts holds the settings applied by Option functions.
41 | type MeterOpts struct {
42 | 	// The struct tag is informational only: NewMeter seeds meterName with
43 | 	// defaultMeterName before applying options.
44 | 	meterName string `default:"github.com/raystack/salt/telemetry/otelgrpc"`
45 | }
46 | 
47 | // Option customises the Meter returned by NewMeter.
48 | type Option func(*MeterOpts)
49 | 
50 | // WithMeterName overrides the instrumentation scope name used to obtain the meter.
51 | func WithMeterName(meterName string) Option {
52 | 	return func(opts *MeterOpts) {
53 | 		opts.meterName = meterName
54 | 	}
55 | }
56 | 
57 | // NewMeter creates the rpc.client.* instruments from the global MeterProvider.
58 | // hostName is the dialled target (e.g. "host:port"); its host and port are
59 | // attached to every measurement, with the port defaulting to "80" if absent.
60 | // Instrument-creation errors are reported via otel.Handle, not returned.
61 | func NewMeter(hostName string, opts ...Option) Meter {
62 | 	meterOpts := &MeterOpts{meterName: defaultMeterName}
63 | 	for _, opt := range opts {
64 | 		opt(meterOpts)
65 | 	}
66 | 	meter := otel.Meter(meterOpts.meterName)
67 | 	duration, err := meter.Int64Histogram("rpc.client.duration", metric.WithUnit("ms"))
68 | 	handleOtelErr(err)
69 | 	requestSize, err := meter.Int64Histogram("rpc.client.request.size", metric.WithUnit("By"))
70 | 	handleOtelErr(err)
71 | 	responseSize, err := meter.Int64Histogram("rpc.client.response.size", metric.WithUnit("By"))
72 | 	handleOtelErr(err)
73 | 	addr, port := ExtractAddress(hostName)
74 | 	return Meter{
75 | 		duration:     duration,
76 | 		requestSize:  requestSize,
77 | 		responseSize: responseSize,
78 | 		attributes: []attribute.KeyValue{
79 | 			semconv.RPCSystemGRPC,
80 | 			attribute.String("network.transport", "tcp"),
81 | 			attribute.String("server.address", addr),
82 | 			attribute.String("server.port", port),
83 | 		},
84 | 	}
85 | }
86 | 
87 | // GetProtoSize returns the encoded size of p when it is a proto.Message,
88 | // and 0 for nil or non-proto values.
89 | func GetProtoSize(p any) int {
90 | 	if p == nil {
91 | 		return 0
92 | 	}
93 | 	if pm, ok := p.(proto.Message); ok {
94 | 		return proto.Size(pm)
95 | 	}
96 | 	return 0
97 | }
98 | 
99 | // RecordUnary records the duration (ms since p.Start) and the request and
100 | // response message sizes of a unary RPC.
101 | func (m *Meter) RecordUnary(ctx context.Context, p UnaryParams) {
102 | 	withAttrs := metric.WithAttributes(m.callAttributes(ctx, p.Method, p.Err)...)
103 | 	m.duration.Record(ctx, time.Since(p.Start).Milliseconds(), withAttrs)
104 | 	m.requestSize.Record(ctx, int64(GetProtoSize(p.Req)), withAttrs)
105 | 	m.responseSize.Record(ctx, int64(GetProtoSize(p.Res)), withAttrs)
106 | }
107 | 
108 | // RecordStream records the duration (ms since start) of a streaming RPC call.
109 | func (m *Meter) RecordStream(ctx context.Context, start time.Time, method string, err error) {
110 | 	m.duration.Record(ctx, time.Since(start).Milliseconds(),
111 | 		metric.WithAttributes(m.callAttributes(ctx, method, err)...))
112 | }
113 | 
114 | // callAttributes returns a new slice holding the meter's static attributes
115 | // followed by the per-call status text, network type and RPC service/method.
116 | // Building a fresh slice guarantees per-call values never alias m.attributes.
117 | func (m *Meter) callAttributes(ctx context.Context, method string, err error) []attribute.KeyValue {
118 | 	attrs := make([]attribute.KeyValue, 0, len(m.attributes)+4)
119 | 	attrs = append(attrs, m.attributes...)
120 | 	attrs = append(attrs,
121 | 		attribute.String("rpc.grpc.status_text", utils.StatusText(err)),
122 | 		attribute.String("network.type", netTypeFromCtx(ctx)),
123 | 	)
124 | 	return append(attrs, ParseFullMethod(method)...)
125 | }
126 | 
127 | // UnaryClientInterceptor returns an interceptor that records every unary call
128 | // through RecordUnary once the invoker returns.
129 | func (m *Meter) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
130 | 	return func(ctx context.Context,
131 | 		method string,
132 | 		req, reply interface{},
133 | 		cc *grpc.ClientConn,
134 | 		invoker grpc.UnaryInvoker,
135 | 		opts ...grpc.CallOption,
136 | 	) (err error) {
137 | 		defer func(start time.Time) {
138 | 			m.RecordUnary(ctx, UnaryParams{
139 | 				Start: start,
140 | 				// Without Method the rpc.service/rpc.method attributes are lost.
141 | 				Method: method,
142 | 				Req:    req,
143 | 				Res:    reply,
144 | 				Err:    err,
145 | 			})
146 | 		}(time.Now())
147 | 		return invoker(ctx, method, req, reply, cc, opts...)
148 | 	}
149 | }
150 | 
151 | // StreamClientInterceptor returns an interceptor that records, via
152 | // RecordStream, how long the streamer took to return. It does not measure
153 | // the lifetime of the stream itself.
154 | func (m *Meter) StreamClientInterceptor() grpc.StreamClientInterceptor {
155 | 	return func(ctx context.Context,
156 | 		desc *grpc.StreamDesc,
157 | 		cc *grpc.ClientConn,
158 | 		method string,
159 | 		streamer grpc.Streamer,
160 | 		opts ...grpc.CallOption,
161 | 	) (s grpc.ClientStream, err error) {
162 | 		defer func(start time.Time) {
163 | 			m.RecordStream(ctx, start, method, err)
164 | 		}(time.Now())
165 | 		return streamer(ctx, desc, cc, method, opts...)
166 | 	}
167 | }
168 | 
169 | // GetAttributes returns the static attributes attached to every measurement.
170 | func (m *Meter) GetAttributes() []attribute.KeyValue {
171 | 	return m.attributes
172 | }
173 | 
174 | // ParseFullMethod splits a gRPC full method name ("/pkg.Service/Method") into
175 | // rpc.service and rpc.method attributes. It returns nil when no "/" separates
176 | // a service from a method; empty service or method parts are omitted.
177 | func ParseFullMethod(fullMethod string) []attribute.KeyValue {
178 | 	name := strings.TrimLeft(fullMethod, "/")
179 | 	service, method, found := strings.Cut(name, "/")
180 | 	if !found {
181 | 		return nil
182 | 	}
183 | 	var attrs []attribute.KeyValue
184 | 	if service != "" {
185 | 		attrs = append(attrs, semconv.RPCService(service))
186 | 	}
187 | 	if method != "" {
188 | 		attrs = append(attrs, semconv.RPCMethod(method))
189 | 	}
190 | 	return attrs
191 | }
192 | 
193 | // handleOtelErr forwards a non-nil err to the global OpenTelemetry error handler.
194 | func handleOtelErr(err error) {
195 | 	if err != nil {
196 | 		otel.Handle(err)
197 | 	}
198 | }
199 | 
200 | // ExtractAddress splits addr into host and port. When addr cannot be split
201 | // (for example because it has no port), addr is returned whole with port "80".
202 | func ExtractAddress(addr string) (host, port string) {
203 | 	host, port, err := net.SplitHostPort(addr)
204 | 	if err != nil {
205 | 		return addr, "80"
206 | 	}
207 | 	return host, port
208 | }
209 | 
210 | // netTypeFromCtx reports "ipv4" or "ipv6" for the peer address carried by
211 | // ctx, or "unknown" when ctx has no peer, the peer has no address, or the
212 | // address is not an IP.
213 | func netTypeFromCtx(ctx context.Context) (ipType string) {
214 | 	ipType = "unknown"
215 | 	p, ok := peer.FromContext(ctx)
216 | 	if !ok || p.Addr == nil {
217 | 		return ipType
218 | 	}
219 | 	clientIP, _, err := net.SplitHostPort(p.Addr.String())
220 | 	if err != nil {
221 | 		return ipType
222 | 	}
223 | 	ip := net.ParseIP(clientIP)
224 | 	if ip.To4() != nil {
225 | 		ipType = "ipv4"
226 | 	} else if ip.To16() != nil {
227 | 		ipType = "ipv6"
228 | 	}
229 | 	return ipType
230 | }
231 | 
--------------------------------------------------------------------------------
/telemetry/otelgrpc/otelgrpc_test.go:
--------------------------------------------------------------------------------
1 | package otelgrpc_test
2 | 
3 | import (
4 | 	"reflect"
5 | 	"testing"
6 | 
7 | 	"github.com/raystack/salt/telemetry/otelgrpc"
8 | 	"github.com/stretchr/testify/assert"
9 | 	"go.opentelemetry.io/otel/attribute"
10 | 	semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
11 | )
12 | 
13 | func Test_parseFullMethod(t *testing.T) {
14 | 	type args struct {
15 | 		fullMethod string
16 | 	}
17 | 	tests := []struct {
18 | 		name string
19 | 		args args
20 | 		want []attribute.KeyValue
21 | 	}{
22 | 		{name: "should parse correct method", args: args{
23 | 			fullMethod: "/test.service.name/MethodNameV1",
24 | 		}, want: []attribute.KeyValue{
25 | 			semconv.RPCService("test.service.name"),
26 | 			semconv.RPCMethod("MethodNameV1"),
27 | 		}},
28 | 		{name: "should return empty attributes on incorrect method", args: args{
29 | 			fullMethod: "incorrectMethod",
30 | 		}, want: nil},
31 | 	}
32 | 	for _, tt := range tests {
33 | 		t.Run(tt.name, func(t *testing.T) {
34 | 			if got := otelgrpc.ParseFullMethod(tt.args.fullMethod); !reflect.DeepEqual(got, tt.want) {
35 | 				t.Errorf("parseFullMethod() = %v, want %v", got, tt.want)
36 | 			}
37 | 		})
38 | 	}
39 | }
40 | 
41 | func TestExtractAddress(t *testing.T) {
42 | 	gotHost, gotPort := otelgrpc.ExtractAddress("localhost:1001")
43 | 	assert.Equal(t, "localhost", gotHost)
44 | 	assert.Equal(t, "1001", gotPort)
45 | 	gotHost, gotPort = otelgrpc.ExtractAddress("localhost")
46 | 	assert.Equal(t, "localhost", gotHost)
47 | 	assert.Equal(t, "80", gotPort)
48 | 	gotHost, gotPort = otelgrpc.ExtractAddress("some.address.golabs.io:15010")
49 | 	assert.Equal(t, "some.address.golabs.io", gotHost)
50 | 	assert.Equal(t, "15010", gotPort)
51 | }
52 | 
--------------------------------------------------------------------------------
/telemetry/otelhhtpclient/annotations.go:
--------------------------------------------------------------------------------
1 | package otelhttpclient
2 | 
3 | import (
4 | 	"context"
5 | 	"net/http"
6 | 
7 | 	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
8 | 	"go.opentelemetry.io/otel/attribute"
9 | )
10 | 
11 | type labelerContextKeyType int
12 | 
13 | const lablelerContextKey labelerContextKeyType = 0
14 | 
15 | // AnnotateRequest adds telemetry related annotations to the request context and returns the annotated request.
16 | // The request context on the returned request should be retained.
17 | // Ensure `route` is a route template and not actual URL to prevent high cardinality
18 | // on the metrics.
19 | func AnnotateRequest(req *http.Request, route string) *http.Request { 20 | ctx := req.Context() 21 | l := &otelhttp.Labeler{} 22 | l.Add(attribute.String(attributeHTTPRoute, route)) 23 | return req.WithContext(context.WithValue(ctx, lablelerContextKey, l)) 24 | } 25 | 26 | // LabelerFromContext returns the labeler annotation from the context if exists. 27 | func LabelerFromContext(ctx context.Context) (*otelhttp.Labeler, bool) { 28 | l, ok := ctx.Value(lablelerContextKey).(*otelhttp.Labeler) 29 | if !ok { 30 | l = &otelhttp.Labeler{} 31 | } 32 | return l, ok 33 | } 34 | -------------------------------------------------------------------------------- /telemetry/otelhhtpclient/http_transport.go: -------------------------------------------------------------------------------- 1 | package otelhttpclient 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "time" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | "go.opentelemetry.io/otel/metric" 12 | ) 13 | 14 | // Refer OpenTelemetry Semantic Conventions for HTTP Client. 
15 | // https://github.com/open-telemetry/semantic-conventions/blob/main/docs/http/http-metrics.md#http-client
16 | const (
17 | 	metricClientDuration        = "http.client.duration"
18 | 	metricClientRequestSize     = "http.client.request.size"
19 | 	metricClientResponseSize    = "http.client.response.size"
20 | 	attributeNetProtoName       = "network.protocol.name"
21 | 	attributeNetProtoVersion    = "network.protocol.version"
22 | 	attributeServerPort         = "server.port"
23 | 	attributeServerAddress      = "server.address"
24 | 	attributeHTTPRoute          = "http.route"
25 | 	attributeRequestMethod      = "http.request.method"
26 | 	attributeResponseStatusCode = "http.response.status_code"
27 | )
28 | 
29 | // httpTransport is an http.RoundTripper that records client metrics for each
30 | // request it forwards to the wrapped roundTripper.
31 | type httpTransport struct {
32 | 	roundTripper             http.RoundTripper
33 | 	metricClientDuration     metric.Float64Histogram
34 | 	metricClientRequestSize  metric.Int64Counter
35 | 	metricClientResponseSize metric.Int64Counter
36 | }
37 | 
38 | // NewHTTPTransport wraps baseTransport (http.DefaultTransport when nil) so
39 | // that every round trip records duration and request/response size metrics
40 | // on the global MeterProvider. A transport that is already instrumented is
41 | // returned unchanged to avoid double counting.
42 | func NewHTTPTransport(baseTransport http.RoundTripper) http.RoundTripper {
43 | 	if _, ok := baseTransport.(*httpTransport); ok {
44 | 		return baseTransport
45 | 	}
46 | 	if baseTransport == nil {
47 | 		baseTransport = http.DefaultTransport
48 | 	}
49 | 	icl := &httpTransport{roundTripper: baseTransport}
50 | 	icl.createMeasures(otel.Meter("github.com/raystack/salt/telemetry/otehttpclient"))
51 | 	return icl
52 | }
53 | 
54 | // RoundTrip executes req with the wrapped transport and records the elapsed
55 | // time in milliseconds, the number of request-body bytes read and, when it is
56 | // known, the response Content-Length. Labels set via AnnotateRequest are
57 | // included in the recorded attributes.
58 | func (tr *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
59 | 	ctx := req.Context()
60 | 	startAt := time.Now()
61 | 	labeler, _ := LabelerFromContext(ctx)
62 | 
63 | 	// Count the bytes the underlying transport actually reads from the body.
64 | 	var bw bodyWrapper
65 | 	if req.Body != nil && req.Body != http.NoBody {
66 | 		bw.ReadCloser = req.Body
67 | 		req.Body = &bw
68 | 	}
69 | 
70 | 	port := req.URL.Port()
71 | 	if port == "" {
72 | 		port = "80"
73 | 		if req.URL.Scheme == "https" {
74 | 			port = "443"
75 | 		}
76 | 	}
77 | 	attribs := append(labeler.Get(),
78 | 		attribute.String(attributeNetProtoName, "http"),
79 | 		attribute.String(attributeRequestMethod, req.Method),
80 | 		attribute.String(attributeServerAddress, req.URL.Hostname()),
81 | 		attribute.String(attributeServerPort, port),
82 | 	)
83 | 
84 | 	resp, err := tr.roundTripper.RoundTrip(req)
85 | 	// Without a response, report status 0 and the request's protocol version.
86 | 	statusCode, protoMajor, protoMinor := 0, req.ProtoMajor, req.ProtoMinor
87 | 	if err == nil {
88 | 		statusCode, protoMajor, protoMinor = resp.StatusCode, resp.ProtoMajor, resp.ProtoMinor
89 | 	}
90 | 	attribs = append(attribs,
91 | 		attribute.Int(attributeResponseStatusCode, statusCode),
92 | 		attribute.String(attributeNetProtoVersion, fmt.Sprintf("%d.%d", protoMajor, protoMinor)),
93 | 	)
94 | 
95 | 	elapsedTime := float64(time.Since(startAt)) / float64(time.Millisecond)
96 | 	withAttribs := metric.WithAttributes(attribs...)
97 | 	tr.metricClientDuration.Record(ctx, elapsedTime, withAttribs)
98 | 	tr.metricClientRequestSize.Add(ctx, int64(bw.read), withAttribs)
99 | 	// ContentLength is -1 when the length is unknown; a counter only accepts
100 | 	// non-negative increments, so record the size only when it is known.
101 | 	if resp != nil && resp.ContentLength >= 0 {
102 | 		tr.metricClientResponseSize.Add(ctx, resp.ContentLength, withAttribs)
103 | 	}
104 | 	return resp, err
105 | }
106 | 
107 | // createMeasures creates the instruments used by RoundTrip; creation errors
108 | // are reported through otel.Handle.
109 | func (tr *httpTransport) createMeasures(meter metric.Meter) {
110 | 	var err error
111 | 	tr.metricClientRequestSize, err = meter.Int64Counter(metricClientRequestSize)
112 | 	handleErr(err)
113 | 	tr.metricClientResponseSize, err = meter.Int64Counter(metricClientResponseSize)
114 | 	handleErr(err)
115 | 	tr.metricClientDuration, err = meter.Float64Histogram(metricClientDuration)
116 | 	handleErr(err)
117 | }
118 | 
119 | // handleErr forwards a non-nil err to the global OpenTelemetry error handler.
120 | func handleErr(err error) {
121 | 	if err != nil {
122 | 		otel.Handle(err)
123 | 	}
124 | }
125 | 
126 | // bodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number
127 | // of bytes read and the last error.
107 | type bodyWrapper struct {
108 | 	io.ReadCloser
109 | 	read int
110 | 	err  error
111 | }
112 | 
113 | // Read reads from the wrapped body, adding the byte count to w.read and
114 | // keeping the returned error in w.err.
115 | func (w *bodyWrapper) Read(p []byte) (n int, err error) {
116 | 	n, err = w.ReadCloser.Read(p)
117 | 	w.read, w.err = w.read+n, err
118 | 	return n, err
119 | }
120 | 
121 | // Close closes the wrapped body.
122 | func (w *bodyWrapper) Close() error { return w.ReadCloser.Close() }
123 | 
--------------------------------------------------------------------------------
/telemetry/otelhhtpclient/http_transport_test.go:
--------------------------------------------------------------------------------
1 | package otelhttpclient_test
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	otelhttpclient "github.com/raystack/salt/telemetry/otelhhtpclient"
7 | 	"github.com/stretchr/testify/assert"
8 | )
9 | 
10 | func TestNewHTTPTransport(t *testing.T) {
11 | 	tr := otelhttpclient.NewHTTPTransport(nil)
12 | 	assert.NotNil(t, tr)
13 | }
14 | 
--------------------------------------------------------------------------------
/telemetry/telemetry.go:
--------------------------------------------------------------------------------
1 | package telemetry
2 | 
3 | import (
4 | 	"context"
5 | 	"time"
6 | 
7 | 	"github.com/raystack/salt/log"
8 | )
9 | 
10 | // gracePeriod bounds how long each provider shutdown may take.
11 | const gracePeriod = 5 * time.Second
12 | 
13 | // Config is the top-level telemetry configuration.
14 | type Config struct {
15 | 	AppVersion    string
16 | 	AppName       string              `yaml:"app_name" mapstructure:"app_name" default:"service"`
17 | 	OpenTelemetry OpenTelemetryConfig `yaml:"open_telemetry" mapstructure:"open_telemetry"`
18 | }
19 | 
20 | // Init sets up OpenTelemetry from cfg and returns a function that shuts the
21 | // telemetry providers down. On error the returned function is a no-op, so it
22 | // is always safe to call.
23 | func Init(ctx context.Context, cfg Config, logger log.Logger) (cleanUp func(), err error) {
24 | 	if cleanUp, err = initOTLP(ctx, cfg, logger); err != nil {
25 | 		return noOp, err
26 | 	}
27 | 	return cleanUp, nil
28 | }
29 | 
--------------------------------------------------------------------------------
/testing/dockertestx/README.md:
--------------------------------------------------------------------------------
1 | # dockertestx
2 | 
3 | This package is an abstraction of several dockerized data storages using
`ory/dockertest` to bootstrap a specific dockerized instance. 4 | 5 | Example: Postgres 6 | 7 | ```go 8 | // create postgres instance 9 | pgDocker, err := dockertest.CreatePostgres( 10 | dockertest.PostgresWithDetail( 11 | pgUser, pgPass, pgDBName, 12 | ), 13 | ) 14 | 15 | // get connection string 16 | connString := pgDocker.GetExternalConnString() 17 | 18 | // purge docker 19 | if err := pgDocker.GetPool().Purge(pgDocker.GetResouce()); err != nil { 20 | return fmt.Errorf("could not purge resource: %w", err) 21 | } 22 | ``` 23 | 24 | Example: SpiceDB 25 | 26 | - bootstrap SpiceDB with Postgres and wire them together internally via a network bridge 27 | 28 | ```go 29 | // create custom pool 30 | pool, err := dockertest.NewPool("") 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | // create a bridge network for testing 36 | network, err = pool.Client.CreateNetwork(docker.CreateNetworkOptions{ 37 | Name: fmt.Sprintf("bridge-%s", uuid.New().String()), 38 | }) 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | 44 | // create postgres instance 45 | pgDocker, err := dockertest.CreatePostgres( 46 | dockertest.PostgresWithDockerPool(pool), 47 | dockertest.PostgresWithDockertestNetwork(network), 48 | dockertest.PostgresWithDetail( 49 | pgUser, pgPass, pgDBName, 50 | ), 51 | ) 52 | 53 | // get connection string 54 | connString := pgDocker.GetInternalConnString() 55 | 56 | // create SpiceDB instance 57 | spiceDocker, err := dockertest.CreateSpiceDB(connString, 58 | dockertest.SpiceDBWithDockerPool(pool), 59 | dockertest.SpiceDBWithDockertestNetwork(network), 60 | ) 61 | 62 | if err := dockertest.MigrateSpiceDB(connString, 63 | dockertest.MigrateSpiceDBWithDockerPool(pool), 64 | dockertest.MigrateSpiceDBWithDockertestNetwork(network), 65 | ); err != nil { 66 | return err 67 | } 68 | 69 | // purge docker resources 70 | if err := pool.Purge(spiceDocker.GetResouce()); err != nil { 71 | return fmt.Errorf("could not purge resource: %w", err) 72 | } 73 | if err :=
pool.Purge(pgDocker.GetResouce()); err != nil { 74 | return fmt.Errorf("could not purge resource: %w", err) 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /testing/dockertestx/configs/cortex/single_process_cortex.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for running Cortex in single-process mode. 2 | # This configuration should not be used in production. 3 | # It is only for getting started and development. 4 | 5 | # Disable the requirement that every request to Cortex has a 6 | # X-Scope-OrgID header. `fake` will be substituted in instead. 7 | auth_enabled: false 8 | 9 | server: 10 | http_listen_port: 9009 11 | 12 | # Configure the server to allow messages up to 100MB. 13 | grpc_server_max_recv_msg_size: 104857600 14 | grpc_server_max_send_msg_size: 104857600 15 | grpc_server_max_concurrent_streams: 1000 16 | 17 | distributor: 18 | shard_by_all_labels: true 19 | pool: 20 | health_check_ingesters: true 21 | 22 | ingester_client: 23 | grpc_client_config: 24 | # Configure the client to allow messages up to 100MB. 25 | max_recv_msg_size: 104857600 26 | max_send_msg_size: 104857600 27 | grpc_compression: gzip 28 | 29 | ingester: 30 | # We want our ingesters to flush chunks at the same time to optimise 31 | # deduplication opportunities. 32 | spread_flushes: true 33 | chunk_age_jitter: 0 34 | 35 | walconfig: 36 | wal_enabled: true 37 | recover_from_wal: true 38 | wal_dir: /tmp/cortex/wal 39 | 40 | lifecycler: 41 | # The address to advertise for this ingester. Will be autodiscovered by 42 | # looking up address on eth0 or en0; can be specified if this fails. 43 | # address: 127.0.0.1 44 | 45 | # We want to start immediately and flush on shutdown. 
46 | join_after: 0 47 | min_ready_duration: 0s 48 | final_sleep: 0s 49 | num_tokens: 512 50 | tokens_file_path: /tmp/cortex/wal/tokens 51 | 52 | # Use an in memory ring store, so we don't need to launch a Consul. 53 | ring: 54 | kvstore: 55 | store: inmemory 56 | replication_factor: 1 57 | 58 | # Use local storage - BoltDB for the index, and the filesystem 59 | # for the chunks. 60 | schema: 61 | configs: 62 | - from: 2019-07-29 63 | store: boltdb 64 | object_store: filesystem 65 | schema: v10 66 | index: 67 | prefix: index_ 68 | period: 1w 69 | 70 | storage: 71 | boltdb: 72 | directory: /tmp/cortex/index 73 | 74 | filesystem: 75 | directory: /tmp/cortex/chunks 76 | 77 | delete_store: 78 | store: boltdb 79 | 80 | purger: 81 | object_store_type: filesystem 82 | 83 | frontend_worker: 84 | # Configure the frontend worker in the querier to match worker count 85 | # to max_concurrent on the queriers. 86 | match_max_concurrent: true 87 | 88 | # Configure the ruler to scan the /tmp/cortex/rules directory for prometheus 89 | # rules: https://prometheus.io/docs/prometheus/latest/configuration/recording_rules/#recording-rules 90 | ruler: 91 | enable_api: true 92 | enable_sharding: false 93 | # alertmanager_url: http://cortex-am:9009/api/prom/alertmanager/ 94 | rule_path: /tmp/cortex/rules 95 | storage: 96 | type: s3 97 | s3: 98 | # endpoint: http://minio1:9000 99 | bucketnames: cortex 100 | secret_access_key: minio123 101 | access_key_id: minio 102 | s3forcepathstyle: true 103 | 104 | alertmanager: 105 | enable_api: true 106 | sharding_enabled: false 107 | data_dir: data/ 108 | external_url: /api/prom/alertmanager 109 | storage: 110 | type: s3 111 | s3: 112 | # endpoint: http://minio1:9000 113 | bucketnames: cortex 114 | secret_access_key: minio123 115 | access_key_id: minio 116 | s3forcepathstyle: true 117 | 118 | alertmanager_storage: 119 | backend: local 120 | local: 121 | path: tmp/cortex/alertmanager 122 | 
-------------------------------------------------------------------------------- /testing/dockertestx/configs/nginx/cortex_nginx.conf: -------------------------------------------------------------------------------- 1 | worker_processes 1; 2 | error_log /dev/stderr; 3 | pid /tmp/nginx.pid; 4 | worker_rlimit_nofile 8192; 5 | 6 | events { 7 | worker_connections 1024; 8 | } 9 | 10 | 11 | http { 12 | client_max_body_size 5M; 13 | default_type application/octet-stream; 14 | log_format main '$remote_addr - $remote_user [$time_local] $status ' 15 | '"$request" $body_bytes_sent "$http_referer" ' 16 | '"$http_user_agent" "$http_x_forwarded_for" $http_x_scope_orgid'; 17 | access_log /dev/stderr main; 18 | sendfile on; 19 | tcp_nopush on; 20 | resolver 127.0.0.11 ipv6=off; 21 | 22 | server { 23 | listen {{.ExposedPort}}; 24 | proxy_connect_timeout 300s; 25 | proxy_send_timeout 300s; 26 | proxy_read_timeout 300s; 27 | proxy_http_version 1.1; 28 | 29 | location = /healthz { 30 | return 200 'alive'; 31 | } 32 | 33 | # Distributor Config 34 | location = /ring { 35 | proxy_pass http://{{.RulerHost}}$request_uri; 36 | } 37 | 38 | location = /all_user_stats { 39 | proxy_pass http://{{.RulerHost}}$request_uri; 40 | } 41 | 42 | location = /api/prom/push { 43 | proxy_pass http://{{.RulerHost}}$request_uri; 44 | } 45 | 46 | ## New Remote write API. 
Ref: https://cortexmetrics.io/docs/api/#remote-write
47 |         location = /api/v1/push {
48 |             proxy_pass http://{{.RulerHost}}$request_uri;
49 |         }
50 | 
51 | 
52 |         # Alertmanager Config
53 |         location ~ /api/prom/alertmanager/.* {
54 |             proxy_pass http://{{.AlertManagerHost}}$request_uri;
55 |         }
56 | 
57 |         location ~ /api/v1/alerts {
58 |             proxy_pass http://{{.AlertManagerHost}}$request_uri;
59 |         }
60 | 
61 |         location ~ /multitenant_alertmanager/status {
62 |             proxy_pass http://{{.AlertManagerHost}}$request_uri;
63 |         }
64 | 
65 |         # Ruler Config
66 |         location ~ /api/v1/rules {
67 |             proxy_pass http://{{.RulerHost}}$request_uri;
68 |         }
69 | 
70 |         location ~ /ruler/ring {
71 |             proxy_pass http://{{.RulerHost}}$request_uri;
72 |         }
73 | 
74 |         # Config Config
75 |         location ~ /api/prom/configs/.* {
76 |             proxy_pass http://{{.RulerHost}}$request_uri;
77 |         }
78 | 
79 |         # Query Config
80 |         location ~ /api/prom/.* {
81 |             proxy_pass http://{{.RulerHost}}$request_uri;
82 |         }
83 | 
84 |         ## New Query frontend APIs as per https://cortexmetrics.io/docs/api/#querier--query-frontend
85 |         location ~ ^/prometheus/api/v1/(read|metadata|labels|series|query_range|query) {
86 |             proxy_pass http://{{.RulerHost}}$request_uri;
87 |         }
88 | 
89 |         location ~ /prometheus/api/v1/label/.* {
90 |             proxy_pass http://{{.RulerHost}}$request_uri;
91 |         }
92 |     }
93 | }
--------------------------------------------------------------------------------
/testing/dockertestx/cortex.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 | 
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"net/http"
7 | 	"os"
8 | 	"path"
9 | 	"runtime"
10 | 	"time"
11 | 
12 | 	"github.com/google/uuid"
13 | 	"github.com/ory/dockertest/v3"
14 | 	"github.com/ory/dockertest/v3/docker"
15 | )
16 | 
17 | type dockerCortexOption func(dc *dockerCortex)
18 | 
19 | // CortexWithDockertestNetwork is an option to assign docker network
20 | func CortexWithDockertestNetwork(network *dockertest.Network) dockerCortexOption {
21 | 	return func(dc *dockerCortex) {
22 | 		dc.network = network
23 | 	}
24 | }
25 | 
26 | // CortexWithVersionTag is an option to assign release tag
27 | // of a `quay.io/cortexproject/cortex` image
28 | func CortexWithVersionTag(versionTag string) dockerCortexOption {
29 | 	return func(dc *dockerCortex) {
30 | 		dc.versionTag = versionTag
31 | 	}
32 | }
33 | 
34 | // CortexWithDockerPool is an option to assign docker pool
35 | func CortexWithDockerPool(pool *dockertest.Pool) dockerCortexOption {
36 | 	return func(dc *dockerCortex) {
37 | 		dc.pool = pool
38 | 	}
39 | }
40 | 
41 | // CortexWithModule is an option to assign cortex module name
42 | // e.g. all, alertmanager, querier, etc
43 | func CortexWithModule(moduleName string) dockerCortexOption {
44 | 	return func(dc *dockerCortex) {
45 | 		dc.moduleName = moduleName
46 | 	}
47 | }
48 | 
49 | // CortexWithAlertmanagerURL is an option to assign alertmanager url
50 | func CortexWithAlertmanagerURL(amURL string) dockerCortexOption {
51 | 	return func(dc *dockerCortex) {
52 | 		dc.alertManagerURL = amURL
53 | 	}
54 | }
55 | 
56 | // CortexWithS3Endpoint is an option to assign external s3/minio storage
57 | func CortexWithS3Endpoint(s3URL string) dockerCortexOption {
58 | 	return func(dc *dockerCortex) {
59 | 		dc.s3URL = s3URL
60 | 	}
61 | }
62 | 
63 | type dockerCortex struct {
64 | 	network            *dockertest.Network
65 | 	pool               *dockertest.Pool
66 | 	moduleName         string
67 | 	alertManagerURL    string
68 | 	s3URL              string
69 | 	internalHost       string
70 | 	externalHost       string
71 | 	versionTag         string
72 | 	dockertestResource *dockertest.Resource
73 | }
74 | 
75 | // CreateCortex is a function to create a dockerized single-process cortex with
76 | // s3/minio as the backend storage.
77 | //
78 | // Defaults are version tag "master-63703f5" and module "all". The bundled
79 | // configs/cortex/single_process_cortex.yaml (located next to this source file)
80 | // is bind-mounted as the cortex config. Rules and alertmanager directories
81 | // are created under <cwd>/tmp/dockertest-configs/cortex and mounted into the
82 | // container. The container expires after 120 seconds; if it cannot be made
83 | // ready it is purged before the error is returned.
84 | func CreateCortex(opts ...dockerCortexOption) (*dockerCortex, error) {
85 | 	dc := &dockerCortex{}
86 | 	for _, opt := range opts {
87 | 		opt(dc)
88 | 	}
89 | 
90 | 	name := fmt.Sprintf("cortex-%s", uuid.New().String())
91 | 
92 | 	if dc.pool == nil {
93 | 		pool, err := dockertest.NewPool("")
94 | 		if err != nil {
95 | 			return nil, fmt.Errorf("could not create dockertest pool: %w", err)
96 | 		}
97 | 		dc.pool = pool
98 | 	}
99 | 
100 | 	if dc.versionTag == "" {
101 | 		dc.versionTag = "master-63703f5"
102 | 	}
103 | 	if dc.moduleName == "" {
104 | 		dc.moduleName = "all"
105 | 	}
106 | 
107 | 	runOpts := &dockertest.RunOptions{
108 | 		Name:       name,
109 | 		Repository: "quay.io/cortexproject/cortex",
110 | 		Tag:        dc.versionTag,
111 | 		Env: []string{
112 | 			"minio_host=siren_nginx_1",
113 | 		},
114 | 		Cmd: []string{
115 | 			fmt.Sprintf("-target=%s", dc.moduleName),
116 | 			"-config.file=/etc/single-process-config.yaml",
117 | 			fmt.Sprintf("-ruler.storage.s3.endpoint=%s", dc.s3URL),
118 | 			fmt.Sprintf("-ruler.alertmanager-url=%s", dc.alertManagerURL),
119 | 			fmt.Sprintf("-alertmanager.storage.s3.endpoint=%s", dc.s3URL),
120 | 		},
121 | 		ExposedPorts: []string{"9009/tcp"},
122 | 		ExtraHosts: []string{
123 | 			"cortex.siren_nginx_1:127.0.0.1",
124 | 		},
125 | 	}
126 | 	if dc.network != nil {
127 | 		runOpts.NetworkID = dc.network.Network.ID
128 | 	}
129 | 
130 | 	pwd, err := os.Getwd()
131 | 	if err != nil {
132 | 		return nil, err
133 | 	}
134 | 
135 | 	var (
136 | 		rulesFolder        = fmt.Sprintf("%s/tmp/dockertest-configs/cortex/rules", pwd)
137 | 		alertManagerFolder = fmt.Sprintf("%s/tmp/dockertest-configs/cortex/alertmanager", pwd)
138 | 	)
139 | 	for _, fp := range []string{rulesFolder, alertManagerFolder} {
140 | 		// MkdirAll succeeds without changes when the directory already exists.
141 | 		if err := os.MkdirAll(fp, 0777); err != nil {
142 | 			return nil, err
143 | 		}
144 | 	}
145 | 
146 | 	// The cortex config ships next to this source file, so resolve its
147 | 	// directory from the caller information rather than the working directory.
148 | 	_, thisFileName, _, ok := runtime.Caller(0)
149 | 	if !ok {
150 | 		return nil, errors.New("could not resolve the dockertestx source directory")
151 | 	}
152 | 	thisFileFolder := path.Dir(thisFileName)
153 | 
154 | 	dc.dockertestResource, err = dc.pool.RunWithOptions(
155 | 		runOpts,
156 | 		func(config *docker.HostConfig) {
157 | 			config.RestartPolicy = docker.RestartPolicy{
158 | 				Name: "no",
159 | 			}
160 | 			config.Mounts = []docker.HostMount{
161 | 				{
162 | 					Target: "/etc/single-process-config.yaml",
163 | 					Source: fmt.Sprintf("%s/configs/cortex/single_process_cortex.yaml", thisFileFolder),
164 | 					Type:   "bind",
165 | 				},
166 | 				{
167 | 					Target: "/tmp/cortex/rules",
168 | 					Source: rulesFolder,
169 | 					Type:   "bind",
170 | 				},
171 | 				{
172 | 					Target: "/tmp/cortex/alertmanager",
173 | 					Source: alertManagerFolder,
174 | 					Type:   "bind",
175 | 				},
176 | 			}
177 | 		},
178 | 	)
179 | 	if err != nil {
180 | 		return nil, err
181 | 	}
182 | 
183 | 	externalPort := dc.dockertestResource.GetPort("9009/tcp")
184 | 	dc.internalHost = fmt.Sprintf("%s:9009", name)
185 | 	dc.externalHost = fmt.Sprintf("localhost:%s", externalPort)
186 | 
187 | 	if err = dc.dockertestResource.Expire(120); err != nil {
188 | 		// Best-effort cleanup; the Expire error is what the caller needs.
189 | 		_ = dc.pool.Purge(dc.dockertestResource)
190 | 		return nil, err
191 | 	}
192 | 
193 | 	// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
194 | 	dc.pool.MaxWait = 60 * time.Second
195 | 
196 | 	// One client for all attempts; the timeout stops a hung request from
197 | 	// blocking the retry loop indefinitely.
198 | 	httpClient := &http.Client{Timeout: 10 * time.Second}
199 | 	if err = dc.pool.Retry(func() error {
200 | 		res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s/config", externalPort))
201 | 		if err != nil {
202 | 			return err
203 | 		}
204 | 		// Close every response so retries do not leak connections.
205 | 		defer res.Body.Close()
206 | 
207 | 		if res.StatusCode != http.StatusOK {
208 | 			return fmt.Errorf("cortex server return status %d", res.StatusCode)
209 | 		}
210 | 		return nil
211 | 	}); err != nil {
212 | 		// Best-effort cleanup of a container that never became ready.
213 | 		_ = dc.pool.Purge(dc.dockertestResource)
214 | 		return nil, fmt.Errorf("could not connect to docker: %w", err)
215 | 	}
216 | 
217 | 	return dc, nil
218 | }
219 | 
220 | // GetInternalHost returns the container hostname and port reachable from
221 | // the docker network, e.g. cortex-<uuid>:9009
222 | func (dc *dockerCortex) GetInternalHost() string {
223 | 	return dc.internalHost
224 | }
225 | 
226 | // GetExternalHost returns localhost and port
227 | // e.g. localhost:51113
228 | func (dc *dockerCortex) GetExternalHost() string {
229 | 	return dc.externalHost
230 | }
231 | 
232 | // GetPool returns docker pool
233 | func (dc *dockerCortex) GetPool() *dockertest.Pool {
234 | 	return dc.pool
235 | }
236 | 
237 | // GetResource returns docker resource
238 | func (dc *dockerCortex) GetResource() *dockertest.Resource {
239 | 	return dc.dockertestResource
240 | }
241 | 
--------------------------------------------------------------------------------
/testing/dockertestx/dockertestx.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 | 
3 | import "runtime"
4 | 
5 | // DockerHostAddress returns the address containers use to reach the Docker
6 | // host: "host.docker.internal" on darwin and "host-gateway" everywhere else.
7 | func DockerHostAddress() string {
8 | 	var dockerHostInternal = "host-gateway" // linux by default
9 | 	if runtime.GOOS == "darwin" {
10 | 		dockerHostInternal = "host.docker.internal"
11 | 	}
12 | 	return dockerHostInternal
13 | }
14 | 
--------------------------------------------------------------------------------
/testing/dockertestx/minio.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net/http"
6 | 	"time"
7 | 
8 | 	"github.com/google/uuid"
9 | 	"github.com/ory/dockertest/v3"
10 | 	"github.com/ory/dockertest/v3/docker"
11 | )
12 | 
13 | const (
14 | 	defaultMinioRootUser     = "minio"
15 | 	defaultMinioRootPassword = "minio123"
16 | 	defaultMinioDomain       = "localhost"
17 | )
18 | 
19 | type dockerMinioOption func(dm *dockerMinio)
20 | 
21 | // MinioWithDockertestNetwork is an option to assign docker network
22 | func MinioWithDockertestNetwork(network *dockertest.Network) dockerMinioOption {
23 | 	return func(dm *dockerMinio) {
24 | 		dm.network = network
25 | 	}
26 | }
27 | 
28 | // MinioWithVersionTag is an option to assign release tag
29 | // of a `quay.io/minio/minio` image
30 | func MinioWithVersionTag(versionTag string) dockerMinioOption {
31 | 	return func(dm *dockerMinio) {
32 | 		dm.versionTag = versionTag
33 | 	}
34 | }
35 | 
36 | //
MinioWithDockerPool is an option to assign docker pool 37 | func MinioWithDockerPool(pool *dockertest.Pool) dockerMinioOption { 38 | return func(dm *dockerMinio) { 39 | dm.pool = pool 40 | } 41 | } 42 | 43 | type dockerMinio struct { 44 | network *dockertest.Network 45 | pool *dockertest.Pool 46 | rootUser string 47 | rootPassword string 48 | domain string 49 | versionTag string 50 | internalHost string 51 | externalHost string 52 | externalConsoleHost string 53 | dockertestResource *dockertest.Resource 54 | } 55 | 56 | // CreateMinio creates a minio instance with default configurations 57 | func CreateMinio(opts ...dockerMinioOption) (*dockerMinio, error) { 58 | var ( 59 | err error 60 | dm = &dockerMinio{} 61 | ) 62 | 63 | for _, opt := range opts { 64 | opt(dm) 65 | } 66 | 67 | name := fmt.Sprintf("minio-%s", uuid.New().String()) 68 | 69 | if dm.pool == nil { 70 | dm.pool, err = dockertest.NewPool("") 71 | if err != nil { 72 | return nil, fmt.Errorf("could not create dockertest pool: %w", err) 73 | } 74 | } 75 | 76 | if dm.rootUser == "" { 77 | dm.rootUser = defaultMinioRootUser 78 | } 79 | 80 | if dm.rootPassword == "" { 81 | dm.rootPassword = defaultMinioRootPassword 82 | } 83 | 84 | if dm.domain == "" { 85 | dm.domain = defaultMinioDomain 86 | } 87 | 88 | if dm.versionTag == "" { 89 | dm.versionTag = "RELEASE.2022-09-07T22-25-02Z" 90 | } 91 | 92 | runOpts := &dockertest.RunOptions{ 93 | Name: name, 94 | Repository: "quay.io/minio/minio", 95 | Tag: dm.versionTag, 96 | Env: []string{ 97 | "MINIO_ROOT_USER=" + dm.rootUser, 98 | "MINIO_ROOT_PASSWORD=" + dm.rootPassword, 99 | "MINIO_DOMAIN=" + dm.domain, 100 | }, 101 | Cmd: []string{"server", "/data1", "--console-address", ":9001"}, 102 | ExposedPorts: []string{"9000/tcp", "9001/tcp"}, 103 | } 104 | 105 | if dm.network != nil { 106 | runOpts.NetworkID = dm.network.Network.ID 107 | } 108 | 109 | dm.dockertestResource, err = dm.pool.RunWithOptions( 110 | runOpts, 111 | func(config *docker.HostConfig) { 112 | 
config.RestartPolicy = docker.RestartPolicy{ 113 | Name: "no", 114 | } 115 | }, 116 | ) 117 | if err != nil { 118 | return nil, err 119 | } 120 | 121 | minioPort := dm.dockertestResource.GetPort("9000/tcp") 122 | minioConsolePort := dm.dockertestResource.GetPort("9001/tcp") 123 | 124 | dm.internalHost = fmt.Sprintf("%s:%s", name, "9000") 125 | dm.externalHost = fmt.Sprintf("%s:%s", "localhost", minioPort) 126 | dm.externalConsoleHost = fmt.Sprintf("%s:%s", "localhost", minioConsolePort) 127 | 128 | if err = dm.dockertestResource.Expire(120); err != nil { 129 | return nil, err 130 | } 131 | 132 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet 133 | dm.pool.MaxWait = 60 * time.Second 134 | 135 | if err = dm.pool.Retry(func() error { 136 | httpClient := &http.Client{} 137 | res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s/minio/health/live", minioPort)) 138 | if err != nil { 139 | return err 140 | } 141 | 142 | if res.StatusCode != 200 { 143 | return fmt.Errorf("minio server return status %d", res.StatusCode) 144 | } 145 | 146 | return nil 147 | }); err != nil { 148 | err = fmt.Errorf("could not connect to docker: %w", err) 149 | return nil, fmt.Errorf("could not connect to docker: %w", err) 150 | } 151 | 152 | return dm, nil 153 | } 154 | 155 | func (dm *dockerMinio) GetInternalHost() string { 156 | return dm.internalHost 157 | } 158 | 159 | func (dm *dockerMinio) GetExternalHost() string { 160 | return dm.externalHost 161 | } 162 | 163 | func (dm *dockerMinio) GetExternalConsoleHost() string { 164 | return dm.externalConsoleHost 165 | } 166 | 167 | // GetPool returns docker pool 168 | func (dm *dockerMinio) GetPool() *dockertest.Pool { 169 | return dm.pool 170 | } 171 | 172 | // GetResource returns docker resource 173 | func (dm *dockerMinio) GetResource() *dockertest.Resource { 174 | return dm.dockertestResource 175 | } 176 | 
-------------------------------------------------------------------------------- /testing/dockertestx/minio_migrate.go: -------------------------------------------------------------------------------- 1 | package dockertestx 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/ory/dockertest/v3" 10 | "github.com/ory/dockertest/v3/docker" 11 | ) 12 | 13 | const waitContainerTimeout = 60 * time.Second 14 | 15 | type dockerMigrateMinioOption func(dmm *dockerMigrateMinio) 16 | 17 | // MigrateMinioWithDockertestNetwork is an option to assign docker network 18 | func MigrateMinioWithDockertestNetwork(network *dockertest.Network) dockerMigrateMinioOption { 19 | return func(dm *dockerMigrateMinio) { 20 | dm.network = network 21 | } 22 | } 23 | 24 | // MigrateMinioWithVersionTag is an option to assign release tag 25 | // of a `minio/mc` image 26 | func MigrateMinioWithVersionTag(versionTag string) dockerMigrateMinioOption { 27 | return func(dm *dockerMigrateMinio) { 28 | dm.versionTag = versionTag 29 | } 30 | } 31 | 32 | // MigrateMinioWithDockerPool is an option to assign docker pool 33 | func MigrateMinioWithDockerPool(pool *dockertest.Pool) dockerMigrateMinioOption { 34 | return func(dm *dockerMigrateMinio) { 35 | dm.pool = pool 36 | } 37 | } 38 | 39 | type dockerMigrateMinio struct { 40 | network *dockertest.Network 41 | pool *dockertest.Pool 42 | versionTag string 43 | } 44 | 45 | // MigrateMinio does migration of a `bucketName` to a minio located in `minioHost` 46 | func MigrateMinio(minioHost string, bucketName string, opts ...dockerMigrateMinioOption) error { 47 | var ( 48 | err error 49 | dm = &dockerMigrateMinio{} 50 | ) 51 | 52 | for _, opt := range opts { 53 | opt(dm) 54 | } 55 | 56 | if dm.pool == nil { 57 | dm.pool, err = dockertest.NewPool("") 58 | if err != nil { 59 | return fmt.Errorf("could not create dockertest pool: %w", err) 60 | } 61 | } 62 | 63 | if dm.versionTag == "" { 64 | dm.versionTag = 
"RELEASE.2022-08-28T20-08-11Z" 65 | } 66 | 67 | runOpts := &dockertest.RunOptions{ 68 | Repository: "minio/mc", 69 | Tag: dm.versionTag, 70 | Entrypoint: []string{ 71 | "bin/sh", 72 | "-c", 73 | fmt.Sprintf(` 74 | /usr/bin/mc alias set myminio http://%s minio minio123; 75 | /usr/bin/mc rm -r --force %s; 76 | /usr/bin/mc mb myminio/%s; 77 | `, minioHost, bucketName, bucketName), 78 | }, 79 | } 80 | 81 | if dm.network != nil { 82 | runOpts.NetworkID = dm.network.Network.ID 83 | } 84 | 85 | resource, err := dm.pool.RunWithOptions(runOpts, func(config *docker.HostConfig) { 86 | config.RestartPolicy = docker.RestartPolicy{ 87 | Name: "no", 88 | } 89 | }) 90 | if err != nil { 91 | return err 92 | } 93 | 94 | if err := resource.Expire(120); err != nil { 95 | return err 96 | } 97 | 98 | waitCtx, cancel := context.WithTimeout(context.Background(), waitContainerTimeout) 99 | defer cancel() 100 | 101 | // Ensure the command completed successfully. 102 | status, err := dm.pool.Client.WaitContainerWithContext(resource.Container.ID, waitCtx) 103 | if err != nil { 104 | return err 105 | } 106 | 107 | if status != 0 { 108 | stream := new(bytes.Buffer) 109 | 110 | if err = dm.pool.Client.Logs(docker.LogsOptions{ 111 | Context: waitCtx, 112 | OutputStream: stream, 113 | ErrorStream: stream, 114 | Stdout: true, 115 | Stderr: true, 116 | Container: resource.Container.ID, 117 | }); err != nil { 118 | return err 119 | } 120 | 121 | return fmt.Errorf("got non-zero exit code %s", stream.String()) 122 | } 123 | 124 | if err := dm.pool.Purge(resource); err != nil { 125 | return err 126 | } 127 | 128 | return nil 129 | } 130 | -------------------------------------------------------------------------------- /testing/dockertestx/nginx.go: -------------------------------------------------------------------------------- 1 | package dockertestx 2 | 3 | import ( 4 | "bytes" 5 | _ "embed" 6 | "fmt" 7 | "io/fs" 8 | "net/http" 9 | "os" 10 | "path" 11 | "text/template" 12 | "time" 13 | 14 | 
"github.com/google/uuid" 15 | "github.com/ory/dockertest/v3" 16 | "github.com/ory/dockertest/v3/docker" 17 | ) 18 | 19 | const ( 20 | nginxDefaultHealthEndpoint = "/healthz" 21 | nginxDefaultExposedPort = "8080" 22 | nginxDefaultVersionTag = "1.23" 23 | ) 24 | 25 | var ( 26 | //go:embed configs/nginx/cortex_nginx.conf 27 | NginxCortexConfig string 28 | ) 29 | 30 | type dockerNginxOption func(dc *dockerNginx) 31 | 32 | // NginxWithHealthEndpoint is an option to assign health endpoint 33 | func NginxWithHealthEndpoint(healthEndpoint string) dockerNginxOption { 34 | return func(dc *dockerNginx) { 35 | dc.healthEndpoint = healthEndpoint 36 | } 37 | } 38 | 39 | // NginxWithDockertestNetwork is an option to assign docker network 40 | func NginxWithDockertestNetwork(network *dockertest.Network) dockerNginxOption { 41 | return func(dc *dockerNginx) { 42 | dc.network = network 43 | } 44 | } 45 | 46 | // NginxWithVersionTag is an option to assign release tag 47 | // of a `nginx` image 48 | func NginxWithVersionTag(versionTag string) dockerNginxOption { 49 | return func(dc *dockerNginx) { 50 | dc.versionTag = versionTag 51 | } 52 | } 53 | 54 | // NginxWithDockerPool is an option to assign docker pool 55 | func NginxWithDockerPool(pool *dockertest.Pool) dockerNginxOption { 56 | return func(dc *dockerNginx) { 57 | dc.pool = pool 58 | } 59 | } 60 | 61 | // NginxWithDockerPool is an option to assign docker pool 62 | func NginxWithExposedPort(port string) dockerNginxOption { 63 | return func(dc *dockerNginx) { 64 | dc.exposedPort = port 65 | } 66 | } 67 | 68 | func NginxWithPresetConfig(presetConfig string) dockerNginxOption { 69 | return func(dc *dockerNginx) { 70 | dc.presetConfig = presetConfig 71 | } 72 | } 73 | 74 | func NginxWithConfigVariables(cv map[string]string) dockerNginxOption { 75 | return func(dc *dockerNginx) { 76 | dc.configVariables = cv 77 | } 78 | } 79 | 80 | type dockerNginx struct { 81 | network *dockertest.Network 82 | pool *dockertest.Pool 83 | exposedPort 
string 84 | internalHost string 85 | externalHost string 86 | presetConfig string 87 | versionTag string 88 | healthEndpoint string 89 | configVariables map[string]string 90 | dockertestResource *dockertest.Resource 91 | } 92 | 93 | // CreateNginx is a function to create a dockerized nginx 94 | func CreateNginx(opts ...dockerNginxOption) (*dockerNginx, error) { 95 | var ( 96 | err error 97 | dc = &dockerNginx{} 98 | ) 99 | 100 | for _, opt := range opts { 101 | opt(dc) 102 | } 103 | 104 | name := fmt.Sprintf("nginx-%s", uuid.New().String()) 105 | 106 | if dc.pool == nil { 107 | dc.pool, err = dockertest.NewPool("") 108 | if err != nil { 109 | return nil, fmt.Errorf("could not create dockertest pool: %w", err) 110 | } 111 | } 112 | 113 | if dc.versionTag == "" { 114 | dc.versionTag = nginxDefaultVersionTag 115 | } 116 | 117 | if dc.exposedPort == "" { 118 | dc.exposedPort = nginxDefaultExposedPort 119 | } 120 | 121 | if dc.healthEndpoint == "" { 122 | dc.healthEndpoint = nginxDefaultHealthEndpoint 123 | } 124 | 125 | runOpts := &dockertest.RunOptions{ 126 | Name: name, 127 | Repository: "nginx", 128 | Tag: dc.versionTag, 129 | ExposedPorts: []string{fmt.Sprintf("%s/tcp", dc.exposedPort)}, 130 | } 131 | 132 | if dc.network != nil { 133 | runOpts.NetworkID = dc.network.Network.ID 134 | } 135 | 136 | var confString string 137 | switch dc.presetConfig { 138 | case "cortex": 139 | confString = NginxCortexConfig 140 | } 141 | 142 | tmpl := template.New("nginx-config") 143 | parsedTemplate, err := tmpl.Parse(confString) 144 | if err != nil { 145 | return nil, err 146 | } 147 | var generatedConf bytes.Buffer 148 | err = parsedTemplate.Execute(&generatedConf, dc.configVariables) 149 | if err != nil { 150 | // it is unlikely that the code returns error here 151 | return nil, err 152 | } 153 | confString = generatedConf.String() 154 | 155 | pwd, err := os.Getwd() 156 | if err != nil { 157 | return nil, err 158 | } 159 | 160 | var ( 161 | confDestinationFolder = 
fmt.Sprintf("%s/tmp/dockertest-configs/nginx", pwd) 162 | ) 163 | 164 | foldersPath := []string{confDestinationFolder} 165 | for _, fp := range foldersPath { 166 | if _, err := os.Stat(fp); os.IsNotExist(err) { 167 | if err := os.MkdirAll(fp, 0777); err != nil { 168 | return nil, err 169 | } 170 | } 171 | } 172 | 173 | if err := os.WriteFile(path.Join(confDestinationFolder, "nginx.conf"), []byte(confString), fs.ModePerm); err != nil { 174 | return nil, err 175 | } 176 | 177 | dc.dockertestResource, err = dc.pool.RunWithOptions( 178 | runOpts, 179 | func(config *docker.HostConfig) { 180 | config.RestartPolicy = docker.RestartPolicy{ 181 | Name: "no", 182 | } 183 | config.Mounts = []docker.HostMount{ 184 | { 185 | Target: "/etc/nginx/nginx.conf", 186 | Source: path.Join(confDestinationFolder, "nginx.conf"), 187 | Type: "bind", 188 | }, 189 | } 190 | }, 191 | ) 192 | if err != nil { 193 | return nil, err 194 | } 195 | 196 | externalPort := dc.dockertestResource.GetPort(fmt.Sprintf("%s/tcp", dc.exposedPort)) 197 | dc.internalHost = fmt.Sprintf("%s:%s", name, dc.exposedPort) 198 | dc.externalHost = fmt.Sprintf("localhost:%s", externalPort) 199 | 200 | if err = dc.dockertestResource.Expire(120); err != nil { 201 | return nil, err 202 | } 203 | 204 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet 205 | dc.pool.MaxWait = 60 * time.Second 206 | 207 | if err = dc.pool.Retry(func() error { 208 | httpClient := &http.Client{} 209 | res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s%s", externalPort, dc.healthEndpoint)) 210 | if err != nil { 211 | return err 212 | } 213 | 214 | if res.StatusCode != 200 { 215 | return fmt.Errorf("nginx server return status %d", res.StatusCode) 216 | } 217 | 218 | return nil 219 | }); err != nil { 220 | err = fmt.Errorf("could not connect to docker: %w", err) 221 | return nil, fmt.Errorf("could not connect to docker: %w", err) 222 | } 223 | 224 | return dc, nil 225 | } 
226 | 227 | // GetPool returns docker pool 228 | func (dc *dockerNginx) GetPool() *dockertest.Pool { 229 | return dc.pool 230 | } 231 | 232 | // GetResource returns docker resource 233 | func (dc *dockerNginx) GetResource() *dockertest.Resource { 234 | return dc.dockertestResource 235 | } 236 | 237 | // GetInternalHost returns internal hostname and port 238 | // e.g. internal-xxxxxx:8080 239 | func (dc *dockerNginx) GetInternalHost() string { 240 | return dc.internalHost 241 | } 242 | 243 | // GetExternalHost returns localhost and port 244 | // e.g. localhost:51113 245 | func (dc *dockerNginx) GetExternalHost() string { 246 | return dc.externalHost 247 | } 248 | -------------------------------------------------------------------------------- /testing/dockertestx/postgres.go: -------------------------------------------------------------------------------- 1 | package dockertestx 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | "github.com/jmoiron/sqlx" 9 | "github.com/ory/dockertest/v3" 10 | "github.com/ory/dockertest/v3/docker" 11 | "github.com/raystack/salt/log" 12 | ) 13 | 14 | const ( 15 | defaultPGUname = "test_user" 16 | defaultPGPasswd = "test_pass" 17 | defaultDBname = "test_db" 18 | ) 19 | 20 | type dockerPostgresOption func(dpg *dockerPostgres) 21 | 22 | func PostgresWithLogger(logger log.Logger) dockerPostgresOption { 23 | return func(dpg *dockerPostgres) { 24 | dpg.logger = logger 25 | } 26 | } 27 | 28 | // PostgresWithDockertestNetwork is an option to assign docker network 29 | func PostgresWithDockertestNetwork(network *dockertest.Network) dockerPostgresOption { 30 | return func(dpg *dockerPostgres) { 31 | dpg.network = network 32 | } 33 | } 34 | 35 | // PostgresWithDockertestResourceExpiry is an option to assign docker resource expiry time 36 | func PostgresWithDockertestResourceExpiry(expiryInSeconds uint) dockerPostgresOption { 37 | return func(dpg *dockerPostgres) { 38 | dpg.expiryInSeconds = expiryInSeconds 39 | } 40 | } 
41 | 42 | // PostgresWithDetail is an option to assign custom details 43 | // like username, password, and database name 44 | func PostgresWithDetail( 45 | username string, 46 | password string, 47 | dbName string, 48 | ) dockerPostgresOption { 49 | return func(dpg *dockerPostgres) { 50 | dpg.username = username 51 | dpg.password = password 52 | dpg.dbName = dbName 53 | } 54 | } 55 | 56 | // PostgresWithVersionTag is an option to assign release tag 57 | // of a `postgres` image 58 | func PostgresWithVersionTag(versionTag string) dockerPostgresOption { 59 | return func(dpg *dockerPostgres) { 60 | dpg.versionTag = versionTag 61 | } 62 | } 63 | 64 | // PostgresWithDockerPool is an option to assign docker pool 65 | func PostgresWithDockerPool(pool *dockertest.Pool) dockerPostgresOption { 66 | return func(dpg *dockerPostgres) { 67 | dpg.pool = pool 68 | } 69 | } 70 | 71 | type dockerPostgres struct { 72 | logger log.Logger 73 | network *dockertest.Network 74 | pool *dockertest.Pool 75 | username string 76 | password string 77 | dbName string 78 | versionTag string 79 | connStringInternal string 80 | connStringExternal string 81 | expiryInSeconds uint 82 | dockertestResource *dockertest.Resource 83 | } 84 | 85 | // CreatePostgres creates a postgres instance with default configurations 86 | func CreatePostgres(opts ...dockerPostgresOption) (*dockerPostgres, error) { 87 | var ( 88 | err error 89 | dpg = &dockerPostgres{} 90 | ) 91 | 92 | for _, opt := range opts { 93 | opt(dpg) 94 | } 95 | 96 | name := fmt.Sprintf("postgres-%s", uuid.New().String()) 97 | 98 | if dpg.pool == nil { 99 | dpg.pool, err = dockertest.NewPool("") 100 | if err != nil { 101 | return nil, fmt.Errorf("could not create dockertest pool: %w", err) 102 | } 103 | } 104 | 105 | if dpg.username == "" { 106 | dpg.username = defaultPGUname 107 | } 108 | 109 | if dpg.password == "" { 110 | dpg.password = defaultPGPasswd 111 | } 112 | 113 | if dpg.dbName == "" { 114 | dpg.dbName = defaultDBname 115 | } 116 | 
117 | if dpg.versionTag == "" { 118 | dpg.versionTag = "12" 119 | } 120 | 121 | if dpg.expiryInSeconds == 0 { 122 | dpg.expiryInSeconds = 120 123 | } 124 | 125 | runOpts := &dockertest.RunOptions{ 126 | Name: name, 127 | Repository: "postgres", 128 | Tag: dpg.versionTag, 129 | Env: []string{ 130 | "POSTGRES_PASSWORD=" + dpg.password, 131 | "POSTGRES_USER=" + dpg.username, 132 | "POSTGRES_DB=" + dpg.dbName, 133 | }, 134 | ExposedPorts: []string{"5432/tcp"}, 135 | } 136 | 137 | if dpg.network != nil { 138 | runOpts.NetworkID = dpg.network.Network.ID 139 | } 140 | 141 | dpg.dockertestResource, err = dpg.pool.RunWithOptions( 142 | runOpts, 143 | func(config *docker.HostConfig) { 144 | config.RestartPolicy = docker.RestartPolicy{ 145 | Name: "no", 146 | } 147 | }, 148 | ) 149 | if err != nil { 150 | return nil, err 151 | } 152 | 153 | pgPort := dpg.dockertestResource.GetPort("5432/tcp") 154 | dpg.connStringInternal = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dpg.username, dpg.password, name, "5432", dpg.dbName) 155 | dpg.connStringExternal = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dpg.username, dpg.password, "localhost", pgPort, dpg.dbName) 156 | 157 | if err = dpg.dockertestResource.Expire(dpg.expiryInSeconds); err != nil { 158 | return nil, err 159 | } 160 | 161 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet 162 | dpg.pool.MaxWait = 60 * time.Second 163 | 164 | if err = dpg.pool.Retry(func() error { 165 | if _, err := sqlx.Connect("postgres", dpg.connStringExternal); err != nil { 166 | return err 167 | } 168 | return nil 169 | }); err != nil { 170 | err = fmt.Errorf("could not connect to docker: %w", err) 171 | return nil, fmt.Errorf("could not connect to docker: %w", err) 172 | } 173 | 174 | return dpg, nil 175 | } 176 | 177 | // GetInternalConnString returns internal connection string of a postgres instance 178 | func (dpg *dockerPostgres) GetInternalConnString() 
string { 179 | return dpg.connStringInternal 180 | } 181 | 182 | // GetExternalConnString returns external connection string of a postgres instance 183 | func (dpg *dockerPostgres) GetExternalConnString() string { 184 | return dpg.connStringExternal 185 | } 186 | 187 | // GetPool returns docker pool 188 | func (dpg *dockerPostgres) GetPool() *dockertest.Pool { 189 | return dpg.pool 190 | } 191 | 192 | // GetResource returns docker resource 193 | func (dpg *dockerPostgres) GetResource() *dockertest.Resource { 194 | return dpg.dockertestResource 195 | } 196 | -------------------------------------------------------------------------------- /testing/dockertestx/spicedb.go: -------------------------------------------------------------------------------- 1 | package dockertestx 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | authzedpb "github.com/authzed/authzed-go/proto/authzed/api/v1" 9 | "github.com/authzed/authzed-go/v1" 10 | "github.com/authzed/grpcutil" 11 | "github.com/google/uuid" 12 | "github.com/ory/dockertest/v3" 13 | "github.com/ory/dockertest/v3/docker" 14 | "google.golang.org/grpc" 15 | "google.golang.org/grpc/codes" 16 | "google.golang.org/grpc/credentials/insecure" 17 | "google.golang.org/grpc/status" 18 | ) 19 | 20 | const ( 21 | defaultPreSharedKey = "default-preshared-key" 22 | defaultLogLevel = "debug" 23 | ) 24 | 25 | type dockerSpiceDBOption func(dsp *dockerSpiceDB) 26 | 27 | func SpiceDBWithLogLevel(logLevel string) dockerSpiceDBOption { 28 | return func(dsp *dockerSpiceDB) { 29 | dsp.logLevel = logLevel 30 | } 31 | } 32 | 33 | // SpiceDBWithDockertestNetwork is an option to assign docker network 34 | func SpiceDBWithDockertestNetwork(network *dockertest.Network) dockerSpiceDBOption { 35 | return func(dsp *dockerSpiceDB) { 36 | dsp.network = network 37 | } 38 | } 39 | 40 | // SpiceDBWithVersionTag is an option to assign release tag 41 | // of a `quay.io/authzed/spicedb` image 42 | func SpiceDBWithVersionTag(versionTag string) 
dockerSpiceDBOption { 43 | return func(dsp *dockerSpiceDB) { 44 | dsp.versionTag = versionTag 45 | } 46 | } 47 | 48 | // SpiceDBWithDockerPool is an option to assign docker pool 49 | func SpiceDBWithDockerPool(pool *dockertest.Pool) dockerSpiceDBOption { 50 | return func(dsp *dockerSpiceDB) { 51 | dsp.pool = pool 52 | } 53 | } 54 | 55 | // SpiceDBWithPreSharedKey is an option to assign pre-shared-key 56 | func SpiceDBWithPreSharedKey(preSharedKey string) dockerSpiceDBOption { 57 | return func(dsp *dockerSpiceDB) { 58 | dsp.preSharedKey = preSharedKey 59 | } 60 | } 61 | 62 | type dockerSpiceDB struct { 63 | network *dockertest.Network 64 | pool *dockertest.Pool 65 | preSharedKey string 66 | versionTag string 67 | logLevel string 68 | externalPort string 69 | dockertestResource *dockertest.Resource 70 | } 71 | 72 | // CreateSpiceDB creates a spicedb instance with postgres backend and default configurations 73 | func CreateSpiceDB(postgresConnectionURL string, opts ...dockerSpiceDBOption) (*dockerSpiceDB, error) { 74 | var ( 75 | err error 76 | dsp = &dockerSpiceDB{} 77 | ) 78 | 79 | for _, opt := range opts { 80 | opt(dsp) 81 | } 82 | 83 | name := fmt.Sprintf("spicedb-%s", uuid.New().String()) 84 | 85 | if dsp.pool == nil { 86 | dsp.pool, err = dockertest.NewPool("") 87 | if err != nil { 88 | return nil, fmt.Errorf("could not create dockertest pool: %w", err) 89 | } 90 | } 91 | 92 | if dsp.preSharedKey == "" { 93 | dsp.preSharedKey = defaultPreSharedKey 94 | } 95 | 96 | if dsp.logLevel == "" { 97 | dsp.logLevel = defaultLogLevel 98 | } 99 | 100 | if dsp.versionTag == "" { 101 | dsp.versionTag = "v1.0.0" 102 | } 103 | 104 | runOpts := &dockertest.RunOptions{ 105 | Name: name, 106 | Repository: "quay.io/authzed/spicedb", 107 | Tag: dsp.versionTag, 108 | Cmd: []string{"spicedb", "serve", "--log-level", dsp.logLevel, "--grpc-preshared-key", dsp.preSharedKey, "--grpc-no-tls", "--datastore-engine", "postgres", "--datastore-conn-uri", postgresConnectionURL}, 109 | 
ExposedPorts: []string{"50051/tcp"}, 110 | } 111 | 112 | if dsp.network != nil { 113 | runOpts.NetworkID = dsp.network.Network.ID 114 | } 115 | 116 | dsp.dockertestResource, err = dsp.pool.RunWithOptions( 117 | runOpts, 118 | func(config *docker.HostConfig) { 119 | config.RestartPolicy = docker.RestartPolicy{ 120 | Name: "no", 121 | } 122 | }, 123 | ) 124 | if err != nil { 125 | return nil, err 126 | } 127 | 128 | dsp.externalPort = dsp.dockertestResource.GetPort("50051/tcp") 129 | 130 | if err = dsp.dockertestResource.Expire(120); err != nil { 131 | return nil, err 132 | } 133 | 134 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet 135 | dsp.pool.MaxWait = 60 * time.Second 136 | 137 | if err = dsp.pool.Retry(func() error { 138 | client, err := authzed.NewClient( 139 | fmt.Sprintf("localhost:%s", dsp.externalPort), 140 | grpc.WithTransportCredentials(insecure.NewCredentials()), 141 | grpcutil.WithInsecureBearerToken(dsp.preSharedKey), 142 | ) 143 | if err != nil { 144 | return err 145 | } 146 | _, err = client.ReadSchema(context.Background(), &authzedpb.ReadSchemaRequest{}) 147 | grpCStatus := status.Convert(err) 148 | if grpCStatus.Code() == codes.Unavailable { 149 | return err 150 | } 151 | return nil 152 | }); err != nil { 153 | err = fmt.Errorf("could not connect to docker: %w", err) 154 | return nil, fmt.Errorf("could not connect to docker: %w", err) 155 | } 156 | 157 | return dsp, nil 158 | } 159 | 160 | // GetExternalPort returns exposed port of the spicedb instance 161 | func (dsp *dockerSpiceDB) GetExternalPort() string { 162 | return dsp.externalPort 163 | } 164 | 165 | // GetPreSharedKey returns pre-shared-key used in the spicedb instance 166 | func (dsp *dockerSpiceDB) GetPreSharedKey() string { 167 | return dsp.preSharedKey 168 | } 169 | 170 | // GetPool returns docker pool 171 | func (dsp *dockerSpiceDB) GetPool() *dockertest.Pool { 172 | return dsp.pool 173 | } 174 | 175 | // 
GetResource returns docker resource 176 | func (dsp *dockerSpiceDB) GetResource() *dockertest.Resource { 177 | return dsp.dockertestResource 178 | } 179 | -------------------------------------------------------------------------------- /testing/dockertestx/spicedb_migrate.go: -------------------------------------------------------------------------------- 1 | package dockertestx 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | 8 | "github.com/ory/dockertest/v3" 9 | "github.com/ory/dockertest/v3/docker" 10 | ) 11 | 12 | type dockerMigrateSpiceDBOption func(dmm *dockerMigrateSpiceDB) 13 | 14 | // MigrateSpiceDBWithDockertestNetwork is an option to assign docker network 15 | func MigrateSpiceDBWithDockertestNetwork(network *dockertest.Network) dockerMigrateSpiceDBOption { 16 | return func(dm *dockerMigrateSpiceDB) { 17 | dm.network = network 18 | } 19 | } 20 | 21 | // MigrateSpiceDBWithVersionTag is an option to assign release tag 22 | // of a `quay.io/authzed/spicedb` image 23 | func MigrateSpiceDBWithVersionTag(versionTag string) dockerMigrateSpiceDBOption { 24 | return func(dm *dockerMigrateSpiceDB) { 25 | dm.versionTag = versionTag 26 | } 27 | } 28 | 29 | // MigrateSpiceDBWithDockerPool is an option to assign docker pool 30 | func MigrateSpiceDBWithDockerPool(pool *dockertest.Pool) dockerMigrateSpiceDBOption { 31 | return func(dm *dockerMigrateSpiceDB) { 32 | dm.pool = pool 33 | } 34 | } 35 | 36 | type dockerMigrateSpiceDB struct { 37 | network *dockertest.Network 38 | pool *dockertest.Pool 39 | versionTag string 40 | } 41 | 42 | // MigrateSpiceDB migrates spicedb with postgres backend 43 | func MigrateSpiceDB(postgresConnectionURL string, opts ...dockerMigrateMinioOption) error { 44 | var ( 45 | err error 46 | dm = &dockerMigrateMinio{} 47 | ) 48 | 49 | for _, opt := range opts { 50 | opt(dm) 51 | } 52 | 53 | if dm.pool == nil { 54 | dm.pool, err = dockertest.NewPool("") 55 | if err != nil { 56 | return fmt.Errorf("could not create dockertest pool: %w", 
err) 57 | } 58 | } 59 | 60 | if dm.versionTag == "" { 61 | dm.versionTag = "v1.0.0" 62 | } 63 | 64 | runOpts := &dockertest.RunOptions{ 65 | Repository: "quay.io/authzed/spicedb", 66 | Tag: dm.versionTag, 67 | Cmd: []string{"spicedb", "migrate", "head", "--datastore-engine", "postgres", "--datastore-conn-uri", postgresConnectionURL}, 68 | } 69 | 70 | if dm.network != nil { 71 | runOpts.NetworkID = dm.network.Network.ID 72 | } 73 | 74 | resource, err := dm.pool.RunWithOptions(runOpts, func(config *docker.HostConfig) { 75 | config.RestartPolicy = docker.RestartPolicy{ 76 | Name: "no", 77 | } 78 | }) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | if err := resource.Expire(120); err != nil { 84 | return err 85 | } 86 | 87 | waitCtx, cancel := context.WithTimeout(context.Background(), waitContainerTimeout) 88 | defer cancel() 89 | 90 | // Ensure the command completed successfully. 91 | status, err := dm.pool.Client.WaitContainerWithContext(resource.Container.ID, waitCtx) 92 | if err != nil { 93 | return err 94 | } 95 | 96 | if status != 0 { 97 | stream := new(bytes.Buffer) 98 | 99 | if err = dm.pool.Client.Logs(docker.LogsOptions{ 100 | Context: waitCtx, 101 | OutputStream: stream, 102 | ErrorStream: stream, 103 | Stdout: true, 104 | Stderr: true, 105 | Container: resource.Container.ID, 106 | }); err != nil { 107 | return err 108 | } 109 | 110 | return fmt.Errorf("got non-zero exit code %s", stream.String()) 111 | } 112 | 113 | if err := dm.pool.Purge(resource); err != nil { 114 | return err 115 | } 116 | 117 | return nil 118 | } 119 | -------------------------------------------------------------------------------- /utils/error_status.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | "google.golang.org/grpc/codes" 6 | "google.golang.org/grpc/status" 7 | ) 8 | 9 | var codeToStr = map[codes.Code]string{ 10 | codes.OK: `"OK"`, 11 | codes.Canceled: `"CANCELED"`, 12 | 
codes.Unknown:            `"UNKNOWN"`,
13 | 	codes.InvalidArgument:    `"INVALID_ARGUMENT"`,
14 | 	codes.DeadlineExceeded:   `"DEADLINE_EXCEEDED"`,
15 | 	codes.NotFound:           `"NOT_FOUND"`,
16 | 	codes.AlreadyExists:      `"ALREADY_EXISTS"`,
17 | 	codes.PermissionDenied:   `"PERMISSION_DENIED"`,
18 | 	codes.ResourceExhausted:  `"RESOURCE_EXHAUSTED"`,
19 | 	codes.FailedPrecondition: `"FAILED_PRECONDITION"`,
20 | 	codes.Aborted:            `"ABORTED"`,
21 | 	codes.OutOfRange:         `"OUT_OF_RANGE"`,
22 | 	codes.Unimplemented:      `"UNIMPLEMENTED"`,
23 | 	codes.Internal:           `"INTERNAL"`,
24 | 	codes.Unavailable:        `"UNAVAILABLE"`,
25 | 	codes.DataLoss:           `"DATA_LOSS"`,
26 | 	codes.Unauthenticated:    `"UNAUTHENTICATED"`,
27 | }
28 | 
29 | // StatusCode returns the gRPC status code carried by err.
30 | //
31 | // A nil err maps to codes.OK. Otherwise err's chain is searched (via
32 | // errors.As) for a value implementing GRPCStatus() *status.Status, and that
33 | // status's code is returned. Errors with no such value in their chain, and
34 | // errors whose GRPCStatus method returns a nil status, map to codes.Unknown
35 | // so that a non-nil error is never reported as OK.
36 | func StatusCode(err error) codes.Code {
37 | 	if err == nil {
38 | 		return codes.OK
39 | 	}
40 | 	var se interface {
41 | 		GRPCStatus() *status.Status
42 | 	}
43 | 	if errors.As(err, &se) {
44 | 		// A nil *status.Status reports codes.OK; since err itself is non-nil,
45 | 		// fall through to Unknown instead (mirrors grpc-go's status.FromError).
46 | 		if st := se.GRPCStatus(); st != nil {
47 | 			return st.Code()
48 | 		}
49 | 	}
50 | 	return codes.Unknown
51 | }
52 | 
53 | // StatusText returns the upper-snake-case name of err's gRPC status code, as
54 | // resolved by StatusCode, wrapped in literal double quotes (for example
55 | // `"NOT_FOUND"`, quotes included). It returns "" for a code that has no
56 | // entry in codeToStr.
57 | func StatusText(err error) string {
58 | 	return codeToStr[StatusCode(err)]
59 | }
60 | 
--------------------------------------------------------------------------------
/utils/error_status_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 | 
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 | 
7 | 	"github.com/pkg/errors"
8 | 	"github.com/stretchr/testify/assert"
9 | 	"google.golang.org/grpc/codes"
10 | 	"google.golang.org/grpc/status"
11 | )
12 | 
13 | // nilStatusError is an error whose GRPCStatus method returns a nil status.
14 | type nilStatusError struct{}
15 | 
16 | func (nilStatusError) Error() string { return "nil status" }
17 | 
18 | func (nilStatusError) GRPCStatus() *status.Status { return nil }
19 | 
20 | func TestStatusCode(t *testing.T) {
21 | 	cases := []struct {
22 | 		name     string
23 | 		err      error
24 | 		expected codes.Code
25 | 	}{
26 | 		{
27 | 			name:     "with status.Error",
28 | 			err:      status.Error(codes.NotFound, "Somebody that I used to know"),
29 | 			expected: codes.NotFound,
30 | 		},
31 | 		{
32 | 			name:     "with wrapped status.Error",
33 | 			err:      fmt.Errorf("%w", status.Error(codes.Unavailable, "I shot the sheriff")),
34 | 			expected: codes.Unavailable,
35 | 		},
36 | 		{
37 | 			name:     "with std lib error",
38 | 			err:      errors.New("Runnin' down a dream"),
39 | 			expected: codes.Unknown,
40 | 		},
41 | 		{
42 | 			name:     "with nil GRPCStatus",
43 | 			err:      nilStatusError{},
44 | 			expected: codes.Unknown,
45 | 		},
46 | 		{
47 | 			name:     "with nil error",
48 | 			err:      nil,
49 | 			expected: codes.OK,
50 | 		},
51 | 	}
52 | 	for _, tc := range cases {
53 | 		t.Run(tc.name, func(t *testing.T) {
54 | 			assert.Equal(t, tc.expected, StatusCode(tc.err))
55 | 		})
56 | 	}
57 | }
58 | 
59 | func TestStatusText(t *testing.T) {
60 | 	assert.Equal(t, `"OK"`, StatusText(nil))
61 | 	assert.Equal(t, `"NOT_FOUND"`, StatusText(status.Error(codes.NotFound, "missing")))
62 | 	assert.Equal(t, `"UNKNOWN"`, StatusText(errors.New("plain")))
63 | 	assert.Equal(t, `"UNKNOWN"`, StatusText(nilStatusError{}))
64 | }
65 | 
--------------------------------------------------------------------------------