├── .github
├── SECURITY.md
└── workflows
│ ├── lint.yaml
│ └── test.yaml
├── .gitignore
├── .golangci.yml
├── LICENSE
├── Makefile
├── README.md
├── auth
├── audit
│ ├── audit.go
│ ├── audit_test.go
│ ├── mocks
│ │ └── repository.go
│ ├── model.go
│ └── repositories
│ │ ├── dockertest_test.go
│ │ ├── postgres.go
│ │ └── postgres_test.go
└── oidc
│ ├── _example
│ └── main.go
│ ├── cobra.go
│ ├── redirect.html
│ ├── source_gsa.go
│ ├── source_oidc.go
│ └── utils.go
├── cli
├── commander
│ ├── codex.go
│ ├── completion.go
│ ├── hooks.go
│ ├── layout.go
│ ├── manager.go
│ ├── reference.go
│ └── topics.go
├── printer
│ ├── colors.go
│ ├── markdown.go
│ ├── progress.go
│ ├── spinner.go
│ ├── structured.go
│ ├── table.go
│ └── text.go
├── prompter
│ └── prompt.go
├── releaser
│ └── release.go
└── terminator
│ ├── brew.go
│ ├── browser.go
│ ├── pager.go
│ └── term.go
├── config
├── config.go
├── config_test.go
├── doc.go
└── helpers.go
├── db
├── config.go
├── db.go
├── db_test.go
├── migrate.go
├── migrate_test.go
└── migrations
│ ├── 1481574547_create_users_table.down.sql
│ └── 1481574547_create_users_table.up.sql
├── go.mod
├── go.sum
├── log
├── logger.go
├── logrus.go
├── logrus_test.go
├── noop.go
├── zap.go
└── zap_test.go
├── rql
├── README.md
├── parser.go
└── parser_test.go
├── server
├── mux
│ ├── README.md
│ ├── mux.go
│ ├── option.go
│ └── serve_target.go
└── spa
│ ├── doc.go
│ ├── handler.go
│ └── router.go
├── telemetry
├── opentelemetry.go
├── otelgrpc
│ ├── otelgrpc.go
│ └── otelgrpc_test.go
├── otelhhtpclient
│ ├── annotations.go
│ ├── http_transport.go
│ └── http_transport_test.go
└── telemetry.go
├── testing
└── dockertestx
│ ├── README.md
│ ├── configs
│ ├── cortex
│ │ └── single_process_cortex.yaml
│ └── nginx
│ │ └── cortex_nginx.conf
│ ├── cortex.go
│ ├── dockertestx.go
│ ├── minio.go
│ ├── minio_migrate.go
│ ├── nginx.go
│ ├── postgres.go
│ ├── spicedb.go
│ └── spicedb_migrate.go
└── utils
├── error_status.go
└── error_status_test.go
/.github/SECURITY.md:
--------------------------------------------------------------------------------
1 | Raystack takes the security of our software products and services seriously.
2 |
3 | If you believe you have found a security vulnerability in this project, you can report it to us directly using [private vulnerability reporting][].
4 |
5 | - Include a description of your investigation of this project's codebase and why you believe an exploit is possible.
6 | - POCs and links to code are greatly encouraged.
7 | - Such reports are not eligible for a bounty reward.
8 |
9 | **Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.**
10 |
11 | Thanks for helping keep Raystack's projects and their users safe.
12 |
13 | [private vulnerability reporting]: https://github.com/raystack/salt/security/advisories
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Lint
2 | on:
3 | push:
4 | paths:
5 | - "**.go"
6 | - go.mod
7 | - go.sum
8 | pull_request:
9 | paths:
10 | - "**.go"
11 | - go.mod
12 | - go.sum
13 |
14 | jobs:
15 | checks:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@v4
20 | with:
21 | fetch-depth: 0
22 | - name: Setup Go
23 | uses: actions/setup-go@v5
24 | with:
25 | go-version-file: "go.mod"
26 | - name: Run linter
27 | uses: golangci/golangci-lint-action@v6
28 | with:
29 | version: v1.60
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v4
15 | with:
16 | fetch-depth: 0
17 | - name: Setup Go
18 | uses: actions/setup-go@v5
19 | with:
20 | go-version-file: "go.mod"
21 | - name: Run unit tests
22 | run: make test
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | vendor/
16 |
17 | .idea
18 | .vscode
19 | expt/
20 | temp.env
21 | .DS_Store
22 | temp.env
23 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | output:
2 | formats:
3 | - format: line-number
4 | linters:
5 | enable-all: false
6 | disable-all: true
7 | enable:
8 | - govet
9 | - goimports
10 | - thelper
11 | - tparallel
12 | - unconvert
13 | - wastedassign
14 | - revive
15 | - unused
16 | - gofmt
17 | - whitespace
18 | - misspell
19 | linters-settings:
20 | revive:
21 | ignore-generated-header: true
22 | severity: warning
23 | severity:
24 | default-severity: error
25 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: check fmt lint test vet help
2 | .DEFAULT_GOAL := help
3 |
4 | check: test lint ## Run tests and linters
5 |
6 | test: ## Run tests
7 | go test ./... -race
8 |
9 | lint: ## Run linter
10 | golangci-lint run
11 |
12 | help:
13 | @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # salt
2 |
3 | [](https://godoc.org/github.com/raystack/salt)
4 | 
5 | [](https://goreportcard.com/report/github.com/raystack/salt)
6 |
7 | Salt is a Golang utility library offering a variety of packages to simplify and enhance application development. It provides modular and reusable components for common tasks, including configuration management, CLI utilities, authentication, logging, and more.
8 |
9 | ## Installation
10 |
11 | To use, run the following command:
12 |
13 | ```
14 | go get github.com/raystack/salt
15 | ```
16 |
17 | ## Packages
18 |
19 | ### Configuration and Environment
20 | - **`config`**
21 | Utilities for managing application configurations using environment variables, files, or defaults.
22 |
23 | ### CLI Utilities
24 | - **`cli/commander`**
25 | Command execution and management tools.
26 |
27 | - **`cli/printer`**
28 | Utilities for formatting and printing output to the terminal.
29 |
30 | - **`cli/prompter`**
31 | Interactive CLI prompts for user input.
32 |
33 | - **`cli/terminator`**
34 | Terminal utilities for colors, cursor management, and formatting.
35 |
36 | - **`cli/releaser`**
37 | Utilities for displaying and managing CLI tool versions.
38 |
39 | ### Authentication and Security
40 | - **`auth/oidc`**
41 | Helpers for integrating OpenID Connect authentication flows.
42 |
43 | - **`auth/audit`**
44 | Auditing tools for tracking security events and compliance.
45 |
46 | ### Server and Infrastructure
47 | - **`server`**
48 | Utilities for setting up and managing HTTP or RPC servers.
49 |
50 | - **`db`**
51 | Helpers for database connections, migrations, and query execution.
52 |
53 | - **`telemetry`**
54 | Observability tools for capturing application metrics and traces.
55 |
56 | ### Development and Testing
57 | - **`testing/dockertestx`**
58 | Tools for creating and managing Docker-based testing environments.
59 |
60 | ### Utilities
61 | - **`log`**
62 | Simplified logging utilities for structured and unstructured log messages.
63 |
64 | - **`utils`**
65 | General-purpose utility functions for common programming tasks.
66 |
--------------------------------------------------------------------------------
/auth/audit/audit.go:
--------------------------------------------------------------------------------
1 | //go:generate mockery --name=repository --exported
2 |
3 | package audit
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "fmt"
9 | "time"
10 | )
11 |
var (
	// TimeNow is the clock used to timestamp audit logs. It is a variable so
	// tests can replace it with a fixed clock.
	TimeNow = time.Now

	// ErrInvalidMetadata is returned by WithMetadata when the context already
	// carries a metadata value that is not a map[string]interface{}.
	ErrInvalidMetadata = errors.New("failed to cast existing metadata to map[string]interface{} type")
)

// actorContextKey and metadataContextKey are unexported so that only this
// package can set or read the corresponding context values.
type actorContextKey struct{}
type metadataContextKey struct{}

// WithActor returns a child of ctx that carries actor. The default actor
// extractor used by Service reads it back when a log is written.
func WithActor(ctx context.Context, actor string) context.Context {
	return context.WithValue(ctx, actorContextKey{}, actor)
}

// WithMetadata returns a child of ctx that carries audit metadata.
//
// If ctx has no metadata yet, md is attached as is. Otherwise the child gets a
// NEW map holding the existing entries overlaid with md (keys in md win). The
// map already stored in ctx is never written to. Parent contexts, which may be
// shared between goroutines, therefore stay untouched. A nil map stored by an
// earlier call also cannot trigger an "assignment to entry in nil map" panic.
//
// It returns ErrInvalidMetadata and a nil context when the existing metadata
// value is not a map[string]interface{}.
func WithMetadata(ctx context.Context, md map[string]interface{}) (context.Context, error) {
	existing := ctx.Value(metadataContextKey{})
	if existing == nil {
		return context.WithValue(ctx, metadataContextKey{}, md), nil
	}

	existingMd, ok := existing.(map[string]interface{})
	if !ok {
		return nil, ErrInvalidMetadata
	}

	// Copy-then-overlay instead of mutating existingMd in place.
	merged := make(map[string]interface{}, len(existingMd)+len(md))
	for k, v := range existingMd {
		merged[k] = v
	}
	for k, v := range md {
		merged[k] = v
	}
	return context.WithValue(ctx, metadataContextKey{}, merged), nil
}
42 |
43 | type repository interface {
44 | Init(context.Context) error
45 | Insert(context.Context, *Log) error
46 | }
47 |
48 | type AuditOption func(*Service)
49 |
50 | func WithRepository(r repository) AuditOption {
51 | return func(s *Service) {
52 | s.repository = r
53 | }
54 | }
55 |
56 | func WithMetadataExtractor(fn func(context.Context) map[string]interface{}) AuditOption {
57 | return func(s *Service) {
58 | s.withMetadata = func(ctx context.Context) (context.Context, error) {
59 | md := fn(ctx)
60 | return WithMetadata(ctx, md)
61 | }
62 | }
63 | }
64 |
65 | func WithActorExtractor(fn func(context.Context) (string, error)) AuditOption {
66 | return func(s *Service) {
67 | s.actorExtractor = fn
68 | }
69 | }
70 |
71 | func defaultActorExtractor(ctx context.Context) (string, error) {
72 | if actor, ok := ctx.Value(actorContextKey{}).(string); ok {
73 | return actor, nil
74 | }
75 | return "", nil
76 | }
77 |
78 | type Service struct {
79 | repository repository
80 | actorExtractor func(context.Context) (string, error)
81 | withMetadata func(context.Context) (context.Context, error)
82 | }
83 |
84 | func New(opts ...AuditOption) *Service {
85 | svc := &Service{
86 | actorExtractor: defaultActorExtractor,
87 | }
88 | for _, o := range opts {
89 | o(svc)
90 | }
91 |
92 | return svc
93 | }
94 |
95 | func (s *Service) Log(ctx context.Context, action string, data interface{}) error {
96 | if s.withMetadata != nil {
97 | var err error
98 | if ctx, err = s.withMetadata(ctx); err != nil {
99 | return err
100 | }
101 | }
102 |
103 | l := &Log{
104 | Timestamp: TimeNow(),
105 | Action: action,
106 | Data: data,
107 | }
108 |
109 | if md, ok := ctx.Value(metadataContextKey{}).(map[string]interface{}); ok {
110 | l.Metadata = md
111 | }
112 |
113 | if s.actorExtractor != nil {
114 | actor, err := s.actorExtractor(ctx)
115 | if err != nil {
116 | return fmt.Errorf("extracting actor: %w", err)
117 | }
118 | l.Actor = actor
119 | }
120 |
121 | return s.repository.Insert(ctx, l)
122 | }
123 |
--------------------------------------------------------------------------------
/auth/audit/audit_test.go:
--------------------------------------------------------------------------------
1 | package audit_test
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "testing"
7 | "time"
8 |
9 | "github.com/raystack/salt/auth/audit"
10 | "github.com/raystack/salt/auth/audit/mocks"
11 |
12 | "github.com/stretchr/testify/mock"
13 | "github.com/stretchr/testify/suite"
14 | )
15 |
16 | type AuditTestSuite struct {
17 | suite.Suite
18 |
19 | now time.Time
20 |
21 | mockRepository *mocks.Repository
22 | service *audit.Service
23 | }
24 |
25 | func (s *AuditTestSuite) setupTest() {
26 | s.mockRepository = new(mocks.Repository)
27 | s.service = audit.New(
28 | audit.WithMetadataExtractor(func(context.Context) map[string]interface{} {
29 | return map[string]interface{}{
30 | "trace_id": "test-trace-id",
31 | "app_name": "guardian_test",
32 | "app_version": 1,
33 | }
34 | }),
35 | audit.WithRepository(s.mockRepository),
36 | )
37 |
38 | s.now = time.Now()
39 | audit.TimeNow = func() time.Time {
40 | return s.now
41 | }
42 | }
43 |
44 | func TestAudit(t *testing.T) {
45 | suite.Run(t, new(AuditTestSuite))
46 | }
47 |
48 | func (s *AuditTestSuite) TestLog() {
49 | s.Run("should insert to repository", func() {
50 | s.setupTest()
51 |
52 | s.mockRepository.On("Insert", mock.Anything, &audit.Log{
53 | Timestamp: s.now,
54 | Action: "action",
55 | Actor: "user@example.com",
56 | Data: map[string]interface{}{"foo": "bar"},
57 | Metadata: map[string]interface{}{
58 | "trace_id": "test-trace-id",
59 | "app_name": "guardian_test",
60 | "app_version": 1,
61 | },
62 | }).Return(nil)
63 |
64 | ctx := context.Background()
65 | ctx = audit.WithActor(ctx, "user@example.com")
66 | err := s.service.Log(ctx, "action", map[string]interface{}{"foo": "bar"})
67 | s.NoError(err)
68 | })
69 |
70 | s.Run("actor extractor", func() {
71 | s.Run("should use actor extractor if option given", func() {
72 | expectedActor := "test-actor"
73 | s.service = audit.New(
74 | audit.WithActorExtractor(func(ctx context.Context) (string, error) {
75 | return expectedActor, nil
76 | }),
77 | audit.WithRepository(s.mockRepository),
78 | )
79 |
80 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
81 | log := args.Get(1).(*audit.Log)
82 | s.Equal(expectedActor, log.Actor)
83 | }).Return(nil).Once()
84 |
85 | err := s.service.Log(context.Background(), "", nil)
86 | s.NoError(err)
87 | })
88 |
89 | s.Run("should return error if extractor returns error", func() {
90 | expectedError := errors.New("test error")
91 | s.service = audit.New(
92 | audit.WithActorExtractor(func(ctx context.Context) (string, error) {
93 | return "", expectedError
94 | }),
95 | )
96 |
97 | err := s.service.Log(context.Background(), "", nil)
98 | s.ErrorIs(err, expectedError)
99 | })
100 | })
101 |
102 | s.Run("metadata", func() {
103 | s.Run("should pass empty trace id if extractor not found", func() {
104 | s.service = audit.New(
105 | audit.WithMetadataExtractor(func(ctx context.Context) map[string]interface{} {
106 | return map[string]interface{}{
107 | "app_name": "guardian_test",
108 | "app_version": 1,
109 | }
110 | }),
111 | audit.WithRepository(s.mockRepository),
112 | )
113 |
114 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
115 | l := args.Get(1).(*audit.Log)
116 | s.IsType(map[string]interface{}{}, l.Metadata)
117 |
118 | md := l.Metadata.(map[string]interface{})
119 | s.Empty(md["trace_id"])
120 | s.NotEmpty(md["app_name"])
121 | s.NotEmpty(md["app_version"])
122 | }).Return(nil).Once()
123 |
124 | err := s.service.Log(context.Background(), "", nil)
125 | s.NoError(err)
126 | })
127 |
128 | s.Run("should append new metadata to existing one", func() {
129 | s.service = audit.New(
130 | audit.WithMetadataExtractor(func(ctx context.Context) map[string]interface{} {
131 | return map[string]interface{}{
132 | "existing": "foobar",
133 | }
134 | }),
135 | audit.WithRepository(s.mockRepository),
136 | )
137 |
138 | expectedMetadata := map[string]interface{}{
139 | "existing": "foobar",
140 | "new": "foobar",
141 | }
142 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
143 | log := args.Get(1).(*audit.Log)
144 | s.Equal(expectedMetadata, log.Metadata)
145 | }).Return(nil).Once()
146 |
147 | ctx, err := audit.WithMetadata(context.Background(), map[string]interface{}{
148 | "new": "foobar",
149 | })
150 | s.Require().NoError(err)
151 |
152 | err = s.service.Log(ctx, "", nil)
153 | s.NoError(err)
154 | })
155 | })
156 |
157 | s.Run("should return error if repository.Insert fails", func() {
158 | s.setupTest()
159 |
160 | expectedError := errors.New("test error")
161 | s.mockRepository.On("Insert", mock.Anything, mock.Anything).Return(expectedError)
162 |
163 | err := s.service.Log(context.Background(), "", nil)
164 | s.ErrorIs(err, expectedError)
165 | })
166 | }
167 |
--------------------------------------------------------------------------------
/auth/audit/mocks/repository.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.10.0. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import (
6 | context "context"
7 | "github.com/raystack/salt/auth/audit"
8 |
9 | mock "github.com/stretchr/testify/mock"
10 | )
11 |
12 | // Repository is an autogenerated mock type for the repository type
13 | type Repository struct {
14 | mock.Mock
15 | }
16 |
17 | // Init provides a mock function with given fields: _a0
18 | func (_m *Repository) Init(_a0 context.Context) error {
19 | ret := _m.Called(_a0)
20 |
21 | var r0 error
22 | if rf, ok := ret.Get(0).(func(context.Context) error); ok {
23 | r0 = rf(_a0)
24 | } else {
25 | r0 = ret.Error(0)
26 | }
27 |
28 | return r0
29 | }
30 |
31 | // Insert provides a mock function with given fields: _a0, _a1
32 | func (_m *Repository) Insert(_a0 context.Context, _a1 *audit.Log) error {
33 | ret := _m.Called(_a0, _a1)
34 |
35 | var r0 error
36 | if rf, ok := ret.Get(0).(func(context.Context, *audit.Log) error); ok {
37 | r0 = rf(_a0, _a1)
38 | } else {
39 | r0 = ret.Error(0)
40 | }
41 |
42 | return r0
43 | }
44 |
--------------------------------------------------------------------------------
/auth/audit/model.go:
--------------------------------------------------------------------------------
1 | package audit
2 |
3 | import "time"
4 |
// Log is a single audit record as produced by Service.Log.
type Log struct {
	// Timestamp is when the entry was created (taken from TimeNow).
	Timestamp time.Time `json:"timestamp"`
	// Action names the operation being audited.
	Action string `json:"action"`
	// Actor identifies who performed the action; empty when unknown.
	Actor string `json:"actor"`
	// Data is the action-specific payload.
	Data interface{} `json:"data"`
	// Metadata holds the contextual values attached with WithMetadata.
	Metadata interface{} `json:"metadata"`
}
12 |
--------------------------------------------------------------------------------
/auth/audit/repositories/dockertest_test.go:
--------------------------------------------------------------------------------
1 | package repositories_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/raystack/salt/auth/audit/repositories"
10 |
11 | _ "github.com/lib/pq"
12 | "github.com/ory/dockertest/v3"
13 | "github.com/ory/dockertest/v3/docker"
14 | "github.com/raystack/salt/log"
15 | )
16 |
17 | func newTestRepository(logger log.Logger) (*repositories.PostgresRepository, *dockertest.Pool, *dockertest.Resource, error) {
18 | host := "localhost"
19 | port := "5433"
20 | user := "test_user"
21 | password := "test_pass"
22 | dbName := "test_db"
23 | sslMode := "disable"
24 |
25 | opts := &dockertest.RunOptions{
26 | Repository: "postgres",
27 | Tag: "13",
28 | Env: []string{
29 | "POSTGRES_PASSWORD=" + password,
30 | "POSTGRES_USER=" + user,
31 | "POSTGRES_DB=" + dbName,
32 | },
33 | PortBindings: map[docker.Port][]docker.PortBinding{
34 | "5432": {
35 | {HostIP: "0.0.0.0", HostPort: port},
36 | },
37 | },
38 | }
39 |
40 | // uses a sensible default on windows (tcp/http) and linux/osx (socket)
41 | pool, err := dockertest.NewPool("")
42 | if err != nil {
43 | return nil, nil, nil, fmt.Errorf("could not create dockertest pool: %w", err)
44 | }
45 |
46 | resource, err := pool.RunWithOptions(opts, func(config *docker.HostConfig) {
47 | config.AutoRemove = true
48 | config.RestartPolicy = docker.RestartPolicy{Name: "no"}
49 | })
50 | if err != nil {
51 | return nil, nil, nil, fmt.Errorf("could not start resource: %w", err)
52 | }
53 |
54 | port = resource.GetPort("5432/tcp")
55 |
56 | // attach terminal logger to container if exists
57 | // for debugging purpose
58 | if logger.Level() == "debug" {
59 | logWaiter, err := pool.Client.AttachToContainerNonBlocking(docker.AttachToContainerOptions{
60 | Container: resource.Container.ID,
61 | OutputStream: logger.Writer(),
62 | ErrorStream: logger.Writer(),
63 | Stderr: true,
64 | Stdout: true,
65 | Stream: true,
66 | })
67 | if err != nil {
68 | logger.Fatal("could not connect to postgres container log output", "error", err)
69 | }
70 | defer func() {
71 | if err = logWaiter.Close(); err != nil {
72 | logger.Fatal("could not close container log", "error", err)
73 | }
74 |
75 | if err = logWaiter.Wait(); err != nil {
76 | logger.Fatal("could not wait for container log to close", "error", err)
77 | }
78 | }()
79 | }
80 |
81 | // Tell docker to hard kill the container in 120 seconds
82 | if err := resource.Expire(120); err != nil {
83 | return nil, nil, nil, err
84 | }
85 |
86 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
87 | pool.MaxWait = 60 * time.Second
88 |
89 | var repo *repositories.PostgresRepository
90 | time.Sleep(5 * time.Second)
91 | if err := pool.Retry(func() error {
92 | db, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", host, port, user, password, dbName, sslMode))
93 | if err != nil {
94 | return err
95 | }
96 | repo = repositories.NewPostgresRepository(db)
97 |
98 | return db.Ping()
99 | }); err != nil {
100 | return nil, nil, nil, fmt.Errorf("could not connect to docker: %w", err)
101 | }
102 |
103 | if err := setup(repo); err != nil {
104 | logger.Fatal("failed to setup and migrate DB", "error", err)
105 | }
106 | return repo, pool, resource, nil
107 | }
108 |
109 | func setup(repo *repositories.PostgresRepository) error {
110 | var queries = []string{
111 | "DROP SCHEMA public CASCADE",
112 | "CREATE SCHEMA public",
113 | }
114 | for _, query := range queries {
115 | repo.DB().Exec(query)
116 | }
117 |
118 | if err := repo.Init(context.Background()); err != nil {
119 | return err
120 | }
121 |
122 | return nil
123 | }
124 |
125 | func purgeTestDocker(pool *dockertest.Pool, resource *dockertest.Resource) error {
126 | if err := pool.Purge(resource); err != nil {
127 | return fmt.Errorf("could not purge resource: %w", err)
128 | }
129 | return nil
130 | }
131 |
--------------------------------------------------------------------------------
/auth/audit/repositories/postgres.go:
--------------------------------------------------------------------------------
1 | package repositories
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "encoding/json"
7 | "fmt"
8 | "time"
9 |
10 | "github.com/raystack/salt/auth/audit"
11 |
12 | "github.com/jmoiron/sqlx/types"
13 | )
14 |
15 | type AuditModel struct {
16 | Timestamp time.Time `db:"timestamp"`
17 | Action string `db:"action"`
18 | Actor string `db:"actor"`
19 | Data types.NullJSONText `db:"data"`
20 | Metadata types.NullJSONText `db:"metadata"`
21 | }
22 |
// PostgresRepository stores audit logs in the audit_logs table of a
// PostgreSQL database.
type PostgresRepository struct {
	db *sql.DB
}

// NewPostgresRepository wraps db. The repository does not take ownership of
// db; closing it is left to the caller.
func NewPostgresRepository(db *sql.DB) *PostgresRepository {
	return &PostgresRepository{db: db}
}

// DB returns the handle the repository was created with.
func (r *PostgresRepository) DB() *sql.DB { return r.db }
34 |
35 | func (r *PostgresRepository) Init(ctx context.Context) error {
36 | sql := `
37 | CREATE TABLE IF NOT EXISTS audit_logs (
38 | timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
39 | action TEXT NOT NULL,
40 | actor TEXT NOT NULL,
41 | data JSONB NOT NULL,
42 | metadata JSONB NOT NULL
43 | );
44 |
45 | CREATE INDEX IF NOT EXISTS audit_logs_timestamp_idx ON audit_logs (timestamp);
46 | CREATE INDEX IF NOT EXISTS audit_logs_action_idx ON audit_logs (action);
47 | CREATE INDEX IF NOT EXISTS audit_logs_actor_idx ON audit_logs (actor);
48 | `
49 | if _, err := r.db.ExecContext(ctx, sql); err != nil {
50 | return fmt.Errorf("migrating audit model to postgres db: %w", err)
51 | }
52 | return nil
53 | }
54 |
55 | func (r *PostgresRepository) Insert(ctx context.Context, l *audit.Log) error {
56 | m := &AuditModel{
57 | Timestamp: l.Timestamp,
58 | Action: l.Action,
59 | Actor: l.Actor,
60 | }
61 |
62 | if l.Data != nil {
63 | data, err := json.Marshal(l.Data)
64 | if err != nil {
65 | return fmt.Errorf("marshalling data: %w", err)
66 | }
67 | m.Data = types.NullJSONText{JSONText: data, Valid: true}
68 | }
69 |
70 | if l.Metadata != nil {
71 | metadata, err := json.Marshal(l.Metadata)
72 | if err != nil {
73 | return fmt.Errorf("marshalling metadata: %w", err)
74 | }
75 | m.Metadata = types.NullJSONText{JSONText: metadata, Valid: true}
76 | }
77 |
78 | if _, err := r.db.ExecContext(ctx, "INSERT INTO audit_logs (timestamp, action, actor, data, metadata) VALUES ($1, $2, $3, $4, $5)", m.Timestamp, m.Action, m.Actor, m.Data, m.Metadata); err != nil {
79 | return fmt.Errorf("inserting to db: %w", err)
80 | }
81 |
82 | return nil
83 | }
84 |
--------------------------------------------------------------------------------
/auth/audit/repositories/postgres_test.go:
--------------------------------------------------------------------------------
1 | package repositories_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/raystack/salt/auth/audit"
9 | "github.com/raystack/salt/auth/audit/repositories"
10 |
11 | "github.com/google/go-cmp/cmp"
12 | "github.com/google/go-cmp/cmp/cmpopts"
13 | "github.com/jmoiron/sqlx/types"
14 | "github.com/raystack/salt/log"
15 | "github.com/stretchr/testify/suite"
16 | )
17 |
18 | type PostgresRepositoryTestSuite struct {
19 | suite.Suite
20 |
21 | repository *repositories.PostgresRepository
22 | }
23 |
24 | func TestPostgresRepository(t *testing.T) {
25 | suite.Run(t, new(PostgresRepositoryTestSuite))
26 | }
27 |
28 | func (s *PostgresRepositoryTestSuite) SetupSuite() {
29 | var err error
30 | repository, pool, dockerResource, err := newTestRepository(log.NewLogrus())
31 | if err != nil {
32 | s.T().Fatal(err)
33 | }
34 | s.repository = repository
35 |
36 | s.T().Cleanup(func() {
37 | if err := s.repository.DB().Close(); err != nil {
38 | s.T().Fatal(err)
39 | }
40 | if err := purgeTestDocker(pool, dockerResource); err != nil {
41 | s.T().Fatal(err)
42 | }
43 | })
44 | }
45 |
46 | func (s *PostgresRepositoryTestSuite) TestInsert() {
47 | s.Run("should insert record to db", func() {
48 | l := &audit.Log{
49 | Timestamp: time.Now(),
50 | Action: "test-action",
51 | Actor: "user@example.com",
52 | Data: types.NullJSONText{
53 | JSONText: []byte(`{"test": "data"}`),
54 | Valid: true,
55 | },
56 | Metadata: types.NullJSONText{
57 | JSONText: []byte(`{"test": "metadata"}`),
58 | Valid: true,
59 | },
60 | }
61 |
62 | err := s.repository.Insert(context.Background(), l)
63 | s.Require().NoError(err)
64 |
65 | rows, err := s.repository.DB().Query("SELECT * FROM audit_logs")
66 | var actualResult repositories.AuditModel
67 | for rows.Next() {
68 | err := rows.Scan(&actualResult.Timestamp, &actualResult.Action, &actualResult.Actor, &actualResult.Data, &actualResult.Metadata)
69 | s.Require().NoError(err)
70 | }
71 |
72 | s.NoError(err)
73 | s.NotNil(actualResult)
74 | if diff := cmp.Diff(l.Timestamp, actualResult.Timestamp, cmpopts.EquateApproxTime(time.Microsecond)); diff != "" {
75 | s.T().Errorf("result not match, diff: %v", diff)
76 | }
77 | s.Equal(l.Action, actualResult.Action)
78 | s.Equal(l.Actor, actualResult.Actor)
79 | s.Equal(l.Data, actualResult.Data)
80 | s.Equal(l.Metadata, actualResult.Metadata)
81 | })
82 |
83 | s.Run("should return error if data marshalling returns error", func() {
84 | l := &audit.Log{
85 | Data: make(chan int),
86 | }
87 |
88 | err := s.repository.Insert(context.Background(), l)
89 | s.EqualError(err, "marshalling data: json: unsupported type: chan int")
90 | })
91 |
92 | s.Run("should return error if metadata marshalling returns error", func() {
93 | l := &audit.Log{
94 | Metadata: map[string]interface{}{
95 | "foo": make(chan int),
96 | },
97 | }
98 |
99 | err := s.repository.Insert(context.Background(), l)
100 | s.EqualError(err, "marshalling metadata: json: unsupported type: chan int")
101 | })
102 | }
103 |
--------------------------------------------------------------------------------
/auth/oidc/_example/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/raystack/salt/auth/oidc"
6 | "log"
7 | "os"
8 | "strings"
9 |
10 | "golang.org/x/oauth2"
11 | "golang.org/x/oauth2/google"
12 | )
13 |
14 | func main() {
15 | cfg := &oauth2.Config{
16 | ClientID: os.Getenv("CLIENT_ID"),
17 | ClientSecret: os.Getenv("CLIENT_SECRET"),
18 | Endpoint: google.Endpoint,
19 | RedirectURL: "http://localhost:5454",
20 | Scopes: strings.Split(os.Getenv("OIDC_SCOPES"), ","),
21 | }
22 | aud := os.Getenv("OIDC_AUDIENCE")
23 | keyFile := os.Getenv("GOOGLE_SERVICE_ACCOUNT")
24 |
25 | onTokenOrErr := func(t *oauth2.Token, err error) {
26 | if err != nil {
27 | log.Fatalf("oidc login failed: %v", err)
28 | }
29 |
30 | _ = json.NewEncoder(os.Stdout).Encode(map[string]interface{}{
31 | "token_type": t.TokenType,
32 | "access_token": t.AccessToken,
33 | "expiry": t.Expiry,
34 | "refresh_token": t.RefreshToken,
35 | "id_token": t.Extra("id_token"),
36 | })
37 | }
38 |
39 | _ = oidc.LoginCmd(cfg, aud, keyFile, onTokenOrErr).Execute()
40 | }
41 |
--------------------------------------------------------------------------------
/auth/oidc/cobra.go:
--------------------------------------------------------------------------------
1 | package oidc
2 |
3 | import (
4 | "context"
5 | "os/signal"
6 | "syscall"
7 |
8 | "github.com/spf13/cobra"
9 | "golang.org/x/oauth2"
10 | )
11 |
12 | func LoginCmd(cfg *oauth2.Config, aud, keyFilePath string, onTokenOrErr func(t *oauth2.Token, err error)) *cobra.Command {
13 | cmd := &cobra.Command{
14 | Use: "login",
15 | Short: "Login with your Google account.",
16 | Run: func(cmd *cobra.Command, args []string) {
17 | ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
18 | defer cancel()
19 |
20 | var ts oauth2.TokenSource
21 | if keyFilePath != "" {
22 | var err error
23 | ts, err = NewGoogleServiceAccountTokenSource(ctx, keyFilePath, aud)
24 | if err != nil {
25 | onTokenOrErr(nil, err)
26 | return
27 | }
28 | } else {
29 | ts = NewTokenSource(ctx, cfg, aud)
30 | }
31 | onTokenOrErr(ts.Token())
32 | },
33 | }
34 |
35 | return cmd
36 | }
37 |
--------------------------------------------------------------------------------
/auth/oidc/redirect.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | Login
4 |
19 |
20 |
21 |
22 |
✅ Done
23 |
It is safe to close this window now.
24 |
Go back to the console to continue.
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/auth/oidc/source_gsa.go:
--------------------------------------------------------------------------------
1 | package oidc
2 |
3 | import (
4 | "context"
5 |
6 | "golang.org/x/oauth2"
7 | "google.golang.org/api/idtoken"
8 | )
9 |
10 | func NewGoogleServiceAccountTokenSource(ctx context.Context, keyFile, aud string) (oauth2.TokenSource, error) {
11 | return idtoken.NewTokenSource(ctx, aud, idtoken.WithCredentialsFile(keyFile))
12 | }
13 |
--------------------------------------------------------------------------------
/auth/oidc/source_oidc.go:
--------------------------------------------------------------------------------
1 | package oidc
2 |
3 | import (
4 | "context"
5 | "crypto/sha256"
6 | "errors"
7 | "fmt"
8 |
9 | "golang.org/x/oauth2"
10 | )
11 |
// OpenID Connect values.
const (
	// scopeOpenID is the OpenID Connect scope added to every request.
	scopeOpenID = "openid"
	// audienceKey is the request parameter carrying the requested audience.
	audienceKey = "audience"
)

// PKCE (Proof Key for Code Exchange) values.
// See https://www.rfc-editor.org/rfc/rfc7636
const (
	pkceS256               = "S256"
	codeVerifierLen        = 32
	codeChallengeKey       = "code_challenge"
	codeVerifierKey        = "code_verifier"
	codeChallengeMethodKey = "code_challenge_method"
)
25 |
26 | func NewTokenSource(ctx context.Context, conf *oauth2.Config, audience string) oauth2.TokenSource {
27 | conf.Scopes = append(conf.Scopes, scopeOpenID)
28 | return &authHandlerSource{
29 | ctx: ctx,
30 | config: conf,
31 | audience: audience,
32 | }
33 | }
34 |
35 | type authHandlerSource struct {
36 | ctx context.Context
37 | config *oauth2.Config
38 | audience string
39 | }
40 |
41 | func (source *authHandlerSource) Token() (*oauth2.Token, error) {
42 | stateBytes, err := randomBytes(10)
43 | if err != nil {
44 | return nil, err
45 | }
46 | actualState := string(stateBytes)
47 |
48 | codeVerifier, codeChallenge, challengeMethod, err := newPKCEParams()
49 | if err != nil {
50 | return nil, err
51 | }
52 |
53 | // Step 1. Send user to authorization page for obtaining consent.
54 | url := source.config.AuthCodeURL(actualState,
55 | oauth2.SetAuthURLParam(audienceKey, source.audience),
56 | oauth2.SetAuthURLParam(codeChallengeKey, codeChallenge),
57 | oauth2.SetAuthURLParam(codeChallengeMethodKey, challengeMethod),
58 | )
59 |
60 | code, receivedState, err := browserAuthzHandler(source.ctx, source.config.RedirectURL, url)
61 | if err != nil {
62 | return nil, err
63 | } else if receivedState != actualState {
64 | return nil, errors.New("state received in redirection does not match")
65 | }
66 |
67 | // Step 2. Exchange code-grant for tokens (access_token, refresh_token, id_token).
68 | tok, err := source.config.Exchange(source.ctx, code,
69 | oauth2.SetAuthURLParam(audienceKey, source.audience),
70 | oauth2.SetAuthURLParam(codeVerifierKey, codeVerifier),
71 | )
72 | if err != nil {
73 | return nil, err
74 | }
75 |
76 | idToken, ok := tok.Extra("id_token").(string)
77 | if !ok {
78 | return nil, errors.New("id_token not found in token response")
79 | }
80 | tok.AccessToken = idToken
81 |
82 | return tok, nil
83 | }
84 |
85 | // newPKCEParams generates parameters for 'Proof Key for Code Exchange'.
86 | // Refer https://www.rfc-editor.org/rfc/rfc7636#section-4.2
87 | func newPKCEParams() (verifier, challenge, method string, err error) {
88 | // generate 'verifier' string.
89 | verifierBytes, err := randomBytes(codeVerifierLen)
90 | if err != nil {
91 | return "", "", "", fmt.Errorf("failed to generate random bytes: %v", err)
92 | }
93 | verifier = encode(verifierBytes)
94 |
95 | // generate S256 challenge.
96 | h := sha256.New()
97 | h.Write([]byte(verifier))
98 | challenge = encode(h.Sum(nil))
99 |
100 | return verifier, challenge, pkceS256, nil
101 | }
102 |
--------------------------------------------------------------------------------
/auth/oidc/utils.go:
--------------------------------------------------------------------------------
1 | package oidc
2 |
import (
	"context"
	"crypto/rand"
	_ "embed" // for embedded html
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os/exec"
	"runtime"
	"strings"
	"time"
)
17 |
// Parameters read from the OAuth2 redirect when authorization succeeds.
const (
	codeParam  = "code"
	stateParam = "state"
)

// Parameters a provider sends on the redirect when authorization fails.
const (
	errParam     = "error"
	errDescParam = "error_description"
)
25 |
// callbackResponsePage is the HTML page, embedded from redirect.html at build
// time, that is written back to the browser when the OAuth2 redirect reaches
// the local callback server.
//
//go:embed redirect.html
var callbackResponsePage string
28 |
// encode returns the unpadded base64url encoding of msg (RFC 4648 §5 alphabet,
// no trailing '='), which is the encoding PKCE uses for verifiers and
// challenges.
func encode(msg []byte) string {
	return base64.RawURLEncoding.EncodeToString(msg)
}
36 |
// randomBytes returns length characters drawn uniformly at random from
// [A-Za-z0-9] using crypto/rand, suitable for PKCE and state values (see
// https://tools.ietf.org/html/rfc7636#section-4.1). Despite the name, the
// result is printable ASCII text rather than arbitrary bytes.
//
// A length of 0 yields an empty slice; a negative length is an error.
func randomBytes(length int) ([]byte, error) {
	const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
	const csLen = byte(len(charset))
	// Largest multiple of 62 that fits in a byte (4*62 = 248). Bytes at or
	// above it are rejected so every character is equally likely.
	const limit = csLen * 4

	if length < 0 {
		return nil, fmt.Errorf("invalid random length %d", length)
	}
	output := make([]byte, 0, length)
	if length == 0 {
		// Without this guard the loop below would never reach len(output) == length.
		return output, nil
	}

	buf := make([]byte, length)
	for {
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			return nil, fmt.Errorf("failed to read random bytes: %v", err)
		}
		for _, b := range buf {
			if b >= limit {
				continue
			}
			output = append(output, charset[b%csLen])
			if len(output) == length {
				return output, nil
			}
		}
	}
}
59 |
60 | func browserAuthzHandler(ctx context.Context, redirectURL, authCodeURL string) (code string, state string, err error) {
61 | if err := openURL(authCodeURL); err != nil {
62 | return "", "", err
63 | }
64 |
65 | u, err := url.Parse(redirectURL)
66 | if err != nil {
67 | return "", "", err
68 | }
69 |
70 | code, state, err = waitForCallback(ctx, fmt.Sprintf(":%s", u.Port()))
71 | if err != nil {
72 | return "", "", err
73 | }
74 | return code, state, nil
75 | }
76 |
77 | func waitForCallback(ctx context.Context, addr string) (code, state string, err error) {
78 | var cb struct {
79 | code string
80 | state string
81 | err error
82 | }
83 |
84 | stopCh := make(chan struct{})
85 | srv := &http.Server{
86 | Addr: addr,
87 | Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88 | cb.code, cb.state, cb.err = parseCallbackRequest(r)
89 |
90 | w.WriteHeader(http.StatusOK)
91 | w.Header().Set("content-type", "text/html")
92 | _, _ = w.Write([]byte(callbackResponsePage))
93 |
94 | // try to flush to ensure the page is shown to user before we close
95 | // the server.
96 | if fl, ok := w.(http.Flusher); ok {
97 | fl.Flush()
98 | }
99 |
100 | close(stopCh)
101 | }),
102 | }
103 |
104 | go func() {
105 | select {
106 | case <-stopCh:
107 | _ = srv.Close()
108 |
109 | case <-ctx.Done():
110 | cb.err = ctx.Err()
111 | _ = srv.Close()
112 | }
113 | }()
114 |
115 | if serveErr := srv.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
116 | return "", "", serveErr
117 | }
118 | return cb.code, cb.state, cb.err
119 | }
120 |
121 | func parseCallbackRequest(r *http.Request) (code string, state string, err error) {
122 | if err = r.ParseForm(); err != nil {
123 | return "", "", err
124 | }
125 |
126 | state = r.Form.Get(stateParam)
127 | if state == "" {
128 | return "", "", errors.New("missing state parameter")
129 | }
130 |
131 | if errorCode := r.Form.Get(errParam); errorCode != "" {
132 | // Got error from provider. Passing through.
133 | return "", "", fmt.Errorf("%s: %s", errorCode, r.Form.Get(errDescParam))
134 | }
135 |
136 | code = r.Form.Get(codeParam)
137 | if code == "" {
138 | return "", "", errors.New("missing code parameter")
139 | }
140 |
141 | return code, state, nil
142 | }
143 |
// openURL opens rawURL in the default application registered for its scheme.
// It uses "cmd /c start" on Windows, "open" on macOS and "xdg-open" on other
// systems.
//
// openURL returns as soon as the launcher process has started. The launcher
// is then waited for in the background so it does not remain as a zombie
// process after it exits; its exit status is ignored.
func openURL(rawURL string) error {
	var name string
	var args []string

	switch runtime.GOOS {
	case "windows":
		name = "cmd"
		args = []string{"/c", "start"}
		// If we don't escape &, cmd will ignore everything after the first &.
		rawURL = strings.ReplaceAll(rawURL, "&", "^&")
	case "darwin":
		name = "open"
	default: // "linux", "freebsd", "openbsd", "netbsd"
		name = "xdg-open"
	}

	cmd := exec.Command(name, append(args, rawURL)...)
	if err := cmd.Start(); err != nil {
		return err
	}
	// Reap the child once it exits.
	go func() { _ = cmd.Wait() }()
	return nil
}
167 |
--------------------------------------------------------------------------------
/cli/commander/codex.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "path/filepath"
7 |
8 | "github.com/spf13/cobra"
9 | "github.com/spf13/cobra/doc"
10 | )
11 |
12 | // addMarkdownCommand integrates a hidden `markdown` command into the root command.
13 | // This command generates a Markdown documentation tree for all commands in the hierarchy.
14 | func (m *Manager) addMarkdownCommand(outputPath string) {
15 | markdownCmd := &cobra.Command{
16 | Use: "markdown",
17 | Short: "Generate Markdown documentation for all commands",
18 | Hidden: true,
19 | Annotations: map[string]string{
20 | "group": "help",
21 | },
22 | RunE: func(cmd *cobra.Command, args []string) error {
23 | return m.generateMarkdownTree(outputPath, m.RootCmd)
24 | },
25 | }
26 |
27 | m.RootCmd.AddCommand(markdownCmd)
28 | }
29 |
30 | // generateMarkdownTree generates a Markdown documentation tree for the given command hierarchy.
31 | //
32 | // Parameters:
33 | // - rootOutputPath: The root directory where the Markdown files will be generated.
34 | // - cmd: The root Cobra command whose hierarchy will be documented.
35 | //
36 | // Returns:
37 | // - An error if any part of the process (file creation, directory creation) fails.
38 | func (m *Manager) generateMarkdownTree(rootOutputPath string, cmd *cobra.Command) error {
39 | dirFilePath := filepath.Join(rootOutputPath, cmd.Name())
40 |
41 | // Handle subcommands by creating a directory and iterating through subcommands.
42 | if len(cmd.Commands()) > 0 {
43 | if err := ensureDir(dirFilePath); err != nil {
44 | return fmt.Errorf("failed to create directory for command %q: %w", cmd.Name(), err)
45 | }
46 | for _, subCmd := range cmd.Commands() {
47 | if err := m.generateMarkdownTree(dirFilePath, subCmd); err != nil {
48 | return err
49 | }
50 | }
51 | } else {
52 | // Generate a Markdown file for leaf commands.
53 | outFilePath := filepath.Join(rootOutputPath, cmd.Name()+".md")
54 |
55 | f, err := os.Create(outFilePath)
56 | if err != nil {
57 | return err
58 | }
59 | defer f.Close()
60 |
61 | // Generate Markdown with a custom link handler.
62 | return doc.GenMarkdownCustom(cmd, f, func(s string) string {
63 | return filepath.Join(dirFilePath, s)
64 | })
65 | }
66 |
67 | return nil
68 | }
69 |
// ensureDir makes sure path exists as a directory. It creates path and any
// missing parents with permission bits os.ModePerm (before umask) when needed,
// and does nothing if path is already a directory.
//
// It returns an error if path, or one of its parents, exists but is not a
// directory, or if creation fails (for example, because of a permission
// problem).
func ensureDir(path string) error {
	// MkdirAll covers both the "already exists" and the "create" cases, and
	// unlike a Stat-then-create check it does not mask non-NotExist errors.
	return os.MkdirAll(path, os.ModePerm)
}
79 |
--------------------------------------------------------------------------------
/cli/commander/completion.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | import (
4 | "os"
5 |
6 | "github.com/MakeNowJust/heredoc"
7 | "github.com/spf13/cobra"
8 | )
9 |
10 | // addCompletionCommand adds a `completion` command to the CLI.
11 | // The `completion` command generates shell completion scripts
12 | // for Bash, Zsh, Fish, and PowerShell.
13 | // Usage:
14 | //
15 | // $ mycli completion bash
16 | // $ mycli completion zsh
17 | func (m *Manager) addCompletionCommand() {
18 | summary := m.generateCompletionSummary(m.RootCmd.Use)
19 |
20 | completionCmd := &cobra.Command{
21 | Use: "completion [bash|zsh|fish|powershell]",
22 | Short: "Generate shell completion scripts",
23 | Long: summary,
24 | DisableFlagsInUseLine: true,
25 | ValidArgs: []string{"bash", "zsh", "fish", "powershell"},
26 | Args: cobra.ExactValidArgs(1),
27 | Run: func(cmd *cobra.Command, args []string) {
28 | switch args[0] {
29 | case "bash":
30 | cmd.Root().GenBashCompletion(os.Stdout)
31 | case "zsh":
32 | cmd.Root().GenZshCompletion(os.Stdout)
33 | case "fish":
34 | cmd.Root().GenFishCompletion(os.Stdout, true)
35 | case "powershell":
36 | cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout)
37 | }
38 | },
39 | }
40 |
41 | m.RootCmd.AddCommand(completionCmd)
42 | }
43 |
44 | // generateCompletionSummary creates the long description for the `completion` command.
45 | func (m *Manager) generateCompletionSummary(exec string) string {
46 | var execs []interface{}
47 | for i := 0; i < 12; i++ {
48 | execs = append(execs, exec)
49 | }
50 | return heredoc.Docf(`To load completions:
51 | `+"```"+`
52 | Bash:
53 |
54 | $ source <(%s completion bash)
55 |
56 | # To load completions for each session, execute once:
57 | # Linux:
58 | $ %s completion bash > /etc/bash_completion.d/%s
59 | # macOS:
60 | $ %s completion bash > /usr/local/etc/bash_completion.d/%s
61 |
62 | Zsh:
63 |
64 | # If shell completion is not already enabled in your environment,
65 | # you will need to enable it. You can execute the following once:
66 |
67 | $ echo "autoload -U compinit; compinit" >> ~/.zshrc
68 |
69 | # To load completions for each session, execute once:
70 | $ %s completion zsh > "${fpath[1]}/_yourprogram"
71 |
72 | # You will need to start a new shell for this setup to take effect.
73 |
74 | Fish:
75 |
76 | $ %s completion fish | source
77 |
78 | # To load completions for each session, execute once:
79 | $ %s completion fish > ~/.config/fish/completions/%s.fish
80 |
81 | PowerShell:
82 |
83 | PS> %s completion powershell | Out-String | Invoke-Expression
84 |
85 | # To load completions for every new session, run:
86 | PS> %s completion powershell > %s.ps1
87 | # and source this file from your PowerShell profile.
88 | `+"```"+`
89 | `, execs...)
90 | }
91 |
--------------------------------------------------------------------------------
/cli/commander/hooks.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | // addClientHooks applies all configured hooks to commands annotated with `client:true`.
4 | func (m *Manager) addClientHooks() {
5 | for _, cmd := range m.RootCmd.Commands() {
6 | for _, hook := range m.Hooks {
7 | if cmd.Annotations["client"] == "true" {
8 | hook.Behavior(cmd)
9 | }
10 | }
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/cli/commander/manager.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/spf13/cobra"
7 | )
8 |
// Manager manages and configures features for a CLI tool.
//
// New fills in defaults, and Init installs the enabled features on RootCmd.
type Manager struct {
	RootCmd    *cobra.Command // Command tree that features are attached to.
	Help       bool           // Enable custom help.
	Reference  bool           // Enable the `reference` command.
	Completion bool           // Enable the `completion` command (shell completion).
	Config     bool           // Enable configuration management. NOTE(review): Init does not read this field; confirm it is used elsewhere.
	Docs       bool           // Enable the hidden `markdown` command, which writes to ./docs.
	Hooks      []HookBehavior // Hook behaviors applied to root subcommands annotated client=true.
	Topics     []HelpTopic    // Help topics, each registered as its own command.
}
20 |
// HelpTopic defines a single help topic with its details. Each topic is
// registered as its own command under the root command (see
// addHelpTopicCommand).
type HelpTopic struct {
	Name    string // Command name (cobra Use) under which the topic is registered.
	Short   string // Short description (cobra Short).
	Long    string // Full text printed by the topic's help function.
	Example string // Optional examples, printed indented under "EXAMPLES".
}
28 |
// HookBehavior defines a specific behavior applied to commands. addClientHooks
// invokes Behavior on root subcommands annotated with client=true.
type HookBehavior struct {
	Name     string                   // Name of the hook (e.g., "setup", "auth"); descriptive only, not read by addClientHooks.
	Behavior func(cmd *cobra.Command) // Function to apply to commands.
}
34 |
35 | // New creates a new CLI Manager using the provided root command and optional configurations.
36 | //
37 | // Parameters:
38 | // - rootCmd: The root Cobra command for the CLI.
39 | // - options: Functional options for configuring the Manager.
40 | //
41 | // Example:
42 | //
43 | // rootCmd := &cobra.Command{Use: "mycli"}
44 | // manager := cmdx.NewCommander(rootCmd, cmdx.WithTopics(...), cmdx.WithHooks(...))
45 | func New(rootCmd *cobra.Command, options ...func(*Manager)) *Manager {
46 | // Create Manager with defaults
47 | manager := &Manager{
48 | RootCmd: rootCmd,
49 | Help: true, // Default enabled
50 | Reference: true, // Default enabled
51 | Completion: true, // Default enabled
52 | Docs: false, // Default disabled
53 | Topics: []HelpTopic{},
54 | Hooks: []HookBehavior{},
55 | }
56 |
57 | // Apply functional options
58 | for _, opt := range options {
59 | opt(manager)
60 | }
61 |
62 | return manager
63 | }
64 |
65 | // Init sets up the CLI features based on the Manager's configuration.
66 | // It enables or disables features like custom help, reference documentation,
67 | // shell completion, help topics, and client hooks based on the Manager's settings.
68 | func (m *Manager) Init() {
69 | if m.Help {
70 | m.setCustomHelp()
71 | }
72 | if m.Reference {
73 | m.addReferenceCommand()
74 | }
75 | if m.Completion {
76 | m.addCompletionCommand()
77 | }
78 | if m.Docs {
79 | m.addMarkdownCommand("./docs")
80 | }
81 | if len(m.Topics) > 0 {
82 | m.addHelpTopics()
83 | }
84 |
85 | if len(m.Hooks) > 0 {
86 | m.addClientHooks()
87 | }
88 | }
89 |
90 | // WithTopics sets the help topics for the Manager.
91 | func WithTopics(topics []HelpTopic) func(*Manager) {
92 | return func(m *Manager) {
93 | m.Topics = topics
94 | }
95 | }
96 |
97 | // WithHooks sets the hook behaviors for the Manager.
98 | func WithHooks(hooks []HookBehavior) func(*Manager) {
99 | return func(m *Manager) {
100 | m.Hooks = hooks
101 | }
102 | }
103 |
// IsCommandErr reports whether err looks like a Cobra usage error (an unknown
// command, unknown flag or unknown shorthand flag) rather than a failure of
// the program itself, so callers can choose which message to display.
//
// Detection is a substring match on err.Error(). A nil error is never a
// command error.
func IsCommandErr(err error) bool {
	if err == nil {
		return false
	}

	msg := err.Error()
	return strings.Contains(msg, "unknown command") ||
		strings.Contains(msg, "unknown flag") ||
		strings.Contains(msg, "unknown shorthand flag")
}
127 |
--------------------------------------------------------------------------------
/cli/commander/reference.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "strings"
8 |
9 | "github.com/raystack/salt/cli/printer"
10 |
11 | "github.com/spf13/cobra"
12 | )
13 |
14 | // addReferenceCommand adds a `reference` command to the CLI.
15 | // The `reference` command generates markdown documentation for all commands
16 | // in the CLI command tree.
17 | func (m *Manager) addReferenceCommand() {
18 | var isPlain bool
19 | refCmd := &cobra.Command{
20 | Use: "reference",
21 | Short: "Comprehensive reference of all commands",
22 | Long: m.generateReferenceMarkdown(),
23 | Run: m.runReferenceCommand(&isPlain),
24 | Annotations: map[string]string{
25 | "group": "help",
26 | },
27 | }
28 | refCmd.SetHelpFunc(m.runReferenceCommand(&isPlain))
29 | refCmd.Flags().BoolVarP(&isPlain, "plain", "p", true, "output in plain markdown (without ANSI color)")
30 |
31 | m.RootCmd.AddCommand(refCmd)
32 | }
33 |
34 | // runReferenceCommand handles the output generation for the `reference` command.
35 | // It renders the documentation either as plain markdown or with ANSI color.
36 | func (m *Manager) runReferenceCommand(isPlain *bool) func(cmd *cobra.Command, args []string) {
37 | return func(cmd *cobra.Command, args []string) {
38 | var (
39 | output string
40 | err error
41 | )
42 |
43 | if *isPlain {
44 | output = cmd.Long
45 | } else {
46 | output, err = printer.Markdown(cmd.Long)
47 | if err != nil {
48 | fmt.Println("Error generating markdown:", err)
49 | return
50 | }
51 | }
52 |
53 | fmt.Print(output)
54 | }
55 | }
56 |
57 | // generateReferenceMarkdown generates a complete markdown representation
58 | // of the command tree for the `reference` command.
59 | func (m *Manager) generateReferenceMarkdown() string {
60 | buf := bytes.NewBufferString(fmt.Sprintf("# %s reference\n\n", m.RootCmd.Name()))
61 | for _, c := range m.RootCmd.Commands() {
62 | if c.Hidden {
63 | continue
64 | }
65 | m.generateCommandReference(buf, c, 2)
66 | }
67 | return buf.String()
68 | }
69 |
70 | // generateCommandReference recursively generates markdown for a given command
71 | // and its subcommands.
72 | func (m *Manager) generateCommandReference(w io.Writer, cmd *cobra.Command, depth int) {
73 | // Name + Description
74 | fmt.Fprintf(w, "%s `%s`\n\n", strings.Repeat("#", depth), cmd.UseLine())
75 | fmt.Fprintf(w, "%s\n\n", cmd.Short)
76 |
77 | // Flags
78 | if flagUsages := cmd.Flags().FlagUsages(); flagUsages != "" {
79 | fmt.Fprintf(w, "```\n%s```\n\n", dedent(flagUsages))
80 | }
81 |
82 | // Subcommands
83 | for _, c := range cmd.Commands() {
84 | if c.Hidden {
85 | continue
86 | }
87 | m.generateCommandReference(w, c, depth+1)
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/cli/commander/topics.go:
--------------------------------------------------------------------------------
1 | package commander
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/spf13/cobra"
7 | )
8 |
9 | // addHelpTopics adds all configured help topics to the CLI.
10 | //
11 | // Help topics provide detailed information about specific subjects,
12 | // such as environment variables or configuration.
13 | func (m *Manager) addHelpTopics() {
14 | for _, topic := range m.Topics {
15 | m.addHelpTopicCommand(topic)
16 | }
17 | }
18 |
19 | // addHelpTopicCommand adds a single help topic command to the CLI.
20 | func (m *Manager) addHelpTopicCommand(topic HelpTopic) {
21 | helpCmd := &cobra.Command{
22 | Use: topic.Name,
23 | Short: topic.Short,
24 | Long: topic.Long,
25 | Example: topic.Example,
26 | Hidden: false,
27 | Annotations: map[string]string{
28 | "group": "help",
29 | },
30 | }
31 |
32 | helpCmd.SetHelpFunc(helpTopicHelpFunc)
33 | helpCmd.SetUsageFunc(helpTopicUsageFunc)
34 |
35 | m.RootCmd.AddCommand(helpCmd)
36 | }
37 |
38 | // helpTopicHelpFunc customizes the help message for a help topic command.
39 | func helpTopicHelpFunc(cmd *cobra.Command, args []string) {
40 | fmt.Fprintln(cmd.OutOrStdout(), cmd.Long)
41 | if cmd.Example != "" {
42 | fmt.Fprintln(cmd.OutOrStdout(), "\nEXAMPLES")
43 | fmt.Fprintln(cmd.OutOrStdout(), indent(cmd.Example, " "))
44 | }
45 | }
46 |
47 | // helpTopicUsageFunc customizes the usage message for a help topic command.
48 | func helpTopicUsageFunc(cmd *cobra.Command) error {
49 | fmt.Fprintf(cmd.OutOrStdout(), "Usage: %s help %s\n", cmd.Root().Name(), cmd.Use)
50 | return nil
51 | }
52 |
--------------------------------------------------------------------------------
/cli/printer/colors.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/muesli/termenv"
7 | )
8 |
9 | var tp = termenv.EnvColorProfile()
10 |
11 | // Theme defines a collection of colors for terminal outputs.
12 | type Theme struct {
13 | Green termenv.Color
14 | Yellow termenv.Color
15 | Cyan termenv.Color
16 | Red termenv.Color
17 | Grey termenv.Color
18 | Blue termenv.Color
19 | Magenta termenv.Color
20 | }
21 |
22 | var themes = map[string]Theme{
23 | "light": {
24 | Green: tp.Color("#005F00"),
25 | Yellow: tp.Color("#FFAF00"),
26 | Cyan: tp.Color("#0087FF"),
27 | Red: tp.Color("#D70000"),
28 | Grey: tp.Color("#303030"),
29 | Blue: tp.Color("#000087"),
30 | Magenta: tp.Color("#AF00FF"),
31 | },
32 | "dark": {
33 | Green: tp.Color("#A8CC8C"),
34 | Yellow: tp.Color("#DBAB79"),
35 | Cyan: tp.Color("#66C2CD"),
36 | Red: tp.Color("#E88388"),
37 | Grey: tp.Color("#B9BFCA"),
38 | Blue: tp.Color("#71BEF2"),
39 | Magenta: tp.Color("#D290E4"),
40 | },
41 | }
42 |
43 | // NewTheme initializes a Theme based on the terminal background (light or dark).
44 | func NewTheme() Theme {
45 | if !termenv.HasDarkBackground() {
46 | return themes["light"]
47 | }
48 | return themes["dark"]
49 | }
50 |
51 | var theme = NewTheme()
52 |
53 | // formatColorize applies the given color to the formatted text.
54 | func formatColorize(color termenv.Color, t string, args ...interface{}) string {
55 | return colorize(color, fmt.Sprintf(t, args...))
56 | }
57 |
// Green renders the text parts t in the current theme's green via colorize.
func Green(t ...string) string {
	return colorize(theme.Green, t...)
}

// Greenf formats t with args (fmt.Sprintf verbs) and renders it in the
// current theme's green.
func Greenf(t string, args ...interface{}) string {
	return formatColorize(theme.Green, t, args...)
}

// Yellow renders the text parts t in the current theme's yellow via colorize.
func Yellow(t ...string) string {
	return colorize(theme.Yellow, t...)
}

// Yellowf formats t with args (fmt.Sprintf verbs) and renders it in the
// current theme's yellow.
func Yellowf(t string, args ...interface{}) string {
	return formatColorize(theme.Yellow, t, args...)
}

// Cyan renders the text parts t in the current theme's cyan via colorize.
func Cyan(t ...string) string {
	return colorize(theme.Cyan, t...)
}

// Cyanf formats t with args (fmt.Sprintf verbs) and renders it in the current
// theme's cyan.
func Cyanf(t string, args ...interface{}) string {
	return formatColorize(theme.Cyan, t, args...)
}

// Red renders the text parts t in the current theme's red via colorize.
func Red(t ...string) string {
	return colorize(theme.Red, t...)
}

// Redf formats t with args (fmt.Sprintf verbs) and renders it in the current
// theme's red.
func Redf(t string, args ...interface{}) string {
	return formatColorize(theme.Red, t, args...)
}

// Grey renders the text parts t in the current theme's grey via colorize.
func Grey(t ...string) string {
	return colorize(theme.Grey, t...)
}

// Greyf formats t with args (fmt.Sprintf verbs) and renders it in the current
// theme's grey.
func Greyf(t string, args ...interface{}) string {
	return formatColorize(theme.Grey, t, args...)
}

// Blue renders the text parts t in the current theme's blue via colorize.
func Blue(t ...string) string {
	return colorize(theme.Blue, t...)
}

// Bluef formats t with args (fmt.Sprintf verbs) and renders it in the current
// theme's blue.
func Bluef(t string, args ...interface{}) string {
	return formatColorize(theme.Blue, t, args...)
}

// Magenta renders the text parts t in the current theme's magenta via
// colorize.
func Magenta(t ...string) string {
	return colorize(theme.Magenta, t...)
}

// Magentaf formats t with args (fmt.Sprintf verbs) and renders it in the
// current theme's magenta.
func Magentaf(t string, args ...interface{}) string {
	return formatColorize(theme.Magenta, t, args...)
}
113 |
// Icon returns the symbol for a status name: "failure" (✘), "success" (✔),
// "info" (ℹ) or "warning" (⚠). Any other name yields the empty string.
func Icon(name string) string {
	switch name {
	case "failure":
		return "✘"
	case "success":
		return "✔"
	case "info":
		return "ℹ"
	case "warning":
		return "⚠"
	default:
		return ""
	}
}
121 |
122 | // colorize applies the given color to the text.
123 | func colorize(color termenv.Color, t ...string) string {
124 | return termenv.String(t...).Foreground(color).String()
125 | }
126 |
--------------------------------------------------------------------------------
/cli/printer/markdown.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/charmbracelet/glamour"
7 | )
8 |
// RenderOpts is a slice of glamour.TermRendererOption values holding the
// options passed to the markdown renderer. It is a defined type, not a type
// alias.
type RenderOpts []glamour.TermRendererOption
12 |
13 | // This ensures the rendered markdown has no extra indentation or margins, providing a compact view.
14 | func withoutIndentation() glamour.TermRendererOption {
15 | overrides := []byte(`
16 | {
17 | "document": {
18 | "margin": 0
19 | },
20 | "code_block": {
21 | "margin": 0
22 | }
23 | }`)
24 |
25 | return glamour.WithStylesFromJSONBytes(overrides)
26 | }
27 |
28 | // withoutWrap ensures the rendered markdown does not wrap lines, useful for wide terminals.
29 | func withoutWrap() glamour.TermRendererOption {
30 | return glamour.WithWordWrap(0)
31 | }
32 |
33 | // render applies the given rendering options to the provided markdown text.
34 | func render(text string, opts RenderOpts) (string, error) {
35 | // Ensure input text uses consistent line endings.
36 | text = strings.ReplaceAll(text, "\r\n", "\n")
37 |
38 | tr, err := glamour.NewTermRenderer(opts...)
39 | if err != nil {
40 | return "", err
41 | }
42 |
43 | return tr.Render(text)
44 | }
45 |
46 | // Markdown renders the given markdown text with default options.
47 | func Markdown(text string) (string, error) {
48 | opts := RenderOpts{
49 | glamour.WithAutoStyle(), // Automatically determine styling based on terminal settings.
50 | glamour.WithEmoji(), // Enable emoji rendering.
51 | withoutIndentation(), // Disable indentation for a compact view.
52 | withoutWrap(), // Disable word wrapping.
53 | }
54 |
55 | return render(text, opts)
56 | }
57 |
58 | // MarkdownWithWrap renders the given markdown text with a specified word wrapping width.
59 | func MarkdownWithWrap(text string, wrap int) (string, error) {
60 | opts := RenderOpts{
61 | glamour.WithAutoStyle(), // Automatically determine styling based on terminal settings.
62 | glamour.WithEmoji(), // Enable emoji rendering.
63 | glamour.WithWordWrap(wrap), // Enable word wrapping with the specified width.
64 | withoutIndentation(), // Disable indentation for a compact view.
65 | }
66 |
67 | return render(text, opts)
68 | }
69 |
--------------------------------------------------------------------------------
/cli/printer/progress.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "github.com/schollz/progressbar/v3"
5 | )
6 |
7 | // Progress creates and returns a progress bar for tracking the progress of an operation.
8 | //
9 | // Parameters:
10 | // - max: The maximum value of the progress bar, indicating 100% completion.
11 | // - description: A brief description of the task associated with the progress bar.
12 |
13 | // Example Usage:
14 | //
15 | // bar := printer.Progress(100, "Downloading files")
16 | // for i := 0; i < 100; i++ {
17 | // bar.Add(1) // Increment progress by 1.
18 | // }
19 | func Progress(max int, description string) *progressbar.ProgressBar {
20 | bar := progressbar.NewOptions(
21 | max,
22 | progressbar.OptionEnableColorCodes(true),
23 | progressbar.OptionSetDescription(description),
24 | progressbar.OptionShowCount(),
25 | )
26 | return bar
27 | }
28 |
--------------------------------------------------------------------------------
/cli/printer/spinner.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/briandowns/spinner"
7 | "github.com/raystack/salt/cli/terminator"
8 | )
9 |
// Indicator represents a terminal spinner used for indicating progress or
// ongoing operations. Spin returns an Indicator with a nil spinner when
// terminator.IsTTY reports false; Stop treats that as a no-op.
type Indicator struct {
	spinner *spinner.Spinner // The spinner instance; nil when disabled.
}
14 |
15 | // Stop halts the spinner animation.
16 | //
17 | // This method ensures the spinner is stopped gracefully. If the spinner is nil (e.g., when the
18 | // terminal does not support TTY), the method does nothing.
19 | //
20 | // Example Usage:
21 | //
22 | // indicator := printer.Spin("Loading")
23 | // // Perform some operation...
24 | // indicator.Stop()
25 | func (s *Indicator) Stop() {
26 | if s.spinner == nil {
27 | return
28 | }
29 | s.spinner.Stop()
30 | }
31 |
32 | // Spin creates and starts a terminal spinner to indicate an ongoing operation.
33 | //
34 | // The spinner uses a predefined character set and updates at a fixed interval. It automatically
35 | // disables itself if the terminal does not support TTY.
36 | //
37 | // Parameters:
38 | // - label: A string to prefix the spinner (e.g., "Loading").
39 | //
40 | // Returns:
41 | // - An *Indicator instance that manages the spinner lifecycle.
42 | //
43 | // Example Usage:
44 | //
45 | // indicator := printer.Spin("Processing data")
46 | // // Perform some long-running operation...
47 | // indicator.Stop()
48 | func Spin(label string) *Indicator {
49 | // Predefined spinner character set (dots style).
50 | set := spinner.CharSets[11]
51 |
52 | // Check if the terminal supports TTY; if not, return a no-op Indicator.
53 | if !terminator.IsTTY() {
54 | return &Indicator{}
55 | }
56 |
57 | // Create a new spinner instance with a 120ms update interval and cyan color.
58 | s := spinner.New(set, 120*time.Millisecond, spinner.WithColor("fgCyan"))
59 |
60 | // Add a label prefix if provided.
61 | if label != "" {
62 | s.Prefix = label + " "
63 | }
64 |
65 | // Start the spinner animation.
66 | s.Start()
67 |
68 | // Return the Indicator wrapping the spinner instance.
69 | return &Indicator{s}
70 | }
71 |
--------------------------------------------------------------------------------
/cli/printer/structured.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 |
7 | "gopkg.in/yaml.v3"
8 | )
9 |
10 | // YAML prints the given data in YAML format.
11 | func YAML(data interface{}) error {
12 | return File(data, "yaml")
13 | }
14 |
15 | // JSON prints the given data in JSON format.
16 | func JSON(data interface{}) error {
17 | return File(data, "json")
18 | }
19 |
20 | // PrettyJSON prints the given data in pretty-printed JSON format.
21 | func PrettyJSON(data interface{}) error {
22 | return File(data, "prettyjson")
23 | }
24 |
25 | // File marshals and prints the given data in the specified format.
26 | func File(data interface{}, format string) (err error) {
27 | var output []byte
28 | switch format {
29 | case "yaml":
30 | output, err = yaml.Marshal(data)
31 | case "json":
32 | output, err = json.Marshal(data)
33 | case "prettyjson":
34 | output, err = json.MarshalIndent(data, "", "\t")
35 | default:
36 | return fmt.Errorf("unknown format: %v", format)
37 | }
38 |
39 | if err != nil {
40 | return err
41 | }
42 |
43 | fmt.Println(string(output))
44 | return nil
45 | }
46 |
--------------------------------------------------------------------------------
/cli/printer/table.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "io"
5 |
6 | "github.com/olekukonko/tablewriter"
7 | )
8 |
9 | // Table renders a terminal-friendly table to the provided writer.
10 | //
11 | // Create a table with customized formatting and styles,
12 | // suitable for displaying data in CLI applications.
13 | //
14 | // Parameters:
15 | // - target: The `io.Writer` where the table will be written (e.g., os.Stdout).
16 | // - rows: A 2D slice of strings representing the table rows and columns.
17 | // Each inner slice represents a single row, with its elements as column values.
18 | //
19 | // Example Usage:
20 | //
21 | // rows := [][]string{
22 | // {"ID", "Name", "Age"},
23 | // {"1", "Alice", "30"},
24 | // {"2", "Bob", "25"},
25 | // }
26 | // printer.Table(os.Stdout, rows)
27 | //
28 | // Behavior:
29 | // - Disables text wrapping for better terminal rendering.
30 | // - Aligns headers and rows to the left.
31 | // - Removes borders and separators for a clean look.
32 | // - Formats the table using tab padding for better alignment in terminals.
33 | func Table(target io.Writer, rows [][]string) {
34 | table := tablewriter.NewWriter(target)
35 | table.SetAutoWrapText(false)
36 | table.SetAutoFormatHeaders(true)
37 | table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
38 | table.SetAlignment(tablewriter.ALIGN_LEFT)
39 | table.SetCenterSeparator("")
40 | table.SetColumnSeparator("")
41 | table.SetRowSeparator("")
42 | table.SetHeaderLine(false)
43 | table.SetBorder(false)
44 | table.SetTablePadding("\t")
45 | table.SetNoWhiteSpace(true)
46 | table.AppendBulk(rows)
47 | table.Render()
48 | }
49 |
--------------------------------------------------------------------------------
/cli/printer/text.go:
--------------------------------------------------------------------------------
1 | package printer
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/muesli/termenv"
7 | )
8 |
9 | // Success prints the given message(s) in green to indicate success.
10 | func Success(t ...string) {
11 | printWithColor(Green, t...)
12 | }
13 |
14 | // Successln prints the given message(s) in green with a newline.
15 | func Successln(t ...string) {
16 | printWithColorln(Green, t...)
17 | }
18 |
19 | // Successf formats and prints the success message in green.
20 | func Successf(t string, args ...interface{}) {
21 | printWithColorf(Greenf, t, args...)
22 | }
23 |
24 | // Warning prints the given message(s) in yellow to indicate a warning.
25 | func Warning(t ...string) {
26 | printWithColor(Yellow, t...)
27 | }
28 |
29 | // Warningln prints the given message(s) in yellow with a newline.
30 | func Warningln(t ...string) {
31 | printWithColorln(Yellow, t...)
32 | }
33 |
34 | // Warningf formats and prints the warning message in yellow.
35 | func Warningf(t string, args ...interface{}) {
36 | printWithColorf(Yellowf, t, args...)
37 | }
38 |
39 | // Error prints the given message(s) in red to indicate an error.
40 | func Error(t ...string) {
41 | printWithColor(Red, t...)
42 | }
43 |
44 | // Errorln prints the given message(s) in red with a newline.
45 | func Errorln(t ...string) {
46 | printWithColorln(Red, t...)
47 | }
48 |
49 | // Errorf formats and prints the error message in red.
50 | func Errorf(t string, args ...interface{}) {
51 | printWithColorf(Redf, t, args...)
52 | }
53 |
54 | // Info prints the given message(s) in cyan to indicate informational messages.
55 | func Info(t ...string) {
56 | printWithColor(Cyan, t...)
57 | }
58 |
59 | // Infoln prints the given message(s) in cyan with a newline.
60 | func Infoln(t ...string) {
61 | printWithColorln(Cyan, t...)
62 | }
63 |
64 | // Infof formats and prints the informational message in cyan.
65 | func Infof(t string, args ...interface{}) {
66 | printWithColorf(Cyanf, t, args...)
67 | }
68 |
69 | // Bold prints the given message(s) in bold style.
70 | func Bold(t ...string) string {
71 | return termenv.String(t...).Bold().String()
72 | }
73 |
74 | // Boldln prints the given message(s) in bold style with a newline.
75 | func Boldln(t ...string) {
76 | fmt.Println(Bold(t...))
77 | }
78 |
79 | // Boldf formats and prints the message in bold style.
80 | func Boldf(t string, args ...interface{}) string {
81 | return Bold(fmt.Sprintf(t, args...))
82 | }
83 |
84 | // Italic prints the given message(s) in italic style.
85 | func Italic(t ...string) string {
86 | return termenv.String(t...).Italic().String()
87 | }
88 |
89 | // Italicln prints the given message(s) in italic style with a newline.
90 | func Italicln(t ...string) {
91 | fmt.Println(Italic(t...))
92 | }
93 |
94 | // Italicf formats and prints the message in italic style.
95 | func Italicf(t string, args ...interface{}) string {
96 | return Italic(fmt.Sprintf(t, args...))
97 | }
98 |
// Space writes exactly one space character to standard output.
func Space() {
	const space = " "
	fmt.Print(space)
}
103 |
// printWithColor applies colorFunc to the messages and writes the result to
// standard output with no trailing newline.
func printWithColor(colorFunc func(...string) string, t ...string) {
	colored := colorFunc(t...)
	fmt.Printf("%s", colored)
}

// printWithColorln applies colorFunc to the messages and writes the result to
// standard output followed by a newline.
func printWithColorln(colorFunc func(...string) string, t ...string) {
	colored := colorFunc(t...)
	fmt.Printf("%s\n", colored)
}

// printWithColorf lets colorFunc format t with args, then writes the result to
// standard output with no trailing newline.
func printWithColorf(colorFunc func(string, ...interface{}) string, t string, args ...interface{}) {
	colored := colorFunc(t, args...)
	fmt.Printf("%s", colored)
}
118 |
--------------------------------------------------------------------------------
/cli/prompter/prompt.go:
--------------------------------------------------------------------------------
1 | package prompter
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/AlecAivazis/survey/v2"
7 | )
8 |
9 | // Prompter defines an interface for user input interactions.
10 | type Prompter interface {
11 | Select(message, defaultValue string, options []string) (int, error)
12 | MultiSelect(message, defaultValue string, options []string) ([]int, error)
13 | Input(message, defaultValue string) (string, error)
14 | Confirm(message string, defaultValue bool) (bool, error)
15 | }
16 |
17 | // New creates and returns a new Prompter instance.
18 | func New() Prompter {
19 | return &surveyPrompter{}
20 | }
21 |
22 | type surveyPrompter struct {
23 | }
24 |
25 | // ask is a helper function to prompt the user and capture the response.
26 | func (p *surveyPrompter) ask(q survey.Prompt, response interface{}) error {
27 | err := survey.AskOne(q, response)
28 | if err != nil {
29 | return fmt.Errorf("prompt error: %w", err)
30 | }
31 | return nil
32 | }
33 |
34 | // Select prompts the user to select one option from a list.
35 | //
36 | // Parameters:
37 | // - message: The prompt message to display.
38 | // - defaultValue: The default selected value.
39 | // - options: The list of options to display.
40 | //
41 | // Returns:
42 | // - The index of the selected option.
43 | // - An error, if any.
44 | func (p *surveyPrompter) Select(message, defaultValue string, options []string) (int, error) {
45 | var result int
46 | err := p.ask(&survey.Select{
47 | Message: message,
48 | Default: defaultValue,
49 | Options: options,
50 | PageSize: 20,
51 | }, &result)
52 | return result, err
53 | }
54 |
55 | // MultiSelect prompts the user to select multiple options from a list.
56 | //
57 | // Parameters:
58 | // - message: The prompt message to display.
59 | // - defaultValue: The default selected values.
60 | // - options: The list of options to display.
61 | //
62 | // Returns:
63 | // - A slice of indices representing the selected options.
64 | // - An error, if any.
65 | func (p *surveyPrompter) MultiSelect(message, defaultValue string, options []string) ([]int, error) {
66 | var result []int
67 | err := p.ask(&survey.MultiSelect{
68 | Message: message,
69 | Default: defaultValue,
70 | Options: options,
71 | PageSize: 20,
72 | }, &result)
73 | return result, err
74 | }
75 |
76 | // Input prompts the user for a text input.
77 | //
78 | // Parameters:
79 | // - message: The prompt message to display.
80 | // - defaultValue: The default input value.
81 | //
82 | // Returns:
83 | // - The user's input as a string.
84 | // - An error, if any.
85 | func (p *surveyPrompter) Input(message, defaultValue string) (string, error) {
86 | var result string
87 | err := p.ask(&survey.Input{
88 | Message: message,
89 | Default: defaultValue,
90 | }, &result)
91 | return result, err
92 | }
93 |
94 | // Confirm prompts the user for a yes/no confirmation.
95 | //
96 | // Parameters:
97 | // - message: The prompt message to display.
98 | // - defaultValue: The default confirmation value.
99 | //
100 | // Returns:
101 | // - A boolean indicating the user's choice.
102 | // - An error, if any.
103 | func (p *surveyPrompter) Confirm(message string, defaultValue bool) (bool, error) {
104 | var result bool
105 | err := p.ask(&survey.Confirm{
106 | Message: message,
107 | Default: defaultValue,
108 | }, &result)
109 | return result, err
110 | }
111 |
--------------------------------------------------------------------------------
/cli/releaser/release.go:
--------------------------------------------------------------------------------
1 | package releaser
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "time"
9 |
10 | "github.com/hashicorp/go-version"
11 | "github.com/pkg/errors"
12 | )
13 |
var (
	// Timeout bounds the whole HTTP exchange FetchInfo performs.
	Timeout = 1 * time.Second

	// APIFormat is a fmt template for GitHub's "latest release" endpoint; its
	// single %s verb receives the repository as "owner/repo".
	APIFormat = "https://api.github.com/repos/%s/releases/latest"
)
21 |
// Info describes a published software release.
type Info struct {
	// Version is the release's tag name, e.g. "v1.2.3".
	Version string
	// TarURL is the URL of the release's source tarball.
	TarURL string
}
27 |
28 | // FetchInfo fetches details related to the latest release from the provided URL.
29 | //
30 | // Parameters:
31 | // - releaseURL: The URL to fetch the latest release information from.
32 | // Example: "https://api.github.com/repos/raystack/optimus/releases/latest"
33 | //
34 | // Returns:
35 | // - An *Info struct containing the release and tarball URL.
36 | // - An error if the HTTP request or response parsing fails.
37 | func FetchInfo(url string) (*Info, error) {
38 | httpClient := http.Client{Timeout: Timeout}
39 | req, err := http.NewRequest(http.MethodGet, url, nil)
40 | if err != nil {
41 | return nil, errors.Wrap(err, "failed to create HTTP request")
42 | }
43 | req.Header.Set("User-Agent", "raystack/salt")
44 |
45 | resp, err := httpClient.Do(req)
46 | if err != nil {
47 | return nil, errors.Wrapf(err, "failed to fetch release information from URL: %s", url)
48 | }
49 | defer func() {
50 | if resp.Body != nil {
51 | resp.Body.Close()
52 | }
53 | }()
54 | if resp.StatusCode != http.StatusOK {
55 | return nil, fmt.Errorf("unexpected status code %d from URL: %s", resp.StatusCode, url)
56 | }
57 |
58 | body, err := io.ReadAll(resp.Body)
59 | if err != nil {
60 | return nil, errors.Wrap(err, "failed to read response body")
61 | }
62 |
63 | var data struct {
64 | TagName string `json:"tag_name"`
65 | Tarball string `json:"tarball_url"`
66 | }
67 | if err = json.Unmarshal(body, &data); err != nil {
68 | return nil, errors.Wrapf(err, "failed to parse JSON response")
69 | }
70 |
71 | return &Info{
72 | Version: data.TagName,
73 | TarURL: data.Tarball,
74 | }, nil
75 | }
76 |
77 | // CompareVersions compares the current release with the latest release.
78 | //
79 | // Parameters:
80 | // - currVersion: The current release string.
81 | // - latestVersion: The latest release string.
82 | //
83 | // Returns:
84 | // - true if the current release is greater than or equal to the latest release.
85 | // - An error if release parsing fails.
86 | func CompareVersions(current, latest string) (bool, error) {
87 | currentVersion, err := version.NewVersion(current)
88 | if err != nil {
89 | return false, errors.Wrap(err, "invalid current version format")
90 | }
91 |
92 | latestVersion, err := version.NewVersion(latest)
93 | if err != nil {
94 | return false, errors.Wrap(err, "invalid latest version format")
95 | }
96 |
97 | return currentVersion.GreaterThanOrEqual(latestVersion), nil
98 | }
99 |
100 | // CheckForUpdate generates a message indicating if an update is available.
101 | //
102 | // Parameters:
103 | // - currentVersion: The current version string (e.g., "v1.0.0").
104 | // - repo: The GitHub repository in the format "owner/repo".
105 | //
106 | // Returns:
107 | // - A string containing the update message if a newer version is available.
108 | // - An empty string if the current version is up-to-date or if an error occurs.
109 | func CheckForUpdate(currentVersion, repo string) string {
110 | releaseURL := fmt.Sprintf(APIFormat, repo)
111 | info, err := FetchInfo(releaseURL)
112 | if err != nil {
113 | return ""
114 | }
115 |
116 | isLatest, err := CompareVersions(currentVersion, info.Version)
117 | if err != nil || isLatest {
118 | return ""
119 | }
120 |
121 | return fmt.Sprintf("A new release (%s) is available. consider updating to latest version.", info.Version)
122 | }
123 |
--------------------------------------------------------------------------------
/cli/terminator/brew.go:
--------------------------------------------------------------------------------
1 | package terminator
2 |
3 | import (
4 | "os/exec"
5 | "path/filepath"
6 | "strings"
7 | )
8 |
// IsUnderHomebrew reports whether the binary at path lives in, or below,
// Homebrew's bin directory ("$(brew --prefix)/bin").
//
// path is lexically cleaned before the comparison. A path such as
// "/opt/homebrew/bin/../lib/tool" textually starts with the bin directory but
// resolves outside it, so it is correctly rejected. Symlinks are not resolved.
//
// It returns false when path is empty, when brew is not on PATH, or when
// "brew --prefix" fails.
func IsUnderHomebrew(path string) bool {
	if path == "" {
		return false
	}

	brewExe, err := exec.LookPath("brew")
	if err != nil {
		return false
	}

	prefixOut, err := exec.Command(brewExe, "--prefix").Output()
	if err != nil {
		return false
	}

	// The trailing separator keeps "/opt/homebrew/binary" from matching
	// "/opt/homebrew/bin".
	binDir := filepath.Join(strings.TrimSpace(string(prefixOut)), "bin") + string(filepath.Separator)
	return strings.HasPrefix(filepath.Clean(path), binDir)
}
30 |
// HasHomebrew reports whether a "brew" executable can be found on PATH, which
// is taken as evidence that Homebrew is installed.
func HasHomebrew() bool {
	if _, err := exec.LookPath("brew"); err != nil {
		return false
	}
	return true
}
39 |
--------------------------------------------------------------------------------
/cli/terminator/browser.go:
--------------------------------------------------------------------------------
1 | package terminator
2 |
3 | import (
4 | "os"
5 | "os/exec"
6 | "strings"
7 | )
8 |
9 | // OpenBrowser opens the default web browser at the specified URL.
10 | //
11 | // Parameters:
12 | // - goos: The operating system name (e.g., "darwin", "windows", or "linux").
13 | // - url: The URL to open in the web browser.
14 | //
15 | // Returns:
16 | // - An *exec.Cmd configured to open the URL. Note that you must call `cmd.Run()`
17 | // or `cmd.Start()` on the returned command to execute it.
18 | //
19 | // Panics:
20 | // - This function will panic if called without a TTY (e.g., not running in a terminal).
21 | func OpenBrowser(goos, url string) *exec.Cmd {
22 | if !IsTTY() {
23 | panic("OpenBrowser called without a TTY")
24 | }
25 |
26 | exe := "open"
27 | var args []string
28 |
29 | switch goos {
30 | case "darwin":
31 | // macOS: Use the "open" command to open the URL.
32 | args = append(args, url)
33 | case "windows":
34 | // Windows: Use "cmd /c start" to open the URL.
35 | exe, _ = exec.LookPath("cmd")
36 | replacer := strings.NewReplacer("&", "^&")
37 | args = append(args, "/c", "start", replacer.Replace(url))
38 | default:
39 | // Linux: Use "xdg-open" or fallback to "wslview" for WSL environments.
40 | exe = linuxExe()
41 | args = append(args, url)
42 | }
43 |
44 | // Create the command to open the browser and set stderr for error reporting.
45 | cmd := exec.Command(exe, args...)
46 | cmd.Stderr = os.Stderr
47 | return cmd
48 | }
49 |
// linuxExe picks the command used to open a browser on Linux.
//
// It prefers xdg-open and falls back to wslview (WSL) only when xdg-open is
// missing and wslview is present. If neither is on PATH it still returns
// "xdg-open", so the eventual exec error names the expected tool.
func linuxExe() string {
	if _, err := exec.LookPath("xdg-open"); err == nil {
		return "xdg-open"
	}
	if _, err := exec.LookPath("wslview"); err == nil {
		return "wslview"
	}
	return "xdg-open"
}
64 |
--------------------------------------------------------------------------------
/cli/terminator/pager.go:
--------------------------------------------------------------------------------
1 | package terminator
2 |
3 | import (
4 | "errors"
5 | "io"
6 | "os"
7 | "os/exec"
8 | "strings"
9 | "syscall"
10 |
11 | "github.com/cli/safeexec"
12 | "github.com/google/shlex"
13 | )
14 |
// Pager pipes output through an external pager program such as "less" or
// "more".
//
// Configure the program with NewPager or Set, call Start to launch it, write
// output to Out, and call Stop when done.
type Pager struct {
	// Out is where output should be written. Start hands the current value to
	// the pager as its stdout and replaces Out with a writer that feeds the
	// pager's stdin.
	Out io.Writer
	// ErrOut receives the pager process's standard error.
	ErrOut io.Writer

	pagerCommand string      // pager command line, e.g. "less -R"; "" or "cat" disables paging
	pagerProcess *os.Process // running pager, or nil when none is active
}
25 |
26 | // NewPager creates a new Pager instance with default settings.
27 | //
28 | // If the "PAGER" environment variable is not set, the default command is "more".
29 | func NewPager() *Pager {
30 | pagerCmd := os.Getenv("PAGER")
31 | if pagerCmd == "" {
32 | pagerCmd = "more"
33 | }
34 |
35 | return &Pager{
36 | pagerCommand: pagerCmd,
37 | Out: os.Stdout,
38 | ErrOut: os.Stderr,
39 | }
40 | }
41 |
42 | // Set updates the pager command used to display output.
43 | //
44 | // Parameters:
45 | // - cmd: The pager command (e.g., "less", "more").
46 | func (p *Pager) Set(cmd string) {
47 | p.pagerCommand = cmd
48 | }
49 |
50 | // Get returns the current pager command.
51 | //
52 | // Returns:
53 | // - The pager command as a string.
54 | func (p *Pager) Get() string {
55 | return p.pagerCommand
56 | }
57 |
58 | // Start begins the pager process to display output.
59 | //
60 | // If the pager command is "cat" or empty, it does nothing.
61 | // The function also sets environment variables to optimize the behavior of
62 | // certain pagers, like "less" and "lv".
63 | //
64 | // Returns:
65 | // - An error if the pager command fails to start or if arguments cannot be parsed.
66 | func (p *Pager) Start() error {
67 | if p.pagerCommand == "" || p.pagerCommand == "cat" {
68 | return nil
69 | }
70 |
71 | pagerArgs, err := shlex.Split(p.pagerCommand)
72 | if err != nil {
73 | return err
74 | }
75 |
76 | // Prepare the environment variables for the pager process.
77 | pagerEnv := os.Environ()
78 | for i := len(pagerEnv) - 1; i >= 0; i-- {
79 | if strings.HasPrefix(pagerEnv[i], "PAGER=") {
80 | pagerEnv = append(pagerEnv[0:i], pagerEnv[i+1:]...)
81 | }
82 | }
83 | if _, ok := os.LookupEnv("LESS"); !ok {
84 | pagerEnv = append(pagerEnv, "LESS=FRX")
85 | }
86 | if _, ok := os.LookupEnv("LV"); !ok {
87 | pagerEnv = append(pagerEnv, "LV=-c")
88 | }
89 |
90 | // Locate the pager executable using safeexec for added security.
91 | pagerExe, err := safeexec.LookPath(pagerArgs[0])
92 | if err != nil {
93 | return err
94 | }
95 |
96 | pagerCmd := exec.Command(pagerExe, pagerArgs[1:]...)
97 | pagerCmd.Env = pagerEnv
98 | pagerCmd.Stdout = p.Out
99 | pagerCmd.Stderr = p.ErrOut
100 | pagedOut, err := pagerCmd.StdinPipe()
101 | if err != nil {
102 | return err
103 | }
104 | p.Out = &pagerWriter{pagedOut}
105 |
106 | // Start the pager process.
107 | err = pagerCmd.Start()
108 | if err != nil {
109 | return err
110 | }
111 | p.pagerProcess = pagerCmd.Process
112 | return nil
113 | }
114 |
115 | // Stop terminates the running pager process and cleans up resources.
116 | func (p *Pager) Stop() {
117 | if p.pagerProcess == nil {
118 | return
119 | }
120 |
121 | // Close the output writer and wait for the process to exit.
122 | _ = p.Out.(io.WriteCloser).Close()
123 | _, _ = p.pagerProcess.Wait()
124 | p.pagerProcess = nil
125 | }
126 |
// pagerWriter wraps the write end of a pager's stdin pipe. It translates
// closed-pipe and broken-pipe write failures into *ErrClosedPagerPipe, so
// callers can tell them apart from other write errors.
type pagerWriter struct {
	io.WriteCloser
}

// Write forwards d to the wrapped writer.
//
// If the write fails because the pipe is closed (io.ErrClosedPipe) or broken
// (EPIPE), the error is returned wrapped in *ErrClosedPagerPipe. Any other
// error is returned unchanged. The byte count from the underlying write is
// always passed through.
func (w *pagerWriter) Write(d []byte) (int, error) {
	n, err := w.WriteCloser.Write(d)
	if err != nil && (errors.Is(err, io.ErrClosedPipe) || isEpipeError(err)) {
		return n, &ErrClosedPagerPipe{err}
	}
	return n, err
}

// isEpipeError reports whether err is, or wraps, syscall.EPIPE (broken pipe).
func isEpipeError(err error) bool {
	return errors.Is(err, syscall.EPIPE)
}

// ErrClosedPagerPipe is returned by writes to a pager whose input pipe has
// been closed or broken (presumably because the pager process exited before
// all output was written).
type ErrClosedPagerPipe struct {
	error
}

// Unwrap returns the underlying write error. This lets errors.Is and
// errors.As match the original cause, such as io.ErrClosedPipe or
// syscall.EPIPE, through the wrapper.
func (e *ErrClosedPagerPipe) Unwrap() error {
	return e.error
}
161 |
--------------------------------------------------------------------------------
/cli/terminator/term.go:
--------------------------------------------------------------------------------
1 | package terminator
2 |
3 | import (
4 | "os"
5 |
6 | "github.com/mattn/go-isatty"
7 | "github.com/muesli/termenv"
8 | )
9 |
10 | // IsTTY checks if the current output is a TTY (teletypewriter) or a Cygwin terminal.
11 | // This function is useful for determining if the program is running in a terminal
12 | // environment, which is important for features like colored output or interactive prompts.
13 | func IsTTY() bool {
14 | return isatty.IsTerminal(os.Stdout.Fd()) || isatty.IsCygwinTerminal(os.Stdout.Fd())
15 | }
16 |
17 | // IsColorDisabled checks if color output is disabled based on the environment settings.
18 | // This function uses the `termenv` library to determine if the NO_COLOR environment
19 | // variable is set, which is a common way to disable colored output.
20 | func IsColorDisabled() bool {
21 | return termenv.EnvNoColor()
22 | }
23 |
// IsCI reports whether the process appears to run under a continuous
// integration system. It returns true as soon as any of these variables is
// non-empty:
//   - CI: GitHub Actions, Travis CI, CircleCI, Cirrus CI, GitLab CI, AppVeyor, CodeShip, dsari
//   - BUILD_NUMBER: Jenkins, TeamCity
//   - RUN_ID: TaskCluster, dsari
func IsCI() bool {
	for _, key := range []string{"CI", "BUILD_NUMBER", "RUN_ID"} {
		if os.Getenv(key) != "" {
			return true
		}
	}
	return false
}
32 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "os"
8 | "path/filepath"
9 | "reflect"
10 | "strings"
11 |
12 | "github.com/go-playground/validator"
13 | "github.com/mcuadros/go-defaults"
14 | "github.com/spf13/pflag"
15 | "github.com/spf13/viper"
16 | "gopkg.in/yaml.v3"
17 | )
18 |
19 | // Loader is responsible for managing configuration
20 | type Loader struct {
21 | v *viper.Viper
22 | flags *pflag.FlagSet
23 | }
24 |
25 | // Option defines a functional option for configuring the Loader.
26 | type Option func(c *Loader)
27 |
28 | // NewLoader creates a new Loader instance with the provided options.
29 | // It initializes Viper with defaults for YAML configuration files and environment variable handling.
30 | //
31 | // Example:
32 | //
33 | // loader := config.NewLoader(
34 | // config.WithFile("./config.yaml"),
35 | // config.WithEnvPrefix("MYAPP"),
36 | // )
37 | func NewLoader(options ...Option) *Loader {
38 | v := viper.New()
39 |
40 | v.SetConfigName("config")
41 | v.SetConfigType("yaml")
42 | v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
43 | v.AutomaticEnv()
44 |
45 | loader := &Loader{v: v}
46 | for _, opt := range options {
47 | opt(loader)
48 | }
49 | return loader
50 | }
51 |
52 | // WithFile specifies the configuration file to use.
53 | func WithFile(configFilePath string) Option {
54 | return func(l *Loader) {
55 | l.v.SetConfigFile(configFilePath)
56 | }
57 | }
58 |
59 | // WithEnvPrefix specifies a prefix for ENV variables.
60 | func WithEnvPrefix(prefix string) Option {
61 | return func(l *Loader) {
62 | l.v.SetEnvPrefix(prefix)
63 | }
64 | }
65 |
66 | // WithFlags specifies a command-line flag set to bind dynamically based on `cmdx` tags.
67 | func WithFlags(flags *pflag.FlagSet) Option {
68 | return func(l *Loader) {
69 | l.flags = flags
70 | }
71 | }
72 |
73 | // WithAppConfig sets up application-specific configuration file handling.
74 | func WithAppConfig(app string) Option {
75 | return func(l *Loader) {
76 | filePath, err := getConfigFilePath(app)
77 | if err != nil {
78 | panic(fmt.Errorf("failed to determine config file path: %w", err))
79 | }
80 | l.v.SetConfigFile(filePath)
81 | }
82 | }
83 |
84 | // Load reads the configuration from the file, environment variables, and command-line flags,
85 | // and merges them into the provided configuration struct. It validates the configuration
86 | // using struct tags.
87 | //
88 | // The priority order is:
89 | // 1. Command-line flags
90 | // 2. Environment variables
91 | // 3. Configuration file
92 | // 4. Default values
93 | func (l *Loader) Load(config interface{}) error {
94 | if err := validateStructPtr(config); err != nil {
95 | return err
96 | }
97 |
98 | // Apply default values before reading configuration
99 | defaults.SetDefaults(config)
100 |
101 | // Bind flags dynamically using reflection on `cmdx` tags if a flag set is provided
102 | if l.flags != nil {
103 | if err := bindFlags(l.v, l.flags, reflect.TypeOf(config).Elem(), ""); err != nil {
104 | return fmt.Errorf("failed to bind flags: %w", err)
105 | }
106 | }
107 |
108 | // Bind environment variables for all keys in the config
109 | keys, err := extractFlattenedKeys(config)
110 | if err != nil {
111 | return fmt.Errorf("failed to extract config keys: %w", err)
112 | }
113 | for _, key := range keys {
114 | if err := l.v.BindEnv(key); err != nil {
115 | return fmt.Errorf("failed to bind environment variable for key %q: %w", key, err)
116 | }
117 | }
118 |
119 | // Attempt to read the configuration file
120 | if err := l.v.ReadInConfig(); err != nil {
121 | var configFileNotFoundError viper.ConfigFileNotFoundError
122 | if errors.As(err, &configFileNotFoundError) {
123 | fmt.Println("Warning: Config file not found. Falling back to defaults and environment variables.")
124 | }
125 | }
126 |
127 | // Unmarshal the merged configuration into the provided struct
128 | if err := l.v.Unmarshal(config); err != nil {
129 | return fmt.Errorf("failed to unmarshal config: %w", err)
130 | }
131 |
132 | // Validate the resulting configuration
133 | if err := validator.New().Struct(config); err != nil {
134 | return fmt.Errorf("invalid configuration: %w", err)
135 | }
136 |
137 | return nil
138 | }
139 |
140 | // Init initializes the configuration file with default values.
141 | func (l *Loader) Init(config interface{}) error {
142 | defaults.SetDefaults(config)
143 |
144 | path := l.v.ConfigFileUsed()
145 | if fileExists(path) {
146 | return errors.New("configuration file already exists")
147 | }
148 |
149 | data, err := yaml.Marshal(config)
150 | if err != nil {
151 | return fmt.Errorf("failed to marshal configuration: %w", err)
152 | }
153 |
154 | if err := ensureDir(filepath.Dir(path)); err != nil {
155 | return fmt.Errorf("failed to create directory: %w", err)
156 | }
157 |
158 | if err := os.WriteFile(path, data, 0644); err != nil {
159 | return fmt.Errorf("failed to write configuration file: %w", err)
160 | }
161 | return nil
162 | }
163 |
164 | // Get retrieves a configuration value by key.
165 | func (l *Loader) Get(key string) interface{} {
166 | return l.v.Get(key)
167 | }
168 |
169 | // Set updates a configuration value in memory (not persisted to file).
170 | func (l *Loader) Set(key string, value interface{}) {
171 | l.v.Set(key, value)
172 | }
173 |
174 | // Save writes the current configuration to the file specified during initialization.
175 | func (l *Loader) Save() error {
176 | configFile := l.v.ConfigFileUsed()
177 | if configFile == "" {
178 | return errors.New("no configuration file specified for saving")
179 | }
180 |
181 | settings := l.v.AllSettings()
182 | content, err := yaml.Marshal(settings)
183 | if err != nil {
184 | return fmt.Errorf("failed to marshal configuration: %w", err)
185 | }
186 |
187 | if err := os.WriteFile(configFile, content, 0644); err != nil {
188 | return fmt.Errorf("failed to write configuration to file: %w", err)
189 | }
190 | return nil
191 | }
192 |
193 | // View returns the current configuration as a formatted JSON string.
194 | func (l *Loader) View() (string, error) {
195 | settings := l.v.AllSettings()
196 | data, err := json.MarshalIndent(settings, "", " ")
197 | if err != nil {
198 | return "", fmt.Errorf("failed to format configuration as JSON: %w", err)
199 | }
200 | return string(data), nil
201 | }
202 |
--------------------------------------------------------------------------------
/config/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package config provides a flexible and extensible configuration management solution for Go applications.
3 |
4 | It integrates configuration files, environment variables, command-line flags, and default values to populate
5 | and validate user-defined structs.
6 |
7 | Configuration Precedence:
8 | The `Loader` merges configuration values from multiple sources in the following order of precedence (highest to lowest):
9 | 1. Command-line flags: Defined using `pflag.FlagSet` and dynamically bound via `cmdx` tags.
10 | 2. Environment variables: Dynamically bound to configuration keys, optionally prefixed using `WithEnvPrefix`.
11 | 3. Configuration file: YAML configuration files specified via `WithFile`.
12 | 4. Default values: Struct fields annotated with `default` tags are populated if no other source provides a value.
13 |
14 | Defaults:
15 | Default values are specified using the `default` struct tag. Fields annotated with `default` are populated
16 | before any other source (flags, environment variables, or files).
17 |
18 | Example:
19 |
20 | type Config struct {
21 | ServerPort int `mapstructure:"server.port" default:"8080"`
22 | LogLevel string `mapstructure:"log.level" default:"info"`
23 | }
24 |
25 | In the absence of higher-priority sources, `ServerPort` will default to `8080` and `LogLevel` to `info`.
26 |
27 | Validation:
28 | Validation is performed using the `go-playground/validator` package. Fields annotated with `validate` tags
29 | are validated after merging all configuration sources.
30 |
31 | Example:
32 |
33 | type Config struct {
34 | ServerPort int `mapstructure:"server.port" validate:"required,min=1"`
35 | LogLevel string `mapstructure:"log.level" validate:"required,oneof=debug info warn error"`
36 | }
37 |
38 | If validation fails, the `Load` method returns a detailed error indicating the invalid fields.
39 |
40 | Annotations:
41 | Configuration structs use the following struct tags to define behavior:
42 | - `mapstructure`: Maps YAML or environment variables to struct fields.
43 | - `default`: Provides fallback values for fields when no source overrides them.
44 | - `validate`: Ensures the final configuration meets application-specific requirements.
45 |
46 | Example:
47 |
48 | type Config struct {
49 | Server struct {
		Port int    `mapstructure:"port" default:"8080" validate:"required,min=1"`
		Host string `mapstructure:"host" default:"localhost" validate:"required"`
52 | } `mapstructure:"server"`
53 |
54 | LogLevel string `mapstructure:"log.level" default:"info" validate:"required,oneof=debug info warn error"`
55 | }
56 |
57 | The `Loader` will merge all sources, apply defaults, and validate the result in a single call to `Load`.
58 |
59 | Features:
60 | - Merges configurations from multiple sources: flags, environment variables, files, and defaults.
61 | - Supports nested structs with dynamic field mapping using `cmdx` tags.
62 | - Validates fields with constraints defined in `validate` tags.
63 | - Saves and views the final configuration in YAML or JSON formats.
64 |
65 | Example Usage:
66 |
67 | type Config struct {
68 | ServerPort int `mapstructure:"server.port" cmdx:"server.port" default:"8080" validate:"required,min=1"`
69 | LogLevel string `mapstructure:"log.level" cmdx:"log.level" default:"info" validate:"required,oneof=debug info warn error"`
70 | }
71 |
72 | func main() {
73 | flags := pflag.NewFlagSet("example", pflag.ExitOnError)
74 | flags.Int("server.port", 8080, "Server port")
75 | flags.String("log.level", "info", "Log level")
76 |
77 | loader := config.NewLoader(
78 | config.WithFile("./config.yaml"),
79 | config.WithEnvPrefix("MYAPP"),
80 | config.WithFlags(flags),
81 | )
82 |
83 | flags.Parse(os.Args[1:])
84 |
85 | cfg := &Config{}
86 | if err := loader.Load(cfg); err != nil {
87 | log.Fatalf("Failed to load configuration: %v", err)
88 | }
89 |
90 | fmt.Printf("Configuration: %+v\n", cfg)
91 | }
92 | */
93 | package config
94 |
--------------------------------------------------------------------------------
/config/helpers.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "os"
7 | "path/filepath"
8 | "reflect"
9 | "runtime"
10 |
11 | "github.com/jeremywohl/flatten"
12 | "github.com/mitchellh/mapstructure"
13 | "github.com/spf13/pflag"
14 | "github.com/spf13/viper"
15 | )
16 |
17 | // bindFlags dynamically binds flags to configuration fields based on `cmdx` tags.
18 | func bindFlags(v *viper.Viper, flagSet *pflag.FlagSet, structType reflect.Type, parentKey string) error {
19 | for i := 0; i < structType.NumField(); i++ {
20 | field := structType.Field(i)
21 | tag := field.Tag.Get("cmdx")
22 | if tag == "" {
23 | continue
24 | }
25 |
26 | if parentKey != "" {
27 | tag = parentKey + "." + tag
28 | }
29 |
30 | if field.Type.Kind() == reflect.Struct {
31 | // Recurse into nested structs
32 | if err := bindFlags(v, flagSet, field.Type, tag); err != nil {
33 | return err
34 | }
35 | } else {
36 | flag := flagSet.Lookup(tag)
37 | if flag == nil {
38 | return fmt.Errorf("missing flag for tag: %s", tag)
39 | }
40 | if err := v.BindPFlag(tag, flag); err != nil {
41 | return fmt.Errorf("failed to bind flag for tag: %s, error: %w", tag, err)
42 | }
43 | }
44 | }
45 | return nil
46 | }
47 |
// validateStructPtr returns nil only when value is a non-nil pointer to a
// struct. Anything else, including nil and typed nil pointers, yields an error.
func validateStructPtr(value interface{}) error {
	rv := reflect.ValueOf(value)
	// Short-circuit keeps Elem() from being called on non-pointer kinds.
	// A nil pointer's Elem() is the zero Value, whose Kind is Invalid.
	isStructPtr := rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Struct
	if isStructPtr {
		return nil
	}
	return errors.New("load requires a pointer to a struct")
}
56 |
57 | // extractFlattenedKeys retrieves all keys from the struct in a flattened format.
58 | func extractFlattenedKeys(config interface{}) ([]string, error) {
59 | var structMap map[string]interface{}
60 | if err := mapstructure.Decode(config, &structMap); err != nil {
61 | return nil, err
62 | }
63 | flatMap, err := flatten.Flatten(structMap, "", flatten.DotStyle)
64 | if err != nil {
65 | return nil, err
66 | }
67 | keys := make([]string, 0, len(flatMap))
68 | for k := range flatMap {
69 | keys = append(keys, k)
70 | }
71 | return keys, nil
72 | }
73 |
74 | // Utilities for app-specific configuration paths
75 | func getConfigFilePath(app string) (string, error) {
76 | dirPath := getConfigDir("raystack")
77 | if err := ensureDir(dirPath); err != nil {
78 | return "", err
79 | }
80 | return filepath.Join(dirPath, app+".yml"), nil
81 | }
82 |
83 | func getConfigDir(root string) string {
84 | switch {
85 | case envSet("RAYSTACK_CONFIG_DIR"):
86 | return filepath.Join(os.Getenv("RAYSTACK_CONFIG_DIR"), root)
87 | case envSet("XDG_CONFIG_HOME"):
88 | return filepath.Join(os.Getenv("XDG_CONFIG_HOME"), root)
89 | case runtime.GOOS == "windows" && envSet("APPDATA"):
90 | return filepath.Join(os.Getenv("APPDATA"), root)
91 | default:
92 | home, _ := os.UserHomeDir()
93 | return filepath.Join(home, ".config", root)
94 | }
95 | }
96 |
// ensureDir creates dir, along with any missing parents, using mode 0755
// (subject to umask). A directory that already exists is not an error.
func ensureDir(dir string) error {
	err := os.MkdirAll(dir, 0o755)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create directory: %w", err)
}
103 |
// fileExists reports whether os.Stat succeeds for filename. Directories count
// as existing. Any Stat failure, not only "not exist" (for example permission
// denied), is reported as false.
func fileExists(filename string) bool {
	if _, err := os.Stat(filename); err != nil {
		return false
	}
	return true
}
108 |
// envSet reports whether the environment variable key is defined with a
// non-empty value. A variable set to "" counts as unset.
func envSet(key string) bool {
	value, found := os.LookupEnv(key)
	return found && value != ""
}
112 |
--------------------------------------------------------------------------------
/db/config.go:
--------------------------------------------------------------------------------
1 | package db
2 |
3 | import (
4 | "time"
5 | )
6 |
// Config holds the settings New uses to open a database connection and size
// its connection pool.
//
// New uses the struct as given and does not apply the `default` tags itself.
// They only take effect when the struct is filled by a loader that honours
// them (presumably the config package's Loader). A hand-built Config therefore
// passes zero values straight to database/sql.
//
// NOTE(review): the 10ms ConnMaxLifeTime default makes database/sql retire
// connections almost immediately. Confirm this is intended rather than e.g. 10m.
type Config struct {
	// Driver is the database/sql driver name passed to sqlx.Connect (e.g. "postgres").
	Driver string `yaml:"driver" mapstructure:"driver"`
	// URL is the connection string. New parses it to record the host and hands
	// it to the driver. It may embed credentials.
	URL string `yaml:"url" mapstructure:"url"`
	// MaxIdleConns is passed to SetMaxIdleConns. Per database/sql, n <= 0 keeps
	// no idle connections.
	MaxIdleConns int `yaml:"max_idle_conns" mapstructure:"max_idle_conns" default:"10"`
	// MaxOpenConns is passed to SetMaxOpenConns. Per database/sql, n <= 0 means
	// unlimited.
	MaxOpenConns int `yaml:"max_open_conns" mapstructure:"max_open_conns" default:"10"`
	// ConnMaxLifeTime is passed to SetConnMaxLifetime.
	ConnMaxLifeTime time.Duration `yaml:"conn_max_life_time" mapstructure:"conn_max_life_time" default:"10ms"`
	// MaxQueryTimeout becomes the client's timeout used by Client.WithTimeout.
	MaxQueryTimeout time.Duration `yaml:"max_query_timeout" mapstructure:"max_query_timeout" default:"100ms"`
}
15 |
--------------------------------------------------------------------------------
/db/db.go:
--------------------------------------------------------------------------------
1 | package db
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "net/url"
8 | "time"
9 |
10 | "github.com/jmoiron/sqlx"
11 | "github.com/pkg/errors"
12 | )
13 |
14 | type Client struct {
15 | *sqlx.DB
16 | queryTimeOut time.Duration
17 | cfg Config
18 | host string
19 | }
20 |
21 | // NewClient creates a new sqlx database client
22 | func New(cfg Config) (*Client, error) {
23 | dbURL, err := url.Parse(cfg.URL)
24 | if err != nil {
25 | return nil, err
26 | }
27 | host := dbURL.Host
28 |
29 | db, err := sqlx.Connect(cfg.Driver, cfg.URL)
30 | if err != nil {
31 | return nil, err
32 | }
33 |
34 | db.SetMaxIdleConns(cfg.MaxIdleConns)
35 | db.SetMaxOpenConns(cfg.MaxOpenConns)
36 | db.SetConnMaxLifetime(cfg.ConnMaxLifeTime)
37 |
38 | return &Client{DB: db, queryTimeOut: cfg.MaxQueryTimeout, cfg: cfg, host: host}, err
39 | }
40 |
41 | func (c Client) WithTimeout(ctx context.Context, op func(ctx context.Context) error) (err error) {
42 | ctxWithTimeout, cancel := context.WithTimeout(ctx, c.queryTimeOut)
43 | defer cancel()
44 |
45 | return op(ctxWithTimeout)
46 | }
47 |
// WithTxn begins a transaction on ctx with txnOptions, runs txFunc inside it,
// and finishes the transaction according to how txFunc ends:
//   - txFunc returns nil: the transaction is committed, and Commit's error
//     (nil on success) is returned.
//   - txFunc returns an error: the transaction is rolled back, and the error is
//     returned wrapped as "rollback: <err>". If Rollback also fails, it is
//     returned as "rollback error: <rollback err> while executing: <err>".
//   - txFunc panics: the transaction is rolled back and the panic is re-raised
//     with its original value.
//
// If BeginTxx fails, its error is returned and txFunc is never called.
func (c Client) WithTxn(ctx context.Context, txnOptions sql.TxOptions, txFunc func(*sqlx.Tx) error) (err error) {
	txn, err := c.BeginTxx(ctx, &txnOptions)
	if err != nil {
		return err
	}

	// Runs after txFunc returns or panics, and rewrites the named result err.
	defer func() {
		if p := recover(); p != nil {
			// NOTE(review): the values stored in err in this branch are never
			// observed. panic(p) below means WithTxn does not return normally,
			// so a Rollback failure here is silently lost.
			switch p := p.(type) {
			case error:
				err = p
			default:
				err = errors.Errorf("%s", p)
			}
			err = txn.Rollback()
			// p here is the original recovered value; the switch's p is scoped to the switch.
			panic(p)
		} else if err != nil {
			if rlbErr := txn.Rollback(); rlbErr != nil {
				err = fmt.Errorf("rollback error: %s while executing: %w", rlbErr, err)
			} else {
				err = fmt.Errorf("rollback: %w", err)
			}
		} else {
			err = txn.Commit()
		}
	}()

	err = txFunc(txn)
	return err
}
78 |
79 | // ConnectionURL fetch the database connection url
80 | func (c *Client) ConnectionURL() string {
81 | return c.cfg.URL
82 | }
83 |
84 | // Host fetch the database host information
85 | func (c *Client) Host() string {
86 | return c.host
87 | }
88 |
89 | // Close closes the database connection
90 | func (c *Client) Close() error {
91 | return c.DB.Close()
92 | }
93 |
--------------------------------------------------------------------------------
/db/db_test.go:
--------------------------------------------------------------------------------
1 | package db_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "log"
8 | "os"
9 | "testing"
10 | "time"
11 |
12 | "github.com/jmoiron/sqlx"
13 | "github.com/ory/dockertest/v3"
14 | "github.com/ory/dockertest/v3/docker"
15 | "github.com/raystack/salt/db"
16 | "github.com/stretchr/testify/assert"
17 | )
18 |
// Connection settings for the throwaway Postgres container that TestMain
// starts. The migration tests reuse the same values to build their DSN.
const (
	dialect  = "postgres"
	user     = "root"
	password = "pass"
	database = "postgres"
	// NOTE(review): dialect and host are not referenced by the visible tests.
	host = "localhost"
	// port is published on the host as well as in the container. Presumably the
	// container fails to start if something else already listens on it.
	port = "5432"
	// dsn is a fmt template whose verbs are filled with user, password, port, database.
	dsn = "postgres://%s:%s@localhost:%s/%s?sslmode=disable"
)

// SQL statements the tests use to create, drop, and inspect the users table.
var (
	createTableQuery = "CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))"
	dropTableQuery   = "DROP TABLE IF EXISTS users"
	// checkTableQuery yields a single boolean: whether public.users exists.
	checkTableQuery = "SELECT EXISTS(SELECT * FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'users');"
)

// client is the shared database client that TestMain opens and every test uses.
var client *db.Client
36 |
37 | func TestMain(m *testing.M) {
38 | pool, err := dockertest.NewPool("")
39 | if err != nil {
40 | log.Fatalf("Could not connect to docker: %s", err)
41 | }
42 |
43 | opts := dockertest.RunOptions{
44 | Repository: "postgres",
45 | Tag: "14",
46 | Env: []string{
47 | "POSTGRES_USER=" + user,
48 | "POSTGRES_PASSWORD=" + password,
49 | "POSTGRES_DB=" + database,
50 | },
51 | ExposedPorts: []string{"5432"},
52 | PortBindings: map[docker.Port][]docker.PortBinding{
53 | "5432": {
54 | {HostIP: "0.0.0.0", HostPort: port},
55 | },
56 | },
57 | }
58 |
59 | resource, err := pool.RunWithOptions(&opts, func(config *docker.HostConfig) {
60 | config.AutoRemove = true
61 | config.RestartPolicy = docker.RestartPolicy{Name: "no"}
62 | })
63 | if err != nil {
64 | log.Fatalf("Could not start resource: %s", err.Error())
65 | }
66 |
67 | fmt.Println(resource.GetPort("5432/tcp"))
68 |
69 | if err := resource.Expire(120); err != nil {
70 | log.Fatalf("Could not expire resource: %s", err.Error())
71 | }
72 |
73 | pool.MaxWait = 60 * time.Second
74 |
75 | dsn := fmt.Sprintf(dsn, user, password, port, database)
76 | var (
77 | pgConfig = db.Config{
78 | Driver: "postgres",
79 | URL: dsn,
80 | }
81 | )
82 |
83 | if err = pool.Retry(func() error {
84 | client, err = db.New(pgConfig)
85 | return err
86 | }); err != nil {
87 | log.Fatalf("Could not connect to docker: %s", err.Error())
88 | }
89 |
90 | defer func() {
91 | client.Close()
92 | }()
93 |
94 | code := m.Run()
95 |
96 | if err := pool.Purge(resource); err != nil {
97 | log.Fatalf("Could not purge resource: %s", err)
98 | }
99 |
100 | os.Exit(code)
101 | }
102 |
103 | func TestWithTxn(t *testing.T) {
104 | if _, err := client.Exec(dropTableQuery); err != nil {
105 | log.Fatalf("Could not cleanup: %s", err)
106 | }
107 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
108 | if _, err := tx.Exec(createTableQuery); err != nil {
109 | return err
110 | }
111 | if _, err := tx.Exec(dropTableQuery); err != nil {
112 | return err
113 | }
114 |
115 | return nil
116 | })
117 | assert.NoError(t, err)
118 |
119 | // Table should be dropped
120 | var tableExist bool
121 | result := client.QueryRow(checkTableQuery)
122 | result.Scan(&tableExist)
123 | assert.Equal(t, false, tableExist)
124 | }
125 |
126 | func TestWithTxnCommit(t *testing.T) {
127 | if _, err := client.Exec(dropTableQuery); err != nil {
128 | log.Fatalf("Could not cleanup: %s", err)
129 | }
130 | query2 := "SELECT 1"
131 |
132 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
133 | if _, err := tx.Exec(createTableQuery); err != nil {
134 | return err
135 | }
136 | if _, err := tx.Exec(query2); err != nil {
137 | return err
138 | }
139 |
140 | return nil
141 | })
142 | // WithTx should not return an error
143 | assert.NoError(t, err)
144 |
145 | // User table should exist
146 | var tableExist bool
147 | result := client.QueryRow(checkTableQuery)
148 | result.Scan(&tableExist)
149 | assert.Equal(t, true, tableExist)
150 | }
151 |
152 | func TestWithTxnRollback(t *testing.T) {
153 | if _, err := client.Exec(dropTableQuery); err != nil {
154 | log.Fatalf("Could not cleanup: %s", err)
155 | }
156 | query2 := "WRONG QUERY"
157 |
158 | err := client.WithTxn(context.Background(), sql.TxOptions{}, func(tx *sqlx.Tx) error {
159 | if _, err := tx.Exec(createTableQuery); err != nil {
160 | return err
161 | }
162 | if _, err := tx.Exec(query2); err != nil {
163 | return err
164 | }
165 |
166 | return nil
167 | })
168 | // WithTx should return an error
169 | assert.Error(t, err)
170 |
171 | // Table should not be created
172 | var tableExist bool
173 | result := client.QueryRow(checkTableQuery)
174 | result.Scan(&tableExist)
175 | assert.Equal(t, false, tableExist)
176 | }
177 |
--------------------------------------------------------------------------------
/db/migrate.go:
--------------------------------------------------------------------------------
1 | package db
2 |
3 | import (
4 | "fmt"
5 | "io/fs"
6 |
7 | "github.com/golang-migrate/migrate/v4"
8 | _ "github.com/golang-migrate/migrate/v4/database"
9 | _ "github.com/golang-migrate/migrate/v4/database/mysql"
10 | _ "github.com/golang-migrate/migrate/v4/database/postgres"
11 | _ "github.com/golang-migrate/migrate/v4/source/file"
12 | "github.com/golang-migrate/migrate/v4/source/iofs"
13 | )
14 |
15 | func RunMigrations(config Config, embeddedMigrations fs.FS, resourcePath string) error {
16 | m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
17 | if err != nil {
18 | return err
19 | }
20 |
21 | err = m.Up()
22 | if err == migrate.ErrNoChange || err == nil {
23 | return nil
24 | }
25 |
26 | return err
27 | }
28 |
29 | func RunRollback(config Config, embeddedMigrations fs.FS, resourcePath string) error {
30 | m, err := getMigrationInstance(config, embeddedMigrations, resourcePath)
31 | if err != nil {
32 | return err
33 | }
34 |
35 | err = m.Steps(-1)
36 | if err == migrate.ErrNoChange || err == nil {
37 | return nil
38 | }
39 |
40 | return err
41 | }
42 |
43 | func getMigrationInstance(config Config, embeddedMigrations fs.FS, resourcePath string) (*migrate.Migrate, error) {
44 | src, err := iofs.New(embeddedMigrations, resourcePath)
45 | if err != nil {
46 | return nil, fmt.Errorf("db migrator: %v", err)
47 | }
48 | return migrate.NewWithSourceInstance("iofs", src, config.URL)
49 | }
50 |
--------------------------------------------------------------------------------
/db/migrate_test.go:
--------------------------------------------------------------------------------
1 | package db_test
2 |
3 | import (
4 | "embed"
5 | "fmt"
6 | "log"
7 | "testing"
8 |
9 | "github.com/raystack/salt/db"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | //go:embed migrations/*.sql
14 | var migrationFs embed.FS
15 |
16 | func TestRunMigrations(t *testing.T) {
17 | if _, err := client.Exec(dropTableQuery); err != nil {
18 | log.Fatalf("Could not cleanup: %s", err)
19 | }
20 |
21 | dsn := fmt.Sprintf(dsn, user, password, port, database)
22 | var (
23 | pgConfig = db.Config{
24 | Driver: "postgres",
25 | URL: dsn,
26 | }
27 | )
28 |
29 | err := db.RunMigrations(pgConfig, migrationFs, "migrations")
30 | assert.NoError(t, err)
31 |
32 | // User table should exist
33 | var tableExist bool
34 | result := client.QueryRow(checkTableQuery)
35 | result.Scan(&tableExist)
36 | assert.Equal(t, true, tableExist)
37 | }
38 |
39 | func TestRunRollback(t *testing.T) {
40 | if _, err := client.Exec(dropTableQuery); err != nil {
41 | log.Fatalf("Could not cleanup: %s", err)
42 | }
43 |
44 | dsn := fmt.Sprintf(dsn, user, password, port, database)
45 | var (
46 | pgConfig = db.Config{
47 | Driver: "postgres",
48 | URL: dsn,
49 | }
50 | )
51 |
52 | err := db.RunRollback(pgConfig, migrationFs, "migrations")
53 | assert.NoError(t, err)
54 |
55 | // User table should not exist
56 | var tableExist bool
57 | result := client.QueryRow(checkTableQuery)
58 | result.Scan(&tableExist)
59 | assert.Equal(t, false, tableExist)
60 | }
61 |
--------------------------------------------------------------------------------
/db/migrations/1481574547_create_users_table.down.sql:
--------------------------------------------------------------------------------
-- Reverses 1481574547_create_users_table.up.sql: removes the users table if it exists.
DROP TABLE IF EXISTS users
--------------------------------------------------------------------------------
/db/migrations/1481574547_create_users_table.up.sql:
--------------------------------------------------------------------------------
-- Creates the users table (id, name) that the db package's migration tests check for.
CREATE TABLE IF NOT EXISTS users (id VARCHAR(36) PRIMARY KEY, name VARCHAR(50))
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/raystack/salt
2 |
3 | go 1.22
4 |
5 | require (
6 | github.com/AlecAivazis/survey/v2 v2.3.6
7 | github.com/MakeNowJust/heredoc v1.0.0
8 | github.com/NYTimes/gziphandler v1.1.1
9 | github.com/authzed/authzed-go v0.7.0
10 | github.com/authzed/grpcutil v0.0.0-20230908193239-4286bb1d6403
11 | github.com/briandowns/spinner v1.18.0
12 | github.com/charmbracelet/glamour v0.3.0
13 | github.com/cli/safeexec v1.0.0
14 | github.com/go-playground/validator v9.31.0+incompatible
15 | github.com/golang-migrate/migrate/v4 v4.16.0
16 | github.com/google/go-cmp v0.6.0
17 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
18 | github.com/google/uuid v1.6.0
19 | github.com/hashicorp/go-version v1.3.0
20 | github.com/jeremywohl/flatten v1.0.1
21 | github.com/jmoiron/sqlx v1.3.5
22 | github.com/lib/pq v1.10.4
23 | github.com/mattn/go-isatty v0.0.19
24 | github.com/mcuadros/go-defaults v1.2.0
25 | github.com/mitchellh/mapstructure v1.5.0
26 | github.com/muesli/termenv v0.11.1-0.20220212125758-44cd13922739
27 | github.com/oklog/run v1.1.0
28 | github.com/olekukonko/tablewriter v0.0.5
29 | github.com/ory/dockertest/v3 v3.9.1
30 | github.com/pkg/errors v0.9.1
31 | github.com/schollz/progressbar/v3 v3.8.5
32 | github.com/sirupsen/logrus v1.9.2
33 | github.com/spf13/cobra v1.8.1
34 | github.com/spf13/pflag v1.0.5
35 | github.com/spf13/viper v1.19.0
36 | github.com/stretchr/testify v1.9.0
37 | go.opentelemetry.io/contrib/instrumentation/host v0.56.0
38 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.52.0
39 | go.opentelemetry.io/contrib/instrumentation/runtime v0.56.0
40 | go.opentelemetry.io/contrib/samplers/probability/consistent v0.25.0
41 | go.opentelemetry.io/otel v1.31.0
42 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0
43 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0
44 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0
45 | go.opentelemetry.io/otel/metric v1.31.0
46 | go.opentelemetry.io/otel/sdk v1.31.0
47 | go.opentelemetry.io/otel/sdk/metric v1.31.0
48 | go.uber.org/zap v1.21.0
49 | golang.org/x/oauth2 v0.22.0
50 | golang.org/x/text v0.19.0
51 | google.golang.org/api v0.171.0
52 | google.golang.org/grpc v1.67.1
53 | google.golang.org/protobuf v1.35.1
54 | gopkg.in/yaml.v3 v3.0.1
55 | )
56 |
57 | require (
58 | cloud.google.com/go/compute/metadata v0.5.0 // indirect
59 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
60 | github.com/Microsoft/go-winio v0.6.1 // indirect
61 | github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
62 | github.com/alecthomas/chroma v0.8.2 // indirect
63 | github.com/alecthomas/repr v0.2.0 // indirect
64 | github.com/aymerick/douceur v0.2.0 // indirect
65 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
66 | github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect
67 | github.com/containerd/continuity v0.3.0 // indirect
68 | github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
69 | github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect
70 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
71 | github.com/dlclark/regexp2 v1.2.0 // indirect
72 | github.com/docker/cli v20.10.14+incompatible // indirect
73 | github.com/docker/docker v20.10.24+incompatible // indirect
74 | github.com/docker/go-connections v0.4.0 // indirect
75 | github.com/docker/go-units v0.5.0 // indirect
76 | github.com/ebitengine/purego v0.8.0 // indirect
77 | github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
78 | github.com/fatih/color v1.15.0 // indirect
79 | github.com/felixge/httpsnoop v1.0.4 // indirect
80 | github.com/fsnotify/fsnotify v1.7.0 // indirect
81 | github.com/go-logr/logr v1.4.2 // indirect
82 | github.com/go-logr/stdr v1.2.2 // indirect
83 | github.com/go-ole/go-ole v1.3.0 // indirect
84 | github.com/go-playground/locales v0.14.1 // indirect
85 | github.com/go-playground/universal-translator v0.18.1 // indirect
86 | github.com/go-sql-driver/mysql v1.6.0 // indirect
87 | github.com/gogo/protobuf v1.3.2 // indirect
88 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
89 | github.com/golang/protobuf v1.5.4 // indirect
90 | github.com/google/s2a-go v0.1.7 // indirect
91 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
92 | github.com/gorilla/css v1.0.0 // indirect
93 | github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
94 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
95 | github.com/hashicorp/errwrap v1.1.0 // indirect
96 | github.com/hashicorp/go-multierror v1.1.1 // indirect
97 | github.com/hashicorp/hcl v1.0.0 // indirect
98 | github.com/imdario/mergo v0.3.12 // indirect
99 | github.com/inconshreveable/mousetrap v1.1.0 // indirect
100 | github.com/jzelinskie/stringz v0.0.0-20210414224931-d6a8ce844a70 // indirect
101 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
102 | github.com/leodido/go-urn v1.4.0 // indirect
103 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
104 | github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect
105 | github.com/magiconair/properties v1.8.7 // indirect
106 | github.com/mattn/go-colorable v0.1.13 // indirect
107 | github.com/mattn/go-runewidth v0.0.13 // indirect
108 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
109 | github.com/microcosm-cc/bluemonday v1.0.6 // indirect
110 | github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
111 | github.com/moby/term v0.5.0 // indirect
112 | github.com/muesli/reflow v0.3.0 // indirect
113 | github.com/opencontainers/go-digest v1.0.0 // indirect
114 | github.com/opencontainers/image-spec v1.0.2 // indirect
115 | github.com/opencontainers/runc v1.1.2 // indirect
116 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect
117 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
118 | github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
119 | github.com/rivo/uniseg v0.2.0 // indirect
120 | github.com/russross/blackfriday/v2 v2.1.0 // indirect
121 | github.com/sagikazarmark/locafero v0.4.0 // indirect
122 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect
123 | github.com/shirou/gopsutil/v4 v4.24.9 // indirect
124 | github.com/sourcegraph/conc v0.3.0 // indirect
125 | github.com/spf13/afero v1.11.0 // indirect
126 | github.com/spf13/cast v1.6.0 // indirect
127 | github.com/stretchr/objx v0.5.2 // indirect
128 | github.com/subosito/gotenv v1.6.0 // indirect
129 | github.com/tklauser/go-sysconf v0.3.14 // indirect
130 | github.com/tklauser/numcpus v0.9.0 // indirect
131 | github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
132 | github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
133 | github.com/xeipuuv/gojsonschema v1.2.0 // indirect
134 | github.com/yuin/goldmark v1.4.13 // indirect
135 | github.com/yuin/goldmark-emoji v1.0.1 // indirect
136 | github.com/yusufpapurcu/wmi v1.2.4 // indirect
137 | go.opencensus.io v0.24.0 // indirect
138 | go.opentelemetry.io/otel/trace v1.31.0 // indirect
139 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect
140 | go.uber.org/atomic v1.10.0 // indirect
141 | go.uber.org/multierr v1.9.0 // indirect
142 | golang.org/x/crypto v0.28.0 // indirect
143 | golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
144 | golang.org/x/mod v0.18.0 // indirect
145 | golang.org/x/net v0.30.0 // indirect
146 | golang.org/x/sync v0.8.0 // indirect
147 | golang.org/x/sys v0.26.0 // indirect
148 | golang.org/x/term v0.25.0 // indirect
149 | golang.org/x/tools v0.22.0 // indirect
150 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
151 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect
152 | gopkg.in/go-playground/assert.v1 v1.2.1 // indirect
153 | gopkg.in/ini.v1 v1.67.0 // indirect
154 | gopkg.in/yaml.v2 v2.4.0 // indirect
155 | gotest.tools/v3 v3.5.1 // indirect
156 | )
157 |
--------------------------------------------------------------------------------
/log/logger.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "io"
5 | )
6 |
// Option configures a logger while it is being constructed. A constructor
// such as NewLogrus calls each Option with the logger under construction, so
// the concrete type an Option receives depends on that constructor.
type Option func(any)
9 |
// Logger is a small logging abstraction implemented by the loggers in this
// package. Use it as is, or declare a narrower interface of your own that
// these implementations satisfy.
//
// Every logging method takes a message followed by zero or more alternating
// key/value pairs. Keys should be strings; values may be anything printable:
//
//	timeTaken := time.Duration(time.Second * 1)
//	l.Debug("processed request", "time taken", timeTaken)
//	l.Info("processed request", "time taken", timeTaken, "started at", startedAt)
type Logger interface {
	// Debug logs msg at debug level with the key/value pairs in args.
	Debug(msg string, args ...interface{})
	// Info logs msg at info level with the key/value pairs in args.
	Info(msg string, args ...interface{})
	// Warn logs msg at warn level with the key/value pairs in args.
	Warn(msg string, args ...interface{})
	// Error logs msg at error level with the key/value pairs in args.
	Error(msg string, args ...interface{})
	// Fatal logs msg at fatal level with the key/value pairs in args.
	Fatal(msg string, args ...interface{})

	// Level returns the priority level at which this logger filters messages.
	Level() string
	// Writer returns the io.Writer this logger prints through.
	Writer() io.Writer
}
52 |
--------------------------------------------------------------------------------
/log/logrus.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "io"
5 |
6 | "github.com/sirupsen/logrus"
7 | )
8 |
// Logrus adapts a *logrus.Logger to this package's message-plus-key/value
// logging methods. Create it with NewLogrus. Its methods dereference log, so
// the zero value (nil log) is not usable.
type Logrus struct {
	log *logrus.Logger
}
12 |
13 | func (l Logrus) getFields(args ...interface{}) map[string]interface{} {
14 | fieldMap := map[string]interface{}{}
15 | if len(args) > 1 && len(args)%2 == 0 {
16 | for i := 1; i < len(args); i += 2 {
17 | fieldMap[args[i-1].(string)] = args[i]
18 | }
19 | }
20 | return fieldMap
21 | }
22 |
23 | func (l *Logrus) Info(msg string, args ...interface{}) {
24 | l.log.WithFields(l.getFields(args...)).Info(msg)
25 | }
26 |
27 | func (l *Logrus) Debug(msg string, args ...interface{}) {
28 | l.log.WithFields(l.getFields(args...)).Debug(msg)
29 | }
30 |
31 | func (l *Logrus) Warn(msg string, args ...interface{}) {
32 | l.log.WithFields(l.getFields(args...)).Warn(msg)
33 | }
34 |
35 | func (l *Logrus) Error(msg string, args ...interface{}) {
36 | l.log.WithFields(l.getFields(args...)).Error(msg)
37 | }
38 |
39 | func (l *Logrus) Fatal(msg string, args ...interface{}) {
40 | l.log.WithFields(l.getFields(args...)).Fatal(msg)
41 | }
42 |
43 | func (l *Logrus) Level() string {
44 | return l.log.Level.String()
45 | }
46 |
47 | func (l *Logrus) Writer() io.Writer {
48 | return l.log.Writer()
49 | }
50 |
51 | func (l *Logrus) Entry(args ...interface{}) *logrus.Entry {
52 | return l.log.WithFields(l.getFields(args...))
53 | }
54 |
55 | func LogrusWithLevel(level string) Option {
56 | return func(logger interface{}) {
57 | logLevel, err := logrus.ParseLevel(level)
58 | if err != nil {
59 | panic(err)
60 | }
61 | logger.(*Logrus).log.SetLevel(logLevel)
62 | }
63 | }
64 |
65 | func LogrusWithWriter(writer io.Writer) Option {
66 | return func(logger interface{}) {
67 | logger.(*Logrus).log.SetOutput(writer)
68 | }
69 | }
70 |
71 | // LogrusWithFormatter can be used to change default formatting
72 | // by implementing logrus.Formatter
73 | // For example:
74 | //
75 | // type PlainFormatter struct{}
76 | // func (p *PlainFormatter) Format(entry *logrus.Entry) ([]byte, error) {
77 | // return []byte(entry.Message), nil
78 | // }
79 | // l := log.NewLogrus(log.LogrusWithFormatter(&PlainFormatter{}))
80 | func LogrusWithFormatter(f logrus.Formatter) Option {
81 | return func(logger interface{}) {
82 | logger.(*Logrus).log.SetFormatter(f)
83 | }
84 | }
85 |
86 | // NewLogrus returns a logrus logger instance with info level as default log level
87 | func NewLogrus(opts ...Option) *Logrus {
88 | logger := &Logrus{
89 | log: logrus.New(),
90 | }
91 | logger.log.Level = logrus.InfoLevel
92 | for _, opt := range opts {
93 | opt(logger)
94 | }
95 | return logger
96 | }
97 |
--------------------------------------------------------------------------------
/log/logrus_test.go:
--------------------------------------------------------------------------------
1 | package log_test
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/sirupsen/logrus"
10 |
11 | "github.com/raystack/salt/log"
12 |
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestLogrus(t *testing.T) {
17 | t.Run("should parse info messages at debug level correctly", func(t *testing.T) {
18 | var b bytes.Buffer
19 | foo := bufio.NewWriter(&b)
20 |
21 | logger := log.NewLogrus(log.LogrusWithLevel("debug"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{
22 | DisableTimestamp: true,
23 | }))
24 | logger.Info("hello world")
25 | foo.Flush()
26 |
27 | assert.Equal(t, "level=info msg=\"hello world\"\n", b.String())
28 | })
29 | t.Run("should not parse debug messages at info level correctly", func(t *testing.T) {
30 | var b bytes.Buffer
31 | foo := bufio.NewWriter(&b)
32 |
33 | logger := log.NewLogrus(log.LogrusWithLevel("info"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{
34 | DisableTimestamp: true,
35 | }))
36 | logger.Debug("hello world")
37 | foo.Flush()
38 |
39 | assert.Equal(t, "", b.String())
40 | })
41 | t.Run("should parse field maps correctly", func(t *testing.T) {
42 | var b bytes.Buffer
43 | foo := bufio.NewWriter(&b)
44 |
45 | logger := log.NewLogrus(log.LogrusWithLevel("debug"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{
46 | DisableTimestamp: true,
47 | }))
48 | logger.Debug("current values", "day", 11, "month", "aug")
49 | foo.Flush()
50 |
51 | assert.Equal(t, "level=debug msg=\"current values\" day=11 month=aug\n", b.String())
52 | })
53 | t.Run("should handle errors correctly", func(t *testing.T) {
54 | var b bytes.Buffer
55 | foo := bufio.NewWriter(&b)
56 |
57 | logger := log.NewLogrus(log.LogrusWithLevel("info"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{
58 | DisableTimestamp: true,
59 | }))
60 | var err = fmt.Errorf("request failed")
61 | logger.Error(err.Error(), "hello", "world")
62 | foo.Flush()
63 | assert.Equal(t, "level=error msg=\"request failed\" hello=world\n", b.String())
64 | })
65 | t.Run("should ignore params if malformed", func(t *testing.T) {
66 | var b bytes.Buffer
67 | foo := bufio.NewWriter(&b)
68 |
69 | logger := log.NewLogrus(log.LogrusWithLevel("info"), log.LogrusWithWriter(foo), log.LogrusWithFormatter(&logrus.TextFormatter{
70 | DisableTimestamp: true,
71 | }))
72 | var err = fmt.Errorf("request failed")
73 | logger.Error(err.Error(), "hello", "world", "!")
74 | foo.Flush()
75 | assert.Equal(t, "level=error msg=\"request failed\"\n", b.String())
76 | })
77 | }
78 |
--------------------------------------------------------------------------------
/log/noop.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "io"
5 | "io/ioutil"
6 | )
7 |
8 | type Noop struct{}
9 |
10 | func (n *Noop) Info(msg string, args ...interface{}) {}
11 | func (n *Noop) Debug(msg string, args ...interface{}) {}
12 | func (n *Noop) Warn(msg string, args ...interface{}) {}
13 | func (n *Noop) Error(msg string, args ...interface{}) {}
14 | func (n *Noop) Fatal(msg string, args ...interface{}) {}
15 |
16 | func (n *Noop) Level() string {
17 | return "unsupported"
18 | }
19 | func (n *Noop) Writer() io.Writer {
20 | return ioutil.Discard
21 | }
22 |
23 | // NewNoop returns a no operation logger, useful in tests
24 | func NewNoop(opts ...Option) *Noop {
25 | return &Noop{}
26 | }
27 |
--------------------------------------------------------------------------------
/log/zap.go:
--------------------------------------------------------------------------------
1 | package log
2 |
3 | import (
4 | "context"
5 | "io"
6 |
7 | "go.uber.org/zap"
8 | )
9 |
10 | type Zap struct {
11 | log *zap.SugaredLogger
12 | conf zap.Config
13 | }
14 |
15 | type ctxKey string
16 |
17 | var loggerCtxKey = ctxKey("zapLoggerCtxKey")
18 |
19 | func (z Zap) Debug(msg string, args ...interface{}) {
20 | z.log.With(args...).Debug(msg)
21 | }
22 |
23 | func (z Zap) Info(msg string, args ...interface{}) {
24 | z.log.With(args...).Info(msg)
25 | }
26 |
27 | func (z Zap) Warn(msg string, args ...interface{}) {
28 | z.log.With(args...).Warn(msg, args)
29 | }
30 |
31 | func (z Zap) Error(msg string, args ...interface{}) {
32 | z.log.With(args...).Error(msg, args)
33 | }
34 |
35 | func (z Zap) Fatal(msg string, args ...interface{}) {
36 | z.log.With(args...).Fatal(msg, args)
37 | }
38 |
39 | func (z Zap) Level() string {
40 | return z.conf.Level.String()
41 | }
42 |
43 | func (z Zap) Writer() io.Writer {
44 | panic("not supported")
45 | }
46 |
47 | func ZapWithConfig(conf zap.Config, opts ...zap.Option) Option {
48 | return func(z interface{}) {
49 | z.(*Zap).conf = conf
50 | prodLogger, err := z.(*Zap).conf.Build(opts...)
51 | if err != nil {
52 | panic(err)
53 | }
54 | z.(*Zap).log = prodLogger.Sugar()
55 | }
56 | }
57 |
58 | // GetInternalZapLogger Gets internal SugaredLogger instance
59 | func (z Zap) GetInternalZapLogger() *zap.SugaredLogger {
60 | return z.log
61 | }
62 |
63 | // NewContext will add Zap inside context
64 | func (z Zap) NewContext(ctx context.Context) context.Context {
65 | return context.WithValue(ctx, loggerCtxKey, z)
66 | }
67 |
68 | // ZapContextWithFields will add Zap Fields to logger in Context
69 | func ZapContextWithFields(ctx context.Context, fields ...zap.Field) context.Context {
70 | return context.WithValue(ctx, loggerCtxKey, Zap{
71 | // Error when not Desugaring when adding fields: github.com/ipfs/go-log/issues/85
72 | log: ZapFromContext(ctx).GetInternalZapLogger().Desugar().With(fields...).Sugar(),
73 | conf: ZapFromContext(ctx).conf,
74 | })
75 | }
76 |
77 | // ZapFromContext will help in fetching back zap logger from context
78 | func ZapFromContext(ctx context.Context) Zap {
79 | if ctxLogger, ok := ctx.Value(loggerCtxKey).(Zap); ok {
80 | return ctxLogger
81 | }
82 |
83 | return Zap{}
84 | }
85 |
86 | func ZapWithNoop() Option {
87 | return func(z interface{}) {
88 | z.(*Zap).log = zap.NewNop().Sugar()
89 | z.(*Zap).conf = zap.Config{}
90 | }
91 | }
92 |
93 | // NewZap returns a zap logger instance with info level as default log level
94 | func NewZap(opts ...Option) *Zap {
95 | defaultConfig := zap.NewProductionConfig()
96 | defaultConfig.Level.SetLevel(zap.InfoLevel)
97 | logger, err := defaultConfig.Build()
98 | if err != nil {
99 | panic(err)
100 | }
101 |
102 | zapper := &Zap{
103 | log: logger.Sugar(),
104 | conf: defaultConfig,
105 | }
106 | for _, opt := range opts {
107 | opt(zapper)
108 | }
109 | return zapper
110 | }
111 |
--------------------------------------------------------------------------------
/log/zap_test.go:
--------------------------------------------------------------------------------
1 | package log_test
2 |
3 | import (
4 | "bufio"
5 | "bytes"
6 | "context"
7 | "crypto/rand"
8 | "fmt"
9 | "io"
10 | "net/url"
11 | "testing"
12 | "time"
13 |
14 | "github.com/stretchr/testify/assert"
15 |
16 | "go.uber.org/zap"
17 |
18 | "github.com/raystack/salt/log"
19 | )
20 |
21 | type zapBufWriter struct {
22 | io.Writer
23 | }
24 |
25 | func (cw zapBufWriter) Close() error {
26 | return nil
27 | }
28 | func (cw zapBufWriter) Sync() error {
29 | return nil
30 | }
31 |
32 | type zapClock struct {
33 | t time.Time
34 | }
35 |
36 | func (m zapClock) Now() time.Time {
37 | return m.t
38 | }
39 |
40 | func (m zapClock) NewTicker(duration time.Duration) *time.Ticker {
41 | return time.NewTicker(duration)
42 | }
43 |
44 | func buildBufferedZapOption(writer io.Writer, t time.Time, bufWriterKey string) log.Option {
45 | config := zap.NewDevelopmentConfig()
46 | config.DisableCaller = true
47 | // register mock writer
48 | _ = zap.RegisterSink(bufWriterKey, func(u *url.URL) (zap.Sink, error) {
49 | return zapBufWriter{writer}, nil
50 | })
51 | // build a valid custom path
52 | customPath := fmt.Sprintf("%s:", bufWriterKey)
53 | config.OutputPaths = []string{customPath}
54 |
55 | return log.ZapWithConfig(config, zap.WithClock(&zapClock{
56 | t: t,
57 | }))
58 | }
59 |
60 | func TestZap(t *testing.T) {
61 | mockedTime := time.Date(2021, 6, 10, 11, 55, 0, 0, time.UTC)
62 |
63 | t.Run("should successfully print at info level", func(t *testing.T) {
64 | var b bytes.Buffer
65 | bWriter := bufio.NewWriter(&b)
66 |
67 | zapper := log.NewZap(buildBufferedZapOption(bWriter, mockedTime, randomString(10)))
68 | zapper.Info("hello", "wor", "ld")
69 | bWriter.Flush()
70 |
71 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"wor\": \"ld\"}\n", b.String())
72 | })
73 |
74 | t.Run("should successfully print log from context", func(t *testing.T) {
75 | var b bytes.Buffer
76 | bWriter := bufio.NewWriter(&b)
77 |
78 | zapper := log.NewZap(buildBufferedZapOption(bWriter, mockedTime, randomString(10)))
79 | ctx := zapper.NewContext(context.Background())
80 | contextualLog := log.ZapFromContext(ctx)
81 | contextualLog.Info("hello", "wor", "ld")
82 | bWriter.Flush()
83 |
84 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"wor\": \"ld\"}\n", b.String())
85 | })
86 |
87 | t.Run("should successfully print log from context with fields", func(t *testing.T) {
88 | var b bytes.Buffer
89 | bWriter := bufio.NewWriter(&b)
90 |
91 | zapper := log.NewZap(buildBufferedZapOption(bWriter, mockedTime, randomString(10)))
92 | ctx := zapper.NewContext(context.Background())
93 | ctx = log.ZapContextWithFields(ctx, zap.Int("one", 1))
94 | ctx = log.ZapContextWithFields(ctx, zap.String("two", "two"))
95 | log.ZapFromContext(ctx).Info("hello", "wor", "ld")
96 | bWriter.Flush()
97 |
98 | assert.Equal(t, mockedTime.Format("2006-01-02T15:04:05.000Z0700")+"\tINFO\thello\t{\"one\": 1, \"two\": \"two\", \"wor\": \"ld\"}\n", b.String())
99 | })
100 | }
101 |
102 | func randomString(n int) string {
103 | const alphabets = "abcdefghijklmnopqrstuvwxyz"
104 | var alphaBytes = make([]byte, n)
105 | rand.Read(alphaBytes)
106 | for i, b := range alphaBytes {
107 | alphaBytes[i] = alphabets[b%byte(len(alphabets))]
108 | }
109 | return string(alphaBytes)
110 | }
111 |
--------------------------------------------------------------------------------
/rql/README.md:
--------------------------------------------------------------------------------
1 | # RQL (Rest Query Language)
2 |
3 | A library to parse support advanced REST API query parameters like (filter, pagination, sort, group, search) and logical operators on the keys (like eq, neq, like, gt, lt etc)
4 |
5 | It takes a Golang struct and a json string as input and returns a Golang object that can be used to prepare SQL Statements (using raw sql or ORM Query builders).
6 |
7 | ### Usage
8 |
9 | Frontend should send the parameters and operator like this schema to the backend service on some route with `POST` HTTP Method
10 |
11 | ```json
12 | {
13 | "filters": [
14 | { "name": "id", "operator": "neq", "value": 20 },
15 | { "name": "title", "operator": "neq", "value": "nasa" },
16 | { "name": "enabled", "operator": "eq", "value": false },
17 | {
18 | "name": "created_at",
19 | "operator": "gte",
20 | "value": "2025-02-05T11:25:37.957Z"
21 | },
22 | { "name": "title", "operator": "like", "value": "xyz" }
23 | ],
24 | "group_by": ["plan_name"],
25 | "offset": 20,
26 | "limit": 50,
27 | "search": "abcd",
28 | "sort": [
29 | { "name": "title", "order": "desc" },
30 | { "name": "created_at", "order": "asc" }
31 | ]
32 | }
33 | ```
34 |
35 | The `rql` library can be used to parse this json, validate it and returns a Struct containing all the info to generate the operations and values for SQL.
36 |
37 | The validation happens via stuct tags defined on your model. Example:
38 |
39 | ```golang
40 | type Organization struct {
41 | Id int `rql:"name=id,type=number,min=10,max=200"`
42 | BillingPlanName string `rql:"name=plan_name,type=string"`
43 | CreatedAt time.Time `rql:"name=created_at,type=datetime"`
44 | MemberCount int `rql:"name=member_count,type=number"`
45 | Title string `rql:"name=title,type=string"`
46 | Enabled bool `rql:"name=enabled,type=bool"`
47 | }
48 |
49 | ```
50 |
51 | **Supported data types:**
52 |
53 | 1. number
54 | 2. string
55 | 3. datetime
56 | 4. bool
57 |
58 | Check `main.go` for more info on usage.
59 |
60 | Using this struct, a SQL query can be generated. Here is an example using `goqu` SQL Builder
61 |
62 | ```go
63 | //init the library's "Query" object with input json bytes
64 | userInput := &parser.Query{}
65 |
66 | //assuming jsonBytes is defined earlier
67 | err = json.Unmarshal(jsonBytes, userInput)
68 | if err != nil {
69 | panic(fmt.Sprintf("failed to unmarshal query string to parser query struct, err:%s", err.Error()))
70 | }
71 |
72 | //validate the json input
73 | err = parser.ValidateQuery(userInput, Organization{})
74 | if err != nil {
75 | panic(err)
76 | }
77 |
78 | //userInput object can be utilized to prepare SQL statement
79 | query := goqu.From("organizations")
80 |
81 | fuzzySearchColumns := []string{"id", "billing_plan_name", "title"}
82 |
83 | for _, filter_item := range userInput.Filters {
84 | query = query.Where(goqu.Ex{
85 | filter_item.Name: goqu.Op{filter_item.Operator: filter_item.Value},
86 | })
87 | }
88 |
89 | listOfExpressions := make([]goqu.Expression, 0)
90 |
91 | if userInput.Search != "" {
92 | for _, col := range fuzzySearchColumns {
93 | listOfExpressions = append(listOfExpressions, goqu.Ex{
94 | col: goqu.Op{"LIKE": userInput.Search},
95 | })
96 | }
97 | }
98 |
99 | query = query.Where(goqu.Or(listOfExpressions...))
100 |
101 | query = query.Offset(uint(userInput.Offset))
102 | for _, sort_item := range userInput.Sort {
103 | switch sort_item.Order {
104 | case "asc":
105 | query = query.OrderAppend(goqu.C(sort_item.Name).Asc())
106 | case "desc":
107 | query = query.OrderAppend(goqu.C(sort_item.Name).Desc())
108 | default:
109 | }
110 | }
111 | query = query.Limit(uint(userInput.Limit))
112 | sql, _, _ := query.ToSQL()
113 | fmt.Println(sql)
114 |
115 |
116 | ```
117 |
118 | giving output as
119 |
120 | ```sql
121 | SELECT * FROM "organizations" WHERE (("id" != 20) AND ("title" != 'nasa') AND ("enabled" IS FALSE) AND ("createdAt" >= '2025-02-05T11:25:37.957Z') AND ("title" LIKE 'xyz') AND (("id" LIKE 'abcd') OR ("billing_plan_name" LIKE 'abcd') OR ("title" LIKE 'abcd'))) ORDER BY "title" DESC, "createdAt" ASC LIMIT 50 OFFSET 20
122 | ```
123 |
124 | ### Improvements
125 |
126 | 1. The operators need to mapped with SQL operators like (`eq` should be converted to `=` etc). Right now we are relying on GoQU to do that, but we can make it SQL ORL lib agnostic.
127 |
128 | 2. Support validation on the range or values of the data. Like `min`, `max` on number etc.
129 |
--------------------------------------------------------------------------------
/rql/parser.go:
--------------------------------------------------------------------------------
1 | package rql
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "slices"
7 | "strings"
8 | "time"
9 | )
10 |
11 | var validNumberOperations = []string{"eq", "neq", "gt", "lt", "gte", "lte"}
12 | var validStringOperations = []string{"eq", "neq", "like", "in", "notin", "notlike", "empty", "notempty"}
13 | var validBoolOperations = []string{"eq", "neq"}
14 | var validDatetimeOperations = []string{"eq", "neq", "gt", "lt", "gte", "lte"}
15 |
16 | const TAG = "rql"
17 | const DATATYPE_NUMBER = "number"
18 | const DATATYPE_DATETIME = "datetime"
19 | const DATATYPE_STRING = "string"
20 | const DATATYPE_BOOL = "bool"
21 | const SORT_ORDER_ASC = "asc"
22 | const SORT_ORDER_DESC = "desc"
23 |
24 | var validSortOrder = []string{SORT_ORDER_ASC, SORT_ORDER_DESC}
25 |
26 | type Query struct {
27 | Filters []Filter `json:"filters"`
28 | GroupBy []string `json:"group_by"`
29 | Offset int `json:"offset"`
30 | Limit int `json:"limit"`
31 | Search string `json:"search"`
32 | Sort []Sort `json:"sort"`
33 | }
34 |
35 | type Filter struct {
36 | Name string `json:"name"`
37 | Operator string `json:"operator"`
38 | dataType string
39 | Value any `json:"value"`
40 | }
41 |
42 | type Sort struct {
43 | Name string `json:"name"`
44 | Order string `json:"order"`
45 | }
46 |
47 | func ValidateQuery(q *Query, checkStruct interface{}) error {
48 | val := reflect.ValueOf(checkStruct)
49 |
50 | // validate filters
51 | for _, filterItem := range q.Filters {
52 | //validate filter key name
53 | filterIdx := searchKeyInsideStruct(filterItem.Name, val)
54 | if filterIdx < 0 {
55 | return fmt.Errorf("'%s' is not a valid filter key", filterItem.Name)
56 | }
57 | structKeyTag := val.Type().Field(filterIdx).Tag.Get(TAG)
58 |
59 | // validate filter key data type
60 | allowedDataType := getDataTypeOfField(structKeyTag)
61 | filterItem.dataType = allowedDataType
62 | switch allowedDataType {
63 | case DATATYPE_NUMBER:
64 | err := validateNumberType(filterItem)
65 | if err != nil {
66 | return err
67 | }
68 | case DATATYPE_BOOL:
69 | err := validateBoolType(filterItem)
70 | if err != nil {
71 | return err
72 | }
73 | case DATATYPE_DATETIME:
74 | err := validateDatetimeType(filterItem)
75 | if err != nil {
76 | return err
77 | }
78 | case DATATYPE_STRING:
79 | err := validateStringType(filterItem)
80 | if err != nil {
81 | return err
82 | }
83 | default:
84 | return fmt.Errorf("type '%s' is not recognized", allowedDataType)
85 | }
86 |
87 | if !isValidOperator(filterItem) {
88 | return fmt.Errorf("value '%s' for key '%s' is valid string", filterItem.Operator, filterItem.Name)
89 | }
90 | }
91 |
92 | err := validateGroupByKeys(q, val)
93 | if err != nil {
94 | return err
95 | }
96 | return validateSortKey(q, val)
97 | }
98 |
99 | func validateNumberType(filterItem Filter) error {
100 | // check if the type is any of Golang numeric types
101 | // if not, return error
102 | switch filterItem.Value.(type) {
103 | case uint8, uint16, uint32, uint64, int8, int16, int32, int64, float32, float64, int, uint:
104 | return nil
105 | default:
106 | return fmt.Errorf("value %v for key '%s' is not int type", filterItem.Value, filterItem.Name)
107 | }
108 | }
109 |
110 | func validateDatetimeType(filterItem Filter) error {
111 | // cast the value to datetime
112 | // if failed, return error
113 | castedVal, ok := filterItem.Value.(string)
114 | if !ok {
115 | return fmt.Errorf("value %s for key '%s' is not a valid ISO datetime string", filterItem.Value, filterItem.Name)
116 | }
117 | _, err := time.Parse(time.RFC3339, castedVal)
118 | if err != nil {
119 | return fmt.Errorf("value %s for key '%s' is not a valid ISO datetime string", filterItem.Value, filterItem.Name)
120 | }
121 | return nil
122 | }
123 |
124 | func validateBoolType(filterItem Filter) error {
125 | // cast the value to bool
126 | // if failed, return error
127 | _, ok := filterItem.Value.(bool)
128 | if !ok {
129 | return fmt.Errorf("value %v for key '%s' is not bool type", filterItem.Value, filterItem.Name)
130 | }
131 | return nil
132 | }
133 |
134 | func validateStringType(filterItem Filter) error {
135 | // cast the value to string
136 | // if failed, return error
137 | _, ok := filterItem.Value.(string)
138 | if !ok {
139 | return fmt.Errorf("value %s for key '%s' is valid string type", filterItem.Value, filterItem.Name)
140 | }
141 | return nil
142 | }
143 |
144 | func searchKeyInsideStruct(keyName string, val reflect.Value) int {
145 | normalizedKey := strings.ToLower(keyName)
146 |
147 | for i := 0; i < val.NumField(); i++ {
148 | field := val.Type().Field(i)
149 |
150 | // Check field name
151 | if strings.ToLower(field.Name) == normalizedKey {
152 | return i
153 | }
154 |
155 | // Check rql tag
156 | if tag, ok := field.Tag.Lookup("rql"); ok {
157 | // Parse the tag string
158 | tagParts := strings.Split(tag, ",")
159 | for _, part := range tagParts {
160 | if strings.HasPrefix(part, "name=") {
161 | tagName := strings.TrimPrefix(part, "name=")
162 | if strings.ToLower(tagName) == normalizedKey {
163 | return i
164 | }
165 | }
166 | }
167 | }
168 | }
169 |
170 | return -1
171 | }
172 |
173 | // parse the tag schema which is of the format
174 | // type=int,min=10,max=200
175 | // to extract type else fallback to string
176 | func getDataTypeOfField(tagString string) string {
177 | res := DATATYPE_STRING
178 | splitted := strings.Split(tagString, ",")
179 | for _, item := range splitted {
180 | kvSplitted := strings.Split(item, "=")
181 | if len(kvSplitted) == 2 {
182 | if kvSplitted[0] == "type" {
183 | return kvSplitted[1]
184 | }
185 | }
186 | }
187 | //fallback to string if type not found in tag value
188 | return res
189 | }
190 |
191 | func GetDataTypeOfField(fieldName string, checkStruct interface{}) (string, error) {
192 | val := reflect.ValueOf(checkStruct)
193 | filterIdx := searchKeyInsideStruct(fieldName, val)
194 | if filterIdx < 0 {
195 | return "", fmt.Errorf("'%s' is not a valid field", fieldName)
196 | }
197 | structKeyTag := val.Type().Field(filterIdx).Tag.Get(TAG)
198 | dataType := getDataTypeOfField(structKeyTag)
199 | if !slices.Contains([]string{DATATYPE_STRING, DATATYPE_BOOL, DATATYPE_NUMBER, DATATYPE_DATETIME}, dataType) {
200 | return "", fmt.Errorf("invalid datatype '%s' is for field %s", dataType, fieldName)
201 | }
202 | return dataType, nil
203 | }
204 |
205 | func isValidOperator(filterItem Filter) bool {
206 | switch filterItem.dataType {
207 | case DATATYPE_NUMBER:
208 | return slices.Contains(validNumberOperations, filterItem.Operator)
209 | case DATATYPE_DATETIME:
210 | return slices.Contains(validDatetimeOperations, filterItem.Operator)
211 | case DATATYPE_STRING:
212 | return slices.Contains(validStringOperations, filterItem.Operator)
213 | case DATATYPE_BOOL:
214 | return slices.Contains(validBoolOperations, filterItem.Operator)
215 | default:
216 | return false
217 | }
218 | }
219 |
220 | func validateSortKey(q *Query, val reflect.Value) error {
221 | for _, item := range q.Sort {
222 | filterIdx := searchKeyInsideStruct(item.Name, val)
223 | if filterIdx < 0 {
224 | return fmt.Errorf("'%s' is not a valid sort key", item.Name)
225 | }
226 | if !slices.Contains(validSortOrder, item.Order) {
227 | return fmt.Errorf("'%s' is not a valid sort key", item.Name)
228 | }
229 | }
230 | return nil
231 | }
232 |
233 | func validateGroupByKeys(q *Query, val reflect.Value) error {
234 | for _, item := range q.GroupBy {
235 | filterIdx := searchKeyInsideStruct(item, val)
236 | if filterIdx < 0 {
237 | return fmt.Errorf("'%s' is not a valid sort key", item)
238 | }
239 | }
240 | return nil
241 | }
242 |
--------------------------------------------------------------------------------
/rql/parser_test.go:
--------------------------------------------------------------------------------
1 | package rql
2 |
3 | import (
4 | "strings"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestValidateQuery(t *testing.T) {
10 | type TestStruct struct {
11 | ID int32 `rql:"name=id,type=number"`
12 | Name string `rql:"name=name,type=string"`
13 | IsActive bool `rql:"name=is_active,type=bool"`
14 | CreatedAt time.Time `rql:"name=created_at,type=datetime"`
15 | }
16 |
17 | tests := []struct {
18 | name string
19 | query Query
20 | checkStruct TestStruct
21 | expectErr bool
22 | }{
23 | {
24 | name: "Valid filters and sort",
25 | query: Query{
26 | Filters: []Filter{
27 | {Name: "ID", Operator: "eq", Value: 123},
28 | {Name: "Name", Operator: "like", Value: "test"},
29 | {Name: "is_active", Operator: "eq", Value: true},
30 | {Name: "created_at", Operator: "eq", Value: "2021-09-15T15:53:00Z"},
31 | },
32 | Sort: []Sort{
33 | {Name: "ID", Order: "asc"},
34 | },
35 | },
36 | checkStruct: TestStruct{},
37 | expectErr: false,
38 | },
39 | {
40 | name: "Invalid filter key",
41 | query: Query{
42 | Filters: []Filter{
43 | {Name: "NonExistentKey", Operator: "eq", Value: "test"},
44 | },
45 | },
46 | checkStruct: TestStruct{},
47 | expectErr: true,
48 | },
49 | {
50 | name: "Invalid filter operator",
51 | query: Query{
52 | Filters: []Filter{
53 | {Name: "ID", Operator: "invalid", Value: 123},
54 | },
55 | },
56 | checkStruct: TestStruct{},
57 | expectErr: true,
58 | },
59 | {
60 | name: "Invalid filter value type",
61 | query: Query{
62 | Filters: []Filter{
63 | {Name: "ID", Operator: "eq", Value: "invalid"},
64 | },
65 | },
66 | checkStruct: TestStruct{},
67 | expectErr: true,
68 | },
69 | {
70 | name: "Invalid sort key",
71 | query: Query{
72 | Sort: []Sort{
73 | {Name: "NonExistentKey", Order: "asc"},
74 | },
75 | },
76 | checkStruct: TestStruct{},
77 | expectErr: true,
78 | },
79 | }
80 |
81 | for _, tt := range tests {
82 | t.Run(tt.name, func(t *testing.T) {
83 | err := ValidateQuery(&tt.query, tt.checkStruct)
84 | if (err != nil) != tt.expectErr {
85 | t.Errorf("ValidateQuery() error = %v, expectErr %v", err, tt.expectErr)
86 | }
87 | })
88 | }
89 | }
90 |
91 | func TestGetDataTypeOfField(t *testing.T) {
92 | type TestStruct struct {
93 | StringField string `rql:"name=string_field,type=string"`
94 | NumberField int `rql:"name=number_field,type=number"`
95 | BoolField bool `rql:"name=bool_field,type=bool"`
96 | DateTimeField time.Time `rql:"name=datetime_field,type=datetime"`
97 | InvalidField string `rql:"name=invalid_field,type=invalid"`
98 | NoTypeField string `rql:"name=no_type_field"` // No type specified
99 | NoTagField string // No tag at all
100 | }
101 |
102 | tests := []struct {
103 | name string
104 | fieldName string
105 | expectedType string
106 | expectedError bool
107 | errorContains string
108 | }{
109 | {
110 | name: "String field by struct name",
111 | fieldName: "StringField",
112 | expectedType: "string",
113 | expectedError: false,
114 | },
115 | {
116 | name: "String field by tag name",
117 | fieldName: "string_field",
118 | expectedType: "string",
119 | expectedError: false,
120 | },
121 | {
122 | name: "Number field by struct name",
123 | fieldName: "NumberField",
124 | expectedType: "number",
125 | expectedError: false,
126 | },
127 | {
128 | name: "Number field by tag name",
129 | fieldName: "number_field",
130 | expectedType: "number",
131 | expectedError: false,
132 | },
133 | {
134 | name: "Bool field by struct name",
135 | fieldName: "BoolField",
136 | expectedType: "bool",
137 | expectedError: false,
138 | },
139 | {
140 | name: "DateTime field by struct name",
141 | fieldName: "DateTimeField",
142 | expectedType: "datetime",
143 | expectedError: false,
144 | },
145 | {
146 | name: "Invalid field name",
147 | fieldName: "NonExistentField",
148 | expectedType: "",
149 | expectedError: true,
150 | errorContains: "is not a valid field",
151 | },
152 | {
153 | name: "No type specified in tag",
154 | fieldName: "NoTypeField",
155 | expectedType: "string", // Should default to string
156 | expectedError: false,
157 | },
158 | {
159 | name: "No tag field",
160 | fieldName: "NoTagField",
161 | expectedType: "string", // Should default to string
162 | expectedError: false,
163 | },
164 | }
165 |
166 | testStruct := TestStruct{}
167 |
168 | for _, tt := range tests {
169 | t.Run(tt.name, func(t *testing.T) {
170 | dataType, err := GetDataTypeOfField(tt.fieldName, testStruct)
171 |
172 | // Check error cases
173 | if tt.expectedError {
174 | if err == nil {
175 | t.Errorf("Expected error but got none")
176 | return
177 | }
178 | if !strings.Contains(err.Error(), tt.errorContains) {
179 | t.Errorf("Expected error containing '%s', got '%s'", tt.errorContains, err.Error())
180 | }
181 | return
182 | }
183 |
184 | // Check success cases
185 | if err != nil {
186 | t.Errorf("Unexpected error: %v", err)
187 | return
188 | }
189 |
190 | if dataType != tt.expectedType {
191 | t.Errorf("Expected type '%s', got '%s'", tt.expectedType, dataType)
192 | }
193 | })
194 | }
195 | }
196 |
--------------------------------------------------------------------------------
/server/mux/README.md:
--------------------------------------------------------------------------------
1 | # Mux
2 |
3 | `mux` package provides helpers for starting multiple servers. HTTP and gRPC
4 | servers are supported currently.
5 |
6 | ## Usage
7 |
8 | ```go
9 | package main
10 |
11 | import (
12 | "context"
13 | "log"
14 | "net/http"
15 | "os/signal"
16 | "syscall"
17 | "time"
18 |
19 | "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
20 | "github.com/raystack/salt/common"
21 | "github.com/raystack/salt/mux"
22 | "google.golang.org/grpc"
23 | "google.golang.org/grpc/reflection"
24 | )
25 |
26 | func main() {
27 | ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
28 | defer cancel()
29 |
30 | grpcServer := grpc.NewServer()
31 |
32 | reflection.Register(grpcServer)
33 |
34 | grpcGateway := runtime.NewServeMux()
35 |
36 | httpMux := http.NewServeMux()
37 | httpMux.Handle("/api/", http.StripPrefix("/api", grpcGateway))
38 |
39 | log.Fatalf("server exited: %v", mux.Serve(
40 | ctx,
41 | mux.WithHTTPTarget(":8080", &http.Server{
42 | Handler: httpMux,
43 | ReadTimeout: 120 * time.Second,
44 | WriteTimeout: 120 * time.Second,
45 | MaxHeaderBytes: 1 << 20,
46 | }),
47 | mux.WithGRPCTarget(":8081", grpcServer),
48 | mux.WithGracePeriod(5*time.Second),
49 | ))
50 | }
51 |
52 | type SlowCommonService struct {
53 | *common.CommonService
54 | }
55 |
56 | func (s SlowCommonService) GetVersion(ctx context.Context, req *commonv1.GetVersionRequest) (*commonv1.GetVersionResponse, error) {
57 | for i := 0; i < 5; i++ {
58 | log.Printf("dooing stuff")
59 | time.Sleep(1 * time.Second)
60 | }
61 | return s.CommonService.GetVersion(ctx, req)
62 | }
63 | ```
64 |
--------------------------------------------------------------------------------
/server/mux/mux.go:
--------------------------------------------------------------------------------
1 | package mux
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "log"
8 | "net"
9 | "time"
10 |
11 | "github.com/oklog/run"
12 | )
13 |
14 | const (
15 | defaultGracePeriod = 10 * time.Second
16 | )
17 |
18 | // Serve starts TCP listeners and serves the registered protocol servers of the
19 | // given serveTarget(s) and blocks until the servers exit. Context can be
20 | // cancelled to perform graceful shutdown.
21 | func Serve(ctx context.Context, opts ...Option) error {
22 | mux := muxServer{gracePeriod: defaultGracePeriod}
23 | for _, opt := range opts {
24 | if err := opt(&mux); err != nil {
25 | return err
26 | }
27 | }
28 |
29 | if len(mux.targets) == 0 {
30 | return errors.New("mux serve: at least one serve target must be set")
31 | }
32 |
33 | return mux.Serve(ctx)
34 | }
35 |
36 | type muxServer struct {
37 | targets []serveTarget
38 | gracePeriod time.Duration
39 | }
40 |
41 | func (mux *muxServer) Serve(ctx context.Context) error {
42 | var g run.Group
43 | for _, t := range mux.targets {
44 | l, err := net.Listen("tcp", t.Address())
45 | if err != nil {
46 | return fmt.Errorf("mux serve: %w", err)
47 | }
48 |
49 | t := t // redeclare to avoid referring to updated value inside closures.
50 | g.Add(func() error {
51 | err := t.Serve(l)
52 | if err != nil {
53 | log.Print("[ERROR] Serve:", err)
54 | }
55 | return err
56 | }, func(error) {
57 | ctx, cancel := context.WithTimeout(context.Background(), mux.gracePeriod)
58 | defer cancel()
59 |
60 | if err := t.Shutdown(ctx); err != nil {
61 | log.Print("[ERROR] Shutdown server gracefully:", err)
62 | }
63 | })
64 | }
65 |
66 | g.Add(func() error {
67 | <-ctx.Done()
68 | return ctx.Err()
69 | }, func(error) {
70 | })
71 |
72 | return g.Run()
73 | }
74 |
--------------------------------------------------------------------------------
/server/mux/option.go:
--------------------------------------------------------------------------------
1 | package mux
2 |
3 | import (
4 | "net/http"
5 | "time"
6 |
7 | "google.golang.org/grpc"
8 | )
9 |
10 | // Option values can be used with Serve() for customisation.
11 | type Option func(m *muxServer) error
12 |
13 | func WithHTTPTarget(addr string, srv *http.Server) Option {
14 | srv.Addr = addr
15 | return func(m *muxServer) error {
16 | m.targets = append(m.targets, httpServeTarget{Server: srv})
17 | return nil
18 | }
19 | }
20 |
21 | func WithGRPCTarget(addr string, srv *grpc.Server) Option {
22 | return func(m *muxServer) error {
23 | m.targets = append(m.targets, gRPCServeTarget{
24 | Addr: addr,
25 | Server: srv,
26 | })
27 | return nil
28 | }
29 | }
30 |
31 | // WithGracePeriod sets the wait duration for graceful shutdown.
32 | func WithGracePeriod(d time.Duration) Option {
33 | return func(m *muxServer) error {
34 | if d <= 0 {
35 | d = defaultGracePeriod
36 | }
37 | m.gracePeriod = d
38 | return nil
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/server/mux/serve_target.go:
--------------------------------------------------------------------------------
1 | package mux
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "net"
8 | "net/http"
9 |
10 | "google.golang.org/grpc"
11 | )
12 |
13 | // serveTarget is a server that can be attached to a listener bound to
14 | // Address() and later shut down.
15 | type serveTarget interface {
16 | Address() string
17 | Serve(l net.Listener) error
18 | Shutdown(ctx context.Context) error
19 | }
20 |
21 | // Compile-time checks that both adapters satisfy serveTarget.
22 | var (
23 | _ serveTarget = httpServeTarget{}
24 | _ serveTarget = gRPCServeTarget{}
25 | )
26 |
27 | // httpServeTarget adapts *http.Server to serveTarget; Shutdown comes from
28 | // the embedded server.
29 | type httpServeTarget struct {
30 | *http.Server
31 | }
32 |
33 | func (h httpServeTarget) Address() string { return h.Addr }
34 |
35 | // Serve serves HTTP on l. http.ErrServerClosed, returned after Shutdown,
36 | // is treated as a clean exit and reported as nil.
37 | func (h httpServeTarget) Serve(l net.Listener) error {
38 | if err := h.Server.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
39 | return err
40 | }
41 | return nil
42 | }
43 |
44 | // gRPCServeTarget adapts *grpc.Server to serveTarget; Serve comes from the
45 | // embedded server.
46 | type gRPCServeTarget struct {
47 | Addr string
48 | *grpc.Server
49 | }
50 |
51 | func (g gRPCServeTarget) Address() string { return g.Addr }
52 |
53 | // Shutdown runs GracefulStop. If ctx is done before it finishes, the server
54 | // is stopped forcefully with Stop and an error wrapping ctx.Err() is
55 | // returned.
56 | func (g gRPCServeTarget) Shutdown(ctx context.Context) error {
57 | done := make(chan struct{})
58 | go func() {
59 | defer close(done)
60 | g.GracefulStop()
61 | }()
62 |
63 | select {
64 | case <-done:
65 | return nil
66 | case <-ctx.Done():
67 | // Stop closes all connections, which should also let the pending
68 | // GracefulStop above return.
69 | g.Stop()
70 | return fmt.Errorf("graceful stop failed: %w", ctx.Err())
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/server/spa/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package spa provides a simple and efficient HTTP handler for serving
3 | Single Page Applications (SPAs).
4 |
5 | The handler serves static files from an embedded file system and falls
6 | back to serving an index file for client-side routing. Optionally, it
7 | supports gzip compression for optimizing responses.
8 |
9 | Features:
10 | - Serves static assets from an embedded file system.
11 | - Falls back to the index file when a requested path does not exist.
12 | - Optional gzip compression for supported clients.
13 |
14 | Usage:
15 |
16 | To use this package, embed your SPA's build assets into your binary using
17 | the `embed` package. Then, create an SPA handler using the `Handler` function
18 | and register it with an HTTP server. Handler returns an error if the index
19 | file is missing from the chosen directory.
20 |
21 | Example:
22 |
23 | package main
24 |
25 | import (
26 | "embed"
27 | "log"
28 | "net/http"
29 |
30 | "github.com/raystack/salt/server/spa"
31 | )
32 |
33 | //go:embed build/*
34 | var build embed.FS
35 |
36 | func main() {
37 | handler, err := spa.Handler(build, "build", "index.html", true)
38 | if err != nil {
39 | log.Fatalf("Failed to initialize SPA handler: %v", err)
40 | }
41 |
42 | log.Println("Serving SPA on http://localhost:8080")
43 | log.Fatal(http.ListenAndServe(":8080", handler))
44 | }
45 | */
46 | package spa
47 |
--------------------------------------------------------------------------------
/server/spa/handler.go:
--------------------------------------------------------------------------------
1 | package spa
2 |
3 | import (
4 | "embed"
5 | "errors"
6 | "fmt"
7 | "io/fs"
8 | "net/http"
9 |
10 | "github.com/NYTimes/gziphandler"
11 | )
12 |
13 | // Handler returns an HTTP handler for serving a Single Page Application (SPA).
14 | //
15 | // The handler serves static files from the specified directory in the embedded
16 | // file system and falls back to serving the index file if a requested file is not found.
17 | // This is useful for client-side routing in SPAs.
18 | //
19 | // Parameters:
20 | // - build: An embedded file system containing the build assets.
21 | // - dir: The directory within the embedded file system where the static files are located.
22 | // - index: The name of the index file (usually "index.html"), relative to dir.
23 | // - gzip: If true, the response body will be compressed using gzip for clients that support it.
24 | //
25 | // Returns:
26 | // - An http.Handler that serves the SPA and optional gzip compression.
27 | // - An error if dir cannot be used as a sub file system or index cannot be found in it.
28 | func Handler(build embed.FS, dir string, index string, gzip bool) (http.Handler, error) {
29 | fsys, err := fs.Sub(build, dir)
30 | if err != nil {
31 | return nil, fmt.Errorf("couldn't create sub filesystem: %w", err)
32 | }
33 |
34 | // Fail fast when the index file is missing. fs.Stat is used instead of
35 | // Open so that no file handle is left unclosed.
36 | if _, err := fs.Stat(fsys, index); err != nil {
37 | if errors.Is(err, fs.ErrNotExist) {
38 | return nil, fmt.Errorf("ui is enabled but no %s found: %w", index, err)
39 | }
40 | return nil, fmt.Errorf("ui assets error: %w", err)
41 | }
42 |
43 | router := &router{index: index, fs: http.FS(fsys)}
44 | hlr := http.FileServer(router)
45 |
46 | if !gzip {
47 | return hlr, nil
48 | }
49 | return gziphandler.GzipHandler(hlr), nil
50 | }
51 |
--------------------------------------------------------------------------------
/server/spa/router.go:
--------------------------------------------------------------------------------
1 | package spa
2 |
3 | import (
4 | "errors"
5 | "io/fs"
6 | "net/http"
7 | )
8 |
9 | // router is an http.FileSystem that resolves names against fs and, when a
10 | // name does not exist, substitutes the index file so the SPA can handle the
11 | // route client-side.
12 | type router struct {
13 | // index is the fallback file name, opened through fs.
14 | index string
15 | // fs holds the static assets.
16 | fs http.FileSystem
17 | }
18 |
19 | // Open returns the file for name when it exists. If the lookup fails with
20 | // fs.ErrNotExist, the index file is opened instead; any other error is
21 | // returned unchanged. Directories are returned as-is, so http.FileServer
22 | // applies its usual directory handling to them.
23 | func (r *router) Open(name string) (http.File, error) {
24 | f, openErr := r.fs.Open(name)
25 | switch {
26 | case openErr == nil:
27 | return f, nil
28 | case errors.Is(openErr, fs.ErrNotExist):
29 | return r.fs.Open(r.index)
30 | default:
31 | return nil, openErr
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/telemetry/opentelemetry.go:
--------------------------------------------------------------------------------
1 | package telemetry
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/raystack/salt/log"
10 | "go.opentelemetry.io/contrib/instrumentation/host"
11 | "go.opentelemetry.io/contrib/instrumentation/runtime"
12 | "go.opentelemetry.io/contrib/samplers/probability/consistent"
13 | "go.opentelemetry.io/otel"
14 | "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
15 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
16 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
17 | "go.opentelemetry.io/otel/propagation"
18 | sdkmetric "go.opentelemetry.io/otel/sdk/metric"
19 | "go.opentelemetry.io/otel/sdk/resource"
20 | sdktrace "go.opentelemetry.io/otel/sdk/trace"
21 | semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
22 | "google.golang.org/grpc/encoding/gzip"
23 | )
24 |
25 | // OpenTelemetryConfig controls export of metrics and traces to an OTLP
26 | // collector over gRPC.
27 | type OpenTelemetryConfig struct {
28 | // Enabled turns export on; when false no providers are installed.
29 | Enabled bool `yaml:"enabled" mapstructure:"enabled" default:"false"`
30 | // CollectorAddr is the collector endpoint. Exporters connect to it
31 | // without TLS (WithInsecure) and with gzip compression.
32 | CollectorAddr string `yaml:"collector_addr" mapstructure:"collector_addr" default:"localhost:4317"`
33 | // PeriodicReadInterval is how often metrics are read and exported.
34 | PeriodicReadInterval time.Duration `yaml:"periodic_read_interval" mapstructure:"periodic_read_interval" default:"1s"`
35 | // TraceSampleProbability is the fraction of traces sampled, passed to a
36 | // consistent probability-based sampler.
37 | TraceSampleProbability float64 `yaml:"trace_sample_probability" mapstructure:"trace_sample_probability" default:"1"`
38 | // VerboseResourceLabelsEnabled adds telemetry SDK, OS, host and process
39 | // attributes to the exported resource.
40 | VerboseResourceLabelsEnabled bool `yaml:"verbose_resource_labels_enabled" mapstructure:"verbose_resource_labels_enabled" default:"false"`
41 | }
42 |
43 | // initOTLP installs global meter and tracer providers that export to the
44 | // OTLP collector, registers the trace-context and baggage propagators, and
45 | // starts host and runtime metric instrumentation. It returns a function that
46 | // shuts both providers down. When telemetry is disabled it installs nothing
47 | // and returns noOp.
48 | func initOTLP(ctx context.Context, cfg Config, logger log.Logger) (func(), error) {
49 | if !cfg.OpenTelemetry.Enabled {
50 | logger.Info("OpenTelemetry monitoring is disabled.")
51 | return noOp, nil
52 | }
53 | resourceOptions := []resource.Option{
54 | resource.WithFromEnv(),
55 | resource.WithAttributes(
56 | semconv.ServiceName(cfg.AppName),
57 | semconv.ServiceVersion(cfg.AppVersion),
58 | ),
59 | }
60 | if cfg.OpenTelemetry.VerboseResourceLabelsEnabled {
61 | resourceOptions = append(resourceOptions,
62 | resource.WithTelemetrySDK(),
63 | resource.WithOS(),
64 | resource.WithHost(),
65 | resource.WithProcess(),
66 | resource.WithProcessRuntimeName(),
67 | resource.WithProcessRuntimeVersion(),
68 | )
69 | }
70 | res, err := resource.New(ctx, resourceOptions...)
71 | if err != nil {
72 | if !errors.Is(err, resource.ErrPartialResource) {
73 | return nil, fmt.Errorf("create resource: %w", err)
74 | }
75 | // Some detectors failed but res still carries every attribute that
76 | // was detected; carry on instead of disabling telemetry entirely.
77 | logger.Error("otlp resource detection partially failed", "err", err)
78 | }
79 | shutdownMetric, err := initGlobalMetrics(ctx, res, cfg.OpenTelemetry, logger)
80 | if err != nil {
81 | return nil, err
82 | }
83 | shutdownTracer, err := initGlobalTracer(ctx, res, cfg.OpenTelemetry, logger)
84 | if err != nil {
85 | shutdownMetric()
86 | return nil, err
87 | }
88 | shutdownProviders := func() {
89 | shutdownTracer()
90 | shutdownMetric()
91 | }
92 | if err := host.Start(); err != nil {
93 | shutdownProviders()
94 | return nil, fmt.Errorf("start host instrumentation: %w", err)
95 | }
96 | if err := runtime.Start(); err != nil {
97 | shutdownProviders()
98 | return nil, fmt.Errorf("start runtime instrumentation: %w", err)
99 | }
100 | return shutdownProviders, nil
101 | }
102 |
103 | // initGlobalMetrics creates an OTLP/gRPC metric exporter that is read every
104 | // cfg.PeriodicReadInterval and installs it as the global meter provider.
105 | // The returned function shuts the provider down, waiting at most gracePeriod.
106 | func initGlobalMetrics(ctx context.Context, res *resource.Resource, cfg OpenTelemetryConfig, logger log.Logger) (func(), error) {
107 | exporter, err := otlpmetricgrpc.New(ctx,
108 | otlpmetricgrpc.WithEndpoint(cfg.CollectorAddr),
109 | otlpmetricgrpc.WithCompressor(gzip.Name),
110 | otlpmetricgrpc.WithInsecure(),
111 | )
112 | if err != nil {
113 | return nil, fmt.Errorf("create metric exporter: %w", err)
114 | }
115 | reader := sdkmetric.NewPeriodicReader(exporter, sdkmetric.WithInterval(cfg.PeriodicReadInterval))
116 | provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader), sdkmetric.WithResource(res))
117 | otel.SetMeterProvider(provider)
118 | return func() {
119 | shutdownCtx, cancel := context.WithTimeout(context.Background(), gracePeriod)
120 | defer cancel()
121 | if err := provider.Shutdown(shutdownCtx); err != nil {
122 | logger.Error("otlp metric-provider failed to shutdown", "err", err)
123 | }
124 | }, nil
125 | }
126 |
127 | // initGlobalTracer creates an OTLP/gRPC trace exporter behind a batch span
128 | // processor, installs it as the global tracer provider and sets the global
129 | // text-map propagator. The returned function shuts the provider down,
130 | // waiting at most gracePeriod.
131 | func initGlobalTracer(ctx context.Context, res *resource.Resource, cfg OpenTelemetryConfig, logger log.Logger) (func(), error) {
132 | exporter, err := otlptrace.New(ctx, otlptracegrpc.NewClient(
133 | otlptracegrpc.WithEndpoint(cfg.CollectorAddr),
134 | otlptracegrpc.WithInsecure(),
135 | otlptracegrpc.WithCompressor(gzip.Name),
136 | ))
137 | if err != nil {
138 | return nil, fmt.Errorf("create trace exporter: %w", err)
139 | }
140 | tracerProvider := sdktrace.NewTracerProvider(
141 | sdktrace.WithSampler(consistent.ProbabilityBased(cfg.TraceSampleProbability)),
142 | sdktrace.WithResource(res),
143 | sdktrace.WithSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter)),
144 | )
145 | otel.SetTracerProvider(tracerProvider)
146 | otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
147 | propagation.TraceContext{}, propagation.Baggage{},
148 | ))
149 | return func() {
150 | shutdownCtx, cancel := context.WithTimeout(context.Background(), gracePeriod)
151 | defer cancel()
152 | if err := tracerProvider.Shutdown(shutdownCtx); err != nil {
153 | logger.Error("otlp trace-provider failed to shutdown", "err", err)
154 | }
155 | }, nil
156 | }
157 |
158 | // noOp is the clean-up function returned when there is nothing to shut down.
159 | func noOp() {}
160 |
--------------------------------------------------------------------------------
/telemetry/otelgrpc/otelgrpc.go:
--------------------------------------------------------------------------------
1 | package otelgrpc
2 |
3 | import (
4 | "context"
5 | "net"
6 | "strings"
7 | "time"
8 |
9 | "github.com/raystack/salt/utils"
10 | "go.opentelemetry.io/otel"
11 | "go.opentelemetry.io/otel/attribute"
12 | "go.opentelemetry.io/otel/metric"
13 | semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
14 | "google.golang.org/grpc"
15 | "google.golang.org/grpc/peer"
16 | "google.golang.org/protobuf/proto"
17 | )
18 |
19 | // defaultMeterName is the instrumentation name used when WithMeterName is
20 | // not supplied.
21 | const defaultMeterName = "github.com/raystack/salt/telemetry/otelgrpc"
22 |
23 | // UnaryParams describes a completed unary RPC passed to RecordUnary.
24 | type UnaryParams struct {
25 | Start time.Time
26 | Method string
27 | Req any
28 | Res any
29 | Err error
30 | }
31 |
32 | // Meter records client-side gRPC call duration and message sizes.
33 | type Meter struct {
34 | duration metric.Int64Histogram
35 | requestSize metric.Int64Histogram
36 | responseSize metric.Int64Histogram
37 | attributes []attribute.KeyValue
38 | }
39 |
40 | // MeterOpts holds the settings NewMeter builds a Meter from.
41 | type MeterOpts struct {
42 | meterName string `default:"github.com/raystack/salt/telemetry/otelgrpc"`
43 | }
44 |
45 | // Option customises the MeterOpts used by NewMeter.
46 | type Option func(*MeterOpts)
47 |
48 | // WithMeterName overrides the instrumentation name passed to otel.Meter.
49 | func WithMeterName(meterName string) Option {
50 | return func(opts *MeterOpts) {
51 | opts.meterName = meterName
52 | }
53 | }
54 |
55 | // NewMeter creates client-side RPC instruments from the global meter
56 | // provider. hostName is the call target ("host" or "host:port"). It is split
57 | // by ExtractAddress into the server.address and server.port attributes,
58 | // which are recorded on every measurement. Instrument creation errors are
59 | // reported through otel.Handle rather than returned.
60 | func NewMeter(hostName string, opts ...Option) Meter {
61 | // Nothing reads the `default` struct tag on meterName, so apply the
62 | // default here; otherwise otel.Meter would receive an empty name.
63 | meterOpts := &MeterOpts{meterName: defaultMeterName}
64 | for _, opt := range opts {
65 | opt(meterOpts)
66 | }
67 | meter := otel.Meter(meterOpts.meterName)
68 | duration, err := meter.Int64Histogram("rpc.client.duration", metric.WithUnit("ms"))
69 | handleOtelErr(err)
70 | requestSize, err := meter.Int64Histogram("rpc.client.request.size", metric.WithUnit("By"))
71 | handleOtelErr(err)
72 | responseSize, err := meter.Int64Histogram("rpc.client.response.size", metric.WithUnit("By"))
73 | handleOtelErr(err)
74 | addr, port := ExtractAddress(hostName)
75 | return Meter{
76 | duration: duration,
77 | requestSize: requestSize,
78 | responseSize: responseSize,
79 | attributes: []attribute.KeyValue{
80 | semconv.RPCSystemGRPC,
81 | attribute.String("network.transport", "tcp"),
82 | attribute.String("server.address", addr),
83 | attribute.String("server.port", port),
84 | },
85 | }
86 | }
87 |
88 | // GetProtoSize returns the wire-format size in bytes of p when it is a
89 | // proto.Message, and 0 otherwise (including when p is nil).
90 | func GetProtoSize(p any) int {
91 | if pm, ok := p.(proto.Message); ok {
92 | return proto.Size(pm)
93 | }
94 | return 0
95 | }
96 |
97 | // callAttributes builds the attribute set for one call. It contains the
98 | // Meter's base attributes plus the status text, the peer network type and
99 | // the parsed method. A new slice is always allocated so m.attributes is
100 | // never shared.
101 | func (m *Meter) callAttributes(ctx context.Context, method string, err error) []attribute.KeyValue {
102 | attrs := make([]attribute.KeyValue, 0, len(m.attributes)+4)
103 | attrs = append(attrs, m.attributes...)
104 | attrs = append(attrs,
105 | attribute.String("rpc.grpc.status_text", utils.StatusText(err)),
106 | attribute.String("network.type", netTypeFromCtx(ctx)),
107 | )
108 | return append(attrs, ParseFullMethod(method)...)
109 | }
110 |
111 | // RecordUnary records duration, request size and response size for a
112 | // completed unary call.
113 | func (m *Meter) RecordUnary(ctx context.Context, p UnaryParams) {
114 | elapsed := time.Since(p.Start).Milliseconds()
115 | withAttrs := metric.WithAttributes(m.callAttributes(ctx, p.Method, p.Err)...)
116 | m.duration.Record(ctx, elapsed, withAttrs)
117 | m.requestSize.Record(ctx, int64(GetProtoSize(p.Req)), withAttrs)
118 | m.responseSize.Record(ctx, int64(GetProtoSize(p.Res)), withAttrs)
119 | }
120 |
121 | // RecordStream records only the duration of a stream call. Message sizes
122 | // are not tracked for streams.
123 | func (m *Meter) RecordStream(ctx context.Context, start time.Time, method string, err error) {
124 | elapsed := time.Since(start).Milliseconds()
125 | m.duration.Record(ctx, elapsed, metric.WithAttributes(m.callAttributes(ctx, method, err)...))
126 | }
127 |
128 | // UnaryClientInterceptor returns an interceptor that records each unary call
129 | // made through it, using RecordUnary once the invoker returns.
130 | func (m *Meter) UnaryClientInterceptor() grpc.UnaryClientInterceptor {
131 | return func(ctx context.Context,
132 | method string,
133 | req, reply interface{},
134 | cc *grpc.ClientConn,
135 | invoker grpc.UnaryInvoker,
136 | opts ...grpc.CallOption,
137 | ) (err error) {
138 | defer func(start time.Time) {
139 | m.RecordUnary(ctx, UnaryParams{
140 | Start: start,
141 | // Method feeds the rpc.service / rpc.method attributes.
142 | Method: method,
143 | Req: req,
144 | Res: reply,
145 | Err: err,
146 | })
147 | }(time.Now())
148 | return invoker(ctx, method, req, reply, cc, opts...)
149 | }
150 | }
151 |
152 | // StreamClientInterceptor returns an interceptor that uses RecordStream to
153 | // record how long the streamer took to return. This covers stream
154 | // establishment, not the whole lifetime of the stream.
155 | func (m *Meter) StreamClientInterceptor() grpc.StreamClientInterceptor {
156 | return func(ctx context.Context,
157 | desc *grpc.StreamDesc,
158 | cc *grpc.ClientConn,
159 | method string,
160 | streamer grpc.Streamer,
161 | opts ...grpc.CallOption,
162 | ) (s grpc.ClientStream, err error) {
163 | defer func(start time.Time) {
164 | m.RecordStream(ctx, start, method, err)
165 | }(time.Now())
166 | return streamer(ctx, desc, cc, method, opts...)
167 | }
168 | }
169 |
170 | // GetAttributes returns the base attributes attached to every measurement.
171 | func (m *Meter) GetAttributes() []attribute.KeyValue {
172 | return m.attributes
173 | }
174 |
175 | // ParseFullMethod splits a gRPC full method name ("/pkg.Service/Method")
176 | // into rpc.service and rpc.method attributes. Leading slashes are ignored.
177 | // It returns nil when no "/" separates service and method, and it omits
178 | // empty parts.
179 | func ParseFullMethod(fullMethod string) []attribute.KeyValue {
180 | name := strings.TrimLeft(fullMethod, "/")
181 | service, method, found := strings.Cut(name, "/")
182 | if !found {
183 | return nil
184 | }
185 | var attrs []attribute.KeyValue
186 | if service != "" {
187 | attrs = append(attrs, semconv.RPCService(service))
188 | }
189 | if method != "" {
190 | attrs = append(attrs, semconv.RPCMethod(method))
191 | }
192 | return attrs
193 | }
194 |
195 | // handleOtelErr forwards a non-nil error to the global OpenTelemetry error handler.
196 | func handleOtelErr(err error) {
197 | if err != nil {
198 | otel.Handle(err)
199 | }
200 | }
201 |
202 | // ExtractAddress splits addr into host and port. If addr has no port or
203 | // cannot be split, it is returned unchanged as the host, with port "80".
204 | func ExtractAddress(addr string) (host, port string) {
205 | host, port, err := net.SplitHostPort(addr)
206 | if err != nil {
207 | return addr, "80"
208 | }
209 | return host, port
210 | }
211 |
212 | // netTypeFromCtx reports "ipv4" or "ipv6" for the gRPC peer address stored
213 | // in ctx. It reports "unknown" when there is no peer, no address, or no
214 | // parsable IP.
215 | func netTypeFromCtx(ctx context.Context) (ipType string) {
216 | ipType = "unknown"
217 | p, ok := peer.FromContext(ctx)
218 | if !ok || p == nil || p.Addr == nil {
219 | return ipType
220 | }
221 | clientIP, _, err := net.SplitHostPort(p.Addr.String())
222 | if err != nil {
223 | return ipType
224 | }
225 | ip := net.ParseIP(clientIP)
226 | if ip.To4() != nil {
227 | ipType = "ipv4"
228 | } else if ip.To16() != nil {
229 | ipType = "ipv6"
230 | }
231 | return ipType
232 | }
233 |
--------------------------------------------------------------------------------
/telemetry/otelgrpc/otelgrpc_test.go:
--------------------------------------------------------------------------------
1 | package otelgrpc_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/raystack/salt/telemetry/otelgrpc"
7 | "github.com/stretchr/testify/assert"
8 | "go.opentelemetry.io/otel/attribute"
9 | semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
10 | )
11 |
12 | // Test_parseFullMethod checks that full gRPC method names are split into
13 | // service/method attributes and that malformed names yield nil.
14 | func Test_parseFullMethod(t *testing.T) {
15 | cases := map[string]struct {
16 | input string
17 | expected []attribute.KeyValue
18 | }{
19 | "should parse correct method": {
20 | input: "/test.service.name/MethodNameV1",
21 | expected: []attribute.KeyValue{
22 | semconv.RPCService("test.service.name"),
23 | semconv.RPCMethod("MethodNameV1"),
24 | },
25 | },
26 | "should return empty attributes on incorrect method": {
27 | input: "incorrectMethod",
28 | expected: nil,
29 | },
30 | }
31 | for name, tc := range cases {
32 | tc := tc
33 | t.Run(name, func(t *testing.T) {
34 | got := otelgrpc.ParseFullMethod(tc.input)
35 | assert.Equal(t, tc.expected, got, "parseFullMethod(%q)", tc.input)
36 | })
37 | }
38 | }
39 |
40 | // TestExtractAddress checks host/port splitting, including the port "80"
41 | // fallback when no port is given.
42 | func TestExtractAddress(t *testing.T) {
43 | for _, tc := range []struct {
44 | addr, wantHost, wantPort string
45 | }{
46 | {"localhost:1001", "localhost", "1001"},
47 | {"localhost", "localhost", "80"},
48 | {"some.address.golabs.io:15010", "some.address.golabs.io", "15010"},
49 | } {
50 | gotHost, gotPort := otelgrpc.ExtractAddress(tc.addr)
51 | assert.Equal(t, tc.wantHost, gotHost)
52 | assert.Equal(t, tc.wantPort, gotPort)
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/telemetry/otelhhtpclient/annotations.go:
--------------------------------------------------------------------------------
1 | package otelhttpclient
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
8 | "go.opentelemetry.io/otel/attribute"
9 | )
10 |
11 | // labelerContextKeyType is unexported so keys from other packages cannot
12 | // collide with labelerContextKey.
13 | type labelerContextKeyType int
14 |
15 | const labelerContextKey labelerContextKeyType = 0
16 |
17 | // AnnotateRequest returns a shallow copy of req whose context carries an
18 | // otelhttp.Labeler holding the http.route attribute set to route. It
19 | // replaces any labeler previously attached by this function. Use the
20 | // returned request; req itself is not modified.
21 | //
22 | // route must be a route template (e.g. "/users/{id}") rather than the
23 | // actual URL; otherwise metric cardinality grows with every distinct path.
24 | func AnnotateRequest(req *http.Request, route string) *http.Request {
25 | l := &otelhttp.Labeler{}
26 | l.Add(attribute.String(attributeHTTPRoute, route))
27 | return req.WithContext(context.WithValue(req.Context(), labelerContextKey, l))
28 | }
29 |
30 | // LabelerFromContext returns the labeler stored in ctx by AnnotateRequest
31 | // and true. If ctx has none, it returns a new empty labeler and false, so
32 | // the result is always safe to use.
33 | func LabelerFromContext(ctx context.Context) (*otelhttp.Labeler, bool) {
34 | l, ok := ctx.Value(labelerContextKey).(*otelhttp.Labeler)
35 | if !ok {
36 | l = &otelhttp.Labeler{}
37 | }
38 | return l, ok
39 | }
40 |
--------------------------------------------------------------------------------
/telemetry/otelhhtpclient/http_transport.go:
--------------------------------------------------------------------------------
1 | package otelhttpclient
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net/http"
7 | "sync/atomic"
8 | "time"
9 |
10 | "go.opentelemetry.io/otel"
11 | "go.opentelemetry.io/otel/attribute"
12 | "go.opentelemetry.io/otel/metric"
13 | )
14 |
15 | // Refer OpenTelemetry Semantic Conventions for HTTP Client.
16 | // https://github.com/open-telemetry/semantic-conventions/blob/main/docs/http/http-metrics.md#http-client
17 | const (
18 | metricClientDuration = "http.client.duration"
19 | metricClientRequestSize = "http.client.request.size"
20 | metricClientResponseSize = "http.client.response.size"
21 | attributeNetProtoName = "network.protocol.name"
22 | // NOTE(review): the semantic conventions name this attribute
23 | // "network.protocol.version"; the existing label is kept so current
24 | // dashboards keep working. Confirm before renaming.
25 | attributeNetProtoVersion = "network.protocol.release"
26 | attributeServerPort = "server.port"
27 | attributeServerAddress = "server.address"
28 | attributeHTTPRoute = "http.route"
29 | attributeRequestMethod = "http.request.method"
30 | attributeResponseStatusCode = "http.response.status_code"
31 | )
32 |
33 | // httpTransport is an http.RoundTripper that records OpenTelemetry HTTP
34 | // client metrics around each request forwarded to roundTripper.
35 | type httpTransport struct {
36 | roundTripper http.RoundTripper
37 | metricClientDuration metric.Float64Histogram
38 | metricClientRequestSize metric.Int64Counter
39 | metricClientResponseSize metric.Int64Counter
40 | }
41 |
42 | // NewHTTPTransport wraps baseTransport so that every request records
43 | // duration, request size and response size metrics on the global meter
44 | // provider. A nil baseTransport means http.DefaultTransport. A transport
45 | // already returned by NewHTTPTransport is returned unchanged.
46 | func NewHTTPTransport(baseTransport http.RoundTripper) http.RoundTripper {
47 | if _, ok := baseTransport.(*httpTransport); ok {
48 | return baseTransport
49 | }
50 | if baseTransport == nil {
51 | baseTransport = http.DefaultTransport
52 | }
53 | icl := &httpTransport{roundTripper: baseTransport}
54 | icl.createMeasures(otel.Meter("github.com/raystack/salt/telemetry/otehttpclient"))
55 | return icl
56 | }
57 |
58 | // RoundTrip forwards req to the wrapped transport and records metrics,
59 | // including any labels attached via AnnotateRequest. When the transport
60 | // returns an error, the status code attribute is recorded as 0.
61 | func (tr *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
62 | ctx := req.Context()
63 | startAt := time.Now()
64 | labeler, _ := LabelerFromContext(ctx)
65 | // A RoundTripper must not modify the caller's request, so the body is
66 | // wrapped on a clone (Clone copies the Body field shallowly).
67 | req = req.Clone(ctx)
68 | var bw bodyWrapper
69 | if req.Body != nil && req.Body != http.NoBody {
70 | bw.ReadCloser = req.Body
71 | req.Body = &bw
72 | }
73 | port := req.URL.Port()
74 | if port == "" {
75 | port = "80"
76 | if req.URL.Scheme == "https" {
77 | port = "443"
78 | }
79 | }
80 | attribs := append(labeler.Get(),
81 | attribute.String(attributeNetProtoName, "http"),
82 | attribute.String(attributeRequestMethod, req.Method),
83 | attribute.String(attributeServerAddress, req.URL.Hostname()),
84 | attribute.String(attributeServerPort, port),
85 | )
86 | resp, err := tr.roundTripper.RoundTrip(req)
87 | if err != nil {
88 | attribs = append(attribs,
89 | attribute.Int(attributeResponseStatusCode, 0),
90 | attribute.String(attributeNetProtoVersion, fmt.Sprintf("%d.%d", req.ProtoMajor, req.ProtoMinor)),
91 | )
92 | } else {
93 | attribs = append(attribs,
94 | attribute.Int(attributeResponseStatusCode, resp.StatusCode),
95 | attribute.String(attributeNetProtoVersion, fmt.Sprintf("%d.%d", resp.ProtoMajor, resp.ProtoMinor)),
96 | )
97 | }
98 | elapsedTime := float64(time.Since(startAt)) / float64(time.Millisecond)
99 | withAttribs := metric.WithAttributes(attribs...)
100 | tr.metricClientDuration.Record(ctx, elapsedTime, withAttribs)
101 | // Atomic load: the transport may still be writing the body.
102 | tr.metricClientRequestSize.Add(ctx, atomic.LoadInt64(&bw.read), withAttribs)
103 | // ContentLength is -1 when unknown; never add a negative to a counter.
104 | if resp != nil && resp.ContentLength >= 0 {
105 | tr.metricClientResponseSize.Add(ctx, resp.ContentLength, withAttribs)
106 | }
107 | return resp, err
108 | }
109 |
110 | // createMeasures creates the instruments used by RoundTrip. Creation errors
111 | // are reported through otel.Handle.
112 | func (tr *httpTransport) createMeasures(meter metric.Meter) {
113 | var err error
114 | tr.metricClientRequestSize, err = meter.Int64Counter(metricClientRequestSize)
115 | handleErr(err)
116 | tr.metricClientResponseSize, err = meter.Int64Counter(metricClientResponseSize)
117 | handleErr(err)
118 | tr.metricClientDuration, err = meter.Float64Histogram(metricClientDuration)
119 | handleErr(err)
120 | }
121 |
122 | // handleErr forwards a non-nil error to the global OpenTelemetry error handler.
123 | func handleErr(err error) {
124 | if err != nil {
125 | otel.Handle(err)
126 | }
127 | }
128 |
129 | // bodyWrapper wraps a request body and counts the bytes read from it. The
130 | // transport may read (and close) the body from another goroutine, even after
131 | // RoundTrip has returned, so the counter is accessed atomically.
132 | type bodyWrapper struct {
133 | // read is the first field so it stays 64-bit aligned for atomic access
134 | // on 32-bit platforms.
135 | read int64
136 | io.ReadCloser
137 | }
138 |
139 | func (w *bodyWrapper) Read(b []byte) (int, error) {
140 | n, err := w.ReadCloser.Read(b)
141 | atomic.AddInt64(&w.read, int64(n))
142 | return n, err
143 | }
144 |
145 | // Close closes the wrapped body.
146 | func (w *bodyWrapper) Close() error {
147 | return w.ReadCloser.Close()
148 | }
149 |
--------------------------------------------------------------------------------
/telemetry/otelhhtpclient/http_transport_test.go:
--------------------------------------------------------------------------------
1 | package otelhttpclient_test
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 |
7 | otelhttpclient "github.com/raystack/salt/telemetry/otelhhtpclient"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | // TestNewHTTPTransport checks that a nil base is wrapped (not returned as
12 | // http.DefaultTransport) and that wrapping is idempotent.
13 | func TestNewHTTPTransport(t *testing.T) {
14 | tr := otelhttpclient.NewHTTPTransport(nil)
15 | assert.NotNil(t, tr)
16 | assert.NotSame(t, http.DefaultTransport, tr)
17 |
18 | // Wrapping an already-instrumented transport must not double-instrument it.
19 | assert.Same(t, tr, otelhttpclient.NewHTTPTransport(tr))
20 | }
21 |
--------------------------------------------------------------------------------
/telemetry/telemetry.go:
--------------------------------------------------------------------------------
1 | package telemetry
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/raystack/salt/log"
8 | )
9 |
10 | // gracePeriod bounds how long each provider shutdown may take during clean-up.
11 | const gracePeriod = 5 * time.Second
12 |
13 | // Config is the top-level telemetry configuration.
14 | type Config struct {
15 | // AppVersion is reported as the service.version resource attribute.
16 | AppVersion string
17 | // AppName is reported as the service.name resource attribute.
18 | AppName string `yaml:"app_name" mapstructure:"app_name" default:"service"`
19 | OpenTelemetry OpenTelemetryConfig `yaml:"open_telemetry" mapstructure:"open_telemetry"`
20 | }
21 |
22 | // Init sets up OpenTelemetry according to cfg. It returns a clean-up
23 | // function that shuts the installed providers down. cleanUp is never nil,
24 | // even when err is non-nil, so it is always safe to call or defer.
25 | func Init(ctx context.Context, cfg Config, logger log.Logger) (cleanUp func(), err error) {
26 | cleanUp, err = initOTLP(ctx, cfg, logger)
27 | if err == nil {
28 | return cleanUp, nil
29 | }
30 | return noOp, err
31 | }
32 |
--------------------------------------------------------------------------------
/testing/dockertestx/README.md:
--------------------------------------------------------------------------------
1 | # dockertestx
2 |
3 | This package provides abstractions over several dockerized data stores, using `ory/dockertest` to bootstrap each instance.
4 |
5 | Example: Postgres
6 |
7 | ```go
8 | // create postgres instance
9 | pgDocker, err := dockertest.CreatePostgres(
10 | dockertest.PostgresWithDetail(
11 | pgUser, pgPass, pgDBName,
12 | ),
13 | )
14 |
15 | // get connection string
16 | connString := pgDocker.GetExternalConnString()
17 |
18 | // purge docker
19 | if err := pgDocker.GetPool().Purge(pgDocker.GetResouce()); err != nil {
20 | return fmt.Errorf("could not purge resource: %w", err)
21 | }
22 | ```
23 |
24 | Example: SpiceDB
25 |
26 | - Bootstrap SpiceDB backed by Postgres, with the two containers connected internally through a bridge network.
27 |
28 | ```go
29 | // create custom pool
30 | pool, err := dockertest.NewPool("")
31 | if err != nil {
32 | return nil, err
33 | }
34 |
35 | // create a bridge network for testing
36 | network, err = pool.Client.CreateNetwork(docker.CreateNetworkOptions{
37 | Name: fmt.Sprintf("bridge-%s", uuid.New().String()),
38 | })
39 | if err != nil {
40 | return nil, err
41 | }
42 |
43 |
44 | // create postgres instance
45 | pgDocker, err := dockertest.CreatePostgres(
46 | dockertest.PostgresWithDockerPool(pool),
47 | dockertest.PostgresWithDockertestNetwork(network),
48 | dockertest.PostgresWithDetail(
49 | pgUser, pgPass, pgDBName,
50 | ),
51 | )
52 |
53 | // get connection string
54 | connString := pgDocker.GetInternalConnString()
55 |
56 | // create spice db instance
57 | spiceDocker, err := dockertest.CreateSpiceDB(connString,
58 | dockertest.SpiceDBWithDockerPool(pool),
59 | dockertest.SpiceDBWithDockertestNetwork(network),
60 | )
61 |
62 | if err := dockertest.MigrateSpiceDB(connString,
63 | dockertest.MigrateSpiceDBWithDockerPool(pool),
64 | dockertest.MigrateSpiceDBWithDockertestNetwork(network),
65 | ); err != nil {
66 | return err
67 | }
68 |
69 | // purge docker resources
70 | if err := pool.Purge(spiceDocker.GetResouce()); err != nil {
71 | return fmt.Errorf("could not purge resource: %w", err)
72 | }
73 | if err := pool.Purge(pgDocker.GetResouce()); err != nil {
74 | return fmt.Errorf("could not purge resource: %w", err)
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/testing/dockertestx/configs/cortex/single_process_cortex.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for running Cortex in single-process mode.
2 | # This configuration should not be used in production.
3 | # It is only for getting started and development.
4 |
5 | # Disable the requirement that every request to Cortex has a
6 | # X-Scope-OrgID header. `fake` will be substituted in instead.
7 | auth_enabled: false
8 |
9 | server:
10 | http_listen_port: 9009
11 |
12 | # Configure the server to allow messages up to 100MB.
13 | grpc_server_max_recv_msg_size: 104857600
14 | grpc_server_max_send_msg_size: 104857600
15 | grpc_server_max_concurrent_streams: 1000
16 |
17 | distributor:
18 | shard_by_all_labels: true
19 | pool:
20 | health_check_ingesters: true
21 |
22 | ingester_client:
23 | grpc_client_config:
24 | # Configure the client to allow messages up to 100MB.
25 | max_recv_msg_size: 104857600
26 | max_send_msg_size: 104857600
27 | grpc_compression: gzip
28 |
29 | ingester:
30 | # We want our ingesters to flush chunks at the same time to optimise
31 | # deduplication opportunities.
32 | spread_flushes: true
33 | chunk_age_jitter: 0
34 |
35 | walconfig:
36 | wal_enabled: true
37 | recover_from_wal: true
38 | wal_dir: /tmp/cortex/wal
39 |
40 | lifecycler:
41 | # The address to advertise for this ingester. Will be autodiscovered by
42 | # looking up address on eth0 or en0; can be specified if this fails.
43 | # address: 127.0.0.1
44 |
45 | # We want to start immediately and flush on shutdown.
46 | join_after: 0
47 | min_ready_duration: 0s
48 | final_sleep: 0s
49 | num_tokens: 512
50 | tokens_file_path: /tmp/cortex/wal/tokens
51 |
52 | # Use an in memory ring store, so we don't need to launch a Consul.
53 | ring:
54 | kvstore:
55 | store: inmemory
56 | replication_factor: 1
57 |
58 | # Use local storage - BoltDB for the index, and the filesystem
59 | # for the chunks.
60 | schema:
61 | configs:
62 | - from: 2019-07-29
63 | store: boltdb
64 | object_store: filesystem
65 | schema: v10
66 | index:
67 | prefix: index_
68 | period: 1w
69 |
70 | storage:
71 | boltdb:
72 | directory: /tmp/cortex/index
73 |
74 | filesystem:
75 | directory: /tmp/cortex/chunks
76 |
77 | delete_store:
78 | store: boltdb
79 |
80 | purger:
81 | object_store_type: filesystem
82 |
83 | frontend_worker:
84 | # Configure the frontend worker in the querier to match worker count
85 | # to max_concurrent on the queriers.
86 | match_max_concurrent: true
87 |
88 | # Configure the ruler to scan the /tmp/cortex/rules directory for prometheus
89 | # rules: https://prometheus.io/docs/prometheus/latest/configuration/recording_rules/#recording-rules
90 | ruler:
91 | enable_api: true
92 | enable_sharding: false
93 | # alertmanager_url: http://cortex-am:9009/api/prom/alertmanager/
94 | rule_path: /tmp/cortex/rules
95 | storage:
96 | type: s3
97 | s3:
98 | # endpoint: http://minio1:9000
99 | bucketnames: cortex
100 | secret_access_key: minio123
101 | access_key_id: minio
102 | s3forcepathstyle: true
103 |
104 | alertmanager:
105 | enable_api: true
106 | sharding_enabled: false
107 | data_dir: data/
108 | external_url: /api/prom/alertmanager
109 | storage:
110 | type: s3
111 | s3:
112 | # endpoint: http://minio1:9000
113 | bucketnames: cortex
114 | secret_access_key: minio123
115 | access_key_id: minio
116 | s3forcepathstyle: true
117 |
118 | alertmanager_storage:
119 | backend: local
120 | local:
121 | path: tmp/cortex/alertmanager
122 |
--------------------------------------------------------------------------------
/testing/dockertestx/configs/nginx/cortex_nginx.conf:
--------------------------------------------------------------------------------
1 | worker_processes 1;
2 | error_log /dev/stderr;
3 | pid /tmp/nginx.pid;
4 | worker_rlimit_nofile 8192;
5 |
6 | events {
7 | worker_connections 1024;
8 | }
9 |
10 |
11 | http {
12 | client_max_body_size 5M;
13 | default_type application/octet-stream;
14 | log_format main '$remote_addr - $remote_user [$time_local] $status '
15 | '"$request" $body_bytes_sent "$http_referer" '
16 | '"$http_user_agent" "$http_x_forwarded_for" $http_x_scope_orgid';
17 | access_log /dev/stderr main;
18 | sendfile on;
19 | tcp_nopush on;
20 | resolver 127.0.0.11 ipv6=off;
21 |
22 | server {
23 | listen {{.ExposedPort}};
24 | proxy_connect_timeout 300s;
25 | proxy_send_timeout 300s;
26 | proxy_read_timeout 300s;
27 | proxy_http_version 1.1;
28 |
29 | location = /healthz {
30 | return 200 'alive';
31 | }
32 |
33 | # Distributor Config
34 | location = /ring {
35 | proxy_pass http://{{.RulerHost}}$request_uri;
36 | }
37 |
38 | location = /all_user_stats {
39 | proxy_pass http://{{.RulerHost}}$request_uri;
40 | }
41 |
42 | location = /api/prom/push {
43 | proxy_pass http://{{.RulerHost}}$request_uri;
44 | }
45 |
46 | ## New Remote write API. Ref: https://cortexmetrics.io/docs/api/#remote-write
47 | location = /api/v1/push {
48 | proxy_pass http://{{.RulerHost}}$request_uri;
49 | }
50 |
51 |
52 | # Alertmanager Config
53 | location ~ /api/prom/alertmanager/.* {
54 | proxy_pass http://{{.AlertManagerHost}}$request_uri;
55 | }
56 |
57 | location ~ /api/v1/alerts {
58 | proxy_pass http://{{.AlertManagerHost}}$request_uri;
59 | }
60 |
61 | location ~ /multitenant_alertmanager/status {
62 | proxy_pass http://{{.AlertManagerHost}}$request_uri;
63 | }
64 |
65 | # Ruler Config
66 | location ~ /api/v1/rules {
67 | proxy_pass http://{{.RulerHost}}$request_uri;
68 | }
69 |
70 | location ~ /ruler/ring {
71 | proxy_pass http://{{.RulerHost}}$request_uri;
72 | }
73 |
74 | # Config Config
75 | location ~ /api/prom/configs/.* {
76 | proxy_pass http://{{.RulerHost}}$request_uri;
77 | }
78 |
79 | # Query Config
80 | location ~ /api/prom/.* {
81 | proxy_pass http://{{.RulerHost}}$request_uri;
82 | }
83 |
84 | ## New Query frontend APIs as per https://cortexmetrics.io/docs/api/#querier--query-frontend
85 | location ~ ^/prometheus/api/v1/(read|metadata|labels|series|query_range|query) {
86 | proxy_pass http://{{.RulerHost}}$request_uri;
87 | }
88 |
89 | location ~ /prometheus/api/v1/label/.* {
90 | proxy_pass http://{{.RulerHost}}$request_uri;
91 | }
92 | }
93 | }
--------------------------------------------------------------------------------
/testing/dockertestx/cortex.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "os"
7 | "path"
8 | "runtime"
9 | "time"
10 |
11 | "github.com/google/uuid"
12 | "github.com/ory/dockertest/v3"
13 | "github.com/ory/dockertest/v3/docker"
14 | )
15 |
16 | type dockerCortexOption func(dc *dockerCortex)
17 |
18 | // CortexWithDockertestNetwork is an option to assign docker network
19 | func CortexWithDockertestNetwork(network *dockertest.Network) dockerCortexOption {
20 | return func(dc *dockerCortex) {
21 | dc.network = network
22 | }
23 | }
24 |
25 | // CortexWithDockertestNetwork is an option to assign release tag
26 | // of a `quay.io/cortexproject/cortex` image
27 | func CortexWithVersionTag(versionTag string) dockerCortexOption {
28 | return func(dc *dockerCortex) {
29 | dc.versionTag = versionTag
30 | }
31 | }
32 |
33 | // CortexWithDockerPool is an option to assign docker pool
34 | func CortexWithDockerPool(pool *dockertest.Pool) dockerCortexOption {
35 | return func(dc *dockerCortex) {
36 | dc.pool = pool
37 | }
38 | }
39 |
40 | // CortexWithModule is an option to assign cortex module name
41 | // e.g. all, alertmanager, querier, etc
42 | func CortexWithModule(moduleName string) dockerCortexOption {
43 | return func(dc *dockerCortex) {
44 | dc.moduleName = moduleName
45 | }
46 | }
47 |
48 | // CortexWithAlertmanagerURL is an option to assign alertmanager url
49 | func CortexWithAlertmanagerURL(amURL string) dockerCortexOption {
50 | return func(dc *dockerCortex) {
51 | dc.alertManagerURL = amURL
52 | }
53 | }
54 |
55 | // CortexWithS3Endpoint is an option to assign external s3/minio storage
56 | func CortexWithS3Endpoint(s3URL string) dockerCortexOption {
57 | return func(dc *dockerCortex) {
58 | dc.s3URL = s3URL
59 | }
60 | }
61 |
62 | type dockerCortex struct {
63 | network *dockertest.Network
64 | pool *dockertest.Pool
65 | moduleName string
66 | alertManagerURL string
67 | s3URL string
68 | internalHost string
69 | externalHost string
70 | versionTag string
71 | dockertestResource *dockertest.Resource
72 | }
73 |
74 | // CreateCortex is a function to create a dockerized single-process cortex with
75 | // s3/minio as the backend storage
76 | func CreateCortex(opts ...dockerCortexOption) (*dockerCortex, error) {
77 | var (
78 | err error
79 | dc = &dockerCortex{}
80 | )
81 |
82 | for _, opt := range opts {
83 | opt(dc)
84 | }
85 |
86 | name := fmt.Sprintf("cortex-%s", uuid.New().String())
87 |
88 | if dc.pool == nil {
89 | dc.pool, err = dockertest.NewPool("")
90 | if err != nil {
91 | return nil, fmt.Errorf("could not create dockertest pool: %w", err)
92 | }
93 | }
94 |
95 | if dc.versionTag == "" {
96 | dc.versionTag = "master-63703f5"
97 | }
98 |
99 | if dc.moduleName == "" {
100 | dc.moduleName = "all"
101 | }
102 |
103 | runOpts := &dockertest.RunOptions{
104 | Name: name,
105 | Repository: "quay.io/cortexproject/cortex",
106 | Tag: dc.versionTag,
107 | Env: []string{
108 | "minio_host=siren_nginx_1",
109 | },
110 | Cmd: []string{
111 | fmt.Sprintf("-target=%s", dc.moduleName),
112 | "-config.file=/etc/single-process-config.yaml",
113 | fmt.Sprintf("-ruler.storage.s3.endpoint=%s", dc.s3URL),
114 | fmt.Sprintf("-ruler.alertmanager-url=%s", dc.alertManagerURL),
115 | fmt.Sprintf("-alertmanager.storage.s3.endpoint=%s", dc.s3URL),
116 | },
117 | ExposedPorts: []string{"9009/tcp"},
118 | ExtraHosts: []string{
119 | "cortex.siren_nginx_1:127.0.0.1",
120 | },
121 | }
122 |
123 | if dc.network != nil {
124 | runOpts.NetworkID = dc.network.Network.ID
125 | }
126 |
127 | pwd, err := os.Getwd()
128 | if err != nil {
129 | return nil, err
130 | }
131 |
132 | var (
133 | rulesFolder = fmt.Sprintf("%s/tmp/dockertest-configs/cortex/rules", pwd)
134 | alertManagerFolder = fmt.Sprintf("%s/tmp/dockertest-configs/cortex/alertmanager", pwd)
135 | )
136 |
137 | foldersPath := []string{rulesFolder, alertManagerFolder}
138 | for _, fp := range foldersPath {
139 | if _, err := os.Stat(fp); os.IsNotExist(err) {
140 | if err := os.MkdirAll(fp, 0777); err != nil {
141 | return nil, err
142 | }
143 | }
144 | }
145 |
146 | _, thisFileName, _, ok := runtime.Caller(0)
147 | if !ok {
148 | return nil, err
149 | }
150 | thisFileFolder := path.Dir(thisFileName)
151 |
152 | dc.dockertestResource, err = dc.pool.RunWithOptions(
153 | runOpts,
154 | func(config *docker.HostConfig) {
155 | config.RestartPolicy = docker.RestartPolicy{
156 | Name: "no",
157 | }
158 | config.Mounts = []docker.HostMount{
159 | {
160 | Target: "/etc/single-process-config.yaml",
161 | Source: fmt.Sprintf("%s/configs/cortex/single_process_cortex.yaml", thisFileFolder),
162 | Type: "bind",
163 | },
164 | {
165 | Target: "/tmp/cortex/rules",
166 | Source: rulesFolder,
167 | Type: "bind",
168 | },
169 | {
170 | Target: "/tmp/cortex/alertmanager",
171 | Source: alertManagerFolder,
172 | Type: "bind",
173 | },
174 | }
175 | },
176 | )
177 | if err != nil {
178 | return nil, err
179 | }
180 |
181 | externalPort := dc.dockertestResource.GetPort("9009/tcp")
182 | dc.internalHost = fmt.Sprintf("%s:9009", name)
183 | dc.externalHost = fmt.Sprintf("localhost:%s", externalPort)
184 |
185 | if err = dc.dockertestResource.Expire(120); err != nil {
186 | return nil, err
187 | }
188 |
189 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
190 | dc.pool.MaxWait = 60 * time.Second
191 |
192 | if err = dc.pool.Retry(func() error {
193 | httpClient := &http.Client{}
194 | res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s/config", externalPort))
195 | if err != nil {
196 | return err
197 | }
198 |
199 | if res.StatusCode != 200 {
200 | return fmt.Errorf("cortex server return status %d", res.StatusCode)
201 | }
202 |
203 | return nil
204 | }); err != nil {
205 | err = fmt.Errorf("could not connect to docker: %w", err)
206 | return nil, fmt.Errorf("could not connect to docker: %w", err)
207 | }
208 |
209 | return dc, nil
210 | }
211 |
212 | // GetInternalHost returns internal hostname and port
213 | // e.g. internal-xxxxxx:8080
214 | func (dc *dockerCortex) GetInternalHost() string {
215 | return dc.internalHost
216 | }
217 |
218 | // GetExternalHost returns localhost and port
219 | // e.g. localhost:51113
220 | func (dc *dockerCortex) GetExternalHost() string {
221 | return dc.externalHost
222 | }
223 |
224 | // GetPool returns docker pool
225 | func (dc *dockerCortex) GetPool() *dockertest.Pool {
226 | return dc.pool
227 | }
228 |
229 | // GetResource returns docker resource
230 | func (dc *dockerCortex) GetResource() *dockertest.Resource {
231 | return dc.dockertestResource
232 | }
233 |
--------------------------------------------------------------------------------
/testing/dockertestx/dockertestx.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import "runtime"
4 |
// DockerHostAddress returns the name to use for the docker host. It returns
// "host.docker.internal" when running on macOS (darwin) and "host-gateway" on
// every other OS.
//
// NOTE(review): "host-gateway" is presumably meant to be used in an
// ExtraHosts / --add-host mapping rather than dialed directly; confirm against
// callers.
func DockerHostAddress() string {
	switch runtime.GOOS {
	case "darwin":
		return "host.docker.internal"
	default:
		return "host-gateway" // linux by default
	}
}
12 |
--------------------------------------------------------------------------------
/testing/dockertestx/minio.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | "github.com/ory/dockertest/v3"
10 | "github.com/ory/dockertest/v3/docker"
11 | )
12 |
// Root credentials and MINIO_DOMAIN applied to every container started by
// CreateMinio.
const (
	defaultMinioRootUser     = "minio"
	defaultMinioRootPassword = "minio123"
	defaultMinioDomain       = "localhost"
)
18 |
19 | type dockerMinioOption func(dm *dockerMinio)
20 |
21 | // MinioWithDockertestNetwork is an option to assign docker network
22 | func MinioWithDockertestNetwork(network *dockertest.Network) dockerMinioOption {
23 | return func(dm *dockerMinio) {
24 | dm.network = network
25 | }
26 | }
27 |
28 | // MinioWithVersionTag is an option to assign release tag
29 | // of a `quay.io/minio/minio` image
30 | func MinioWithVersionTag(versionTag string) dockerMinioOption {
31 | return func(dm *dockerMinio) {
32 | dm.versionTag = versionTag
33 | }
34 | }
35 |
36 | // MinioWithDockerPool is an option to assign docker pool
37 | func MinioWithDockerPool(pool *dockertest.Pool) dockerMinioOption {
38 | return func(dm *dockerMinio) {
39 | dm.pool = pool
40 | }
41 | }
42 |
43 | type dockerMinio struct {
44 | network *dockertest.Network
45 | pool *dockertest.Pool
46 | rootUser string
47 | rootPassword string
48 | domain string
49 | versionTag string
50 | internalHost string
51 | externalHost string
52 | externalConsoleHost string
53 | dockertestResource *dockertest.Resource
54 | }
55 |
56 | // CreateMinio creates a minio instance with default configurations
57 | func CreateMinio(opts ...dockerMinioOption) (*dockerMinio, error) {
58 | var (
59 | err error
60 | dm = &dockerMinio{}
61 | )
62 |
63 | for _, opt := range opts {
64 | opt(dm)
65 | }
66 |
67 | name := fmt.Sprintf("minio-%s", uuid.New().String())
68 |
69 | if dm.pool == nil {
70 | dm.pool, err = dockertest.NewPool("")
71 | if err != nil {
72 | return nil, fmt.Errorf("could not create dockertest pool: %w", err)
73 | }
74 | }
75 |
76 | if dm.rootUser == "" {
77 | dm.rootUser = defaultMinioRootUser
78 | }
79 |
80 | if dm.rootPassword == "" {
81 | dm.rootPassword = defaultMinioRootPassword
82 | }
83 |
84 | if dm.domain == "" {
85 | dm.domain = defaultMinioDomain
86 | }
87 |
88 | if dm.versionTag == "" {
89 | dm.versionTag = "RELEASE.2022-09-07T22-25-02Z"
90 | }
91 |
92 | runOpts := &dockertest.RunOptions{
93 | Name: name,
94 | Repository: "quay.io/minio/minio",
95 | Tag: dm.versionTag,
96 | Env: []string{
97 | "MINIO_ROOT_USER=" + dm.rootUser,
98 | "MINIO_ROOT_PASSWORD=" + dm.rootPassword,
99 | "MINIO_DOMAIN=" + dm.domain,
100 | },
101 | Cmd: []string{"server", "/data1", "--console-address", ":9001"},
102 | ExposedPorts: []string{"9000/tcp", "9001/tcp"},
103 | }
104 |
105 | if dm.network != nil {
106 | runOpts.NetworkID = dm.network.Network.ID
107 | }
108 |
109 | dm.dockertestResource, err = dm.pool.RunWithOptions(
110 | runOpts,
111 | func(config *docker.HostConfig) {
112 | config.RestartPolicy = docker.RestartPolicy{
113 | Name: "no",
114 | }
115 | },
116 | )
117 | if err != nil {
118 | return nil, err
119 | }
120 |
121 | minioPort := dm.dockertestResource.GetPort("9000/tcp")
122 | minioConsolePort := dm.dockertestResource.GetPort("9001/tcp")
123 |
124 | dm.internalHost = fmt.Sprintf("%s:%s", name, "9000")
125 | dm.externalHost = fmt.Sprintf("%s:%s", "localhost", minioPort)
126 | dm.externalConsoleHost = fmt.Sprintf("%s:%s", "localhost", minioConsolePort)
127 |
128 | if err = dm.dockertestResource.Expire(120); err != nil {
129 | return nil, err
130 | }
131 |
132 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
133 | dm.pool.MaxWait = 60 * time.Second
134 |
135 | if err = dm.pool.Retry(func() error {
136 | httpClient := &http.Client{}
137 | res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s/minio/health/live", minioPort))
138 | if err != nil {
139 | return err
140 | }
141 |
142 | if res.StatusCode != 200 {
143 | return fmt.Errorf("minio server return status %d", res.StatusCode)
144 | }
145 |
146 | return nil
147 | }); err != nil {
148 | err = fmt.Errorf("could not connect to docker: %w", err)
149 | return nil, fmt.Errorf("could not connect to docker: %w", err)
150 | }
151 |
152 | return dm, nil
153 | }
154 |
155 | func (dm *dockerMinio) GetInternalHost() string {
156 | return dm.internalHost
157 | }
158 |
159 | func (dm *dockerMinio) GetExternalHost() string {
160 | return dm.externalHost
161 | }
162 |
163 | func (dm *dockerMinio) GetExternalConsoleHost() string {
164 | return dm.externalConsoleHost
165 | }
166 |
167 | // GetPool returns docker pool
168 | func (dm *dockerMinio) GetPool() *dockertest.Pool {
169 | return dm.pool
170 | }
171 |
172 | // GetResource returns docker resource
173 | func (dm *dockerMinio) GetResource() *dockertest.Resource {
174 | return dm.dockertestResource
175 | }
176 |
--------------------------------------------------------------------------------
/testing/dockertestx/minio_migrate.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/ory/dockertest/v3"
10 | "github.com/ory/dockertest/v3/docker"
11 | )
12 |
// waitContainerTimeout bounds how long MigrateMinio waits for its one-shot
// `mc` container to exit, including fetching its logs on failure.
const waitContainerTimeout = time.Minute
14 |
15 | type dockerMigrateMinioOption func(dmm *dockerMigrateMinio)
16 |
17 | // MigrateMinioWithDockertestNetwork is an option to assign docker network
18 | func MigrateMinioWithDockertestNetwork(network *dockertest.Network) dockerMigrateMinioOption {
19 | return func(dm *dockerMigrateMinio) {
20 | dm.network = network
21 | }
22 | }
23 |
24 | // MigrateMinioWithVersionTag is an option to assign release tag
25 | // of a `minio/mc` image
26 | func MigrateMinioWithVersionTag(versionTag string) dockerMigrateMinioOption {
27 | return func(dm *dockerMigrateMinio) {
28 | dm.versionTag = versionTag
29 | }
30 | }
31 |
32 | // MigrateMinioWithDockerPool is an option to assign docker pool
33 | func MigrateMinioWithDockerPool(pool *dockertest.Pool) dockerMigrateMinioOption {
34 | return func(dm *dockerMigrateMinio) {
35 | dm.pool = pool
36 | }
37 | }
38 |
39 | type dockerMigrateMinio struct {
40 | network *dockertest.Network
41 | pool *dockertest.Pool
42 | versionTag string
43 | }
44 |
45 | // MigrateMinio does migration of a `bucketName` to a minio located in `minioHost`
46 | func MigrateMinio(minioHost string, bucketName string, opts ...dockerMigrateMinioOption) error {
47 | var (
48 | err error
49 | dm = &dockerMigrateMinio{}
50 | )
51 |
52 | for _, opt := range opts {
53 | opt(dm)
54 | }
55 |
56 | if dm.pool == nil {
57 | dm.pool, err = dockertest.NewPool("")
58 | if err != nil {
59 | return fmt.Errorf("could not create dockertest pool: %w", err)
60 | }
61 | }
62 |
63 | if dm.versionTag == "" {
64 | dm.versionTag = "RELEASE.2022-08-28T20-08-11Z"
65 | }
66 |
67 | runOpts := &dockertest.RunOptions{
68 | Repository: "minio/mc",
69 | Tag: dm.versionTag,
70 | Entrypoint: []string{
71 | "bin/sh",
72 | "-c",
73 | fmt.Sprintf(`
74 | /usr/bin/mc alias set myminio http://%s minio minio123;
75 | /usr/bin/mc rm -r --force %s;
76 | /usr/bin/mc mb myminio/%s;
77 | `, minioHost, bucketName, bucketName),
78 | },
79 | }
80 |
81 | if dm.network != nil {
82 | runOpts.NetworkID = dm.network.Network.ID
83 | }
84 |
85 | resource, err := dm.pool.RunWithOptions(runOpts, func(config *docker.HostConfig) {
86 | config.RestartPolicy = docker.RestartPolicy{
87 | Name: "no",
88 | }
89 | })
90 | if err != nil {
91 | return err
92 | }
93 |
94 | if err := resource.Expire(120); err != nil {
95 | return err
96 | }
97 |
98 | waitCtx, cancel := context.WithTimeout(context.Background(), waitContainerTimeout)
99 | defer cancel()
100 |
101 | // Ensure the command completed successfully.
102 | status, err := dm.pool.Client.WaitContainerWithContext(resource.Container.ID, waitCtx)
103 | if err != nil {
104 | return err
105 | }
106 |
107 | if status != 0 {
108 | stream := new(bytes.Buffer)
109 |
110 | if err = dm.pool.Client.Logs(docker.LogsOptions{
111 | Context: waitCtx,
112 | OutputStream: stream,
113 | ErrorStream: stream,
114 | Stdout: true,
115 | Stderr: true,
116 | Container: resource.Container.ID,
117 | }); err != nil {
118 | return err
119 | }
120 |
121 | return fmt.Errorf("got non-zero exit code %s", stream.String())
122 | }
123 |
124 | if err := dm.pool.Purge(resource); err != nil {
125 | return err
126 | }
127 |
128 | return nil
129 | }
130 |
--------------------------------------------------------------------------------
/testing/dockertestx/nginx.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "bytes"
5 | _ "embed"
6 | "fmt"
7 | "io/fs"
8 | "net/http"
9 | "os"
10 | "path"
11 | "text/template"
12 | "time"
13 |
14 | "github.com/google/uuid"
15 | "github.com/ory/dockertest/v3"
16 | "github.com/ory/dockertest/v3/docker"
17 | )
18 |
// Defaults applied by CreateNginx when the corresponding option is not given.
const (
	nginxDefaultHealthEndpoint = "/healthz"
	nginxDefaultExposedPort    = "8080"
	nginxDefaultVersionTag     = "1.23"
)
24 |
25 | var (
26 | //go:embed configs/nginx/cortex_nginx.conf
27 | NginxCortexConfig string
28 | )
29 |
30 | type dockerNginxOption func(dc *dockerNginx)
31 |
32 | // NginxWithHealthEndpoint is an option to assign health endpoint
33 | func NginxWithHealthEndpoint(healthEndpoint string) dockerNginxOption {
34 | return func(dc *dockerNginx) {
35 | dc.healthEndpoint = healthEndpoint
36 | }
37 | }
38 |
39 | // NginxWithDockertestNetwork is an option to assign docker network
40 | func NginxWithDockertestNetwork(network *dockertest.Network) dockerNginxOption {
41 | return func(dc *dockerNginx) {
42 | dc.network = network
43 | }
44 | }
45 |
46 | // NginxWithVersionTag is an option to assign release tag
47 | // of a `nginx` image
48 | func NginxWithVersionTag(versionTag string) dockerNginxOption {
49 | return func(dc *dockerNginx) {
50 | dc.versionTag = versionTag
51 | }
52 | }
53 |
54 | // NginxWithDockerPool is an option to assign docker pool
55 | func NginxWithDockerPool(pool *dockertest.Pool) dockerNginxOption {
56 | return func(dc *dockerNginx) {
57 | dc.pool = pool
58 | }
59 | }
60 |
61 | // NginxWithDockerPool is an option to assign docker pool
62 | func NginxWithExposedPort(port string) dockerNginxOption {
63 | return func(dc *dockerNginx) {
64 | dc.exposedPort = port
65 | }
66 | }
67 |
68 | func NginxWithPresetConfig(presetConfig string) dockerNginxOption {
69 | return func(dc *dockerNginx) {
70 | dc.presetConfig = presetConfig
71 | }
72 | }
73 |
74 | func NginxWithConfigVariables(cv map[string]string) dockerNginxOption {
75 | return func(dc *dockerNginx) {
76 | dc.configVariables = cv
77 | }
78 | }
79 |
80 | type dockerNginx struct {
81 | network *dockertest.Network
82 | pool *dockertest.Pool
83 | exposedPort string
84 | internalHost string
85 | externalHost string
86 | presetConfig string
87 | versionTag string
88 | healthEndpoint string
89 | configVariables map[string]string
90 | dockertestResource *dockertest.Resource
91 | }
92 |
93 | // CreateNginx is a function to create a dockerized nginx
94 | func CreateNginx(opts ...dockerNginxOption) (*dockerNginx, error) {
95 | var (
96 | err error
97 | dc = &dockerNginx{}
98 | )
99 |
100 | for _, opt := range opts {
101 | opt(dc)
102 | }
103 |
104 | name := fmt.Sprintf("nginx-%s", uuid.New().String())
105 |
106 | if dc.pool == nil {
107 | dc.pool, err = dockertest.NewPool("")
108 | if err != nil {
109 | return nil, fmt.Errorf("could not create dockertest pool: %w", err)
110 | }
111 | }
112 |
113 | if dc.versionTag == "" {
114 | dc.versionTag = nginxDefaultVersionTag
115 | }
116 |
117 | if dc.exposedPort == "" {
118 | dc.exposedPort = nginxDefaultExposedPort
119 | }
120 |
121 | if dc.healthEndpoint == "" {
122 | dc.healthEndpoint = nginxDefaultHealthEndpoint
123 | }
124 |
125 | runOpts := &dockertest.RunOptions{
126 | Name: name,
127 | Repository: "nginx",
128 | Tag: dc.versionTag,
129 | ExposedPorts: []string{fmt.Sprintf("%s/tcp", dc.exposedPort)},
130 | }
131 |
132 | if dc.network != nil {
133 | runOpts.NetworkID = dc.network.Network.ID
134 | }
135 |
136 | var confString string
137 | switch dc.presetConfig {
138 | case "cortex":
139 | confString = NginxCortexConfig
140 | }
141 |
142 | tmpl := template.New("nginx-config")
143 | parsedTemplate, err := tmpl.Parse(confString)
144 | if err != nil {
145 | return nil, err
146 | }
147 | var generatedConf bytes.Buffer
148 | err = parsedTemplate.Execute(&generatedConf, dc.configVariables)
149 | if err != nil {
150 | // it is unlikely that the code returns error here
151 | return nil, err
152 | }
153 | confString = generatedConf.String()
154 |
155 | pwd, err := os.Getwd()
156 | if err != nil {
157 | return nil, err
158 | }
159 |
160 | var (
161 | confDestinationFolder = fmt.Sprintf("%s/tmp/dockertest-configs/nginx", pwd)
162 | )
163 |
164 | foldersPath := []string{confDestinationFolder}
165 | for _, fp := range foldersPath {
166 | if _, err := os.Stat(fp); os.IsNotExist(err) {
167 | if err := os.MkdirAll(fp, 0777); err != nil {
168 | return nil, err
169 | }
170 | }
171 | }
172 |
173 | if err := os.WriteFile(path.Join(confDestinationFolder, "nginx.conf"), []byte(confString), fs.ModePerm); err != nil {
174 | return nil, err
175 | }
176 |
177 | dc.dockertestResource, err = dc.pool.RunWithOptions(
178 | runOpts,
179 | func(config *docker.HostConfig) {
180 | config.RestartPolicy = docker.RestartPolicy{
181 | Name: "no",
182 | }
183 | config.Mounts = []docker.HostMount{
184 | {
185 | Target: "/etc/nginx/nginx.conf",
186 | Source: path.Join(confDestinationFolder, "nginx.conf"),
187 | Type: "bind",
188 | },
189 | }
190 | },
191 | )
192 | if err != nil {
193 | return nil, err
194 | }
195 |
196 | externalPort := dc.dockertestResource.GetPort(fmt.Sprintf("%s/tcp", dc.exposedPort))
197 | dc.internalHost = fmt.Sprintf("%s:%s", name, dc.exposedPort)
198 | dc.externalHost = fmt.Sprintf("localhost:%s", externalPort)
199 |
200 | if err = dc.dockertestResource.Expire(120); err != nil {
201 | return nil, err
202 | }
203 |
204 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
205 | dc.pool.MaxWait = 60 * time.Second
206 |
207 | if err = dc.pool.Retry(func() error {
208 | httpClient := &http.Client{}
209 | res, err := httpClient.Get(fmt.Sprintf("http://localhost:%s%s", externalPort, dc.healthEndpoint))
210 | if err != nil {
211 | return err
212 | }
213 |
214 | if res.StatusCode != 200 {
215 | return fmt.Errorf("nginx server return status %d", res.StatusCode)
216 | }
217 |
218 | return nil
219 | }); err != nil {
220 | err = fmt.Errorf("could not connect to docker: %w", err)
221 | return nil, fmt.Errorf("could not connect to docker: %w", err)
222 | }
223 |
224 | return dc, nil
225 | }
226 |
227 | // GetPool returns docker pool
228 | func (dc *dockerNginx) GetPool() *dockertest.Pool {
229 | return dc.pool
230 | }
231 |
232 | // GetResource returns docker resource
233 | func (dc *dockerNginx) GetResource() *dockertest.Resource {
234 | return dc.dockertestResource
235 | }
236 |
237 | // GetInternalHost returns internal hostname and port
238 | // e.g. internal-xxxxxx:8080
239 | func (dc *dockerNginx) GetInternalHost() string {
240 | return dc.internalHost
241 | }
242 |
243 | // GetExternalHost returns localhost and port
244 | // e.g. localhost:51113
245 | func (dc *dockerNginx) GetExternalHost() string {
246 | return dc.externalHost
247 | }
248 |
--------------------------------------------------------------------------------
/testing/dockertestx/postgres.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | "github.com/google/uuid"
8 | "github.com/jmoiron/sqlx"
9 | "github.com/ory/dockertest/v3"
10 | "github.com/ory/dockertest/v3/docker"
11 | "github.com/raystack/salt/log"
12 | )
13 |
// Credentials and database name CreatePostgres uses when PostgresWithDetail
// leaves them empty.
const (
	defaultPGUname  = "test_user"
	defaultPGPasswd = "test_pass"
	defaultDBname   = "test_db"
)
19 |
20 | type dockerPostgresOption func(dpg *dockerPostgres)
21 |
22 | func PostgresWithLogger(logger log.Logger) dockerPostgresOption {
23 | return func(dpg *dockerPostgres) {
24 | dpg.logger = logger
25 | }
26 | }
27 |
28 | // PostgresWithDockertestNetwork is an option to assign docker network
29 | func PostgresWithDockertestNetwork(network *dockertest.Network) dockerPostgresOption {
30 | return func(dpg *dockerPostgres) {
31 | dpg.network = network
32 | }
33 | }
34 |
35 | // PostgresWithDockertestResourceExpiry is an option to assign docker resource expiry time
36 | func PostgresWithDockertestResourceExpiry(expiryInSeconds uint) dockerPostgresOption {
37 | return func(dpg *dockerPostgres) {
38 | dpg.expiryInSeconds = expiryInSeconds
39 | }
40 | }
41 |
42 | // PostgresWithDetail is an option to assign custom details
43 | // like username, password, and database name
44 | func PostgresWithDetail(
45 | username string,
46 | password string,
47 | dbName string,
48 | ) dockerPostgresOption {
49 | return func(dpg *dockerPostgres) {
50 | dpg.username = username
51 | dpg.password = password
52 | dpg.dbName = dbName
53 | }
54 | }
55 |
56 | // PostgresWithVersionTag is an option to assign release tag
57 | // of a `postgres` image
58 | func PostgresWithVersionTag(versionTag string) dockerPostgresOption {
59 | return func(dpg *dockerPostgres) {
60 | dpg.versionTag = versionTag
61 | }
62 | }
63 |
64 | // PostgresWithDockerPool is an option to assign docker pool
65 | func PostgresWithDockerPool(pool *dockertest.Pool) dockerPostgresOption {
66 | return func(dpg *dockerPostgres) {
67 | dpg.pool = pool
68 | }
69 | }
70 |
71 | type dockerPostgres struct {
72 | logger log.Logger
73 | network *dockertest.Network
74 | pool *dockertest.Pool
75 | username string
76 | password string
77 | dbName string
78 | versionTag string
79 | connStringInternal string
80 | connStringExternal string
81 | expiryInSeconds uint
82 | dockertestResource *dockertest.Resource
83 | }
84 |
85 | // CreatePostgres creates a postgres instance with default configurations
86 | func CreatePostgres(opts ...dockerPostgresOption) (*dockerPostgres, error) {
87 | var (
88 | err error
89 | dpg = &dockerPostgres{}
90 | )
91 |
92 | for _, opt := range opts {
93 | opt(dpg)
94 | }
95 |
96 | name := fmt.Sprintf("postgres-%s", uuid.New().String())
97 |
98 | if dpg.pool == nil {
99 | dpg.pool, err = dockertest.NewPool("")
100 | if err != nil {
101 | return nil, fmt.Errorf("could not create dockertest pool: %w", err)
102 | }
103 | }
104 |
105 | if dpg.username == "" {
106 | dpg.username = defaultPGUname
107 | }
108 |
109 | if dpg.password == "" {
110 | dpg.password = defaultPGPasswd
111 | }
112 |
113 | if dpg.dbName == "" {
114 | dpg.dbName = defaultDBname
115 | }
116 |
117 | if dpg.versionTag == "" {
118 | dpg.versionTag = "12"
119 | }
120 |
121 | if dpg.expiryInSeconds == 0 {
122 | dpg.expiryInSeconds = 120
123 | }
124 |
125 | runOpts := &dockertest.RunOptions{
126 | Name: name,
127 | Repository: "postgres",
128 | Tag: dpg.versionTag,
129 | Env: []string{
130 | "POSTGRES_PASSWORD=" + dpg.password,
131 | "POSTGRES_USER=" + dpg.username,
132 | "POSTGRES_DB=" + dpg.dbName,
133 | },
134 | ExposedPorts: []string{"5432/tcp"},
135 | }
136 |
137 | if dpg.network != nil {
138 | runOpts.NetworkID = dpg.network.Network.ID
139 | }
140 |
141 | dpg.dockertestResource, err = dpg.pool.RunWithOptions(
142 | runOpts,
143 | func(config *docker.HostConfig) {
144 | config.RestartPolicy = docker.RestartPolicy{
145 | Name: "no",
146 | }
147 | },
148 | )
149 | if err != nil {
150 | return nil, err
151 | }
152 |
153 | pgPort := dpg.dockertestResource.GetPort("5432/tcp")
154 | dpg.connStringInternal = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dpg.username, dpg.password, name, "5432", dpg.dbName)
155 | dpg.connStringExternal = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dpg.username, dpg.password, "localhost", pgPort, dpg.dbName)
156 |
157 | if err = dpg.dockertestResource.Expire(dpg.expiryInSeconds); err != nil {
158 | return nil, err
159 | }
160 |
161 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
162 | dpg.pool.MaxWait = 60 * time.Second
163 |
164 | if err = dpg.pool.Retry(func() error {
165 | if _, err := sqlx.Connect("postgres", dpg.connStringExternal); err != nil {
166 | return err
167 | }
168 | return nil
169 | }); err != nil {
170 | err = fmt.Errorf("could not connect to docker: %w", err)
171 | return nil, fmt.Errorf("could not connect to docker: %w", err)
172 | }
173 |
174 | return dpg, nil
175 | }
176 |
177 | // GetInternalConnString returns internal connection string of a postgres instance
178 | func (dpg *dockerPostgres) GetInternalConnString() string {
179 | return dpg.connStringInternal
180 | }
181 |
182 | // GetExternalConnString returns external connection string of a postgres instance
183 | func (dpg *dockerPostgres) GetExternalConnString() string {
184 | return dpg.connStringExternal
185 | }
186 |
187 | // GetPool returns docker pool
188 | func (dpg *dockerPostgres) GetPool() *dockertest.Pool {
189 | return dpg.pool
190 | }
191 |
192 | // GetResource returns docker resource
193 | func (dpg *dockerPostgres) GetResource() *dockertest.Resource {
194 | return dpg.dockertestResource
195 | }
196 |
--------------------------------------------------------------------------------
/testing/dockertestx/spicedb.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | authzedpb "github.com/authzed/authzed-go/proto/authzed/api/v1"
9 | "github.com/authzed/authzed-go/v1"
10 | "github.com/authzed/grpcutil"
11 | "github.com/google/uuid"
12 | "github.com/ory/dockertest/v3"
13 | "github.com/ory/dockertest/v3/docker"
14 | "google.golang.org/grpc"
15 | "google.golang.org/grpc/codes"
16 | "google.golang.org/grpc/credentials/insecure"
17 | "google.golang.org/grpc/status"
18 | )
19 |
// Defaults CreateSpiceDB applies when the corresponding option is not given.
const (
	defaultPreSharedKey = "default-preshared-key"
	defaultLogLevel     = "debug"
)
24 |
25 | type dockerSpiceDBOption func(dsp *dockerSpiceDB)
26 |
27 | func SpiceDBWithLogLevel(logLevel string) dockerSpiceDBOption {
28 | return func(dsp *dockerSpiceDB) {
29 | dsp.logLevel = logLevel
30 | }
31 | }
32 |
33 | // SpiceDBWithDockertestNetwork is an option to assign docker network
34 | func SpiceDBWithDockertestNetwork(network *dockertest.Network) dockerSpiceDBOption {
35 | return func(dsp *dockerSpiceDB) {
36 | dsp.network = network
37 | }
38 | }
39 |
40 | // SpiceDBWithVersionTag is an option to assign release tag
41 | // of a `quay.io/authzed/spicedb` image
42 | func SpiceDBWithVersionTag(versionTag string) dockerSpiceDBOption {
43 | return func(dsp *dockerSpiceDB) {
44 | dsp.versionTag = versionTag
45 | }
46 | }
47 |
48 | // SpiceDBWithDockerPool is an option to assign docker pool
49 | func SpiceDBWithDockerPool(pool *dockertest.Pool) dockerSpiceDBOption {
50 | return func(dsp *dockerSpiceDB) {
51 | dsp.pool = pool
52 | }
53 | }
54 |
55 | // SpiceDBWithPreSharedKey is an option to assign pre-shared-key
56 | func SpiceDBWithPreSharedKey(preSharedKey string) dockerSpiceDBOption {
57 | return func(dsp *dockerSpiceDB) {
58 | dsp.preSharedKey = preSharedKey
59 | }
60 | }
61 |
62 | type dockerSpiceDB struct {
63 | network *dockertest.Network
64 | pool *dockertest.Pool
65 | preSharedKey string
66 | versionTag string
67 | logLevel string
68 | externalPort string
69 | dockertestResource *dockertest.Resource
70 | }
71 |
72 | // CreateSpiceDB creates a spicedb instance with postgres backend and default configurations
73 | func CreateSpiceDB(postgresConnectionURL string, opts ...dockerSpiceDBOption) (*dockerSpiceDB, error) {
74 | var (
75 | err error
76 | dsp = &dockerSpiceDB{}
77 | )
78 |
79 | for _, opt := range opts {
80 | opt(dsp)
81 | }
82 |
83 | name := fmt.Sprintf("spicedb-%s", uuid.New().String())
84 |
85 | if dsp.pool == nil {
86 | dsp.pool, err = dockertest.NewPool("")
87 | if err != nil {
88 | return nil, fmt.Errorf("could not create dockertest pool: %w", err)
89 | }
90 | }
91 |
92 | if dsp.preSharedKey == "" {
93 | dsp.preSharedKey = defaultPreSharedKey
94 | }
95 |
96 | if dsp.logLevel == "" {
97 | dsp.logLevel = defaultLogLevel
98 | }
99 |
100 | if dsp.versionTag == "" {
101 | dsp.versionTag = "v1.0.0"
102 | }
103 |
104 | runOpts := &dockertest.RunOptions{
105 | Name: name,
106 | Repository: "quay.io/authzed/spicedb",
107 | Tag: dsp.versionTag,
108 | Cmd: []string{"spicedb", "serve", "--log-level", dsp.logLevel, "--grpc-preshared-key", dsp.preSharedKey, "--grpc-no-tls", "--datastore-engine", "postgres", "--datastore-conn-uri", postgresConnectionURL},
109 | ExposedPorts: []string{"50051/tcp"},
110 | }
111 |
112 | if dsp.network != nil {
113 | runOpts.NetworkID = dsp.network.Network.ID
114 | }
115 |
116 | dsp.dockertestResource, err = dsp.pool.RunWithOptions(
117 | runOpts,
118 | func(config *docker.HostConfig) {
119 | config.RestartPolicy = docker.RestartPolicy{
120 | Name: "no",
121 | }
122 | },
123 | )
124 | if err != nil {
125 | return nil, err
126 | }
127 |
128 | dsp.externalPort = dsp.dockertestResource.GetPort("50051/tcp")
129 |
130 | if err = dsp.dockertestResource.Expire(120); err != nil {
131 | return nil, err
132 | }
133 |
134 | // exponential backoff-retry, because the application in the container might not be ready to accept connections yet
135 | dsp.pool.MaxWait = 60 * time.Second
136 |
137 | if err = dsp.pool.Retry(func() error {
138 | client, err := authzed.NewClient(
139 | fmt.Sprintf("localhost:%s", dsp.externalPort),
140 | grpc.WithTransportCredentials(insecure.NewCredentials()),
141 | grpcutil.WithInsecureBearerToken(dsp.preSharedKey),
142 | )
143 | if err != nil {
144 | return err
145 | }
146 | _, err = client.ReadSchema(context.Background(), &authzedpb.ReadSchemaRequest{})
147 | grpCStatus := status.Convert(err)
148 | if grpCStatus.Code() == codes.Unavailable {
149 | return err
150 | }
151 | return nil
152 | }); err != nil {
153 | err = fmt.Errorf("could not connect to docker: %w", err)
154 | return nil, fmt.Errorf("could not connect to docker: %w", err)
155 | }
156 |
157 | return dsp, nil
158 | }
159 |
160 | // GetExternalPort returns exposed port of the spicedb instance
161 | func (dsp *dockerSpiceDB) GetExternalPort() string {
162 | return dsp.externalPort
163 | }
164 |
165 | // GetPreSharedKey returns pre-shared-key used in the spicedb instance
166 | func (dsp *dockerSpiceDB) GetPreSharedKey() string {
167 | return dsp.preSharedKey
168 | }
169 |
170 | // GetPool returns docker pool
171 | func (dsp *dockerSpiceDB) GetPool() *dockertest.Pool {
172 | return dsp.pool
173 | }
174 |
175 | // GetResource returns docker resource
176 | func (dsp *dockerSpiceDB) GetResource() *dockertest.Resource {
177 | return dsp.dockertestResource
178 | }
179 |
--------------------------------------------------------------------------------
/testing/dockertestx/spicedb_migrate.go:
--------------------------------------------------------------------------------
1 | package dockertestx
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "fmt"
7 |
8 | "github.com/ory/dockertest/v3"
9 | "github.com/ory/dockertest/v3/docker"
10 | )
11 |
12 | type dockerMigrateSpiceDBOption func(dmm *dockerMigrateSpiceDB)
13 |
14 | // MigrateSpiceDBWithDockertestNetwork is an option to assign docker network
15 | func MigrateSpiceDBWithDockertestNetwork(network *dockertest.Network) dockerMigrateSpiceDBOption {
16 | return func(dm *dockerMigrateSpiceDB) {
17 | dm.network = network
18 | }
19 | }
20 |
21 | // MigrateSpiceDBWithVersionTag is an option to assign release tag
22 | // of a `quay.io/authzed/spicedb` image
23 | func MigrateSpiceDBWithVersionTag(versionTag string) dockerMigrateSpiceDBOption {
24 | return func(dm *dockerMigrateSpiceDB) {
25 | dm.versionTag = versionTag
26 | }
27 | }
28 |
29 | // MigrateSpiceDBWithDockerPool is an option to assign docker pool
30 | func MigrateSpiceDBWithDockerPool(pool *dockertest.Pool) dockerMigrateSpiceDBOption {
31 | return func(dm *dockerMigrateSpiceDB) {
32 | dm.pool = pool
33 | }
34 | }
35 |
36 | type dockerMigrateSpiceDB struct {
37 | network *dockertest.Network
38 | pool *dockertest.Pool
39 | versionTag string
40 | }
41 |
42 | // MigrateSpiceDB migrates spicedb with postgres backend
43 | func MigrateSpiceDB(postgresConnectionURL string, opts ...dockerMigrateMinioOption) error {
44 | var (
45 | err error
46 | dm = &dockerMigrateMinio{}
47 | )
48 |
49 | for _, opt := range opts {
50 | opt(dm)
51 | }
52 |
53 | if dm.pool == nil {
54 | dm.pool, err = dockertest.NewPool("")
55 | if err != nil {
56 | return fmt.Errorf("could not create dockertest pool: %w", err)
57 | }
58 | }
59 |
60 | if dm.versionTag == "" {
61 | dm.versionTag = "v1.0.0"
62 | }
63 |
64 | runOpts := &dockertest.RunOptions{
65 | Repository: "quay.io/authzed/spicedb",
66 | Tag: dm.versionTag,
67 | Cmd: []string{"spicedb", "migrate", "head", "--datastore-engine", "postgres", "--datastore-conn-uri", postgresConnectionURL},
68 | }
69 |
70 | if dm.network != nil {
71 | runOpts.NetworkID = dm.network.Network.ID
72 | }
73 |
74 | resource, err := dm.pool.RunWithOptions(runOpts, func(config *docker.HostConfig) {
75 | config.RestartPolicy = docker.RestartPolicy{
76 | Name: "no",
77 | }
78 | })
79 | if err != nil {
80 | return err
81 | }
82 |
83 | if err := resource.Expire(120); err != nil {
84 | return err
85 | }
86 |
87 | waitCtx, cancel := context.WithTimeout(context.Background(), waitContainerTimeout)
88 | defer cancel()
89 |
90 | // Ensure the command completed successfully.
91 | status, err := dm.pool.Client.WaitContainerWithContext(resource.Container.ID, waitCtx)
92 | if err != nil {
93 | return err
94 | }
95 |
96 | if status != 0 {
97 | stream := new(bytes.Buffer)
98 |
99 | if err = dm.pool.Client.Logs(docker.LogsOptions{
100 | Context: waitCtx,
101 | OutputStream: stream,
102 | ErrorStream: stream,
103 | Stdout: true,
104 | Stderr: true,
105 | Container: resource.Container.ID,
106 | }); err != nil {
107 | return err
108 | }
109 |
110 | return fmt.Errorf("got non-zero exit code %s", stream.String())
111 | }
112 |
113 | if err := dm.pool.Purge(resource); err != nil {
114 | return err
115 | }
116 |
117 | return nil
118 | }
119 |
--------------------------------------------------------------------------------
/utils/error_status.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "github.com/pkg/errors"
5 | "google.golang.org/grpc/codes"
6 | "google.golang.org/grpc/status"
7 | )
8 |
9 | var codeToStr = map[codes.Code]string{
10 | codes.OK: `"OK"`,
11 | codes.Canceled: `"CANCELED"`,
12 | codes.Unknown: `"UNKNOWN"`,
13 | codes.InvalidArgument: `"INVALID_ARGUMENT"`,
14 | codes.DeadlineExceeded: `"DEADLINE_EXCEEDED"`,
15 | codes.NotFound: `"NOT_FOUND"`,
16 | codes.AlreadyExists: `"ALREADY_EXISTS"`,
17 | codes.PermissionDenied: `"PERMISSION_DENIED"`,
18 | codes.ResourceExhausted: `"RESOURCE_EXHAUSTED"`,
19 | codes.FailedPrecondition: `"FAILED_PRECONDITION"`,
20 | codes.Aborted: `"ABORTED"`,
21 | codes.OutOfRange: `"OUT_OF_RANGE"`,
22 | codes.Unimplemented: `"UNIMPLEMENTED"`,
23 | codes.Internal: `"INTERNAL"`,
24 | codes.Unavailable: `"UNAVAILABLE"`,
25 | codes.DataLoss: `"DATA_LOSS"`,
26 | codes.Unauthenticated: `"UNAUTHENTICATED"`,
27 | }
28 |
29 | func StatusCode(err error) codes.Code {
30 | if err == nil {
31 | return codes.OK
32 | }
33 | var se interface {
34 | GRPCStatus() *status.Status
35 | }
36 | if errors.As(err, &se) {
37 | return se.GRPCStatus().Code()
38 | }
39 | return codes.Unknown
40 | }
41 | func StatusText(err error) string {
42 | return codeToStr[StatusCode(err)]
43 | }
44 |
--------------------------------------------------------------------------------
/utils/error_status_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 |
7 | "github.com/pkg/errors"
8 | "github.com/stretchr/testify/assert"
9 | "google.golang.org/grpc/codes"
10 | "google.golang.org/grpc/status"
11 | )
12 |
13 | func TestStatusCode(t *testing.T) {
14 | cases := []struct {
15 | name string
16 | err error
17 | expected codes.Code
18 | }{
19 | {
20 | name: "with status.Error",
21 | err: status.Error(codes.NotFound, "Somebody that I used to know"),
22 | expected: codes.NotFound,
23 | },
24 | {
25 | name: "with wrapped status.Error",
26 | err: fmt.Errorf("%w", status.Error(codes.Unavailable, "I shot the sheriff")),
27 | expected: codes.Unavailable,
28 | },
29 | {
30 | name: "with std lib error",
31 | err: errors.New("Runnin' down a dream"),
32 | expected: codes.Unknown,
33 | },
34 | {
35 | name: "with nil error",
36 | err: nil,
37 | expected: codes.OK,
38 | },
39 | }
40 | for _, tc := range cases {
41 | t.Run(tc.name, func(t *testing.T) {
42 | assert.Equal(t, tc.expected, StatusCode(tc.err))
43 | })
44 | }
45 | }
46 |
--------------------------------------------------------------------------------