├── .gitignore ├── .golangci.yaml ├── LICENSE ├── Makefile ├── README.md ├── app.go ├── cconfig ├── cconfigtest │ ├── doc.go │ └── util.go ├── doc.go ├── loader.go ├── loader_test.go └── tree.go ├── cerrors ├── error.go ├── error_test.go ├── utils.go └── utils_test.go ├── chttp ├── chttptest │ ├── content_type.go │ ├── doc.go │ ├── handler.go │ ├── html_reader_writer.go │ ├── json_reader_writer.go │ ├── route.go │ └── src │ │ ├── layouts │ │ └── main.html │ │ ├── pages │ │ ├── index.html │ │ └── not-found.html │ │ └── partials │ │ └── test-component.html ├── config.go ├── doc.go ├── error.html ├── fs.go ├── handler.go ├── handler_test.go ├── html_reader_writer.go ├── html_reader_writer_test.go ├── html_renderer.go ├── html_router.go ├── http.go ├── json_reader_writer.go ├── json_reader_writer_test.go ├── middleware.go ├── panic_logger_mw.go ├── panic_logger_mw_test.go ├── request_id_mw.go ├── request_logger_mw.go ├── request_logger_mw_test.go ├── route_ctx_mw.go ├── route_ctx_mw_test.go ├── server.go ├── server_test.go └── wire.go ├── clifecycle ├── doc.go ├── lifecycle.go └── logger.go ├── clogger ├── config.go ├── doc.go ├── json_redactor.go ├── json_redactor_test.go ├── level.go ├── level_test.go ├── logger.go ├── logger_test.go ├── noop.go ├── noop_test.go ├── recorder.go ├── recorder_test.go ├── util.go └── zap.go ├── cmetrics ├── metrics.go ├── models.go ├── noop.go └── wire.go ├── csql ├── config.go ├── ctx.go ├── ctx_test.go ├── db.go ├── db_test.go ├── doc.go ├── migrations_test.sql ├── migrator.go ├── migrator_test.go ├── qb │ ├── query_builder.go │ └── query_builder_test.go ├── querier.go ├── querier_test.go ├── tx_middleware.go ├── tx_middleware_test.go └── wire.go ├── doc.go ├── flags.go ├── go.mod ├── go.sum ├── wire.go └── wire_gen.go /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | vendor/ 4 | -------------------------------------------------------------------------------- 
/.golangci.yaml: -------------------------------------------------------------------------------- 1 | issues: 2 | exclude-use-default: false 3 | fix: true 4 | 5 | run: 6 | build-tags: 7 | - wireinject 8 | 9 | linters: 10 | disable-all: true 11 | enable: 12 | - bodyclose 13 | - depguard 14 | - dogsled 15 | - dupl 16 | - errcheck 17 | - exportloopref 18 | - exhaustive 19 | - funlen 20 | - gochecknoinits 21 | - goconst 22 | - gocritic 23 | - gocyclo 24 | - gofmt 25 | - goimports 26 | - gomnd 27 | - goprintffuncname 28 | - gosec 29 | - gosimple 30 | - govet 31 | - ineffassign 32 | - misspell 33 | - nakedret 34 | - noctx 35 | - nolintlint 36 | - revive 37 | - rowserrcheck 38 | - staticcheck 39 | - stylecheck 40 | - typecheck 41 | - unconvert 42 | - whitespace 43 | - paralleltest 44 | - tparallel 45 | - testpackage 46 | - thelper 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tushar Soni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GO=GO111MODULE=on go 2 | WIRE=wire 3 | COVERAGE_FILE="/tmp/copper_coverage.out" 4 | 5 | .PHONY: all 6 | all: lint generate test 7 | 8 | .PHONY: cover 9 | cover: test 10 | $(GO) tool cover -html=$(COVERAGE_FILE) 11 | 12 | .PHONY: test 13 | test: 14 | $(GO) test -coverprofile=$(COVERAGE_FILE) ./... 15 | 16 | .PHONY: lint 17 | lint: tidy 18 | golangci-lint run 19 | 20 | .PHONY: tidy 21 | tidy: 22 | $(GO) mod tidy 23 | 24 | .PHONY: generate 25 | generate: 26 | $(WIRE) . 27 | 28 | .PHONY: release 29 | release: 30 | git tag -a $(version) -m "Release $(version)" 31 | git push origin $(version) 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | Copper logo 4 | 5 |

6 | 7 |

8 | 9 | Go Report Card 10 | 11 | 12 | Go Doc 13 | 14 |

15 | 16 |
17 | 18 | # Copper 19 | 20 |

21 | Copper is a Go toolkit complete with everything you need to build web apps. It focuses on developer productivity and makes building web apps in Go more fun with less boilerplate and out-of-the-box support for common needs. 22 |

23 | 24 | #### 🚀 Fullstack Toolkit 25 |

Copper provides a toolkit complete with everything you need to build web apps quickly.

26 | 27 | 28 | #### 📦 One Binary 29 |

Build frontend apps along with your backend and ship everything in a single binary.

30 | 31 | 32 | #### 📝 Server-side HTML 33 |

Copper includes utilities that help build web apps with server rendered HTML pages.

34 | 35 | #### 💡 Auto Restarts 36 |

Copper detects changes and automatically restarts the server to save time.

37 | 38 | #### 🏗 Scaffolding 39 |

Skip boilerplate and scaffold code for your packages, database queries and routes.

40 | 41 | #### 🔋 Batteries Included 42 |

Includes CLI, lint, dev server, config management, and more!

43 | 44 | #### 🔩 First-party packages 45 |

Includes packages for authentication, pub/sub, queues, emails, and websockets.

46 | 47 | 48 |
49 | 50 | ## Intro Video (Hacker News Clone) 51 | [![](https://user-images.githubusercontent.com/2974009/175425653-da11ba79-d9ec-4e82-a2e0-5f5515d3417e.png)](https://vimeo.com/723537998) 52 | 53 |
54 | 55 | ## Getting Started 56 | 57 | Head over to the documentation to get started - [https://docs.gocopper.dev/getting-started](https://docs.gocopper.dev/getting-started) 58 | 59 | Join us on Discord here - https://discord.gg/fT2AEZyM6A 60 | 61 |
62 | 63 | ## License 64 | MIT 65 | -------------------------------------------------------------------------------- /app.go: -------------------------------------------------------------------------------- 1 | package copper 2 | 3 | import ( 4 | "log" 5 | "os" 6 | "os/signal" 7 | "syscall" 8 | 9 | "github.com/gocopper/copper/cconfig" 10 | "github.com/gocopper/copper/clifecycle" 11 | "github.com/gocopper/copper/clogger" 12 | ) 13 | 14 | // Runner provides an interface that can be run by a Copper app using the Run or Start funcs. 15 | // This interface is implemented by various packages within Copper including chttp.Server. 16 | type Runner interface { 17 | Run() error 18 | } 19 | 20 | // New provides a convenience wrapper around InitApp that logs and exits if there is an error. 21 | func New() *App { 22 | app, err := InitApp() 23 | if err != nil { 24 | log.Fatal(err) 25 | } 26 | 27 | return app 28 | } 29 | 30 | // NewApp creates a new Copper app that holds the given lifecycle manager, 31 | // config loader, and logger. 32 | func NewApp(lifecycle *clifecycle.Lifecycle, config cconfig.Loader, logger clogger.Logger) *App { 33 | return &App{ 34 | Lifecycle: lifecycle, 35 | Config: config, 36 | Logger: logger, 37 | } 38 | } 39 | 40 | // App defines a Copper app container that can run provided code in its managed lifecycle. 41 | // It provides functionality to read config in multiple environments as defined by 42 | // command-line flags. 43 | type App struct { 44 | Lifecycle *clifecycle.Lifecycle 45 | Config cconfig.Loader 46 | Logger clogger.Logger 47 | } 48 | 49 | // Run runs the provided funcs. Once all of the functions complete their run, 50 | // the lifecycle's stop funcs are also called. If any of the fns return an error, 51 | // the lifecycle's stop funcs are called and the app exits with an exit code 1. 52 | // Run should be used when none of the fns are long-running. For long-running funcs like 53 | // an HTTP server, use Start.
54 | func (a *App) Run(fns ...Runner) { 55 | for i := range fns { 56 | err := fns[i].Run() 57 | if err != nil { 58 | a.Logger.Error("Failed to run", err) 59 | a.Lifecycle.Stop(a.Logger) 60 | os.Exit(1) 61 | } 62 | } 63 | 64 | a.Lifecycle.Stop(a.Logger) 65 | } 66 | 67 | // Start runs the provided fns and then waits on the OS's INT and TERM signals from the 68 | // user to exit. Once the signal is received, the lifecycle's stop funcs are 69 | // called. 70 | // If any of the fns fail to run and return an error, the lifecycle's stop funcs are 71 | // called (as in Run) and the app exits with exit code 1. 72 | func (a *App) Start(fns ...Runner) { 73 | for i := range fns { 74 | err := fns[i].Run() 75 | if err != nil { 76 | a.Logger.Error("Failed to run", err) 77 | // Mirror Run: call the lifecycle's stop funcs before exiting. 78 | a.Lifecycle.Stop(a.Logger) 79 | os.Exit(1) 80 | } 81 | } 82 | 83 | osInt := make(chan os.Signal, 1) 84 | 85 | signal.Notify(osInt, syscall.SIGINT, syscall.SIGTERM) 86 | 87 | <-osInt 88 | 89 | a.Lifecycle.Stop(a.Logger) 90 | } 91 | -------------------------------------------------------------------------------- /cconfig/cconfigtest/doc.go: -------------------------------------------------------------------------------- 1 | // Package cconfigtest provides helper methods to test the cconfig package 2 | package cconfigtest 3 | -------------------------------------------------------------------------------- /cconfig/cconfigtest/util.go: -------------------------------------------------------------------------------- 1 | package cconfigtest 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // SetupDirWithConfigs creates a temp directory and writes each entry of configs into it as a file 12 | // (map key = file name, map value = file contents), returning the directory's path. 13 | // The directory is cleaned up after the test run.
14 | func SetupDirWithConfigs(t *testing.T, configs map[string]string) string { 15 | t.Helper() 16 | 17 | // t.TempDir registers its own cleanup that removes the directory once the test and its subtests complete. 18 | dir := t.TempDir() 19 | 20 | for fp, data := range configs { 21 | // filepath.Join builds an OS-correct path; 0o600 avoids creating world-writable config files. 22 | err := os.WriteFile(filepath.Join(dir, fp), []byte(data), 0o600) 23 | assert.NoError(t, err) 24 | } 25 | 26 | return dir 27 | } 28 | -------------------------------------------------------------------------------- /cconfig/doc.go: -------------------------------------------------------------------------------- 1 | // Package cconfig provides utilities to read app config files. See New and Loader for more documentation and example. 2 | package cconfig 3 | -------------------------------------------------------------------------------- /cconfig/loader.go: -------------------------------------------------------------------------------- 1 | package cconfig 2 | 3 | import ( 4 | "github.com/gocopper/copper/cerrors" 5 | "github.com/pelletier/go-toml" 6 | ) 7 | 8 | type ( 9 | // Path defines the path to the config file. 10 | Path string 11 | // Overrides defines a ';' separated string of config overrides in TOML format 12 | Overrides string 13 | ) 14 | 15 | // Loader provides methods to load config files into structs. 16 | type Loader interface { 17 | // Load reads the config values defined under the given key (TOML table) and sets them into the dest struct. If the key is not present, the loaders returned by New and NewWithKeyOverrides leave dest unchanged and return nil.
18 | // For example: 19 | // # prod.toml 20 | // [my_config] 21 | // key1 = "val1" 22 | // # config.go 23 | // type MyConfig struct { 24 | // Key1 string `toml:"key1"` 25 | // } 26 | // 27 | // func LoadMyConfig(loader cconfig.Loader) (MyConfig, error) { 28 | // var config MyConfig 29 | // 30 | // err := loader.Load("my_config", &config) 31 | // if err != nil { 32 | // return MyConfig{}, cerrors.New(err, "failed to load configs for my_config", nil) 33 | // } 34 | // 35 | // return config, nil 36 | // } 37 | Load(key string, dest interface{}) error 38 | } 39 | 40 | // New provides an implementation of Loader that reads a config file at the given file path. It supports extending the 41 | // config file at the given path by using an 'extends' key. For example, the config file may be extended like so: 42 | // 43 | // # base.toml 44 | // key1 = "val1" 45 | // 46 | // # prod.toml 47 | // extends = "base.toml" 48 | // key2 = "val2" 49 | // 50 | // If New is called with the path to prod.toml, it loads both key2 (from prod.toml) and key1 (from base.toml) since 51 | // prod.toml extends base.toml. 52 | // 53 | // The extends key can support multiple files like so: 54 | // extends = ["base.toml", "secrets.toml"] 55 | // 56 | // If a config key is present in multiple files, New returns an error. For example, if prod.toml sets a value for 'key1' 57 | // that has already been set in base.toml, an error will be returned. To enable key overrides see NewWithKeyOverrides. 58 | func New(fp Path, ov Overrides) (Loader, error) { 59 | return newLoader(string(fp), string(ov), true) 60 | } 61 | 62 | // NewWithKeyOverrides works exactly the same way as New except it supports key overrides. For example, this is a valid 63 | // config: 64 | // # base.toml 65 | // key1 = "val1" 66 | // 67 | // # prod.toml 68 | // extends = "base.toml" 69 | // key1 = "val2" 70 | // 71 | // If prod.toml is loaded, key1 will be set to "val2" since it has been overridden in prod.toml.
72 | func NewWithKeyOverrides(fp Path, overrides Overrides) (Loader, error) { 73 | return newLoader(string(fp), string(overrides), false) 74 | } 75 | 76 | func newLoader(fp, overrides string, disableKeyOverrides bool) (*loader, error) { 77 | tree, err := loadTree(fp, overrides, disableKeyOverrides) 78 | if err != nil { 79 | return nil, cerrors.New(err, "failed to load config tree", map[string]interface{}{ 80 | "path": fp, 81 | }) 82 | } 83 | 84 | return &loader{ 85 | tree: tree, 86 | }, nil 87 | } 88 | 89 | type loader struct { 90 | tree *toml.Tree 91 | } 92 | 93 | func (l *loader) Load(key string, dest interface{}) error { 94 | if !l.tree.Has(key) { 95 | return nil 96 | } 97 | 98 | keyTree, ok := l.tree.Get(key).(*toml.Tree) 99 | if !ok { 100 | return cerrors.New(nil, "invalid key type", map[string]interface{}{ 101 | "key": key, 102 | }) 103 | } 104 | 105 | err := keyTree.Unmarshal(dest) 106 | if err != nil { 107 | return cerrors.New(err, "failed to unmarshal config into dest", map[string]interface{}{ 108 | "key": key, 109 | }) 110 | } 111 | 112 | return nil 113 | } 114 | -------------------------------------------------------------------------------- /cconfig/loader_test.go: -------------------------------------------------------------------------------- 1 | package cconfig_test 2 | 3 | import ( 4 | "path" 5 | "testing" 6 | 7 | "github.com/gocopper/copper/cconfig" 8 | "github.com/gocopper/copper/cconfig/cconfigtest" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestLoader_Load(t *testing.T) { 13 | t.Parallel() 14 | 15 | var ( 16 | dir = cconfigtest.SetupDirWithConfigs(t, map[string]string{ 17 | "base.toml": ` 18 | [group1] 19 | key1 = "val1-base" 20 | key2 = "val2-base" 21 | `, 22 | "test.toml": ` 23 | extends = "base.toml" 24 | 25 | [group1] 26 | key1 = "val1-test" 27 | `, 28 | }) 29 | fp = cconfig.Path(path.Join(dir, "test.toml")) 30 | ov = cconfig.Overrides("group1.key2=\"val2-override\"") 31 | ) 32 | 33 | t.Run("load with extends and key 
overrides", func(t *testing.T) { 34 | t.Parallel() 35 | 36 | var testConfig struct { 37 | Key1 string `toml:"key1"` 38 | Key2 string `toml:"key2"` 39 | } 40 | 41 | configs, err := cconfig.NewWithKeyOverrides(fp, ov) 42 | assert.NoError(t, err) 43 | 44 | err = configs.Load("group1", &testConfig) 45 | assert.NoError(t, err) 46 | 47 | assert.Equal(t, "val1-test", testConfig.Key1) 48 | assert.Equal(t, "val2-override", testConfig.Key2) 49 | }) 50 | 51 | t.Run("error with key overrides disabled", func(t *testing.T) { 52 | t.Parallel() 53 | 54 | _, err := cconfig.New(fp, ov) 55 | 56 | assert.NotNil(t, err) 57 | assert.Contains(t, err.Error(), "key is being overridden when key overrides are disabled") 58 | }) 59 | } 60 | -------------------------------------------------------------------------------- /cconfig/tree.go: -------------------------------------------------------------------------------- 1 | package cconfig 2 | 3 | import ( 4 | "html/template" 5 | "os" 6 | "path/filepath" 7 | "reflect" 8 | "strings" 9 | 10 | "github.com/gocopper/copper/cerrors" 11 | "github.com/pelletier/go-toml" 12 | ) 13 | 14 | //nolint:funlen 15 | func loadTree(fp, overrides string, disableKeyOverrides bool) (*toml.Tree, error) { 16 | tmpl, err := template.ParseFiles(fp) 17 | if err != nil { 18 | return nil, cerrors.New(err, "failed to parse config file as template", map[string]interface{}{ 19 | "path": fp, 20 | }) 21 | } 22 | 23 | envVars := make(map[string]string) 24 | for _, e := range os.Environ() { 25 | pair := strings.Split(e, "=") 26 | envVars[pair[0]] = pair[1] 27 | } 28 | 29 | var tomlOut strings.Builder 30 | err = tmpl.Execute(&tomlOut, map[string]interface{}{ 31 | "EnvVars": envVars, 32 | }) 33 | if err != nil { 34 | return nil, cerrors.New(err, "failed to execute config file template", map[string]interface{}{ 35 | "path": fp, 36 | }) 37 | } 38 | 39 | tree, err := toml.LoadBytes([]byte(tomlOut.String())) 40 | if err != nil { 41 | return nil, cerrors.New(err, "failed to load config 
file", map[string]interface{}{ 42 | "path": fp, 43 | }) 44 | } 45 | 46 | // If the TOML tree does not have a top-level 'extends' key, we can return the tree as-is 47 | if !tree.Has("extends") { 48 | return tree, nil 49 | } 50 | 51 | parentFilePaths := make([]string, 0) 52 | 53 | // The extends key can be a string or a list of strings representing the config file paths that need to be loaded 54 | switch extends := tree.Get("extends").(type) { 55 | case string: 56 | parentFilePaths = append(parentFilePaths, extends) 57 | 58 | // If extends is set to a list, verify each value is a valid string, and add it to parentFilePaths 59 | case []interface{}: 60 | for i := range extends { 61 | parentFilePath, ok := extends[i].(string) 62 | if !ok { 63 | return nil, cerrors.New(nil, "extends can only contain strings", map[string]interface{}{ 64 | "path": fp, 65 | "value": extends[i], 66 | }) 67 | } 68 | 69 | parentFilePaths = append(parentFilePaths, parentFilePath) 70 | } 71 | default: 72 | return nil, cerrors.New(nil, "'extends' must be string or []string", map[string]interface{}{ 73 | "path": fp, 74 | "type": reflect.TypeOf(extends).String(), 75 | }) 76 | } 77 | 78 | // Load each parentFilePath in-order 79 | for _, parentFP := range parentFilePaths { 80 | parentFilePath := filepath.Join(filepath.Dir(fp), parentFP) 81 | 82 | // Load the parent tree at the given path defined by the extends key. Note that this is a recursive call 83 | // that will load all ancestors. 
84 | parentTree, err := loadTree(parentFilePath, "", disableKeyOverrides) 85 | if err != nil { 86 | return nil, cerrors.New(err, "failed to load parent tree", map[string]interface{}{ 87 | "parentPath": parentFilePath, 88 | }) 89 | } 90 | 91 | // Once the parent tree and its ancestors are loaded, we need to merge it with our current tree 92 | tree, err = mergeTrees(parentTree, tree, disableKeyOverrides) 93 | if err != nil { 94 | return nil, cerrors.New(err, "failed to merge with parent tree", map[string]interface{}{ 95 | "parentPath": parentFilePath, 96 | }) 97 | } 98 | } 99 | 100 | // Apply overrides 101 | for _, ov := range strings.Split(overrides, ";") { 102 | t, err := toml.Load(ov) 103 | if err != nil { 104 | return nil, cerrors.New(err, "failed to parse override as TOML", map[string]interface{}{ 105 | "override": ov, 106 | }) 107 | } 108 | 109 | tree, err = mergeTrees(tree, t, disableKeyOverrides) 110 | if err != nil { 111 | return nil, cerrors.New(err, "failed to merge tree with overrides", map[string]interface{}{ 112 | "override": ov, 113 | }) 114 | } 115 | } 116 | return tree, nil 117 | } 118 | 119 | //nolint:funlen 120 | func mergeTrees(base, override *toml.Tree, disableKeyOverrides bool) (*toml.Tree, error) { 121 | // For each key in the override tree, we need to apply it to the base tree 122 | for _, key := range override.Keys() { 123 | switch keyVal := override.Get(key).(type) { 124 | // If the value at the given key is a TOML tree (aka a table according to the spec), we need to merge it with 125 | // the base table. 
126 | // For example, if the base tree contains: 127 | // [group1] 128 | // key1="val1" 129 | // and the override tree contains: 130 | // [group1] 131 | // key2="val"2 132 | // We need to load it in such a way where group1 contains both key1 and key2 133 | case *toml.Tree: 134 | // If the base does not contain the key, we can set the entire value from the override table as-is 135 | if !base.Has(key) { 136 | base.Set(key, keyVal) 137 | continue 138 | } 139 | 140 | // Verify that the value type in the base and override trees are the same. For example, this is invalid: 141 | // # base.toml 142 | // group1 = "I am a string" 143 | // 144 | // # prod.toml 145 | // extends = "base.toml" 146 | // [group1] # I am a table! 147 | // key1="val"1 148 | // 149 | // The above configuration is invalid because group1 is a table in prod.toml but a string in base.toml. As 150 | // a result, they cannot be merged. 151 | baseTree, ok := base.Get(key).(*toml.Tree) 152 | if !ok { 153 | return nil, cerrors.New(nil, "base and override key types don't match", map[string]interface{}{ 154 | "key": key, 155 | }) 156 | } 157 | 158 | // Now that we have two trees, we can merge them recursively 159 | mergedTree, err := mergeTrees(baseTree, keyVal, disableKeyOverrides) 160 | if err != nil { 161 | return nil, cerrors.New(err, "failed to merge tree for key", map[string]interface{}{ 162 | "key": key, 163 | }) 164 | } 165 | 166 | base.Set(key, mergedTree) 167 | 168 | // This handles all non-table keys. The key, as found in the override tree, is set on the base tree. If the base 169 | // tree already has a value for the key, it is only overridden if disableKeyOverrides is false. If a key is 170 | // being overridden with disableKeyOverrides=true, an error is returned. 
171 | default: 172 | if base.Has(key) && disableKeyOverrides { 173 | return nil, cerrors.New(nil, "key is being overridden when key overrides are disabled", map[string]interface{}{ 174 | "key": key, 175 | }) 176 | } 177 | 178 | base.Set(key, keyVal) 179 | } 180 | } 181 | 182 | return base, nil 183 | } 184 | -------------------------------------------------------------------------------- /cerrors/error.go: -------------------------------------------------------------------------------- 1 | // Package cerrors provides a custom error type that can hold more context than the stdlib error package. 2 | package cerrors 3 | 4 | import ( 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "sort" 9 | "strings" 10 | ) 11 | 12 | // Error can wrap an error with additional context such as structured tags. 13 | type Error struct { 14 | Message string 15 | Tags map[string]interface{} 16 | Cause error 17 | } 18 | 19 | // New creates an error by (optionally) wrapping an existing error and 20 | // annotating the error with structured tags. 21 | func New(cause error, msg string, tags map[string]interface{}) error { 22 | return Error{ 23 | Message: msg, 24 | Tags: tags, 25 | Cause: cause, 26 | } 27 | } 28 | 29 | // WithTags annotates an existing error with structured tags. 30 | func WithTags(err error, tags map[string]interface{}) error { 31 | cerr, ok := err.(Error) //nolint:errorlint 32 | if !ok { 33 | return Error{ 34 | Message: err.Error(), 35 | Tags: tags, 36 | Cause: errors.Unwrap(err), 37 | } 38 | } 39 | 40 | return Error{ 41 | Message: cerr.Message, 42 | Tags: tags, 43 | Cause: cerr.Cause, 44 | } 45 | } 46 | 47 | // Unwrap returns the underlying cause of an error (if any). 48 | func (e Error) Unwrap() error { 49 | return e.Cause 50 | } 51 | 52 | // Error returns a human-friendly string that contains the 53 | // entire error chain along with all of the tags on each 54 | // error. 
55 | func (e Error) Error() string { 56 | var ( 57 | err strings.Builder 58 | tags []string 59 | ) 60 | 61 | err.WriteString(e.Message) 62 | 63 | for tag, val := range e.Tags { 64 | reflectVal := reflect.ValueOf(val) 65 | if reflectVal.Kind() == reflect.Ptr && !reflectVal.IsNil() { 66 | tags = append(tags, fmt.Sprintf("%s=%+v", tag, reflectVal.Elem())) 67 | } else { 68 | tags = append(tags, fmt.Sprintf("%s=%+v", tag, val)) 69 | } 70 | } 71 | 72 | sort.Strings(tags) 73 | 74 | if len(tags) > 0 { 75 | err.WriteString(" where ") 76 | err.WriteString(strings.Join(tags, ",")) 77 | } 78 | 79 | if e.Cause != nil { 80 | err.WriteString(" because\n") 81 | err.WriteString("> ") 82 | err.WriteString(e.Cause.Error()) 83 | } 84 | 85 | return err.String() 86 | } 87 | -------------------------------------------------------------------------------- /cerrors/error_test.go: -------------------------------------------------------------------------------- 1 | package cerrors_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/gocopper/copper/cerrors" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNew(t *testing.T) { 12 | t.Parallel() 13 | 14 | err := cerrors.New(nil, "test-err", nil) 15 | 16 | cErr, ok := err.(cerrors.Error) //nolint:errorlint 17 | 18 | assert.True(t, ok) 19 | assert.Equal(t, "test-err", cErr.Message) 20 | assert.Nil(t, cErr.Cause) 21 | assert.Nil(t, cErr.Tags) 22 | } 23 | 24 | func TestWithTags_StdErr(t *testing.T) { 25 | t.Parallel() 26 | 27 | err := cerrors.WithTags(errors.New("test-err"), map[string]interface{}{ //nolint:goerr113 28 | "key": "val", 29 | }) 30 | 31 | cErr, ok := err.(cerrors.Error) //nolint:errorlint 32 | 33 | assert.True(t, ok) 34 | assert.Equal(t, "test-err", cErr.Message) 35 | assert.Contains(t, cErr.Tags, "key") 36 | assert.Equal(t, cErr.Tags["key"], "val") 37 | assert.Nil(t, cErr.Cause) 38 | } 39 | 40 | func TestWithTags_CErr(t *testing.T) { 41 | t.Parallel() 42 | 43 | err := cerrors.WithTags(cerrors.New(nil, 
"test-cerr", nil), map[string]interface{}{ 44 | "key": "val", 45 | }) 46 | 47 | cErr, ok := err.(cerrors.Error) //nolint:errorlint 48 | 49 | assert.True(t, ok) 50 | assert.Equal(t, "test-cerr", cErr.Message) 51 | assert.Contains(t, cErr.Tags, "key") 52 | assert.Equal(t, cErr.Tags["key"], "val") 53 | assert.Nil(t, cErr.Cause) 54 | } 55 | 56 | func TestError_Unwrap(t *testing.T) { 57 | t.Parallel() 58 | 59 | err := cerrors.New(errors.New("cause-err"), "test-err", nil) //nolint:goerr113 60 | cause := errors.Unwrap(err) 61 | 62 | assert.NotNil(t, cause) 63 | assert.EqualError(t, cause, "cause-err") 64 | } 65 | 66 | func TestError_Unwrap_NoCause(t *testing.T) { 67 | t.Parallel() 68 | 69 | err := cerrors.New(nil, "test-err", nil) 70 | 71 | assert.Nil(t, errors.Unwrap(err)) 72 | } 73 | 74 | func TestError_Error(t *testing.T) { 75 | t.Parallel() 76 | 77 | err := cerrors.New(nil, "test-err", nil) 78 | 79 | assert.NotNil(t, err) 80 | assert.Equal(t, "test-err", err.Error()) 81 | } 82 | 83 | func TestError_Error_Cause(t *testing.T) { 84 | t.Parallel() 85 | 86 | err := cerrors.New(errors.New("cause-err"), "test-err", nil) //nolint:goerr113 87 | 88 | expectedErr := `test-err because 89 | > cause-err` 90 | 91 | assert.NotNil(t, err) 92 | assert.Equal(t, expectedErr, err.Error()) 93 | } 94 | 95 | func TestError_Error_Tags(t *testing.T) { 96 | t.Parallel() 97 | 98 | err := cerrors.New(nil, "test-err", map[string]interface{}{ 99 | "key": "val", 100 | }) 101 | 102 | assert.NotNil(t, err) 103 | assert.Equal(t, "test-err where key=val", err.Error()) 104 | } 105 | 106 | func TestError_Error_PtrTags(t *testing.T) { 107 | t.Parallel() 108 | 109 | val := "val" 110 | err := cerrors.New(nil, "test-err", map[string]interface{}{ 111 | "key": &val, 112 | }) 113 | 114 | assert.NotNil(t, err) 115 | assert.Equal(t, "test-err where key=val", err.Error()) 116 | } 117 | -------------------------------------------------------------------------------- /cerrors/utils.go: 
-------------------------------------------------------------------------------- 1 | package cerrors 2 | 3 | import "errors" 4 | 5 | // Tags returns the tags of every Error in err's chain merged into a single map. When the same key is set at 6 | // multiple levels, the value from the deeper cause wins. It returns nil if err's chain contains no Error. 7 | func Tags(err error) map[string]interface{} { 8 | var cerr Error 9 | 10 | if !errors.As(err, &cerr) { 11 | return nil 12 | } 13 | 14 | return mergeTags(cerr.Tags, Tags(cerr.Cause)) 15 | } 16 | 17 | // WithoutTags returns a copy of the first Error in err's chain (and, recursively, of its causes) with all tags 18 | // removed. Non-Error wrappers above that first Error are not preserved. If err's chain contains no Error, err is 19 | // returned unchanged. 20 | func WithoutTags(err error) error { 21 | var cerr Error 22 | 23 | if !errors.As(err, &cerr) { 24 | return err 25 | } 26 | 27 | cerr.Tags = nil 28 | cerr.Cause = WithoutTags(cerr.Cause) 29 | 30 | return cerr 31 | } 32 | 33 | // mergeTags returns a new map containing the entries of t1 and t2; entries from t2 win on key conflicts. 34 | func mergeTags(t1, t2 map[string]interface{}) map[string]interface{} { 35 | merged := make(map[string]interface{}, len(t1)+len(t2)) 36 | 37 | for k, v := range t1 { 38 | merged[k] = v 39 | } 40 | 41 | for k, v := range t2 { 42 | merged[k] = v 43 | } 44 | 45 | return merged 46 | } 47 | -------------------------------------------------------------------------------- /cerrors/utils_test.go: -------------------------------------------------------------------------------- 1 | package cerrors_test 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/gocopper/copper/cerrors" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestWithoutTags(t *testing.T) { 12 | t.Parallel() 13 | 14 | err := fmt.Errorf("test-error-0; %w", cerrors.New(cerrors.New(nil, "test-error-2", map[string]interface{}{ 15 | "tag": "val", 16 | }), "test-error-1", map[string]interface{}{ 17 | "tag": "val", 18 | })) 19 | 20 | out := cerrors.WithoutTags(err) 21 | 22 | assert.NotContains(t, out.Error(), "tag") 23 | } 24 | -------------------------------------------------------------------------------- /chttp/chttptest/content_type.go: -------------------------------------------------------------------------------- 1 | package chttptest 2 | 3 | // ContentTypeApplicationJSON defines the application/json content type that can be used when making HTTP requests.
4 | const ContentTypeApplicationJSON = "application/json" 5 | -------------------------------------------------------------------------------- /chttp/chttptest/doc.go: -------------------------------------------------------------------------------- 1 | // Package chttptest provides utility functions that are useful when testing chttp 2 | package chttptest 3 | -------------------------------------------------------------------------------- /chttp/chttptest/handler.go: -------------------------------------------------------------------------------- 1 | package chttptest 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/gocopper/copper/chttp" 10 | "github.com/gocopper/copper/clogger" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | // PingRoutes creates a handler using chttp.NewHandler, starts a test 15 | // http server, and calls each provided route. It verifies that each 16 | // route's handler is called successfully. 17 | // 18 | // Note that PingRoutes replaces the Handler of every route in the given slice (in place) 19 | // with a stub that writes the route's path as the response body. 20 | func PingRoutes(t *testing.T, routes []chttp.Route) { 21 | t.Helper() 22 | 23 | for i := range routes { 24 | body := routes[i].Path 25 | routes[i].Handler = func(w http.ResponseWriter, r *http.Request) { 26 | _, err := w.Write([]byte(body)) 27 | assert.NoError(t, err) 28 | } 29 | } 30 | 31 | server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{ 32 | Routers: []chttp.Router{NewRouter(routes)}, 33 | GlobalMiddlewares: nil, 34 | Logger: clogger.NewNoop(), 35 | })) 36 | defer server.Close() 37 | 38 | for _, route := range routes { 39 | resp, err := http.Get(server.URL + route.Path) //nolint:noctx 40 | if !assert.NoError(t, err) { 41 | // resp may be nil when the request fails, so skip reading its body instead of panicking. 42 | continue 43 | } 44 | 45 | body, err := io.ReadAll(resp.Body) 46 | assert.NoError(t, resp.Body.Close()) 47 | assert.NoError(t, err) 48 | 49 | assert.Equal(t, route.Path, string(body)) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /chttp/chttptest/html_reader_writer.go: --------------------------------------------------------------------------------
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import (
4 | 	"embed"
5 | 	"testing"
6 |
7 | 	"github.com/gocopper/copper/chttp"
8 | 	"github.com/gocopper/copper/clogger"
9 | 	"github.com/stretchr/testify/assert"
10 | )
11 |
12 | // HTMLDir embeds the src directory (layouts, pages, and partials) so it can be used as the chttp.HTMLDir of the
13 | // renderer behind a chttp.HTMLReaderWriter.
14 | //
15 | //go:embed src
16 | var HTMLDir embed.FS
17 |
18 | // NewHTMLReaderWriter creates a *chttp.HTMLReaderWriter backed by HTMLDir, suitable for use in tests.
19 | // The test fails immediately if the renderer cannot be created.
20 | func NewHTMLReaderWriter(t *testing.T) *chttp.HTMLReaderWriter {
21 | 	t.Helper()
22 |
23 | 	r, err := chttp.NewHTMLRenderer(chttp.NewHTMLRendererParams{
24 | 		HTMLDir:   HTMLDir,
25 | 		StaticDir: nil,
26 | 		Config:    chttp.Config{},
27 | 		Logger:    clogger.NewNoop(),
28 | 	})
29 | 	if !assert.NoError(t, err) {
30 | 		// Continuing would hand callers a reader/writer with a nil renderer.
31 | 		t.FailNow()
32 | 	}
33 |
34 | 	return chttp.NewHTMLReaderWriter(r, chttp.Config{}, clogger.NewNoop())
35 | }
36 |
--------------------------------------------------------------------------------
/chttp/chttptest/json_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/gocopper/copper/chttp"
7 | 	"github.com/gocopper/copper/clogger"
8 | )
9 |
10 | // NewJSONReaderWriter creates a *chttp.JSONReaderWriter suitable for use in tests
11 | func NewJSONReaderWriter(t *testing.T) *chttp.JSONReaderWriter {
12 | 	t.Helper()
13 |
14 | 	return chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
15 | }
16 |
--------------------------------------------------------------------------------
/chttp/chttptest/route.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import "github.com/gocopper/copper/chttp"
4 |
5 | // ReverseRoutes reverses the provided slice of chttp.Route in place and returns it.
6 | func ReverseRoutes(routes []chttp.Route) []chttp.Route {
7 | 	for i := 0; i < len(routes)/2; i++ {
8 | 		j := len(routes) - i - 1
9 | 		routes[i], routes[j] = routes[j], routes[i]
10 | 	}
11 |
12 | 	return routes
13 | }
14 |
15 | // NewRouter returns a chttp.Router whose Routes method returns the given routes.
16 | func NewRouter(routes []chttp.Route) chttp.Router {
17 | 	return &router{routes: routes}
18 | }
19 |
20 | // router is a chttp.Router backed by a fixed slice of routes.
21 | type router struct {
22 | 	routes []chttp.Route
23 | }
24 |
25 | // Routes implements chttp.Router.
26 | func (ro *router) Routes() []chttp.Route {
27 | 	return ro.routes
28 | }
29 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/layouts/main.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Test Page
7 |
8 |
9 | {{ template "content" . }}
10 |
11 |
12 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/pages/index.html:
--------------------------------------------------------------------------------
1 | {{ define "content" }}
2 | Test Page
3 | {{ end }}
4 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/pages/not-found.html:
--------------------------------------------------------------------------------
1 | {{ define "content" }}
2 | The content you are looking for was not found.
3 | {{ end }} 4 | -------------------------------------------------------------------------------- /chttp/chttptest/src/partials/test-component.html: -------------------------------------------------------------------------------- 1 | {{ define "test-component" }} 2 | {{ end }} 3 | -------------------------------------------------------------------------------- /chttp/config.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "github.com/gocopper/copper/cconfig" 5 | "github.com/gocopper/copper/cerrors" 6 | ) 7 | 8 | // LoadConfig loads Config from app's config 9 | func LoadConfig(appConfig cconfig.Loader) (Config, error) { 10 | var config Config 11 | 12 | err := appConfig.Load("chttp", &config) 13 | if err != nil { 14 | return Config{}, cerrors.New(err, "failed to load chttp config", nil) 15 | } 16 | 17 | return config, nil 18 | } 19 | 20 | // Config holds the params needed to configure Server 21 | type Config struct { 22 | Port uint `default:"7501"` 23 | UseLocalHTML bool `toml:"use_local_html"` 24 | RenderHTMLError bool `toml:"render_html_error"` 25 | EnableSinglePageRouting bool `toml:"enable_single_page_routing"` 26 | ReadTimeoutSeconds uint `toml:"read_timeout_seconds" default:"10"` 27 | RedirectURLForUnauthorizedRequests *string `toml:"redirect_url_for_unauthorized_requests"` 28 | BasePath *string `toml:"base_path"` 29 | } 30 | -------------------------------------------------------------------------------- /chttp/doc.go: -------------------------------------------------------------------------------- 1 | // Package chttp helps setup a http server with routing, middlewares, and more. 2 | package chttp 3 | -------------------------------------------------------------------------------- /chttp/error.html: -------------------------------------------------------------------------------- 1 | 28 | 29 |
30 |

Failed to handle request

31 |
> {{ .Error }}
32 | 33 | 38 |
39 |
--------------------------------------------------------------------------------
/chttp/fs.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import "io/fs"
4 |
5 | // EmptyFS is a simple implementation of the fs.FS interface that contains no files.
6 | // This implementation emulates an empty directory in which no file can be found.
7 | type EmptyFS struct{}
8 |
9 | // Open always fails since this fs is empty. As the fs.FS contract requires, the error is an *fs.PathError; it
10 | // wraps fs.ErrNotExist so callers such as http.FileServer report the file as not found.
11 | func (*EmptyFS) Open(name string) (fs.File, error) {
12 | 	return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
13 | }
14 |
--------------------------------------------------------------------------------
/chttp/handler.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	"net/http"
5 | 	"path"
6 | 	"regexp"
7 | 	"sort"
8 | 	"strings"
9 |
10 | 	"github.com/gocopper/copper/clogger"
11 |
12 | 	"github.com/gorilla/mux"
13 | )
14 |
15 | // routeMatcherRe matches a single gorilla/mux path variable such as "{id}" or "{path:.*}". The (?U) flag makes
16 | // the quantifier lazy so that adjacent variables are matched separately.
17 | var routeMatcherRe = regexp.MustCompile(`(?U)(\{.*\})`)
18 |
19 | // NewHandlerParams holds the params needed for NewHandler.
20 | type NewHandlerParams struct {
21 | 	Routers           []Router
22 | 	GlobalMiddlewares []Middleware
23 | 	Config            Config
24 | 	Logger            clogger.Logger
25 | }
26 |
27 | // NewHandler creates a http.Handler with the given routes and middlewares.
28 | // The handler can be used with a http.Server or as an argument to StartServer.
29 | //
30 | // Routes from all routers are ordered by sortRoutes before registration. Each route's handler is wrapped so
31 | // that, from outermost to innermost, the panic logger runs first, then the route-path context middleware,
32 | // then the global middlewares in order, then the route's own middlewares in order.
33 | func NewHandler(p NewHandlerParams) http.Handler {
34 | 	var (
35 | 		muxRouter  = mux.NewRouter()
36 | 		muxHandler = http.NewServeMux()
37 | 	)
38 |
39 | 	routes := make([]Route, 0)
40 | 	for _, router := range p.Routers {
41 | 		routes = append(routes, router.Routes()...)
42 | 	}
43 |
44 | 	sortRoutes(routes)
45 |
46 | 	for _, route := range routes {
47 | 		routePath := route.Path
48 |
49 | 		// If a base path is set in the configuration, only routes under the base path are registered, and they
50 | 		// are registered relative to it (the base path prefix is removed; see trimBasePath). Routes marked with
51 | 		// RegisterWithBasePath are first joined onto the base path.
52 | 		if p.Config.BasePath != nil {
53 | 			if route.RegisterWithBasePath {
54 | 				routePath = path.Join(*p.Config.BasePath, routePath)
55 | 			}
56 |
57 | 			relPath, ok := trimBasePath(routePath, *p.Config.BasePath)
58 | 			if !ok {
59 | 				continue
60 | 			}
61 |
62 | 			routePath = relPath
63 | 		}
64 |
65 | 		handler := http.Handler(route.Handler)
66 |
67 | 		// Register route-level handlers
68 | 		// Since we are wrapping the handler in middleware functions, the outermost one will run first.
69 | 		// By applying the middlewares in reverse, we ensure that the first middleware in the list is the outermost one.
70 | 		for i := len(route.Middlewares) - 1; i >= 0; i-- {
71 | 			handler = route.Middlewares[i].Handle(handler)
72 | 		}
73 |
74 | 		// Register global middlewares
75 | 		for i := len(p.GlobalMiddlewares) - 1; i >= 0; i-- {
76 | 			handler = p.GlobalMiddlewares[i].Handle(handler)
77 | 		}
78 |
79 | 		handler = setRoutePathInCtxMiddleware(routePath).Handle(handler)
80 | 		handler = panicLoggerMiddleware(p.Logger).Handle(handler)
81 |
82 | 		muxRoute := muxRouter.Handle(routePath, handler)
83 |
84 | 		if len(route.Methods) > 0 {
85 | 			muxRoute.Methods(route.Methods...)
86 | 		}
87 | 	}
88 |
89 | 	muxHandler.Handle("/", muxRouter)
90 |
91 | 	return muxHandler
92 | }
93 |
94 | // trimBasePath reports whether routePath lies under basePath and, if it does, returns routePath relative to
95 | // basePath, always starting with "/". The match is segment-aware: "/app" covers "/app" and "/app/foo" but not
96 | // "/application". basePath is cleaned and trailing slashes are ignored, so "/app/" behaves like "/app" and "/"
97 | // covers every route path that starts with "/".
98 | func trimBasePath(routePath, basePath string) (string, bool) {
99 | 	if basePath != "" {
100 | 		basePath = strings.TrimSuffix(path.Clean(basePath), "/")
101 | 	}
102 |
103 | 	if !strings.HasPrefix(routePath, basePath) {
104 | 		return "", false
105 | 	}
106 |
107 | 	rest := routePath[len(basePath):]
108 |
109 | 	switch {
110 | 	case rest == "":
111 | 		// The route is the base path itself, which is the root once the prefix is removed.
112 | 		return "/", true
113 | 	case strings.HasPrefix(rest, "/"):
114 | 		return rest, true
115 | 	default:
116 | 		// routePath only shares part of a segment with basePath (e.g. "/application" vs "/app").
117 | 		return "", false
118 | 	}
119 | }
120 |
121 | // sortRoutes orders routes in place so that more specific routes are registered first:
122 | //   - routes with more "/"-separated parts come before routes with fewer ("/" itself counts as none);
123 | //   - among routes with the same number of parts, the route whose first segment made up entirely of a path
124 | //     variable comes later sorts first, and a route without such a segment sorts before any route with one.
125 | //
126 | // The sort is stable, so routes that compare equal keep the order in which the routers provided them. This
127 | // keeps the registration order, and therefore gorilla/mux's match precedence, predictable.
128 | func sortRoutes(routes []Route) {
129 | 	const matcherPlaceholder = "{{matcher}}"
130 |
131 | 	sort.SliceStable(routes, func(i, j int) bool {
132 | 		aPath := routeMatcherRe.ReplaceAllString(routes[i].Path, matcherPlaceholder)
133 | 		bPath := routeMatcherRe.ReplaceAllString(routes[j].Path, matcherPlaceholder)
134 |
135 | 		aParts := strings.Split(aPath, "/")
136 | 		bParts := strings.Split(bPath, "/")
137 |
138 | 		if aPath == "/" {
139 | 			aParts = nil
140 | 		}
141 |
142 | 		if bPath == "/" {
143 | 			bParts = nil
144 | 		}
145 |
146 | 		if len(aParts) != len(bParts) {
147 | 			return len(aParts) > len(bParts)
148 | 		}
149 |
150 | 		for k, aPart := range aParts {
151 | 			if aPart == matcherPlaceholder {
152 | 				return false
153 | 			}
154 |
155 | 			if bParts[k] == matcherPlaceholder {
156 | 				return true
157 | 			}
158 | 		}
159 |
160 | 		return false
161 | 	})
162 | }
163 |
--------------------------------------------------------------------------------
/chttp/handler_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | 	"io"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 |
9 | 	"github.com/gocopper/copper/clogger"
10 |
11 | 	"github.com/gocopper/copper/chttp"
12 | 	"github.com/gocopper/copper/chttp/chttptest"
13 | 	"github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestNewHandler(t *testing.T) {
17 | 	t.Parallel()
18 |
19 | 	router := chttptest.NewRouter([]chttp.Route{
20 | 		{
21 | 			Path:    "/",
22 | 			Methods: []string{http.MethodGet},
23 | 			Handler: func(w http.ResponseWriter, r *http.Request) {
24 | 				_, err := w.Write([]byte("success"))
25 | 				assert.NoError(t, err)
26 | 			},
27 | 		},
28 | 	})
29 |
30 | 	server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
31 | 		Routers:           []chttp.Router{router},
32 | 		GlobalMiddlewares: nil,
33 | 		Logger:            clogger.NewNoop(),
34 | 	}))
35 | 	defer server.Close()
36 |
37 | 	resp, err := http.Get(server.URL) //nolint:noctx
38 | 	assert.NoError(t, err)
39 |
40 | 	body, err := io.ReadAll(resp.Body)
41 | 	assert.NoError(t, resp.Body.Close())
42 | 	assert.NoError(t, err)
43 |
44 | 	assert.Equal(t, "success", string(body))
45 | }
46 |
47 | func TestNewHandler_GlobalMiddleware(t *testing.T) {
48 | 	t.Parallel()
49 |
50 | 	didCallGlobalMiddleware := false
51 |
52 | 	router := chttptest.NewRouter([]chttp.Route{
53 | 		{
54 | 			Path:    "/",
55 | 			Handler: func(w http.ResponseWriter, r *http.Request) {},
56 | 		},
57 | 	})
58 |
59 | 	server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
60 | 		Routers: []chttp.Router{router},
61 | 		GlobalMiddlewares: []chttp.Middleware{
62 | 			chttp.HandleMiddleware(func(next http.Handler) http.Handler {
63 | 				return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64 | 					didCallGlobalMiddleware = true
65 | 					next.ServeHTTP(w, r)
66 | 				})
67 | 			}),
68 | 		},
69 | 		Logger: clogger.NewNoop(),
70 | 	}))
71 | 	defer server.Close()
72 |
73 | 	resp, err := http.Get(server.URL) //nolint:noctx
74 | 	assert.NoError(t, err)
75 | 	assert.NoError(t, resp.Body.Close())
76 |
77 | 	assert.True(t, didCallGlobalMiddleware)
78 | }
79 |
80 | func TestNewHandler_RouteMiddleware(t *testing.T) {
81 | 	t.Parallel()
82 |
83 | 	middlewareCalls := make([]string, 0)
84 |
85 | 	router := chttptest.NewRouter([]chttp.Route{
86 | 		{
87 | 			Path: "/",
88 | 			Middlewares: []chttp.Middleware{
89 | 				chttp.HandleMiddleware(func(next http.Handler) http.Handler {
90 | 					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91 | 						middlewareCalls = append(middlewareCalls, "1")
92 | 						next.ServeHTTP(w, r)
93 | 					})
94 | 				}),
95 |
96 | 				chttp.HandleMiddleware(func(next http.Handler) http.Handler {
97 | 					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98 | 						middlewareCalls = append(middlewareCalls, "2")
99 | 						next.ServeHTTP(w, r)
100 | 					})
101 | 				}),
102 | 			},
103 | 			Handler: func(w http.ResponseWriter, r *http.Request) {},
104 | 		},
105 | 	})
106 |
107 | 	server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
108 | 		Routers:           []chttp.Router{router},
109 | 		GlobalMiddlewares: nil,
110 | 		Logger:            clogger.NewNoop(),
111 | 	}))
112 | 	defer server.Close()
113 |
114 | 	resp, err := http.Get(server.URL) //nolint:noctx
115 | 	assert.NoError(t, err)
116 | 	assert.NoError(t, resp.Body.Close())
117 |
118 | 	assert.Equal(t, []string{"1", "2"}, middlewareCalls)
119 | }
120 |
121 | func TestNewHandler_RoutePriority_WithPlaceholder(t *testing.T) {
122 | 	t.Parallel()
123 |
124 | 	routes := []chttp.Route{
125 | 		{Path: "/foo"},
126 | 		{Path: "/{id}"},
127 | 	}
128 |
129 | 	chttptest.PingRoutes(t, routes)
130 | 	chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
131 | }
132 |
133 | func TestNewHandler_RoutePriority_WithIndex(t *testing.T) {
134 | 	t.Parallel()
135 |
136 | 	routes := []chttp.Route{
137 | 		{Path: "/foo"},
138 | 		{Path: "/"},
139 | 	}
140 |
141 | 	chttptest.PingRoutes(t, routes)
142 | 	chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
143 | }
144 |
145 | func TestNewHandler_RoutePriority_Equal(t *testing.T) {
146 | 	t.Parallel()
147 |
148 | 	routes := []chttp.Route{
149 | 		{Path: "/foo"},
150 | 		{Path: "/bar"},
151 | 	}
152 |
153 | 	chttptest.PingRoutes(t, routes)
154 | 	chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
155 | }
156 |
157 | func TestNewHandler_BasePath(t *testing.T) {
158 | 	t.Parallel()
159 |
160 | 	basePath := "/app/"
161 |
162 | 	newRoute := func(routePath string, registerWithBasePath bool) chttp.Route {
163 | 		return chttp.Route{
164 | 			Path:                 routePath,
165 | 			RegisterWithBasePath: registerWithBasePath,
166 | 			Handler: func(w http.ResponseWriter, r *http.Request) {
167 | 				_, err := w.Write([]byte(routePath))
168 | 				assert.NoError(t, err)
169 | 			},
170 | 		}
171 | 	}
172 |
173 | 	server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
174 | 		Routers: []chttp.Router{chttptest.NewRouter([]chttp.Route{
175 | 			newRoute("/app/foo", false),
176 | 			newRoute("/application/bar", false),
177 | 			newRoute("/baz", true),
178 | 		})},
179 | 		Config: chttp.Config{BasePath: &basePath},
180 | 		Logger: clogger.NewNoop(),
181 | 	}))
182 | 	defer server.Close()
183 |
184 | 	for reqPath, wantStatus := range map[string]int{
185 | 		"/foo":             http.StatusOK,
186 | 		"/baz":             http.StatusOK,
187 | 		"/lication/bar":    http.StatusNotFound,
188 | 		"/application/bar": http.StatusNotFound,
189 | 	} {
190 | 		resp, err := http.Get(server.URL + reqPath) //nolint:noctx
191 | 		if !assert.NoError(t, err) {
192 | 			continue
193 | 		}
194 |
195 | 		assert.NoError(t, resp.Body.Close())
196 | 		assert.Equal(t, wantStatus, resp.StatusCode, reqPath)
197 | 	}
198 | }
199 |
--------------------------------------------------------------------------------
/chttp/html_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	// Used to embed error.html
5 | 	_ "embed"
6 | 	"html/template"
7 | 	"net/http"
8 |
9 | 	"github.com/gorilla/mux"
10 |
11 | 	"github.com/gocopper/copper/cerrors"
12 | 	"github.com/gocopper/copper/clogger"
13 | )
14 |
15 | type (
16 | 	// WriteHTMLParams holds the params for the WriteHTML function in HTMLReaderWriter
17 | 	WriteHTMLParams struct {
18 | 		StatusCode     int
19 | 		Error          error
20 | 		Data           interface{}
21 | 		PageTemplate   string
22 | 		LayoutTemplate string
23 | 	}
24 |
25 | 	// WritePartialParams are the parameters for WritePartial
26 | 	WritePartialParams struct {
27 | 		Name string
28 | 		Data interface{}
29 | 	}
30 |
31 | 	// HTMLReaderWriter provides functions to read data from HTTP requests and write HTML response bodies
32 | 	HTMLReaderWriter struct {
33 | 		html   *HTMLRenderer
34 | 		config Config
35 | 		logger clogger.Logger
36 | 	}
37 | )
38 |
39 | // URLParams returns the route variables for the current request, if any
40 | var URLParams = mux.Vars
41 |
42 | //go:embed error.html
43 | var errorHTML string
44 |
45 | // errorHTMLTmpl is the built-in error page, parsed once from the embedded error.html.
46 | var errorHTMLTmpl = template.Must(template.New("chtml/error.html").Parse(errorHTML))
47 |
48 | // NewHTMLReaderWriter instantiates a new HTMLReaderWriter with its dependencies
49 | func NewHTMLReaderWriter(html *HTMLRenderer, config Config, logger clogger.Logger) *HTMLReaderWriter {
50 | 	return &HTMLReaderWriter{
51 | 		html:   html,
52 | 		config: config,
53 | 		logger: logger,
54 | 	}
55 | }
56 |
57 | // WriteHTMLError handles the given error. If render_html_error is configured to true, it writes an HTML page with
58 | // the error. Errors are always logged.
59 | func (rw *HTMLReaderWriter) WriteHTMLError(w http.ResponseWriter, r *http.Request, err error) {
60 | 	rw.WriteHTML(w, r, WriteHTMLParams{
61 | 		Error: err,
62 | 	})
63 | }
64 |
65 | // WriteHTML writes an HTML response to the provided http.ResponseWriter. Using the given WriteHTMLParams, the HTML
66 | // is generated with a layout, page, and component templates.
67 | //
68 | // Defaults: StatusCode is 200, or 500 when Error is set; LayoutTemplate is "main.html"; PageTemplate is
69 | // "internal-error.html" for a 500 and "not-found.html" for a 404. A non-nil Error is always logged, and when
70 | // render_html_error is enabled the built-in error page is written in place of the templates.
71 | //
72 | // If the templates fail to render, the failure is logged and an empty 500 text/html response is written, or,
73 | // with render_html_error enabled, the built-in error page describing the render failure.
74 | func (rw *HTMLReaderWriter) WriteHTML(w http.ResponseWriter, r *http.Request, p WriteHTMLParams) {
75 | 	if p.StatusCode == 0 && p.Error == nil {
76 | 		p.StatusCode = http.StatusOK
77 | 	}
78 |
79 | 	if p.StatusCode == 0 && p.Error != nil {
80 | 		p.StatusCode = http.StatusInternalServerError
81 | 	}
82 |
83 | 	if p.LayoutTemplate == "" {
84 | 		p.LayoutTemplate = "main.html"
85 | 	}
86 |
87 | 	if p.PageTemplate == "" && p.StatusCode == http.StatusInternalServerError {
88 | 		p.PageTemplate = "internal-error.html"
89 | 	}
90 |
91 | 	if p.PageTemplate == "" && p.StatusCode == http.StatusNotFound {
92 | 		p.PageTemplate = "not-found.html"
93 | 	}
94 |
95 | 	if p.Error != nil {
96 | 		rw.logger.WithTags(map[string]interface{}{
97 | 			"url": r.URL.String(),
98 | 		}).Error("Failed to handle request", p.Error)
99 | 	}
100 |
101 | 	if p.Error != nil && rw.config.RenderHTMLError {
102 | 		writeHTMLErrorPage(w, p.StatusCode, p.Error)
103 | 		return
104 | 	}
105 |
106 | 	out, err := rw.html.render(r, p.LayoutTemplate, p.PageTemplate, p.Data)
107 | 	if err != nil {
108 | 		rw.logger.Error("Failed to render html template", cerrors.WithTags(err, map[string]interface{}{
109 | 			"layout": p.LayoutTemplate,
110 | 			"page":   p.PageTemplate,
111 | 		}))
112 |
113 | 		if rw.config.RenderHTMLError {
114 | 			writeHTMLErrorPage(w, http.StatusInternalServerError, err)
115 | 			return
116 | 		}
117 |
118 | 		// Still send a well-formed (if empty) HTML response: the page templates are app-provided and may be
119 | 		// missing (e.g. no internal-error.html).
120 | 		w.Header().Set("content-type", "text/html")
121 | 		w.WriteHeader(http.StatusInternalServerError)
122 |
123 | 		return
124 | 	}
125 |
126 | 	w.Header().Set("content-type", "text/html")
127 | 	w.WriteHeader(p.StatusCode)
128 | 	_, _ = w.Write([]byte(out))
129 | }
130 |
131 | // writeHTMLErrorPage writes the built-in error page describing err with the given status code.
132 | func writeHTMLErrorPage(w http.ResponseWriter, statusCode int, err error) {
133 | 	w.Header().Set("content-type", "text/html")
134 | 	w.WriteHeader(statusCode)
135 |
136 | 	// Best effort: the status line has already been sent, so a failure here cannot be reported to the client.
137 | 	_ = errorHTMLTmpl.Execute(w, map[string]interface{}{
138 | 		"Error": err.Error(),
139 | 	})
140 | }
141 |
142 | // WritePartial renders a partial template with the given name and data; failures are handled by WriteHTMLError
143 | func (rw *HTMLReaderWriter) WritePartial(w http.ResponseWriter, r *http.Request, p WritePartialParams) {
144 | 	out, err := rw.html.partial(r)(p.Name, p.Data)
145 | 	if err != nil {
146 | 		rw.WriteHTMLError(w, r, cerrors.New(err, "failed to render partial", map[string]interface{}{
147 | 			"name": p.Name,
148 | 		}))
149 | 		return
150 | 	}
151 |
152 | 	w.Header().Set("content-type", "text/html")
153 | 	_, _ = w.Write([]byte(out))
154 | }
155 |
156 | // Unauthorized writes a 401 Unauthorized response to the http.ResponseWriter. If a redirect URL is configured,
157 | // the user is redirected to that URL instead (303 See Other).
158 | func (rw *HTMLReaderWriter) Unauthorized(w http.ResponseWriter, r *http.Request) {
159 | 	if rw.config.RedirectURLForUnauthorizedRequests != nil {
160 | 		http.Redirect(w, r, *rw.config.RedirectURLForUnauthorizedRequests, http.StatusSeeOther)
161 | 		return
162 | 	}
163 |
164 | 	w.WriteHeader(http.StatusUnauthorized)
165 | }
166 |
--------------------------------------------------------------------------------
/chttp/html_reader_writer_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | 	"errors"
5 | 	"net/http"
6 | 	"net/http/httptest"
7 | 	"testing"
8 |
9 | 	"github.com/gocopper/copper/chttp"
10 | 	"github.com/gocopper/copper/chttp/chttptest"
11 | 	"github.com/gocopper/copper/clogger"
12 | 	"github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestHTMLReaderWriter_WriteHTML(t *testing.T) {
16 | 	t.Parallel()
17 |
18 | 	rw := chttptest.NewHTMLReaderWriter(t)
19 | 	resp := httptest.NewRecorder()
20 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
21 |
22 | 	rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
23 | 		StatusCode:   http.StatusOK,
24 | 		Data:         map[string]string{"user": "test"},
25 | 		PageTemplate: "index.html",
26 | 	})
27 |
28 | 	assert.Equal(t, http.StatusOK, resp.Code)
29 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
30 | 	assert.Contains(t, resp.Body.String(), `Test Page`)
31 | }
32 |
33 | func TestHTMLReaderWriter_WriteHTML_NotFound(t *testing.T) {
34 | 	t.Parallel()
35 |
36 | 	rw := chttptest.NewHTMLReaderWriter(t)
37 | 	resp := httptest.NewRecorder()
38 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
39 |
40 | 	rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
41 | 		StatusCode: http.StatusNotFound,
42 | 		Data:       map[string]string{"user": "test"},
43 | 	})
44 |
45 | 	assert.Equal(t, http.StatusNotFound, resp.Code)
46 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
47 | 	assert.Contains(t, resp.Body.String(), `not found`)
48 | }
49 |
50 | func TestHTMLReaderWriter_WriteHTML_WithError(t *testing.T) {
51 | 	t.Parallel()
52 |
53 | 	rw := chttptest.NewHTMLReaderWriter(t)
54 | 	resp := httptest.NewRecorder()
55 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
56 |
57 | 	rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
58 | 		Error: errors.New("test error"),
59 | 		Data:  map[string]string{"user": "test"},
60 | 	})
61 |
62 | 	assert.Equal(t, http.StatusInternalServerError, resp.Code)
63 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
64 | 	// chttptest.HTMLDir has no internal-error.html page, so this exercises the render-failure fallback;
65 | 	// the body is not checked.
66 | }
67 |
68 | func TestHTMLReaderWriter_WriteHTMLError(t *testing.T) {
69 | 	t.Parallel()
70 |
71 | 	rw := chttptest.NewHTMLReaderWriter(t)
72 | 	resp := httptest.NewRecorder()
73 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
74 | 	testErr := errors.New("test error")
75 |
76 | 	rw.WriteHTMLError(resp, req, testErr)
77 |
78 | 	assert.Equal(t, http.StatusInternalServerError, resp.Code)
79 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
80 | 	// Since the error rendering depends on config, we're not checking the exact content
81 | }
82 |
83 | func TestHTMLReaderWriter_WriteHTML_CustomLayout(t *testing.T) {
84 | 	t.Parallel()
85 |
86 | 	rw := chttptest.NewHTMLReaderWriter(t)
87 | 	resp := httptest.NewRecorder()
88 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
89 |
90 | 	rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
91 | 		StatusCode:     http.StatusOK,
92 | 		Data:           map[string]string{"user": "test"},
93 | 		PageTemplate:   "index.html",
94 | 		LayoutTemplate: "main.html", // explicitly set the default layout
95 | 	})
96 |
97 | 	assert.Equal(t, http.StatusOK, resp.Code)
98 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
99 | 	assert.Contains(t, resp.Body.String(), `Test Page`)
100 | }
101 |
102 | func TestHTMLReaderWriter_WritePartial(t *testing.T) {
103 | 	t.Parallel()
104 |
105 | 	rw := chttptest.NewHTMLReaderWriter(t)
106 | 	resp := httptest.NewRecorder()
107 | 	req := httptest.NewRequest(http.MethodGet, "/", nil)
108 |
109 | 	rw.WritePartial(resp, req, chttp.WritePartialParams{
110 | 		Name: "partial.html",
111 | 		Data: map[string]string{"content": "partial content"},
112 | 	})
113 |
114 | 	// chttptest.HTMLDir has no partial with this name, so the failure is reported through WriteHTMLError.
115 | 	assert.Equal(t, http.StatusInternalServerError, resp.Code)
116 | 	assert.Equal(t, "text/html", resp.Header().Get("content-type"))
117 | }
118 |
119 | func TestHTMLReaderWriter_Unauthorized_WithoutRedirect(t *testing.T) {
120 | 	t.Parallel()
121 |
122 | 	// Create a reader/writer with default config (no redirect URL)
123 | 	rw := chttptest.NewHTMLReaderWriter(t)
124 | 	resp := httptest.NewRecorder()
125 | 	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
126 |
127 | 	rw.Unauthorized(resp, req)
128 |
129 | 	assert.Equal(t, http.StatusUnauthorized, resp.Code)
130 | }
131 |
132 | func TestHTMLReaderWriter_Unauthorized_WithRedirect(t *testing.T) {
133 | 	t.Parallel()
134 |
135 | 	redirectURL := "/login"
136 |
137 | 	// Unauthorized never renders templates, so no HTMLRenderer is needed.
138 | 	rw := chttp.NewHTMLReaderWriter(nil, chttp.Config{
139 | 		RedirectURLForUnauthorizedRequests: &redirectURL,
140 | 	}, clogger.NewNoop())
141 | 	resp := httptest.NewRecorder()
142 | 	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
143 |
144 | 	rw.Unauthorized(resp, req)
145 |
146 | 	assert.Equal(t, http.StatusSeeOther, resp.Code)
147 | 	assert.Equal(t, "/login", resp.Header().Get("Location"))
148 | }
149 |
--------------------------------------------------------------------------------
/chttp/html_renderer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	"html/template"
5 | 	"io/fs"
6 | 	"net/http"
7 | 	"os"
8 | 	"path"
9 | 	"path/filepath"
10 | 	"strings"
11 |
12 | 	"github.com/gocopper/copper/clogger"
13 |
14 | 	"github.com/Masterminds/sprig/v3"
15 | 	"github.com/gocopper/copper/cerrors"
16 | )
17 |
18 | type (
19 | 	// HTMLDir is a directory that can be embedded or found on the host system. It should contain sub-directories
20 | 	// and files to support the WriteHTML function in HTMLReaderWriter: templates are read from src/layouts,
21 | 	// src/pages, and src/partials.
22 | 	HTMLDir fs.FS
23 |
24 | 	// StaticDir represents a directory that holds static resources (JS, CSS, images, etc.)
25 | 	StaticDir fs.FS
26 |
27 | 	// HTMLRenderer provides functionality in rendering templatized HTML along with HTML components
28 | 	HTMLRenderer struct {
29 | 		htmlDir     HTMLDir
30 | 		staticDir   StaticDir
31 | 		renderFuncs []HTMLRenderFunc
32 | 	}
33 |
34 | 	// HTMLRenderFunc can be used to register new template functions
35 | 	HTMLRenderFunc struct {
36 | 		// Name for the function that can be invoked in a template
37 | 		Name string
38 |
39 | 		// Func should return a function that takes in any number of params and returns either a single return value,
40 | 		// or two return values of which the second has type error.
41 | 		Func func(r *http.Request) interface{}
42 | 	}
43 |
44 | 	// NewHTMLRendererParams holds the params needed to create HTMLRenderer
45 | 	NewHTMLRendererParams struct {
46 | 		HTMLDir     HTMLDir
47 | 		StaticDir   StaticDir
48 | 		RenderFuncs []HTMLRenderFunc
49 | 		Config      Config
50 | 		Logger      clogger.Logger
51 | 	}
52 | )
53 |
54 | // NewHTMLRenderer creates a new HTMLRenderer that reads templates from p.HTMLDir and exposes p.RenderFuncs to
55 | // them. When p.Config.UseLocalHTML is set, templates are instead read from the "web" directory under the
56 | // current working directory.
57 | func NewHTMLRenderer(p NewHTMLRendererParams) (*HTMLRenderer, error) {
58 | 	hr := HTMLRenderer{
59 | 		htmlDir:     p.HTMLDir,
60 | 		staticDir:   p.StaticDir,
61 | 		renderFuncs: p.RenderFuncs,
62 | 	}
63 |
64 | 	if p.Config.UseLocalHTML {
65 | 		wd, err := os.Getwd()
66 | 		if err != nil {
67 | 			return nil, cerrors.New(err, "failed to get current working directory", nil)
68 | 		}
69 |
70 | 		hr.htmlDir = os.DirFS(filepath.Join(wd, "web"))
71 | 	}
72 |
73 | 	return &hr, nil
74 | }
75 |
76 | // funcMap returns the template functions available while handling req: sprig's functions, then the registered
77 | // render funcs (which replace sprig functions with the same name), then "partial".
78 | func (r *HTMLRenderer) funcMap(req *http.Request) template.FuncMap {
79 | 	var funcMap = sprig.FuncMap()
80 |
81 | 	for i := range r.renderFuncs {
82 | 		funcMap[r.renderFuncs[i].Name] = r.renderFuncs[i].Func(req)
83 | 	}
84 |
85 | 	funcMap["partial"] = r.partial(req)
86 |
87 | 	return funcMap
88 | }
89 |
90 | // render parses src/layouts/<layout> and src/pages/<page> from the HTML dir and executes the layout template
91 | // (which is expected to invoke templates defined by the page) with data.
92 | func (r *HTMLRenderer) render(req *http.Request, layout, page string, data interface{}) (template.HTML, error) {
93 | 	var dest strings.Builder
94 |
95 | 	tmpl, err := template.New(layout).
96 | 		Funcs(r.funcMap(req)).
97 | 		ParseFS(r.htmlDir,
98 | 			path.Join("src", "layouts", layout),
99 | 			path.Join("src", "pages", page),
100 | 		)
101 | 	if err != nil {
102 | 		return "", cerrors.New(err, "failed to parse templates in html dir", map[string]interface{}{
103 | 			"layout": layout,
104 | 			"page":   page,
105 | 		})
106 | 	}
107 |
108 | 	err = tmpl.Execute(&dest, data)
109 | 	if err != nil {
110 | 		return "", cerrors.New(err, "failed to execute template", nil)
111 | 	}
112 |
113 | 	//nolint:gosec
114 | 	return template.HTML(dest.String()), nil
115 | }
116 |
117 | // partial returns a function that parses every src/partials/*html file from the HTML dir and executes the
118 | // template named name+".html", else name+".gohtml", else name, with data.
119 | func (r *HTMLRenderer) partial(req *http.Request) func(name string, data interface{}) (template.HTML, error) {
120 | 	return func(name string, data interface{}) (template.HTML, error) {
121 | 		var dest strings.Builder
122 |
123 | 		tmpl, err := template.New(name).
124 | 			Funcs(r.funcMap(req)).
125 | 			ParseFS(r.htmlDir,
126 | 				path.Join("src", "partials", "*html"),
127 | 			)
128 | 		if err != nil {
129 | 			return "", cerrors.New(err, "failed to parse partial template", map[string]interface{}{
130 | 				"name": name,
131 | 			})
132 | 		}
133 |
134 | 		for _, ext := range []string{".html", ".gohtml"} {
135 | 			t := tmpl.Lookup(name + ext)
136 | 			if t != nil {
137 | 				tmpl = t
138 | 				break
139 | 			}
140 | 		}
141 |
142 | 		err = tmpl.Execute(&dest, data)
143 | 		if err != nil {
144 | 			return "", cerrors.New(err, "failed to execute partial template", map[string]interface{}{
145 | 				"name": name,
146 | 			})
147 | 		}
148 |
149 | 		//nolint:gosec
150 | 		return template.HTML(dest.String()), nil
151 | 	}
152 | }
153 |
--------------------------------------------------------------------------------
/chttp/html_router.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	"net/http"
5 | 	"path"
6 | 	"strings"
7 | )
8 |
9 | type (
10 | 	// HTMLRouter provides routes to serve (1) static assets (2) index page for an SPA
11 | 	HTMLRouter struct {
12 | 		rw        *HTMLReaderWriter
13 | 		staticDir StaticDir
14 | 		config    Config
15 | 	}
16 |
17 | 	// NewHTMLRouterParams holds the params needed to instantiate a new Router
18 | 	NewHTMLRouterParams struct {
19 | 		StaticDir StaticDir
20 | 		RW        *HTMLReaderWriter
21 | 		Config    Config
22 | 	}
23 | )
24 |
25 | // NewHTMLRouter instantiates a new Router
26 | func NewHTMLRouter(p NewHTMLRouterParams) (*HTMLRouter, error) {
27 | 	return &HTMLRouter{
28 | 		rw:        p.RW,
29 | 		staticDir: p.StaticDir,
30 | 		config:    p.Config,
31 | 	}, nil
32 | }
33 |
34 | // Routes defines the HTTP routes for this router: GET /static/{path:.*} always, plus a GET catch-all that renders
35 | // index.html when Config.EnableSinglePageRouting is set. Both are registered with the base path.
36 | func (ro *HTMLRouter) Routes() []Route {
37 | 	routes := []Route{
38 | 		{
39 | 			Path:                 "/static/{path:.*}",
40 | 			Methods:              []string{http.MethodGet},
41 | 			Handler:              ro.HandleStaticFile,
42 | 			RegisterWithBasePath: true,
43 | 		},
44 | 	}
45 |
46 | 	if ro.config.EnableSinglePageRouting {
47 | 		routes = append(routes, Route{
48 | 			Path:                 "/{path:.*}",
49 | 			Methods:              []string{http.MethodGet},
50 | 			Handler:              ro.HandleIndexPage,
51 | 			RegisterWithBasePath: true,
52 | 		})
53 | 	}
54 |
55 | 	return routes
56 | }
57 |
58 | // HandleStaticFile serves the requested static file. When Config.UseLocalHTML is set, the file is read from the
59 | // web/public directory on disk; otherwise it is served from StaticDir (e.g. files embedded in the binary) using
60 | // the full request path. Requests for paths ending in "/" get a 404 so directories are never listed.
61 | func (ro *HTMLRouter) HandleStaticFile(w http.ResponseWriter, r *http.Request) {
62 | 	// Disable directory listing
63 | 	if strings.HasSuffix(r.URL.Path, "/") {
64 | 		http.NotFound(w, r)
65 | 		return
66 | 	}
67 |
68 | 	if ro.config.UseLocalHTML {
69 | 		http.ServeFile(w, r, path.Join("web", "public", URLParams(r)["path"]))
70 | 		return
71 | 	}
72 |
73 | 	http.FileServer(http.FS(ro.staticDir)).ServeHTTP(w, r)
74 | }
75 |
76 | // HandleIndexPage renders the index.html page
77 | func (ro *HTMLRouter) HandleIndexPage(w http.ResponseWriter, r *http.Request) {
78 | 	ro.rw.WriteHTML(w, r, WriteHTMLParams{
79 | 		PageTemplate: "index.html",
80 | 	})
81 | }
82 |
--------------------------------------------------------------------------------
/chttp/http.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	"net/http"
5 | )
6 |
7 | // Route represents a single HTTP route (ex. /api/profile) that can be configured with middlewares, path,
8 | // HTTP methods, and a handler.
9 | type Route struct {
10 | 	Middlewares []Middleware
11 | 	Path        string
12 | 	Methods     []string
13 | 	Handler     http.HandlerFunc
14 |
15 | 	// RegisterWithBasePath ensures the route is registered with the base path,
16 | 	// even if its original path does not include the base path prefix
17 | 	RegisterWithBasePath bool
18 | }
19 |
20 | // Router is used to group routes together that are returned by the Routes method.
21 | type Router interface {
22 | 	Routes() []Route
23 | }
24 |
--------------------------------------------------------------------------------
/chttp/json_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | 	"bytes"
5 | 	"encoding/json"
6 | 	"errors"
7 | 	"io"
8 | 	"net/http"
9 |
10 | 	"github.com/asaskevich/govalidator"
11 | 	"github.com/gocopper/copper/cerrors"
12 | 	"github.com/gocopper/copper/clogger"
13 | )
14 |
15 | type (
16 | 	// WriteJSONParams holds the params for the WriteJSON function in JSONReaderWriter
17 | 	WriteJSONParams struct {
18 | 		StatusCode int
19 | 		Data       interface{}
20 | 	}
21 |
22 | 	// JSONReaderWriter provides functions to read and write JSON data from/to HTTP requests/responses
23 | 	JSONReaderWriter struct {
24 | 		config Config
25 | 		logger clogger.Logger
26 | 	}
27 | )
28 |
29 | // NewJSONReaderWriter instantiates a new JSONReaderWriter with its dependencies
30 | func NewJSONReaderWriter(config Config, logger clogger.Logger) *JSONReaderWriter {
31 | 	return &JSONReaderWriter{
32 | 		config: config,
33 | 		logger: logger,
34 | 	}
35 | }
36 |
37 | // WriteJSON writes a JSON response to the http.ResponseWriter. It can be configured with status code and data using
38 | // WriteJSONParams. A zero StatusCode leaves the status to net/http's default, a nil Data writes no body, and an
39 | // error Data is written as {"error": "<err.Error()>"}.
40 | //
41 | // The body is fully encoded before anything is written, so an encoding failure is reported with a 500 status
42 | // instead of being attempted after the status line has already been sent.
43 | func (rw *JSONReaderWriter) WriteJSON(w http.ResponseWriter, p WriteJSONParams) {
44 | 	w.Header().Set("Content-Type", "application/json")
45 |
46 | 	if p.Data == nil {
47 | 		if p.StatusCode > 0 {
48 | 			w.WriteHeader(p.StatusCode)
49 | 		}
50 |
51 | 		return
52 | 	}
53 |
54 | 	data, failureMsg := p.Data, "Failed to marshal response as json"
55 | 	if errData, ok := p.Data.(error); ok {
56 | 		data = map[string]string{"error": errData.Error()}
57 | 		failureMsg = "Failed to marshal error response as json"
58 | 	}
59 |
60 | 	var buf bytes.Buffer
61 |
62 | 	err := json.NewEncoder(&buf).Encode(data)
63 | 	if err != nil {
64 | 		rw.logger.Error(failureMsg, err)
65 | 		w.WriteHeader(http.StatusInternalServerError)
66 |
67 | 		return
68 | 	}
69 |
70 | 	if p.StatusCode > 0 {
71 | 		w.WriteHeader(p.StatusCode)
72 | 	}
73 |
74 | 	_, err = w.Write(buf.Bytes())
75 | 	if err != nil {
76 | 		rw.logger.Warn("Failed to write json response", err)
77 | 	}
78 | }
79 |
80 | // ReadJSON reads JSON from the http.Request into the body var. If the body struct has validate tags on it, the
81 | // struct is also validated. If the validation fails, a BadRequest response is sent back and the function returns
82 | // false.
83 | func (rw *JSONReaderWriter) ReadJSON(w http.ResponseWriter, req *http.Request, body interface{}) bool {
84 | 	url := req.URL.String()
85 |
86 | 	err := json.NewDecoder(req.Body).Decode(body)
87 | 	if err != nil && errors.Is(err, io.EOF) {
88 | 		rw.WriteJSON(w, WriteJSONParams{
89 | 			StatusCode: http.StatusBadRequest,
90 | 			Data:       map[string]string{"error": "empty body"},
91 | 		})
92 |
93 | 		return false
94 | 	} else if err != nil {
95 | 		rw.WriteJSON(w, WriteJSONParams{
96 | 			StatusCode: http.StatusBadRequest,
97 | 			Data: cerrors.New(err, "invalid body json", map[string]interface{}{
98 | 				"url": url,
99 | 			}),
100 | 		})
101 |
102 | 		return false
103 | 	}
104 |
105 | 	ok, err := govalidator.ValidateStruct(body)
106 | 	if !ok {
107 | 		rw.logger.Warn("Failed to read body", cerrors.New(err, "data validation failed", map[string]interface{}{
108 | 			"url": url,
109 | 		}))
110 |
111 | 		rw.WriteJSON(w, WriteJSONParams{
112 | 			StatusCode: http.StatusBadRequest,
113 | 			Data:       err,
114 | 		})
115 |
116 | 		return false
117 | 	}
118 |
119 | 	return true
120 | }
121 |
122 | // Unauthorized writes a 401 Unauthorized JSON response
123 | func (rw *JSONReaderWriter) Unauthorized(w http.ResponseWriter) {
124 | 	rw.WriteJSON(w, WriteJSONParams{
125 | 		StatusCode: http.StatusUnauthorized,
126 | 		Data:       map[string]string{"error": "unauthorized"},
127 | 	})
128 | }
129 |
--------------------------------------------------------------------------------
/chttp/json_reader_writer_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | 	"bytes"
5 | 	"errors"
6 | 	"net/http"
7 | 	"net/http/httptest"
8 | 	"testing"
9 |
10 | 	"github.com/gocopper/copper/chttp"
11 | 	"github.com/gocopper/copper/clogger"
12 | 	"github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestJSONReaderWriter_ReadJSON(t *testing.T) {
16 | 	t.Parallel()
17 |
18 | 	var body struct {
19 | 		Key string `json:"key"`
20 | 	}
21 |
22 | 	rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
23 | 24 | ok := rw.ReadJSON( 25 | httptest.NewRecorder(), 26 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{"key": "value"}`))), 27 | &body, 28 | ) 29 | 30 | assert.True(t, ok) 31 | assert.Equal(t, "value", body.Key) 32 | } 33 | 34 | func TestJSONReaderWriter_ReadJSON_Invalid_Body(t *testing.T) { 35 | t.Parallel() 36 | 37 | var body struct { 38 | Key string `json:"key"` 39 | } 40 | 41 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 42 | resp := httptest.NewRecorder() 43 | 44 | ok := rw.ReadJSON( 45 | resp, 46 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{ invalid json }`))), 47 | &body, 48 | ) 49 | 50 | assert.False(t, ok) 51 | assert.Equal(t, http.StatusBadRequest, resp.Code) 52 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 53 | assert.Contains(t, resp.Body.String(), "invalid body json") 54 | } 55 | 56 | func TestJSONReaderWriter_ReadJSON_Empty_Body(t *testing.T) { 57 | t.Parallel() 58 | 59 | var body struct { 60 | Key string `json:"key"` 61 | } 62 | 63 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 64 | resp := httptest.NewRecorder() 65 | 66 | ok := rw.ReadJSON( 67 | resp, 68 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(``))), 69 | &body, 70 | ) 71 | 72 | assert.False(t, ok) 73 | assert.Equal(t, http.StatusBadRequest, resp.Code) 74 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 75 | assert.Contains(t, resp.Body.String(), "empty body") 76 | } 77 | 78 | func TestJSONReaderWriter_ReadJSON_Validator(t *testing.T) { 79 | t.Parallel() 80 | 81 | var body struct { 82 | Key string `json:"key" valid:"email"` 83 | } 84 | 85 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 86 | resp := httptest.NewRecorder() 87 | 88 | ok := rw.ReadJSON( 89 | resp, 90 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{"key": "value"}`))), 91 | &body, 92 | ) 93 | 94 | assert.False(t, ok) 95 | 
assert.Equal(t, http.StatusBadRequest, resp.Code) 96 | } 97 | 98 | func TestJSONReaderWriter_WriteJSON_Data(t *testing.T) { 99 | t.Parallel() 100 | 101 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 102 | resp := httptest.NewRecorder() 103 | 104 | rw.WriteJSON(resp, chttp.WriteJSONParams{ 105 | Data: map[string]string{ 106 | "key": "val", 107 | }, 108 | }) 109 | 110 | assert.Equal(t, http.StatusOK, resp.Code) 111 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 112 | assert.Contains(t, resp.Body.String(), `{"key":"val"}`) 113 | } 114 | 115 | func TestJSONReaderWriter_WriteJSON_StatusCode(t *testing.T) { 116 | t.Parallel() 117 | 118 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 119 | resp := httptest.NewRecorder() 120 | 121 | rw.WriteJSON(resp, chttp.WriteJSONParams{ 122 | StatusCode: http.StatusCreated, 123 | Data: map[string]string{ 124 | "key": "val", 125 | }, 126 | }) 127 | 128 | assert.Equal(t, http.StatusCreated, resp.Code) 129 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 130 | assert.Contains(t, resp.Body.String(), `{"key":"val"}`) 131 | } 132 | 133 | func TestJSONReaderWriter_WriteJSON_Error(t *testing.T) { 134 | t.Parallel() 135 | 136 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 137 | resp := httptest.NewRecorder() 138 | 139 | rw.WriteJSON(resp, chttp.WriteJSONParams{ 140 | StatusCode: http.StatusBadRequest, 141 | Data: errors.New("test-err"), //nolint:goerr113 142 | }) 143 | 144 | assert.Equal(t, http.StatusBadRequest, resp.Code) 145 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 146 | assert.Contains(t, resp.Body.String(), `{"error":"test-err"}`) 147 | } 148 | 149 | func TestJSONReaderWriter_WriteJSON_NilData(t *testing.T) { 150 | t.Parallel() 151 | 152 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 153 | resp := httptest.NewRecorder() 154 | 155 | rw.WriteJSON(resp, chttp.WriteJSONParams{ 
156 | StatusCode: http.StatusOK, 157 | Data: nil, 158 | }) 159 | 160 | assert.Equal(t, http.StatusOK, resp.Code) 161 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 162 | assert.Empty(t, resp.Body.String()) 163 | } 164 | 165 | func TestJSONReaderWriter_Unauthorized(t *testing.T) { 166 | t.Parallel() 167 | 168 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop()) 169 | resp := httptest.NewRecorder() 170 | 171 | rw.Unauthorized(resp) 172 | 173 | assert.Equal(t, http.StatusUnauthorized, resp.Code) 174 | assert.Equal(t, "application/json", resp.Header().Get("content-type")) 175 | assert.Contains(t, resp.Body.String(), `{"error":"unauthorized"}`) 176 | } 177 | -------------------------------------------------------------------------------- /chttp/middleware.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import "net/http" 4 | 5 | // Middleware defines the interface with a Handle func which is of the MiddlewareFunc type. Implementations of 6 | // Middleware can be used with NewHandler for global middlewares or Route for route-specific middlewares. 7 | type Middleware interface { 8 | Handle(next http.Handler) http.Handler 9 | } 10 | 11 | // MiddlewareFunc is a function that takes in a http.Handler and returns one as well. It allows you to execute 12 | // code before or after calling the handler. 13 | type MiddlewareFunc func(next http.Handler) http.Handler 14 | 15 | // HandleMiddleware returns an implementation of Middleware that runs the provided func. 
16 | func HandleMiddleware(fn MiddlewareFunc) Middleware { 17 | return &middlewareFuncHandler{fn: fn} 18 | } 19 | 20 | type middlewareFuncHandler struct { 21 | fn MiddlewareFunc 22 | } 23 | 24 | func (mw *middlewareFuncHandler) Handle(next http.Handler) http.Handler { 25 | return mw.fn(next) 26 | } 27 | -------------------------------------------------------------------------------- /chttp/panic_logger_mw.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "net/http" 5 | "runtime/debug" 6 | 7 | "github.com/gocopper/copper/clogger" 8 | ) 9 | 10 | func panicLoggerMiddleware(logger clogger.Logger) Middleware { 11 | mw := func(next http.Handler) http.Handler { 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | defer func() { 14 | log := logger.WithTags(map[string]interface{}{ 15 | "path": r.URL.Path, 16 | }) 17 | 18 | switch r := recover().(type) { 19 | case nil: 20 | break 21 | case error: 22 | log.WithTags(map[string]interface{}{ 23 | "stack": string(debug.Stack()), 24 | }).Error("Recovered from a panic while handling HTTP request", r) 25 | w.WriteHeader(http.StatusInternalServerError) 26 | default: 27 | log.WithTags(map[string]interface{}{ 28 | "error": r, 29 | "stack": string(debug.Stack()), 30 | }).Error("Recovered from a panic while handling HTTP request", nil) 31 | w.WriteHeader(http.StatusInternalServerError) 32 | } 33 | }() 34 | 35 | next.ServeHTTP(w, r) 36 | }) 37 | } 38 | 39 | return HandleMiddleware(mw) 40 | } 41 | -------------------------------------------------------------------------------- /chttp/panic_logger_mw_test.go: -------------------------------------------------------------------------------- 1 | package chttp_test 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/gocopper/copper/clogger" 10 | 11 | "github.com/gocopper/copper/chttp" 12 | "github.com/gocopper/copper/chttp/chttptest" 13 | 
"github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestPanicLoggerMiddleware_PanicError(t *testing.T) { 17 | t.Parallel() 18 | 19 | var ( 20 | router = chttptest.NewRouter([]chttp.Route{ 21 | { 22 | Path: "/", 23 | Methods: []string{http.MethodGet}, 24 | Handler: func(w http.ResponseWriter, r *http.Request) { 25 | panic(errors.New("test-error")) 26 | }, 27 | }, 28 | }) 29 | 30 | logs = make([]clogger.RecordedLog, 0) 31 | 32 | handler = chttp.NewHandler(chttp.NewHandlerParams{ 33 | Routers: []chttp.Router{router}, 34 | Logger: clogger.NewRecorder(&logs), 35 | }) 36 | ) 37 | 38 | server := httptest.NewServer(handler) 39 | defer server.Close() 40 | 41 | resp, err := http.Get(server.URL) //nolint:noctx 42 | assert.NoError(t, err) 43 | assert.NoError(t, resp.Body.Close()) 44 | 45 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 46 | assert.Equal(t, 1, len(logs)) 47 | assert.Equal(t, "Recovered from a panic while handling HTTP request", logs[0].Msg) 48 | assert.Equal(t, clogger.LevelError, logs[0].Level) 49 | assert.Equal(t, errors.New("test-error"), logs[0].Error) 50 | assert.Contains(t, logs[0].Tags["stack"], "panic_logger_mw.go") 51 | } 52 | 53 | func TestPanicLoggerMiddleware_NoPanic(t *testing.T) { 54 | t.Parallel() 55 | 56 | var ( 57 | router = chttptest.NewRouter([]chttp.Route{ 58 | { 59 | Path: "/", 60 | Methods: []string{http.MethodGet}, 61 | Handler: func(w http.ResponseWriter, r *http.Request) { 62 | w.WriteHeader(http.StatusOK) 63 | }, 64 | }, 65 | }) 66 | 67 | logs = make([]clogger.RecordedLog, 0) 68 | 69 | handler = chttp.NewHandler(chttp.NewHandlerParams{ 70 | Routers: []chttp.Router{router}, 71 | Logger: clogger.NewRecorder(&logs), 72 | }) 73 | ) 74 | 75 | server := httptest.NewServer(handler) 76 | defer server.Close() 77 | 78 | resp, err := http.Get(server.URL) //nolint:noctx 79 | assert.NoError(t, err) 80 | assert.NoError(t, resp.Body.Close()) 81 | 82 | assert.Equal(t, http.StatusOK, resp.StatusCode) 83 | assert.Equal(t, 0, 
len(logs)) 84 | } 85 | 86 | func TestPanicLoggerMiddleware_PanicNonError(t *testing.T) { 87 | t.Parallel() 88 | 89 | var ( 90 | router = chttptest.NewRouter([]chttp.Route{ 91 | { 92 | Path: "/", 93 | Methods: []string{http.MethodGet}, 94 | Handler: func(w http.ResponseWriter, r *http.Request) { 95 | panic("test-error") 96 | }, 97 | }, 98 | }) 99 | 100 | logs = make([]clogger.RecordedLog, 0) 101 | 102 | handler = chttp.NewHandler(chttp.NewHandlerParams{ 103 | Routers: []chttp.Router{router}, 104 | Logger: clogger.NewRecorder(&logs), 105 | }) 106 | ) 107 | 108 | server := httptest.NewServer(handler) 109 | defer server.Close() 110 | 111 | resp, err := http.Get(server.URL) //nolint:noctx 112 | assert.NoError(t, err) 113 | assert.NoError(t, resp.Body.Close()) 114 | 115 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 116 | assert.Equal(t, 1, len(logs)) 117 | assert.Equal(t, "Recovered from a panic while handling HTTP request", logs[0].Msg) 118 | assert.Equal(t, clogger.LevelError, logs[0].Level) 119 | assert.Equal(t, "test-error", logs[0].Tags["error"]) 120 | assert.Nil(t, logs[0].Error) 121 | assert.Contains(t, logs[0].Tags["stack"], "panic_logger_mw.go") 122 | } 123 | -------------------------------------------------------------------------------- /chttp/request_id_mw.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/google/uuid" 8 | ) 9 | 10 | type ctxRequestID string 11 | 12 | const ctxRequestIDKey = ctxRequestID("chttp/request-id") 13 | 14 | // SetRequestIDInCtxMiddleware sets a unique request id in the context 15 | func SetRequestIDInCtxMiddleware() Middleware { 16 | var mw = func(next http.Handler) http.Handler { 17 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | ctx := context.WithValue(r.Context(), ctxRequestIDKey, uuid.New().String()) 19 | 20 | next.ServeHTTP(w, r.WithContext(ctx)) 21 | }) 22 | } 
23 | 24 | return HandleMiddleware(mw) 25 | } 26 | 27 | // GetRequestID returns the request id from the context. 28 | // If the request id is not found in the context, it returns 29 | // an empty string. 30 | // It should be used only after the request id middleware is applied. 31 | func GetRequestID(ctx context.Context) string { 32 | requestID, ok := ctx.Value(ctxRequestIDKey).(string) 33 | if !ok || requestID == "" { 34 | return "" 35 | } 36 | 37 | return requestID 38 | } 39 | -------------------------------------------------------------------------------- /chttp/request_logger_mw.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "github.com/gocopper/copper/cmetrics" 8 | "net" 9 | "net/http" 10 | "time" 11 | 12 | "github.com/gocopper/copper/clogger" 13 | ) 14 | 15 | var errRWIsNotHijacker = errors.New("internal response writer is not http.Hijacker") 16 | 17 | // NewRequestLoggerMiddleware creates a new RequestLoggerMiddleware. 18 | func NewRequestLoggerMiddleware(metrics cmetrics.Metrics, logger clogger.Logger) *RequestLoggerMiddleware { 19 | return &RequestLoggerMiddleware{ 20 | metrics: metrics, 21 | logger: logger, 22 | } 23 | } 24 | 25 | // RequestLoggerMiddleware logs each request's HTTP method, path, and status code along with user uuid 26 | // (from basic auth) if any. 27 | type RequestLoggerMiddleware struct { 28 | metrics cmetrics.Metrics 29 | logger clogger.Logger 30 | } 31 | 32 | // Handle wraps the current request with a request/response recorder. It records the method path and the 33 | // return status code. It logs this with the given logger. 
34 | func (mw *RequestLoggerMiddleware) Handle(next http.Handler) http.Handler { 35 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 36 | var ( 37 | loggerRw = requestLoggerRw{ 38 | internal: w, 39 | statusCode: http.StatusOK, 40 | } 41 | 42 | tags = map[string]interface{}{ 43 | "method": r.Method, 44 | "url": r.URL.Path, 45 | } 46 | 47 | begin = time.Now() 48 | ) 49 | 50 | user, _, ok := r.BasicAuth() 51 | if ok { 52 | tags["user"] = user 53 | } 54 | 55 | next.ServeHTTP(&loggerRw, r) 56 | 57 | tags["statusCode"] = loggerRw.statusCode 58 | tags["duration"] = time.Since(begin).String() 59 | 60 | mw.metrics.CounterInc("http_requests_total", map[string]string{ 61 | "status_code": fmt.Sprintf("%d", loggerRw.statusCode), 62 | "path": RawRoutePath(r), 63 | }) 64 | mw.metrics.HistogramObserve("http_request_duration_seconds", map[string]string{ 65 | "status_code": fmt.Sprintf("%d", loggerRw.statusCode), 66 | "path": RawRoutePath(r), 67 | }, time.Since(begin).Seconds()) 68 | 69 | mw.logger.WithTags(tags).Info(fmt.Sprintf("%s %s %d", r.Method, r.URL.Path, loggerRw.statusCode)) 70 | }) 71 | } 72 | 73 | type requestLoggerRw struct { 74 | internal http.ResponseWriter 75 | statusCode int 76 | } 77 | 78 | func (rw *requestLoggerRw) Hijack() (net.Conn, *bufio.ReadWriter, error) { 79 | h, ok := rw.internal.(http.Hijacker) 80 | if !ok { 81 | return nil, nil, errRWIsNotHijacker 82 | } 83 | 84 | return h.Hijack() 85 | } 86 | 87 | func (rw *requestLoggerRw) Header() http.Header { 88 | return rw.internal.Header() 89 | } 90 | 91 | func (rw *requestLoggerRw) Write(b []byte) (int, error) { 92 | return rw.internal.Write(b) 93 | } 94 | 95 | func (rw *requestLoggerRw) WriteHeader(statusCode int) { 96 | rw.internal.WriteHeader(statusCode) 97 | rw.statusCode = statusCode 98 | } 99 | -------------------------------------------------------------------------------- /chttp/request_logger_mw_test.go: 
-------------------------------------------------------------------------------- 1 | package chttp_test 2 | 3 | import ( 4 | "github.com/gocopper/copper/cmetrics" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/gocopper/copper/chttp" 11 | "github.com/gocopper/copper/chttp/chttptest" 12 | "github.com/gocopper/copper/clogger" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestNewRequestLoggerMiddleware(t *testing.T) { 17 | t.Parallel() 18 | 19 | var ( 20 | logs = make([]clogger.RecordedLog, 0) 21 | logger = clogger.NewRecorder(&logs) 22 | metrics = cmetrics.NewNoopMetrics() 23 | router = chttptest.NewRouter([]chttp.Route{ 24 | { 25 | Middlewares: []chttp.Middleware{ 26 | chttp.NewRequestLoggerMiddleware(metrics, logger), 27 | }, 28 | Path: "/test", 29 | Methods: []string{http.MethodGet}, 30 | Handler: func(w http.ResponseWriter, r *http.Request) { 31 | w.WriteHeader(201) 32 | 33 | _, err := w.Write([]byte("OK")) 34 | assert.NoError(t, err) 35 | }, 36 | }, 37 | }) 38 | handler = chttp.NewHandler(chttp.NewHandlerParams{ 39 | Routers: []chttp.Router{router}, 40 | GlobalMiddlewares: nil, 41 | Logger: clogger.NewNoop(), 42 | }) 43 | ) 44 | 45 | server := httptest.NewServer(handler) 46 | defer server.Close() 47 | 48 | resp, err := http.Get(server.URL + "/test") //nolint:noctx 49 | assert.NoError(t, err) 50 | 51 | body, err := io.ReadAll(resp.Body) 52 | assert.NoError(t, err) 53 | assert.NoError(t, resp.Body.Close()) 54 | 55 | assert.Equal(t, "OK", string(body)) 56 | assert.Equal(t, clogger.LevelInfo, logs[0].Level) 57 | assert.Equal(t, "GET /test 201", logs[0].Msg) 58 | } 59 | -------------------------------------------------------------------------------- /chttp/route_ctx_mw.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type ctxRoutePath string 9 | 10 | const ctxRoutePathKey = 
ctxRoutePath("chttp/route-path") 11 | 12 | func setRoutePathInCtxMiddleware(path string) Middleware { 13 | var mw = func(next http.Handler) http.Handler { 14 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 15 | ctx := context.WithValue(r.Context(), ctxRoutePathKey, path) 16 | 17 | next.ServeHTTP(w, r.WithContext(ctx)) 18 | }) 19 | } 20 | 21 | return HandleMiddleware(mw) 22 | } 23 | 24 | // RawRoutePath returns the route path that matched for the given http.Request. This path includes the raw URL 25 | // variables. For example, a route path "/foo/{id}" will be returned as-is (i.e. {id} will NOT be replaced with the 26 | // actual url path) 27 | func RawRoutePath(r *http.Request) string { 28 | return r.Context().Value(ctxRoutePathKey).(string) 29 | } 30 | -------------------------------------------------------------------------------- /chttp/route_ctx_mw_test.go: -------------------------------------------------------------------------------- 1 | package chttp_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/gocopper/copper/clogger" 9 | 10 | "github.com/gocopper/copper/chttp" 11 | "github.com/gocopper/copper/chttp/chttptest" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestRoutePathInCtxMiddleware(t *testing.T) { 16 | t.Parallel() 17 | 18 | var ( 19 | routeMWRawRoutePath string 20 | globalMWRawRoutePath string 21 | 22 | router = chttptest.NewRouter([]chttp.Route{ 23 | { 24 | Path: "/foo/{id}", 25 | Methods: []string{http.MethodGet}, 26 | Handler: func(w http.ResponseWriter, r *http.Request) { 27 | routeMWRawRoutePath = chttp.RawRoutePath(r) 28 | }, 29 | }, 30 | }) 31 | 32 | handler = chttp.NewHandler(chttp.NewHandlerParams{ 33 | Routers: []chttp.Router{router}, 34 | GlobalMiddlewares: []chttp.Middleware{ 35 | chttp.HandleMiddleware(func(next http.Handler) http.Handler { 36 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 | globalMWRawRoutePath = 
chttp.RawRoutePath(r) 38 | 39 | next.ServeHTTP(w, r) 40 | }) 41 | }), 42 | }, 43 | Logger: clogger.NewNoop(), 44 | }) 45 | ) 46 | 47 | server := httptest.NewServer(handler) 48 | defer server.Close() 49 | 50 | resp, err := http.Get(server.URL + "/foo/bar") //nolint:noctx 51 | assert.NoError(t, err) 52 | assert.NoError(t, resp.Body.Close()) 53 | 54 | assert.Equal(t, "/foo/{id}", globalMWRawRoutePath) 55 | assert.Equal(t, "/foo/{id}", routeMWRawRoutePath) 56 | } 57 | -------------------------------------------------------------------------------- /chttp/server.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/gocopper/copper/clifecycle" 11 | "github.com/gocopper/copper/clogger" 12 | ) 13 | 14 | // NewServerParams holds the params needed to create a server. 15 | type NewServerParams struct { 16 | Handler http.Handler 17 | Lifecycle *clifecycle.Lifecycle 18 | Config Config 19 | Logger clogger.Logger 20 | } 21 | 22 | // NewServer creates a new server. 23 | func NewServer(p NewServerParams) *Server { 24 | return &Server{ 25 | handler: p.Handler, 26 | config: p.Config, 27 | logger: p.Logger, 28 | lc: p.Lifecycle, 29 | internal: http.Server{ 30 | ReadTimeout: time.Duration(p.Config.ReadTimeoutSeconds) * time.Second, 31 | }, 32 | } 33 | } 34 | 35 | // Server represents a configurable HTTP server that supports graceful shutdown. 36 | type Server struct { 37 | handler http.Handler 38 | config Config 39 | logger clogger.Logger 40 | lc *clifecycle.Lifecycle 41 | 42 | internal http.Server 43 | } 44 | 45 | // Run configures an HTTP server using the provided app config and starts it. 
46 | func (s *Server) Run() error { 47 | s.internal.Addr = fmt.Sprintf(":%d", s.config.Port) 48 | s.internal.Handler = s.handler 49 | 50 | s.lc.OnStop(func(ctx context.Context) error { 51 | s.logger.Info("Shutting down http server..") 52 | 53 | return s.internal.Shutdown(ctx) 54 | }) 55 | 56 | go func() { 57 | s.logger. 58 | WithTags(map[string]interface{}{"port": s.config.Port}). 59 | Info("Starting http server..") 60 | 61 | err := s.internal.ListenAndServe() 62 | if err != nil && !errors.Is(err, http.ErrServerClosed) { 63 | s.logger.Error("Server did not close cleanly", err) 64 | } 65 | }() 66 | 67 | return nil 68 | } 69 | -------------------------------------------------------------------------------- /chttp/server_test.go: -------------------------------------------------------------------------------- 1 | package chttp_test 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | "time" 7 | 8 | "github.com/gocopper/copper/chttp" 9 | "github.com/gocopper/copper/clifecycle" 10 | "github.com/gocopper/copper/clogger" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestServer_Run(t *testing.T) { 15 | t.Parallel() 16 | 17 | logger := clogger.New() 18 | lc := clifecycle.New() 19 | 20 | server := chttp.NewServer(chttp.NewServerParams{ 21 | Handler: http.NotFoundHandler(), 22 | Config: chttp.Config{Port: 8999}, 23 | Logger: logger, 24 | Lifecycle: lc, 25 | }) 26 | 27 | go func() { 28 | err := server.Run() 29 | assert.NoError(t, err) 30 | }() 31 | 32 | time.Sleep(50 * time.Millisecond) // wait for server to start 33 | 34 | resp, err := http.Get("http://127.0.0.1:8999") //nolint:noctx 35 | assert.NoError(t, err) 36 | assert.NoError(t, resp.Body.Close()) 37 | 38 | assert.NoError(t, resp.Body.Close()) 39 | assert.Equal(t, http.StatusNotFound, resp.StatusCode) 40 | 41 | lc.Stop(logger) 42 | 43 | time.Sleep(50 * time.Millisecond) // wait for server to stop 44 | 45 | _, err = http.Get("http://127.0.0.1:8999") //nolint:noctx,bodyclose 46 | assert.EqualError(t, err, 
"Get \"http://127.0.0.1:8999\": dial tcp 127.0.0.1:8999: connect: connection refused") 47 | } 48 | -------------------------------------------------------------------------------- /chttp/wire.go: -------------------------------------------------------------------------------- 1 | package chttp 2 | 3 | import ( 4 | "github.com/google/wire" 5 | ) 6 | 7 | // WireModule can be used as part of google/wire setup. 8 | var WireModule = wire.NewSet( //nolint:gochecknoglobals 9 | LoadConfig, 10 | NewJSONReaderWriter, 11 | NewHTMLReaderWriter, 12 | NewRequestLoggerMiddleware, 13 | wire.Struct(new(NewServerParams), "*"), 14 | NewServer, 15 | wire.Struct(new(NewHTMLRouterParams), "*"), 16 | NewHTMLRouter, 17 | wire.Struct(new(NewHTMLRendererParams), "*"), 18 | NewHTMLRenderer, 19 | ) 20 | 21 | // WireModuleEmptyHTML provides empty/default values for html and static dirs. This can be used to satisfy 22 | // wire when the project does not use/need html rendering. 23 | var WireModuleEmptyHTML = wire.NewSet( //nolint:gochecknoglobals 24 | wire.InterfaceValue(new(HTMLDir), &EmptyFS{}), 25 | wire.InterfaceValue(new(StaticDir), &EmptyFS{}), 26 | wire.Value([]HTMLRenderFunc{}), 27 | ) 28 | -------------------------------------------------------------------------------- /clifecycle/doc.go: -------------------------------------------------------------------------------- 1 | // Package clifecycle provides a Copper app's lifecycle management. It allows hooks to be registered that are run 2 | // during various checkpoints in the app's lifecycle. 3 | package clifecycle 4 | -------------------------------------------------------------------------------- /clifecycle/lifecycle.go: -------------------------------------------------------------------------------- 1 | package clifecycle 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | const defaultStopTimeout = 10 * time.Second 9 | 10 | // New instantiates and returns a new Lifecycle that can be used with 11 | // New to create a Copper app. 
12 | func New() *Lifecycle { 13 | return &Lifecycle{ 14 | onStop: make([]func(ctx context.Context) error, 0), 15 | stopTimeout: defaultStopTimeout, 16 | } 17 | } 18 | 19 | // Lifecycle represents the lifecycle of an app. Most importantly, it 20 | // allows various parts of the app to register stop funcs that are run 21 | // before the app exits. 22 | // Packages such as chttp use Lifecycle to gracefully stop the HTTP 23 | // server before the app exits. 24 | type Lifecycle struct { 25 | onStop []func(ctx context.Context) error 26 | stopTimeout time.Duration 27 | } 28 | 29 | // OnStop registers the provided fn to run before the app exits. The fn 30 | // is given a context with a deadline. Once the deadline expires, the 31 | // app may exit forcefully. 32 | func (lc *Lifecycle) OnStop(fn func(ctx context.Context) error) { 33 | lc.onStop = append(lc.onStop, fn) 34 | } 35 | 36 | // Stop runs all of the registered stop funcs in order along with a 37 | // context with a configured timeout. 38 | func (lc *Lifecycle) Stop(logger Logger) { 39 | for _, fn := range lc.onStop { 40 | ctx, cancel := context.WithTimeout(context.Background(), lc.stopTimeout) 41 | 42 | err := fn(ctx) 43 | if err != nil { 44 | logger.Error("Failed to run cleanup func", err) 45 | } 46 | 47 | cancel() 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /clifecycle/logger.go: -------------------------------------------------------------------------------- 1 | package clifecycle 2 | 3 | // Logger provides the methods needed by Lifecycle to log errors. 
4 | type Logger interface { 5 | Error(msg string, err error) 6 | } 7 | -------------------------------------------------------------------------------- /clogger/config.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "github.com/gocopper/copper/cconfig" 5 | "github.com/gocopper/copper/cerrors" 6 | ) 7 | 8 | // Format represents the output format of log statements 9 | type Format string 10 | 11 | // Formats supported by Logger 12 | const ( 13 | FormatPlain = Format("plain") 14 | FormatJSON = Format("json") 15 | ) 16 | 17 | // LoadConfig loads Config from app's config 18 | func LoadConfig(appConfig cconfig.Loader) (Config, error) { 19 | var config Config 20 | 21 | err := appConfig.Load("clogger", &config) 22 | if err != nil { 23 | return Config{}, cerrors.New(err, "failed to load clogger config", nil) 24 | } 25 | 26 | if config.Format != FormatPlain && config.Format != FormatJSON { 27 | config.Format = FormatPlain 28 | } 29 | 30 | return config, nil 31 | } 32 | 33 | // Config holds the params needed to configure Logger 34 | type Config struct { 35 | Out string `toml:"out"` 36 | Err string `toml:"err"` 37 | Format Format `toml:"format"` 38 | RedactFields []string `toml:"redact_fields"` 39 | } 40 | -------------------------------------------------------------------------------- /clogger/doc.go: -------------------------------------------------------------------------------- 1 | // Package clogger provides a Logger interface that can be used to log messages and errors 2 | package clogger 3 | -------------------------------------------------------------------------------- /clogger/json_redactor.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "strings" 7 | ) 8 | 9 | func redactJSONObject(in map[string]interface{}, redactFields []string) (map[string]interface{}, error) { 10 | var b 
bytes.Buffer 11 | 12 | enc := json.NewEncoder(&b) 13 | enc.SetEscapeHTML(false) 14 | 15 | err := enc.Encode(in) 16 | if err != nil { 17 | return nil, err 18 | } 19 | 20 | redacted, err := redactJSON(b.Bytes(), redactFields) 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | var out map[string]interface{} 26 | err = json.Unmarshal(redacted, &out) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | return out, nil 32 | } 33 | 34 | func redactJSON(in json.RawMessage, redactFields []string) (json.RawMessage, error) { 35 | var err error 36 | 37 | if in[0] == 123 { // 123 is `{` => object 38 | var cont map[string]json.RawMessage 39 | 40 | err = json.Unmarshal(in, &cont) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | for k, v := range cont { 46 | 47 | didRedact := false 48 | for i := range redactFields { 49 | if strings.Contains(strings.ToLower(k), strings.ToLower(redactFields[i])) { 50 | cont[k] = json.RawMessage(`"redacted"`) 51 | didRedact = true 52 | break 53 | } 54 | } 55 | 56 | if didRedact { 57 | continue 58 | } 59 | 60 | cont[k], err = redactJSON(v, redactFields) 61 | if err != nil { 62 | return nil, err 63 | } 64 | } 65 | 66 | return json.Marshal(cont) 67 | } else if in[0] == 91 { // 91 is `[` => array 68 | var cont []json.RawMessage 69 | 70 | err = json.Unmarshal(in, &cont) 71 | if err != nil { 72 | return nil, err 73 | } 74 | 75 | for i, v := range cont { 76 | cont[i], err = redactJSON(v, redactFields) 77 | if err != nil { 78 | return nil, err 79 | } 80 | } 81 | 82 | return json.Marshal(cont) 83 | } 84 | 85 | return in, nil 86 | } 87 | -------------------------------------------------------------------------------- /clogger/json_redactor_test.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/shopspring/decimal" 7 | "github.com/stretchr/testify/assert" 8 | "testing" 9 | ) 10 | 11 | type FooDecimal struct { 12 | 
decimal.Decimal 13 | } 14 | 15 | func (d *FooDecimal) MarshalJSON() ([]byte, error) { 16 | return json.Marshal("0x" + d.BigInt().Text(16)) 17 | } 18 | 19 | func TestRedactJSONObject(t *testing.T) { 20 | d := FooDecimal{decimal.NewFromInt(100)} 21 | o, err := json.Marshal(&d) 22 | assert.NoError(t, err) 23 | 24 | fmt.Println("====> ", string(o)) 25 | 26 | var t1 = map[string]interface{}{ 27 | "a": &d, 28 | } 29 | 30 | _, err = redactJSONObject(t1, []string{"b"}) 31 | assert.NoError(t, err) 32 | } 33 | 34 | func TestRedactJSON(t *testing.T) { 35 | var t1 = map[string]interface{}{ 36 | "a": 1, 37 | "b": "foo", 38 | "c": map[string]interface{}{"d": 2}, 39 | "e": []interface{}{1, 2, map[string]interface{}{"f": 3}}, 40 | } 41 | 42 | in, err := json.Marshal(t1) 43 | assert.NoError(t, err) 44 | 45 | out, err := redactJSON(in, []string{"f"}) 46 | assert.NoError(t, err) 47 | 48 | assert.Equal(t, `{"a":1,"b":"foo","c":{"d":2},"e":[1,2,{"f":"redacted"}]}`, string(out)) 49 | } 50 | -------------------------------------------------------------------------------- /clogger/level.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | // Level represents the severity level of a log. 4 | type Level int 5 | 6 | // Pre-defined log levels. 
7 | const ( 8 | LevelDebug = Level(iota + 1) 9 | LevelInfo 10 | LevelWarn 11 | LevelError 12 | ) 13 | 14 | func (l Level) String() string { 15 | switch l { 16 | case LevelDebug: 17 | return "DEBUG" 18 | case LevelInfo: 19 | return "INFO" 20 | case LevelWarn: 21 | return "WARN" 22 | case LevelError: 23 | return "ERROR" 24 | default: 25 | return "UNKNOWN" 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /clogger/level_test.go: -------------------------------------------------------------------------------- 1 | package clogger_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gocopper/copper/clogger" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestLevel_String(t *testing.T) { 11 | t.Parallel() 12 | 13 | assert.Equal(t, "DEBUG", clogger.LevelDebug.String()) 14 | assert.Equal(t, "INFO", clogger.LevelInfo.String()) 15 | assert.Equal(t, "WARN", clogger.LevelWarn.String()) 16 | assert.Equal(t, "ERROR", clogger.LevelError.String()) 17 | assert.Equal(t, "UNKNOWN", clogger.Level(-99).String()) 18 | } 19 | -------------------------------------------------------------------------------- /clogger/logger.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "log" 7 | "os" 8 | "strings" 9 | "time" 10 | 11 | "github.com/gocopper/copper/cerrors" 12 | ) 13 | 14 | // Logger can be used to log messages and errors. 15 | type Logger interface { 16 | WithTags(tags map[string]interface{}) Logger 17 | 18 | Debug(msg string) 19 | Info(msg string) 20 | Warn(msg string, err error) 21 | Error(msg string, err error) 22 | } 23 | 24 | // New returns a Logger implementation that can logs to console. 25 | func New() Logger { 26 | return NewWithWriters(os.Stdout, os.Stderr, FormatPlain, nil) 27 | } 28 | 29 | // NewWithConfig creates a Logger based on the provided config. 
30 | func NewWithConfig(config Config) (Logger, error) { 31 | const LogFilePerms = 0666 32 | 33 | var ( 34 | outFile io.Writer = os.Stdout 35 | errFile io.Writer = os.Stderr 36 | err error 37 | ) 38 | 39 | if config.Out != "" { 40 | outFile, err = os.OpenFile(config.Out, os.O_APPEND|os.O_CREATE|os.O_WRONLY, LogFilePerms) 41 | if err != nil { 42 | return nil, cerrors.New(err, "failed to open log file", map[string]interface{}{ 43 | "path": config.Out, 44 | }) 45 | } 46 | } 47 | 48 | if config.Err != "" && config.Err == config.Out { 49 | errFile = outFile 50 | } else if config.Err != "" { 51 | errFile, err = os.OpenFile(config.Err, os.O_APPEND|os.O_CREATE|os.O_WRONLY, LogFilePerms) 52 | if err != nil { 53 | return nil, cerrors.New(err, "failed to open error log file", map[string]interface{}{ 54 | "path": config.Err, 55 | }) 56 | } 57 | } 58 | 59 | return NewWithWriters(outFile, errFile, config.Format, config.RedactFields), nil 60 | } 61 | 62 | // NewWithWriters creates a Logger that uses the provided writers. out is 63 | // used for debug and info levels. err is used for warn and error levels.
64 | func NewWithWriters(out, err io.Writer, format Format, redactFields []string) Logger { 65 | return &logger{ 66 | out: out, 67 | err: err, 68 | tags: make(map[string]interface{}), 69 | format: format, 70 | redactFields: expandRedactedFields(redactFields), 71 | } 72 | } 73 | 74 | type logger struct { 75 | out io.Writer 76 | err io.Writer 77 | tags map[string]interface{} 78 | format Format 79 | redactFields []string 80 | } 81 | 82 | func (l *logger) WithTags(tags map[string]interface{}) Logger { 83 | return &logger{ 84 | out: l.out, 85 | err: l.err, 86 | tags: mergeTags(l.tags, tags), 87 | format: l.format, 88 | redactFields: l.redactFields, 89 | } 90 | } 91 | 92 | func (l *logger) Debug(msg string) { 93 | l.log(l.out, LevelDebug, msg, nil) //nolint:goerr113 94 | } 95 | 96 | func (l *logger) Info(msg string) { 97 | l.log(l.out, LevelInfo, msg, nil) //nolint:goerr113 98 | } 99 | 100 | func (l *logger) Warn(msg string, err error) { 101 | l.log(l.err, LevelWarn, msg, err) 102 | } 103 | 104 | func (l *logger) Error(msg string, err error) { 105 | l.log(l.err, LevelError, msg, err) 106 | } 107 | 108 | func (l *logger) log(dest io.Writer, lvl Level, msg string, err error) { 109 | switch l.format { 110 | case FormatJSON: 111 | l.logJSON(dest, lvl, msg, err) 112 | case FormatPlain: 113 | fallthrough 114 | default: 115 | l.logPlain(dest, lvl, msg, err) 116 | } 117 | } 118 | 119 | func (l *logger) logJSON(dest io.Writer, lvl Level, msg string, err error) { 120 | var dict = map[string]interface{}{ 121 | "ts": time.Now().Format(time.RFC3339), 122 | "level": lvl.String(), 123 | "msg": msg, 124 | } 125 | 126 | if err != nil { 127 | dict["error"] = cerrors.WithoutTags(err).Error() 128 | } 129 | 130 | if redactedTags, err := redactJSONObject(mergeTags(cerrors.Tags(err), l.tags), l.redactFields); err != nil { 131 | dict["tags"] = cerrors.New(err, "tag redaction failed", nil).Error() 132 | } else { 133 | dict["tags"] = redactedTags 134 | } 135 | 136 | enc := json.NewEncoder(dest) 
137 | enc.SetEscapeHTML(false) 138 | 139 | _ = enc.Encode(dict) 140 | } 141 | 142 | // logPlain writes one human-readable line to dest. When redactFields is configured, the 143 | // error's tags and the logger's tags are merged and redacted (as in logJSON) before being 144 | // rendered, and the error itself is printed without its tags so redacted values cannot leak through it. 145 | func (l *logger) logPlain(dest io.Writer, lvl Level, msg string, err error) { 146 | var o strings.Builder 147 | 148 | if len(l.redactFields) == 0 { 149 | o.WriteString(cerrors.New(nil, msg, l.tags).Error()) 150 | 151 | if err != nil { 152 | o.WriteString(" because\n> ") 153 | o.WriteString(err.Error()) 154 | } 155 | } else { 156 | tags, redactErr := redactJSONObject(mergeTags(cerrors.Tags(err), l.tags), l.redactFields) 157 | if redactErr != nil { 158 | // Never fall back to the unredacted tags; report the failure in their place. 159 | tags = map[string]interface{}{ 160 | "tags": cerrors.New(redactErr, "tag redaction failed", nil).Error(), 161 | } 162 | } 163 | 164 | o.WriteString(cerrors.New(nil, msg, tags).Error()) 165 | 166 | if err != nil { 167 | o.WriteString(" because\n> ") 168 | o.WriteString(cerrors.WithoutTags(err).Error()) 169 | } 170 | } 171 | 172 | log.New(dest, "", log.LstdFlags).Printf("[%s] %s", lvl.String(), o.String()) 173 | } 174 | -------------------------------------------------------------------------------- /clogger/logger_test.go: -------------------------------------------------------------------------------- 1 | package clogger_test 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "os" 7 | "testing" 8 | 9 | "github.com/gocopper/copper/cerrors" 10 | 11 | "github.com/gocopper/copper/clogger" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestNew(t *testing.T) { 16 | t.Parallel() 17 | 18 | logger := clogger.New() 19 | assert.NotNil(t, logger) 20 | } 21 | 22 | func TestNewWithConfig(t *testing.T) { 23 | t.Parallel() 24 | 25 | log, err := os.CreateTemp("", "*") 26 | 
assert.NoError(t, err) 27 | 28 | t.Cleanup(func() { 29 | assert.NoError(t, os.Remove(log.Name())) 30 | }) 31 | 32 | logger, err := clogger.NewWithConfig(clogger.Config{ 33 | Out: log.Name(), 34 | Err: log.Name(), 35 | Format: clogger.FormatPlain, 36 | }) 37 | assert.NoError(t, err) 38 | assert.NotNil(t, logger) 39 | } 40 | 41 | func TestNewWithConfig_OutFileErr(t *testing.T) { 42 | t.Parallel() 43 | 44 | log, err := os.CreateTemp("", "*") 45 | assert.NoError(t, err) 46 | 47 | assert.NoError(t, os.Chmod(log.Name(), 0000)) 48 | 49 | t.Cleanup(func() { 50 | assert.NoError(t, os.Remove(log.Name())) 51 | }) 52 | 53 | _, err = clogger.NewWithConfig(clogger.Config{ 54 | Out: log.Name(), 55 | Format: clogger.FormatPlain, 56 | }) 57 |
assert.Error(t, err) 58 | } 59 | 60 | func TestNewWithConfig_ErrFileErr(t *testing.T) { 61 | t.Parallel() 62 | 63 | log, err := os.CreateTemp("", "*") 64 | assert.NoError(t, err) 65 | 66 | assert.NoError(t, os.Chmod(log.Name(), 0000)) 67 | 68 | t.Cleanup(func() { 69 | assert.NoError(t, os.Remove(log.Name())) 70 | }) 71 | 72 | _, err = clogger.NewWithConfig(clogger.Config{ 73 | Err: log.Name(), 74 | Format: clogger.FormatPlain, 75 | }) 76 | assert.Error(t, err) 77 | } 78 | 79 | func TestNewWithParams(t *testing.T) { 80 | t.Parallel() 81 | 82 | logger := clogger.NewWithWriters(nil, nil, clogger.FormatPlain, nil) 83 | assert.NotNil(t, logger) 84 | } 85 | 86 | func TestLogger_Debug(t *testing.T) { 87 | t.Parallel() 88 | 89 | var ( 90 | buf bytes.Buffer 91 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil) 92 | ) 93 | 94 | logger.Debug("test debug log") 95 | 96 | assert.Contains(t, buf.String(), "[DEBUG] test debug log") 97 | } 98 | 99 | func TestLogger_WithTags_Debug(t *testing.T) { 100 | t.Parallel() 101 | 102 | var ( 103 | buf bytes.Buffer 104 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil) 105 | ) 106 | 107 | logger. 108 | WithTags(map[string]interface{}{ 109 | "key": "val", 110 | }). 
111 | WithTags(map[string]interface{}{ 112 | "key2": "val2", 113 | }).Debug("test debug log") 114 | 115 | assert.Contains(t, buf.String(), "[DEBUG] test debug log where key2=val2,key=val") 116 | } 117 | 118 | func TestLogger_WithTags_RedactedFields(t *testing.T) { 119 | t.Parallel() 120 | 121 | for _, format := range []clogger.Format{clogger.FormatJSON, clogger.FormatPlain} { 122 | var ( 123 | buf bytes.Buffer 124 | logger = clogger.NewWithWriters(&buf, &buf, format, []string{ 125 | "secret", "password", "userPin", 126 | }) 127 | 128 | testErr = cerrors.New(nil, "test-error", map[string]interface{}{ 129 | "secret": "my_api_key", 130 | "user-pin": "12456", 131 | "data": map[string]string{ 132 | "password": "abc123", 133 | }, 134 | }) 135 | ) 136 | 137 | logger.WithTags(map[string]interface{}{ 138 | "passwordOwner": "abc123", 139 | "USER_PIN": "123456", 140 | "params": map[string]string{ 141 | "myPassword": "abc123", 142 | }, 143 | }).Error("test debug log", testErr) 144 | 145 | assert.NotContains(t, buf.String(), "my_api_key") 146 | assert.NotContains(t, buf.String(), "12456") 147 | assert.NotContains(t, buf.String(), "123456") 148 | assert.NotContains(t, buf.String(), "abc123") 149 | assert.Contains(t, buf.String(), "redact") 150 | } 151 | } 152 | 153 | func TestLogger_Info(t *testing.T) { 154 | t.Parallel() 155 | 156 | var ( 157 | buf bytes.Buffer 158 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil) 159 | ) 160 | 161 | logger.Info("test info log") 162 | 163 | assert.Contains(t, buf.String(), "[INFO] test info log", nil) 164 | } 165 | 166 | func TestLogger_Warn(t *testing.T) { 167 | t.Parallel() 168 | 169 | var ( 170 | buf bytes.Buffer 171 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil) 172 | ) 173 | 174 | logger.Warn("test warn log", errors.New("test-error")) //nolint:goerr113 175 | 176 | assert.Contains(t, buf.String(), "[WARN] test warn log because\n> test-error") 177 | } 178 | 179 | func TestLogger_Error(t 
*testing.T) { 180 | t.Parallel() 181 | 182 | var ( 183 | buf bytes.Buffer 184 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil) 185 | ) 186 | 187 | logger.Error("test error log", errors.New("test-error")) //nolint:goerr113 188 | 189 | assert.Contains(t, buf.String(), "[ERROR] test error log because\n> test-error") 190 | } 191 | -------------------------------------------------------------------------------- /clogger/noop.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | // NewNoop returns a no-op implementation of Logger. 4 | // Useful in passing it as a valid logger in unit tests. 5 | func NewNoop() Logger { 6 | return &noop{} 7 | } 8 | 9 | type noop struct{} 10 | 11 | func (l *noop) WithTags(tags map[string]interface{}) Logger { 12 | return l 13 | } 14 | 15 | func (l *noop) Debug(msg string) {} 16 | 17 | func (l *noop) Info(msg string) {} 18 | 19 | func (l *noop) Warn(msg string, err error) {} 20 | 21 | func (l *noop) Error(msg string, err error) {} 22 | -------------------------------------------------------------------------------- /clogger/noop_test.go: -------------------------------------------------------------------------------- 1 | package clogger_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gocopper/copper/clogger" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNewNoop(t *testing.T) { 11 | t.Parallel() 12 | 13 | logger := clogger.NewNoop() 14 | assert.NotNil(t, logger) 15 | } 16 | 17 | func TestNoopLogger_WithTags(t *testing.T) { 18 | t.Parallel() 19 | 20 | logger := clogger.NewNoop().WithTags(nil) 21 | 22 | assert.NotNil(t, logger) 23 | } 24 | 25 | func TestNoopLogger_Debug(t *testing.T) { 26 | t.Parallel() 27 | 28 | logger := clogger.NewNoop().WithTags(nil) 29 | 30 | logger.Debug("test-debug") 31 | } 32 | 33 | func TestNoopLogger_Info(t *testing.T) { 34 | t.Parallel() 35 | 36 | logger := clogger.NewNoop().WithTags(nil) 37 | 38 | 
logger.Info("info") 39 | } 40 | 41 | func TestNoopLogger_Warn(t *testing.T) { 42 | t.Parallel() 43 | 44 | logger := clogger.NewNoop().WithTags(nil) 45 | 46 | logger.Warn("warn", nil) 47 | } 48 | 49 | func TestNoopLogger_Error(t *testing.T) { 50 | t.Parallel() 51 | 52 | logger := clogger.NewNoop().WithTags(nil) 53 | 54 | logger.Error("error", nil) 55 | } 56 | -------------------------------------------------------------------------------- /clogger/recorder.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | // NewRecorder returns an implementation of Logger that keeps 4 | // a record of each log. Useful in unit tests when logs need 5 | // to be tested. 6 | func NewRecorder(logs *[]RecordedLog) Logger { 7 | return &recorder{ 8 | Logs: logs, 9 | tags: make(map[string]interface{}), 10 | } 11 | } 12 | 13 | // RecordedLog represents a single log. 14 | type RecordedLog struct { 15 | Level Level 16 | Tags map[string]interface{} 17 | Msg string 18 | Error error 19 | } 20 | 21 | type recorder struct { 22 | Logs *[]RecordedLog 23 | tags map[string]interface{} 24 | } 25 | 26 | func (l *recorder) WithTags(tags map[string]interface{}) Logger { 27 | return &recorder{ 28 | Logs: l.Logs, 29 | tags: mergeTags(l.tags, tags), 30 | } 31 | } 32 | 33 | func (l *recorder) Debug(msg string) { 34 | *l.Logs = append(*l.Logs, RecordedLog{ 35 | Level: LevelDebug, 36 | Tags: mergeTags(l.tags, nil), 37 | Msg: msg, 38 | Error: nil, 39 | }) 40 | } 41 | 42 | func (l *recorder) Info(msg string) { 43 | *l.Logs = append(*l.Logs, RecordedLog{ 44 | Level: LevelInfo, 45 | Tags: mergeTags(l.tags, nil), 46 | Msg: msg, 47 | Error: nil, 48 | }) 49 | } 50 | 51 | func (l *recorder) Warn(msg string, err error) { 52 | *l.Logs = append(*l.Logs, RecordedLog{ 53 | Level: LevelWarn, 54 | Tags: mergeTags(l.tags, nil), 55 | Msg: msg, 56 | Error: err, 57 | }) 58 | } 59 | 60 | func (l *recorder) Error(msg string, err error) { 61 | *l.Logs = append(*l.Logs, 
RecordedLog{ 62 | Level: LevelError, 63 | Tags: mergeTags(l.tags, nil), 64 | Msg: msg, 65 | Error: err, 66 | }) 67 | } 68 | -------------------------------------------------------------------------------- /clogger/recorder_test.go: -------------------------------------------------------------------------------- 1 | package clogger_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/gocopper/copper/clogger" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewRecorder(t *testing.T) { 12 | t.Parallel() 13 | 14 | logger := clogger.NewRecorder(nil) 15 | assert.NotNil(t, logger) 16 | } 17 | 18 | func TestRecorder_Debug(t *testing.T) { 19 | t.Parallel() 20 | 21 | var ( 22 | logs = make([]clogger.RecordedLog, 0) 23 | logger = clogger.NewRecorder(&logs) 24 | ) 25 | 26 | logger.Debug("test debug log") 27 | 28 | assert.Len(t, logs, 1) 29 | 30 | log := logs[0] 31 | 32 | assert.Equal(t, clogger.LevelDebug, log.Level) 33 | assert.Equal(t, "test debug log", log.Msg) 34 | assert.Empty(t, log.Tags) 35 | assert.Nil(t, log.Error) 36 | } 37 | 38 | func TestRecorder_WithTags_Debug(t *testing.T) { 39 | t.Parallel() 40 | 41 | var ( 42 | logs = make([]clogger.RecordedLog, 0) 43 | logger = clogger.NewRecorder(&logs) 44 | ) 45 | 46 | logger. 47 | WithTags(map[string]interface{}{ 48 | "key": "val", 49 | }). 
50 | WithTags(map[string]interface{}{ 51 | "key2": "val2", 52 | }).Debug("test debug log") 53 | 54 | assert.Len(t, logs, 1) 55 | 56 | log := logs[0] 57 | 58 | assert.Equal(t, clogger.LevelDebug, log.Level) 59 | assert.Equal(t, "test debug log", log.Msg) 60 | assert.Equal(t, map[string]interface{}{ 61 | "key": "val", 62 | "key2": "val2", 63 | }, log.Tags) 64 | assert.Nil(t, log.Error) 65 | } 66 | 67 | func TestRecorder_Info(t *testing.T) { 68 | t.Parallel() 69 | 70 | var ( 71 | logs = make([]clogger.RecordedLog, 0) 72 | logger = clogger.NewRecorder(&logs) 73 | ) 74 | 75 | logger.Info("test info log") 76 | 77 | assert.Len(t, logs, 1) 78 | 79 | log := logs[0] 80 | 81 | assert.Equal(t, clogger.LevelInfo, log.Level) 82 | assert.Equal(t, "test info log", log.Msg) 83 | assert.Empty(t, log.Tags) 84 | assert.Nil(t, log.Error) 85 | } 86 | 87 | func TestRecorder_Warn(t *testing.T) { 88 | t.Parallel() 89 | 90 | var ( 91 | logs = make([]clogger.RecordedLog, 0) 92 | logger = clogger.NewRecorder(&logs) 93 | ) 94 | 95 | logger.Warn("test warn log", errors.New("test-err")) //nolint:goerr113 96 | 97 | assert.Len(t, logs, 1) 98 | 99 | log := logs[0] 100 | 101 | assert.Equal(t, clogger.LevelWarn, log.Level) 102 | assert.Equal(t, "test warn log", log.Msg) 103 | assert.Empty(t, log.Tags) 104 | assert.EqualError(t, log.Error, "test-err") 105 | } 106 | 107 | func TestRecorder_Error(t *testing.T) { 108 | t.Parallel() 109 | 110 | var ( 111 | logs = make([]clogger.RecordedLog, 0) 112 | logger = clogger.NewRecorder(&logs) 113 | ) 114 | 115 | logger.Error("test error log", errors.New("test-err")) //nolint:goerr113 116 | 117 | assert.Len(t, logs, 1) 118 | 119 | log := logs[0] 120 | 121 | assert.Equal(t, clogger.LevelError, log.Level) 122 | assert.Equal(t, "test error log", log.Msg) 123 | assert.Empty(t, log.Tags) 124 | assert.EqualError(t, log.Error, "test-err") 125 | } 126 | -------------------------------------------------------------------------------- /clogger/util.go: 
-------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "github.com/iancoleman/strcase" 5 | ) 6 | 7 | func mergeTags(t1, t2 map[string]interface{}) map[string]interface{} { 8 | merged := make(map[string]interface{}) 9 | 10 | for k, v := range t1 { 11 | merged[k] = v 12 | } 13 | 14 | for k, v := range t2 { 15 | merged[k] = v 16 | } 17 | 18 | return merged 19 | } 20 | 21 | func tagsToKVs(tags map[string]interface{}) []interface{} { 22 | const TagsToKVsMultiplier = 2 23 | 24 | kvs := make([]interface{}, 0, len(tags)*TagsToKVsMultiplier) 25 | for k, v := range tags { 26 | kvs = append(kvs, k, v) 27 | } 28 | return kvs 29 | } 30 | 31 | func formatToZapEncoding(f Format) string { 32 | switch f { 33 | case FormatJSON: 34 | return "json" 35 | case FormatPlain: 36 | return "console" 37 | default: 38 | return "console" 39 | } 40 | } 41 | 42 | func expandRedactedFields(redactedFields []string) []string { 43 | expanded := make([]string, 0, len(redactedFields)) 44 | for _, f := range redactedFields { 45 | expanded = append(expanded, f, 46 | strcase.ToSnake(f), 47 | strcase.ToKebab(f), 48 | strcase.ToDelimited(f, '-'), 49 | strcase.ToDelimited(f, '.'), 50 | strcase.ToCamel(f), 51 | ) 52 | } 53 | return expanded 54 | } 55 | -------------------------------------------------------------------------------- /clogger/zap.go: -------------------------------------------------------------------------------- 1 | package clogger 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gocopper/copper/cerrors" 7 | "github.com/gocopper/copper/clifecycle" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | // NewZapLogger creates a Logger that internally uses go.uber.org/zap for logging 12 | func NewZapLogger(config Config, lc *clifecycle.Lifecycle) (Logger, error) { 13 | const OutStdErr = "stderr" 14 | 15 | var ( 16 | outPath = OutStdErr 17 | errOutPath = OutStdErr 18 | ) 19 | 20 | if config.Out != "" { 21 | outPath = config.Out 22 | } 23 | 24 | 
if config.Err != "" { 25 | errOutPath = config.Err 26 | } 27 | 28 | encoderConfig := zap.NewDevelopmentEncoderConfig() 29 | if config.Format == FormatJSON { 30 | encoderConfig = zap.NewProductionEncoderConfig() 31 | } 32 | 33 | z, err := zap.Config{ 34 | Level: zap.NewAtomicLevelAt(zap.DebugLevel), 35 | Encoding: formatToZapEncoding(config.Format), 36 | EncoderConfig: encoderConfig, 37 | OutputPaths: []string{outPath}, 38 | ErrorOutputPaths: []string{errOutPath}, 39 | }.Build(zap.AddCallerSkip(1)) 40 | if err != nil { 41 | return nil, cerrors.New(err, "failed to create zap logger", nil) 42 | } 43 | 44 | lc.OnStop(func(ctx context.Context) error { 45 | // Skip sync if logs are written to stderr because it will throw an error: 46 | // https://github.com/uber-go/zap/issues/880 47 | if outPath == OutStdErr && errOutPath == OutStdErr { 48 | return nil 49 | } 50 | 51 | return z.Sync() 52 | }) 53 | 54 | return &zapLogger{ 55 | zap: z.Sugar(), 56 | tags: make(map[string]interface{}), 57 | }, nil 58 | } 59 | 60 | type zapLogger struct { 61 | zap *zap.SugaredLogger 62 | tags map[string]interface{} 63 | } 64 | 65 | func (l *zapLogger) WithTags(tags map[string]interface{}) Logger { 66 | return &zapLogger{ 67 | zap: l.zap, 68 | tags: mergeTags(l.tags, tags), 69 | } 70 | } 71 | 72 | func (l *zapLogger) Debug(msg string) { 73 | l.zap.Debugw(msg, tagsToKVs(l.tags)...) 74 | } 75 | 76 | func (l *zapLogger) Info(msg string) { 77 | l.zap.Infow(msg, tagsToKVs(l.tags)...) 78 | } 79 | 80 | func (l *zapLogger) Warn(msg string, err error) { 81 | l.zap.With("error", err).Warnw(msg, tagsToKVs(l.tags)...) 82 | } 83 | 84 | func (l *zapLogger) Error(msg string, err error) { 85 | l.zap.With("error", err).Errorw(msg, tagsToKVs(l.tags)...) 
86 | } 87 | -------------------------------------------------------------------------------- /cmetrics/metrics.go: -------------------------------------------------------------------------------- 1 | package cmetrics 2 | 3 | import ( 4 | "github.com/gocopper/copper/cerrors" 5 | "github.com/gocopper/copper/clogger" 6 | "github.com/prometheus/client_golang/prometheus" 7 | ) 8 | 9 | type Metrics interface { 10 | CounterInc(name string, labels map[string]string) 11 | HistogramObserve(name string, labels map[string]string, value float64) 12 | } 13 | 14 | type metrics struct { 15 | counters map[string]*prometheus.CounterVec 16 | histograms map[string]*prometheus.HistogramVec 17 | 18 | logger clogger.Logger 19 | } 20 | 21 | func NewMetrics(registry *Registry, logger clogger.Logger) (Metrics, error) { 22 | countersByName := make(map[string]*prometheus.CounterVec) 23 | for i := range registry.Counters { 24 | countersByName[registry.Counters[i].Name] = prometheus.NewCounterVec(prometheus.CounterOpts{ 25 | Name: registry.Counters[i].Name, 26 | }, registry.Counters[i].Labels) 27 | 28 | err := prometheus.DefaultRegisterer.Register(countersByName[registry.Counters[i].Name]) 29 | if err != nil { 30 | return nil, cerrors.New(err, "failed to register counter metric", map[string]interface{}{ 31 | "name": registry.Counters[i].Name, 32 | }) 33 | } 34 | } 35 | 36 | histogramsByName := make(map[string]*prometheus.HistogramVec) 37 | for i := range registry.Histograms { 38 | histogramsByName[registry.Histograms[i].Name] = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 39 | Name: registry.Histograms[i].Name, 40 | Buckets: registry.Histograms[i].Buckets, 41 | }, registry.Histograms[i].Labels) 42 | 43 | err := prometheus.DefaultRegisterer.Register(histogramsByName[registry.Histograms[i].Name]) 44 | if err != nil { 45 | return nil, cerrors.New(err, "failed to register histogram metric", map[string]interface{}{ 46 | "name": registry.Histograms[i].Name, 47 | }) 48 | 49 | } 50 | } 51 | 52 
| return &metrics{ 53 | counters: countersByName, 54 | histograms: histogramsByName, 55 | 56 | logger: logger, 57 | }, nil 58 | } 59 | 60 | func (m *metrics) HistogramObserve(name string, labels map[string]string, value float64) { 61 | histogram, ok := m.histograms[name] 62 | if !ok { 63 | m.logger.WithTags(map[string]interface{}{ 64 | "name": name, 65 | }).Warn("Histogram is not registered. Ignoring..", nil) 66 | return 67 | } 68 | 69 | metric, err := histogram.GetMetricWith(labels) 70 | if err != nil { 71 | m.logger.WithTags(map[string]interface{}{ 72 | "name": name, 73 | }).Warn("Failed to get histogram metric with labels", err) 74 | return 75 | 76 | } 77 | 78 | metric.Observe(value) 79 | } 80 | 81 | func (m *metrics) CounterInc(name string, labels map[string]string) { 82 | counter, ok := m.counters[name] 83 | if !ok { 84 | m.logger.WithTags(map[string]interface{}{ 85 | "name": name, 86 | }).Warn("Counter is not registered. Ignoring..", nil) 87 | return 88 | } 89 | 90 | metric, err := counter.GetMetricWith(labels) 91 | if err != nil { 92 | m.logger.WithTags(map[string]interface{}{ 93 | "name": name, 94 | }).Warn("Failed to get counter metric with labels", err) 95 | return 96 | } 97 | 98 | metric.Inc() 99 | } 100 | -------------------------------------------------------------------------------- /cmetrics/models.go: -------------------------------------------------------------------------------- 1 | package cmetrics 2 | 3 | type ( 4 | Registry struct { 5 | Counters []Counter 6 | Histograms []Histogram 7 | } 8 | 9 | Counter struct { 10 | Name string 11 | Labels []string 12 | } 13 | 14 | Histogram struct { 15 | Name string 16 | Labels []string 17 | Buckets []float64 18 | } 19 | ) 20 | 21 | type NewRegistryParams struct { 22 | Counters []Counter 23 | Histograms []Histogram 24 | } 25 | 26 | var ( 27 | internalCounters = []Counter{ 28 | { 29 | Name: "http_requests_total", 30 | Labels: []string{"status_code", "path"}, 31 | }, 32 | } 33 | 34 | internalHistograms = 
[]Histogram{ 35 | { 36 | Name: "http_request_duration_seconds", 37 | Labels: []string{"status_code", "path"}, 38 | Buckets: []float64{0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0}, 39 | }, 40 | } 41 | ) 42 | 43 | func NewRegistry(p NewRegistryParams) *Registry { 44 | return &Registry{ 45 | Counters: append(p.Counters, internalCounters...), 46 | Histograms: append(p.Histograms, internalHistograms...), 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /cmetrics/noop.go: -------------------------------------------------------------------------------- 1 | package cmetrics 2 | 3 | func NewNoopMetrics() Metrics { 4 | return &noop{} 5 | } 6 | 7 | type noop struct{} 8 | 9 | func (m *noop) CounterInc(name string, labels map[string]string) { 10 | } 11 | 12 | func (m *noop) HistogramObserve(name string, labels map[string]string, value float64) { 13 | } 14 | -------------------------------------------------------------------------------- /cmetrics/wire.go: -------------------------------------------------------------------------------- 1 | package cmetrics 2 | 3 | import "github.com/google/wire" 4 | 5 | var WireModule = wire.NewSet( 6 | NewMetrics, 7 | ) 8 | -------------------------------------------------------------------------------- /csql/config.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/gocopper/copper/cconfig" 7 | "github.com/gocopper/copper/cerrors" 8 | migrate "github.com/rubenv/sql-migrate" 9 | ) 10 | 11 | // MigrationsSource are valid options for csql.migrations.source configuration option. 12 | // Use "dir" to load migrations from the local filesystem. 13 | // Use "embed" to load migrations from the embedded directory in the binary. 
14 | const ( 15 | MigrationsSourceDir = "dir" 16 | MigrationsSourceEmbed = "embed" 17 | ) 18 | 19 | // MigrationsDirection are valid options for csql.migrations.direction configuration option. 20 | // Use "up" when running forward migrations and "down" when rolling back migrations. 21 | const ( 22 | MigrationsDirectionUp = "up" 23 | MigrationsDirectionDown = "down" 24 | ) 25 | 26 | // LoadConfig loads the csql config from the app config 27 | func LoadConfig(appConfig cconfig.Loader) (Config, error) { 28 | config := Config{ 29 | Migrations: ConfigMigrations{ 30 | Direction: MigrationsDirectionUp, 31 | Source: MigrationsSourceEmbed, 32 | }, 33 | } 34 | 35 | err := appConfig.Load("csql", &config) 36 | if err != nil { 37 | return Config{}, cerrors.New(err, "failed to load sql config", nil) 38 | } 39 | 40 | return config, nil 41 | } 42 | 43 | type ( 44 | // Config configures the csql module 45 | Config struct { 46 | Dialect string `toml:"dialect"` 47 | DSN string `toml:"dsn"` 48 | Migrations ConfigMigrations `toml:"migrations"` 49 | MaxOpenConnections *int `toml:"max_open_connections"` 50 | } 51 | 52 | // ConfigMigrations configures the migrations 53 | ConfigMigrations struct { 54 | Direction string `toml:"direction"` 55 | Source string `toml:"source"` 56 | } 57 | ) 58 | 59 | func (cm ConfigMigrations) sqlMigrateDirection() (migrate.MigrationDirection, error) { 60 | switch strings.ToLower(cm.Direction) { 61 | case MigrationsDirectionUp: 62 | return migrate.Up, nil 63 | case MigrationsDirectionDown: 64 | return migrate.Down, nil 65 | default: 66 | return 0, cerrors.New(nil, "invalid migration direction", map[string]interface{}{ 67 | "direction": cm.Direction, 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /csql/ctx.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/gocopper/copper/cerrors" 8 
| "github.com/jmoiron/sqlx" 9 | ) 10 | 11 | type ctxKey string 12 | 13 | const connCtxKey = ctxKey("csql/*sqlx.Tx") 14 | 15 | // CtxWithTx creates a context with a new database transaction. Any queries run using Querier will be run within 16 | // this transaction. 17 | func CtxWithTx(parentCtx context.Context, db *sql.DB, dialect string) (context.Context, *sql.Tx, error) { 18 | tx, err := sqlx.NewDb(db, dialect).Beginx() 19 | if err != nil { 20 | return nil, nil, cerrors.New(err, "failed to begin db transaction", map[string]interface{}{ 21 | "dialect": dialect, 22 | }) 23 | } 24 | 25 | return context.WithValue(parentCtx, connCtxKey, tx), tx.Tx, nil 26 | } 27 | 28 | // TxFromCtx returns an existing transaction from the context. This method should be called with context created 29 | // using CtxWithTx. 30 | func TxFromCtx(ctx context.Context) (*sql.Tx, error) { 31 | tx, err := txFromCtx(ctx) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | return tx.Tx, nil 37 | } 38 | 39 | func txFromCtx(ctx context.Context) (*sqlx.Tx, error) { 40 | tx, ok := ctx.Value(connCtxKey).(*sqlx.Tx) 41 | if !ok { 42 | return nil, cerrors.New(nil, "no database transaction in the context", nil) 43 | } 44 | 45 | return tx, nil 46 | } 47 | 48 | func mustTxFromCtx(ctx context.Context) *sqlx.Tx { 49 | tx, err := txFromCtx(ctx) 50 | if err != nil { 51 | panic(err) 52 | } 53 | 54 | return tx 55 | } 56 | -------------------------------------------------------------------------------- /csql/ctx_test.go: -------------------------------------------------------------------------------- 1 | package csql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | 8 | "github.com/gocopper/copper/csql" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestCtxWithTx(t *testing.T) { 13 | t.Parallel() 14 | 15 | db, err := sql.Open("sqlite3", ":memory:") 16 | assert.NoError(t, err) 17 | 18 | ctx, tx1, err := csql.CtxWithTx(context.Background(), db, "sqlite3") 19 | 
assert.NoError(t, err) 20 | 21 | tx2, err := csql.TxFromCtx(ctx) 22 | assert.NoError(t, err) 23 | 24 | assert.Equal(t, tx1, tx2) 25 | } 26 | -------------------------------------------------------------------------------- /csql/db.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | 7 | "github.com/gocopper/copper/cerrors" 8 | "github.com/gocopper/copper/clifecycle" 9 | "github.com/gocopper/copper/clogger" 10 | ) 11 | 12 | // NewDBConnection creates and returns a new database connection. The connection is closed when the app exits. 13 | func NewDBConnection(lc *clifecycle.Lifecycle, config Config, logger clogger.Logger) (*sql.DB, error) { 14 | logger.WithTags(map[string]interface{}{ 15 | "dialect": config.Dialect, 16 | }).Info("Opening a database connection..") 17 | 18 | db, err := sql.Open(config.Dialect, config.DSN) 19 | if err != nil { 20 | return nil, cerrors.New(err, "failed to open db connection", map[string]interface{}{ 21 | "dialect": config.Dialect, 22 | }) 23 | } 24 | 25 | err = db.Ping() 26 | if err != nil { 27 | return nil, cerrors.New(err, "failed to ping db", nil) 28 | } 29 | 30 | if config.MaxOpenConnections != nil { 31 | db.SetMaxOpenConns(*config.MaxOpenConnections) 32 | } 33 | 34 | lc.OnStop(func(ctx context.Context) error { 35 | logger.Info("Closing database connection..") 36 | 37 | err := db.Close() 38 | if err != nil { 39 | return cerrors.New(err, "failed to close db connection", nil) 40 | } 41 | 42 | return nil 43 | }) 44 | 45 | return db, nil 46 | } 47 | -------------------------------------------------------------------------------- /csql/db_test.go: -------------------------------------------------------------------------------- 1 | package csql_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gocopper/copper/clifecycle" 7 | "github.com/gocopper/copper/clogger" 8 | "github.com/gocopper/copper/csql" 9 | _ 
"github.com/mattn/go-sqlite3" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewDBConnection(t *testing.T) { 14 | t.Parallel() 15 | 16 | var ( 17 | logger = clogger.New() 18 | lc = clifecycle.New() 19 | ) 20 | 21 | db, err := csql.NewDBConnection(lc, csql.Config{ 22 | Dialect: "sqlite3", 23 | DSN: ":memory:", 24 | }, logger) 25 | assert.NoError(t, err) 26 | 27 | assert.NoError(t, db.Ping()) 28 | 29 | lc.Stop(logger) 30 | 31 | assert.Error(t, db.Ping()) 32 | } 33 | -------------------------------------------------------------------------------- /csql/doc.go: -------------------------------------------------------------------------------- 1 | // Package csql helps create and manage database connections 2 | package csql 3 | -------------------------------------------------------------------------------- /csql/migrations_test.sql: -------------------------------------------------------------------------------- 1 | -- +migrate Up 2 | create table people (name text); 3 | insert into people (name) values ('test'); 4 | 5 | -- +migrate Down 6 | drop table people; 7 | -------------------------------------------------------------------------------- /csql/migrator.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import ( 4 | "crypto/sha256" 5 | "database/sql" 6 | "embed" 7 | "fmt" 8 | "io" 9 | 10 | "github.com/gocopper/copper/cerrors" 11 | "github.com/gocopper/copper/clogger" 12 | migrate "github.com/rubenv/sql-migrate" 13 | ) 14 | 15 | // Migrations is a collection of .sql files that represent the database schema 16 | type Migrations embed.FS 17 | 18 | // NewMigratorParams holds the params needed for NewMigrator 19 | type NewMigratorParams struct { 20 | DB *sql.DB 21 | Migrations Migrations 22 | Config Config 23 | Logger clogger.Logger 24 | } 25 | 26 | // NewMigrator creates a new Migrator 27 | func NewMigrator(p NewMigratorParams) *Migrator { 28 | return &Migrator{ 29 | db: p.DB, 30 | 
migrations: embed.FS(p.Migrations), 31 | config: p.Config, 32 | logger: p.Logger, 33 | } 34 | } 35 | 36 | // Migrator can run database migrations by running the provided migrations in the migrations dir 37 | type Migrator struct { 38 | db *sql.DB 39 | migrations embed.FS 40 | config Config 41 | logger clogger.Logger 42 | } 43 | 44 | // Run runs the provided database migrations 45 | func (m *Migrator) Run() error { 46 | m.logger.WithTags(map[string]interface{}{ 47 | "direction": m.config.Migrations.Direction, 48 | "source": m.config.Migrations.Source, 49 | }).Info("Running database migrations..") 50 | 51 | direction, err := m.config.Migrations.sqlMigrateDirection() 52 | if err != nil { 53 | return cerrors.New(err, "failed to get sql migrate direction from config", nil) 54 | } 55 | 56 | hasMigrations, err := m.hasMigrations() 57 | if err != nil { 58 | return cerrors.New(err, "failed to check for migrations", nil) 59 | } 60 | 61 | if !hasMigrations { 62 | m.logger.Info("No migrations found") 63 | return nil 64 | } 65 | 66 | source := migrate.MigrationSource(migrate.EmbedFileSystemMigrationSource{ 67 | FileSystem: m.migrations, 68 | Root: ".", 69 | }) 70 | if m.config.Migrations.Source == MigrationsSourceDir { 71 | source = migrate.FileMigrationSource{ 72 | Dir: "./migrations", 73 | } 74 | } 75 | 76 | migrateMax := 0 // no limit 77 | if direction == migrate.Down { 78 | migrateMax = 1 // only run 1 migration when reverting 79 | } 80 | 81 | dialect := m.config.Dialect 82 | if dialect == "pgx" { 83 | dialect = "postgres" 84 | } 85 | 86 | n, err := migrate.ExecMax(m.db, dialect, source, direction, migrateMax) 87 | if err != nil { 88 | return cerrors.New(err, "failed to exec database migrations", nil) 89 | } 90 | 91 | m.logger.WithTags(map[string]interface{}{ 92 | "count": n, 93 | }).Info("Successfully applied migrations") 94 | 95 | return nil 96 | } 97 | 98 | // hasMigrations returns true if the migrations directory has at least 1 non-empty migration file. 
99 | func (m *Migrator) hasMigrations() (bool, error) { 100 | const emptyMigrationsChecksum = "fba9ab24993a94e181dc952f2568a4e98b47e331d89772af3115fe1c7b90d27f" 101 | 102 | entries, err := m.migrations.ReadDir(".") 103 | if err != nil { 104 | return false, cerrors.New(err, "failed to read migrations dir", nil) 105 | } 106 | 107 | if len(entries) == 0 { 108 | return false, nil 109 | } 110 | 111 | if len(entries) > 1 { 112 | return true, nil 113 | } 114 | 115 | f, err := m.migrations.Open(entries[0].Name()) 116 | if err != nil { 117 | return false, cerrors.New(err, "failed to open migrations file", map[string]interface{}{ 118 | "name": entries[0].Name(), 119 | }) 120 | } 121 | defer func() { _ = f.Close() }() 122 | 123 | h := sha256.New() 124 | if _, err = io.Copy(h, f); err != nil { 125 | return false, cerrors.New(err, "failed to calculate sha256 for migration file", nil) 126 | } 127 | 128 | checksum := fmt.Sprintf("%x", h.Sum(nil)) 129 | if checksum == emptyMigrationsChecksum { 130 | return false, nil 131 | } 132 | 133 | return true, nil 134 | } 135 | -------------------------------------------------------------------------------- /csql/migrator_test.go: -------------------------------------------------------------------------------- 1 | package csql_test 2 | 3 | import ( 4 | "database/sql" 5 | "embed" 6 | "testing" 7 | 8 | "github.com/gocopper/copper/clogger" 9 | "github.com/gocopper/copper/csql" 10 | _ "github.com/mattn/go-sqlite3" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | //go:embed migrations_test.sql 15 | var Migrations embed.FS 16 | 17 | func TestMigrator_Run(t *testing.T) { 18 | t.Parallel() 19 | 20 | db, err := sql.Open("sqlite3", ":memory:") 21 | assert.NoError(t, err) 22 | 23 | // migrate up 24 | 25 | migratorUp := csql.NewMigrator(csql.NewMigratorParams{ 26 | DB: db, 27 | Migrations: csql.Migrations(Migrations), 28 | Config: csql.Config{ 29 | Dialect: "sqlite3", 30 | Migrations: csql.ConfigMigrations{ 31 | Direction: 
csql.MigrationsDirectionUp, 32 | Source: csql.MigrationsSourceEmbed, 33 | }, 34 | }, 35 | Logger: clogger.NewNoop(), 36 | }) 37 | 38 | err = migratorUp.Run() 39 | assert.NoError(t, err) 40 | 41 | res, err := db.Query("select * from people") 42 | assert.NoError(t, err) 43 | assert.NoError(t, res.Err()) 44 | 45 | assert.True(t, res.Next()) 46 | 47 | // migrate down 48 | 49 | migratorDown := csql.NewMigrator(csql.NewMigratorParams{ 50 | DB: db, 51 | Migrations: csql.Migrations(Migrations), 52 | Config: csql.Config{ 53 | Dialect: "sqlite3", 54 | Migrations: csql.ConfigMigrations{ 55 | Direction: csql.MigrationsDirectionDown, 56 | Source: csql.MigrationsSourceEmbed, 57 | }, 58 | }, 59 | Logger: clogger.NewNoop(), 60 | }) 61 | 62 | err = migratorDown.Run() 63 | assert.NoError(t, err) 64 | 65 | _, err = db.Query("select * from people") //nolint:rowserrcheck 66 | assert.EqualError(t, err, "no such table: people") 67 | } 68 | -------------------------------------------------------------------------------- /csql/qb/query_builder.go: -------------------------------------------------------------------------------- 1 | package qb 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | type fieldExtractor func(reflect.StructField, reflect.Value) (string, any, bool) 10 | 11 | func ValuePlaceholders(model any) string { 12 | columns := extractRawColumns(model) 13 | placeholders := make([]string, len(columns)) 14 | for i := range columns { 15 | placeholders[i] = "?" 
16 | } 17 | return strings.Join(placeholders, ", ") 18 | } 19 | 20 | func Columns(model any, alias ...string) string { 21 | a := "" 22 | if len(alias) > 0 { 23 | a = alias[0] 24 | } 25 | 26 | columns := extractRawColumns(model) 27 | if a == "" { 28 | return strings.Join(columns, ", ") 29 | } 30 | 31 | prefixedColumns := make([]string, len(columns)) 32 | for i, col := range columns { 33 | prefixedColumns[i] = fmt.Sprintf("%s.%s", a, col) 34 | } 35 | return strings.Join(prefixedColumns, ", ") 36 | } 37 | 38 | func SetColumns(model any) string { 39 | return strings.Join(extractSetColumns(model), ", ") 40 | } 41 | 42 | func Values(model any) []any { 43 | return extractValues(model, columnExtractor) 44 | } 45 | 46 | func SetValues(model any) []any { 47 | return extractValues(model, setColumnExtractor) 48 | } 49 | 50 | func ValuesAndSetValues(model any) []any { 51 | values := extractValues(model, columnExtractor) 52 | setValues := extractValues(model, setColumnExtractor) 53 | return append(values, setValues...) 
54 | } 55 | 56 | func extractRawColumns(model any) []string { 57 | var columns []string 58 | processModelFields(model, func(field reflect.StructField, value reflect.Value) bool { 59 | tag := field.Tag.Get("db") 60 | if tag == "" || tag == "-" { 61 | return false 62 | } 63 | 64 | parts := strings.Split(tag, ",") 65 | columns = append(columns, parts[0]) 66 | return true 67 | }) 68 | return columns 69 | } 70 | 71 | func extractSetColumns(model any) []string { 72 | var columns []string 73 | processModelFields(model, func(field reflect.StructField, value reflect.Value) bool { 74 | tag := field.Tag.Get("db") 75 | if tag == "" || tag == "-" { 76 | return false 77 | } 78 | 79 | parts := strings.Split(tag, ",") 80 | if len(parts) >= 2 && parts[1] == "readonly" { 81 | return false 82 | } 83 | 84 | columns = append(columns, fmt.Sprintf("%s = ?", parts[0])) 85 | return true 86 | }) 87 | return columns 88 | } 89 | 90 | func extractValues(model any, extractor fieldExtractor) []any { 91 | var values []any 92 | processModelFields(model, func(field reflect.StructField, value reflect.Value) bool { 93 | _, val, ok := extractor(field, value) 94 | if ok { 95 | values = append(values, val) 96 | } 97 | return ok 98 | }) 99 | return values 100 | } 101 | 102 | func columnExtractor(field reflect.StructField, value reflect.Value) (string, any, bool) { 103 | tag := field.Tag.Get("db") 104 | if tag == "" || tag == "-" { 105 | return "", nil, false 106 | } 107 | 108 | parts := strings.Split(tag, ",") 109 | return parts[0], value.Interface(), true 110 | } 111 | 112 | func setColumnExtractor(field reflect.StructField, value reflect.Value) (string, any, bool) { 113 | tag := field.Tag.Get("db") 114 | if tag == "" || tag == "-" { 115 | return "", nil, false 116 | } 117 | 118 | parts := strings.Split(tag, ",") 119 | if len(parts) >= 2 && parts[1] == "readonly" { 120 | return "", nil, false 121 | } 122 | 123 | return fmt.Sprintf("%s = ?", parts[0]), value.Interface(), true 124 | } 125 | 126 | func 
processModelFields(model any, processor func(reflect.StructField, reflect.Value) bool) { 127 | t := reflect.TypeOf(model) 128 | v := reflect.ValueOf(model) 129 | 130 | if t.Kind() == reflect.Ptr { 131 | t = t.Elem() 132 | v = v.Elem() 133 | } 134 | 135 | for i := 0; i < t.NumField(); i++ { 136 | field := t.Field(i) 137 | 138 | if !field.IsExported() { 139 | continue 140 | } 141 | 142 | if field.Anonymous && field.Type.Kind() == reflect.Struct { 143 | processModelFields(v.Field(i).Interface(), processor) 144 | continue 145 | } 146 | 147 | processor(field, v.Field(i)) 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /csql/qb/query_builder_test.go: -------------------------------------------------------------------------------- 1 | package qb_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/gocopper/copper/csql/qb" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type TestModel struct { 11 | ID int `db:"id,readonly"` 12 | Name string `db:"name"` 13 | Email string `db:"email"` 14 | Ignored string `db:"-"` 15 | } 16 | 17 | func TestPlaceholders(t *testing.T) { 18 | model := TestModel{ID: 1, Name: "Test", Email: "test@example.com"} 19 | result := qb.ValuePlaceholders(model) 20 | assert.Equal(t, "?, ?, ?", result) 21 | } 22 | 23 | func TestColumns(t *testing.T) { 24 | model := TestModel{} 25 | 26 | // Without alias 27 | result := qb.Columns(model) 28 | assert.Equal(t, "id, name, email", result) 29 | 30 | // With alias 31 | resultWithAlias := qb.Columns(model, "t") 32 | assert.Equal(t, "t.id, t.name, t.email", resultWithAlias) 33 | } 34 | 35 | func TestSetColumns(t *testing.T) { 36 | model := TestModel{} 37 | result := qb.SetColumns(model) 38 | assert.Equal(t, "name = ?, email = ?", result) 39 | } 40 | 41 | func TestValues(t *testing.T) { 42 | model := TestModel{ID: 1, Name: "Test", Email: "test@example.com"} 43 | result := qb.Values(model) 44 | 45 | assert.Len(t, result, 3) 46 | assert.Equal(t, 1, 
result[0]) 47 | assert.Equal(t, "Test", result[1]) 48 | assert.Equal(t, "test@example.com", result[2]) 49 | } 50 | 51 | func TestSetValues(t *testing.T) { 52 | model := TestModel{ID: 1, Name: "Test", Email: "test@example.com"} 53 | result := qb.SetValues(model) 54 | 55 | assert.Len(t, result, 2) 56 | assert.Equal(t, "Test", result[0]) 57 | assert.Equal(t, "test@example.com", result[1]) 58 | } 59 | 60 | func TestWithEmbeddedStruct(t *testing.T) { 61 | type BaseModel struct { 62 | ID int `db:"id,readonly"` 63 | Version int `db:"version"` 64 | } 65 | 66 | type UserModel struct { 67 | BaseModel 68 | Name string `db:"name"` 69 | Skip string `db:"-"` 70 | } 71 | 72 | model := UserModel{ 73 | BaseModel: BaseModel{ID: 1, Version: 2}, 74 | Name: "Test", 75 | Skip: "Ignored", 76 | } 77 | 78 | columns := qb.Columns(model) 79 | assert.Equal(t, "id, version, name", columns) 80 | 81 | setColumns := qb.SetColumns(model) 82 | assert.Equal(t, "version = ?, name = ?", setColumns) 83 | } 84 | -------------------------------------------------------------------------------- /csql/querier.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "strings" 8 | "time" 9 | 10 | "github.com/gocopper/copper/cerrors" 11 | "github.com/gocopper/copper/clogger" 12 | "github.com/jmoiron/sqlx" 13 | ) 14 | 15 | // Querier provides a set of helpful methods to run database queries. It can be used to run parameterized queries 16 | // and scan results into Go structs or slices. 
17 | type Querier interface { 18 | CtxWithTx(ctx context.Context) (context.Context, *sql.Tx, error) 19 | InTx(ctx context.Context, fn func(context.Context) error) error 20 | WithIn() Querier 21 | Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error 22 | Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error 23 | Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 24 | OnCommit(ctx context.Context, cb func(context.Context) error) error 25 | CommitTx(tx *sql.Tx) error 26 | } 27 | 28 | // NewQuerier returns a querier using the given database connection and the dialect 29 | func NewQuerier(db *sql.DB, config Config, logger clogger.Logger) Querier { 30 | return &querier{ 31 | db: sqlx.NewDb(db, config.Dialect), 32 | dialect: config.Dialect, 33 | in: false, 34 | logger: logger, 35 | callbacksByTx: make(map[*sql.Tx][]func(context.Context) error), 36 | } 37 | } 38 | 39 | type querier struct { 40 | db *sqlx.DB 41 | dialect string 42 | in bool 43 | logger clogger.Logger 44 | 45 | callbacksByTx map[*sql.Tx][]func(context.Context) error 46 | } 47 | 48 | func (q *querier) OnCommit(ctx context.Context, cb func(context.Context) error) error { 49 | tx, err := TxFromCtx(ctx) 50 | if err != nil { 51 | return cerrors.New(err, "failed to get database transaction from context", nil) 52 | } 53 | 54 | if _, ok := q.callbacksByTx[tx]; !ok { 55 | q.callbacksByTx[tx] = make([]func(context.Context) error, 0) 56 | } 57 | 58 | q.callbacksByTx[tx] = append(q.callbacksByTx[tx], cb) 59 | 60 | return nil 61 | } 62 | 63 | func (q *querier) CommitTx(tx *sql.Tx) error { 64 | err := tx.Commit() 65 | if err != nil && !errors.Is(err, sql.ErrTxDone) && !strings.Contains(err.Error(), "commit unexpectedly resulted in rollback") { 66 | return err 67 | } 68 | 69 | if err != nil && strings.Contains(err.Error(), "commit unexpectedly resulted in rollback") { 70 | q.logger.Warn(err.Error(), nil) 71 | } 72 | 73 | if 
callbacks, ok := q.callbacksByTx[tx]; ok { 74 | for i := range callbacks { 75 | go func(cb func(context.Context) error) { 76 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 77 | 78 | err := cb(ctx) 79 | if err != nil { 80 | q.logger.Error("Failed to run callback", err) 81 | } 82 | 83 | cancel() 84 | }(callbacks[i]) 85 | } 86 | } 87 | 88 | delete(q.callbacksByTx, tx) 89 | 90 | return nil 91 | } 92 | 93 | func (q *querier) CtxWithTx(ctx context.Context) (context.Context, *sql.Tx, error) { 94 | return CtxWithTx(ctx, q.db.DB, q.dialect) 95 | } 96 | 97 | func (q *querier) InTx(ctx context.Context, fn func(context.Context) error) error { 98 | ctx, tx, err := CtxWithTx(ctx, q.db.DB, q.dialect) 99 | if err != nil { 100 | return cerrors.New(err, "failed to create context with database transaction", nil) 101 | } 102 | 103 | defer func() { 104 | // Try a rollback in a deferred function to account for panics 105 | err := tx.Rollback() 106 | if err != nil && !errors.Is(err, sql.ErrTxDone) { 107 | q.logger.Error("Failed to rollback database transaction", err) 108 | return 109 | } 110 | 111 | if err == nil { 112 | q.logger.Warn("Rolled back an unexpectedly open database transaction", nil) 113 | } 114 | }() 115 | 116 | err = fn(ctx) 117 | if err != nil { 118 | rollbackErr := tx.Rollback() 119 | if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) { 120 | q.logger.Error("Failed to rollback database transaction", err) 121 | } 122 | return err 123 | } 124 | 125 | err = q.CommitTx(tx) 126 | if err != nil { 127 | return cerrors.New(err, "failed to commit database transaction", nil) 128 | } 129 | 130 | return nil 131 | } 132 | 133 | func (q *querier) WithIn() Querier { 134 | return &querier{ 135 | db: q.db, 136 | dialect: q.dialect, 137 | in: true, 138 | logger: q.logger, 139 | } 140 | } 141 | 142 | func (q *querier) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 143 | query, args, err := q.mkQueryWithArgs(ctx, 
query, args) 144 | if err != nil { 145 | return err 146 | } 147 | 148 | return mustTxFromCtx(ctx).GetContext(ctx, dest, query, args...) 149 | } 150 | 151 | func (q *querier) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error { 152 | query, args, err := q.mkQueryWithArgs(ctx, query, args) 153 | if err != nil { 154 | return err 155 | } 156 | 157 | return mustTxFromCtx(ctx).SelectContext(ctx, dest, query, args...) 158 | } 159 | 160 | func (q *querier) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 161 | query, args, err := q.mkQueryWithArgs(ctx, query, args) 162 | if err != nil { 163 | return nil, err 164 | } 165 | 166 | return mustTxFromCtx(ctx).ExecContext(ctx, query, args...) 167 | } 168 | 169 | func (q *querier) mkQueryWithArgs(ctx context.Context, query string, args []interface{}) (string, []interface{}, error) { 170 | var err error 171 | 172 | if q.in { 173 | query, args, err = sqlx.In(query, args...) 174 | if err != nil { 175 | return "", nil, cerrors.New(err, "failed to create IN query", nil) 176 | } 177 | } 178 | 179 | tx, err := txFromCtx(ctx) 180 | if err != nil { 181 | return "", nil, err 182 | } 183 | 184 | return tx.Rebind(query), args, nil 185 | } 186 | -------------------------------------------------------------------------------- /csql/querier_test.go: -------------------------------------------------------------------------------- 1 | package csql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | 8 | "github.com/gocopper/copper/clogger" 9 | "github.com/gocopper/copper/csql" 10 | _ "github.com/mattn/go-sqlite3" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestQuerier_Read(t *testing.T) { 15 | t.Parallel() 16 | 17 | db, err := sql.Open("sqlite3", ":memory:") 18 | assert.NoError(t, err) 19 | 20 | _, err = db.Exec("create table people (name text);insert into people (name) values ('test');") 21 | assert.NoError(t, err) 22 | 23 | querier := 
csql.NewQuerier(db, csql.Config{Dialect: "sqlite3"}, clogger.NewNoop()) 24 | 25 | ctx, _, err := csql.CtxWithTx(context.Background(), db, "sqlite3") 26 | assert.NoError(t, err) 27 | 28 | t.Run("get", func(t *testing.T) { 29 | t.Parallel() 30 | 31 | var dest struct { 32 | Name string 33 | } 34 | 35 | err = querier.Get(ctx, &dest, "select * from people") 36 | assert.NoError(t, err) 37 | 38 | assert.Equal(t, "test", dest.Name) 39 | }) 40 | 41 | t.Run("select", func(t *testing.T) { 42 | t.Parallel() 43 | 44 | var dest []struct { 45 | Name string 46 | } 47 | 48 | err = querier.Select(ctx, &dest, "select * from people") 49 | assert.NoError(t, err) 50 | 51 | assert.Equal(t, 1, len(dest)) 52 | assert.Equal(t, "test", dest[0].Name) 53 | }) 54 | 55 | t.Run("select in", func(t *testing.T) { 56 | t.Parallel() 57 | 58 | var dest []struct { 59 | Name string 60 | } 61 | 62 | err = querier.WithIn().Select(ctx, &dest, "select * from people where name in (?)", []string{"test"}) 63 | assert.NoError(t, err) 64 | 65 | assert.Equal(t, 1, len(dest)) 66 | assert.Equal(t, "test", dest[0].Name) 67 | }) 68 | } 69 | 70 | func TestQuerier_Exec(t *testing.T) { 71 | t.Parallel() 72 | 73 | db, err := sql.Open("sqlite3", ":memory:") 74 | assert.NoError(t, err) 75 | 76 | _, err = db.Exec("create table people (name text);insert into people (name) values ('test');") 77 | assert.NoError(t, err) 78 | 79 | querier := csql.NewQuerier(db, csql.Config{Dialect: "sqlite3"}, clogger.NewNoop()) 80 | 81 | ctx, _, err := csql.CtxWithTx(context.Background(), db, "sqlite3") 82 | assert.NoError(t, err) 83 | 84 | res, err := querier.Exec(ctx, "delete from people") 85 | assert.NoError(t, err) 86 | 87 | n, err := res.RowsAffected() 88 | assert.NoError(t, err) 89 | 90 | assert.Equal(t, int64(1), n) 91 | } 92 | -------------------------------------------------------------------------------- /csql/tx_middleware.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | 
import ( 4 | "bufio" 5 | "database/sql" 6 | "errors" 7 | "github.com/gocopper/copper/cerrors" 8 | "github.com/gocopper/copper/clogger" 9 | "net" 10 | "net/http" 11 | ) 12 | 13 | // NewTxMiddleware creates a new TxMiddleware 14 | func NewTxMiddleware(db *sql.DB, querier Querier, config Config, logger clogger.Logger) *TxMiddleware { 15 | return &TxMiddleware{ 16 | db: db, 17 | querier: querier, 18 | config: config, 19 | logger: logger, 20 | } 21 | } 22 | 23 | // TxMiddleware is a chttp.Middleware that wraps an HTTP request in a database transaction. If the request succeeds 24 | // (i.e. 2xx or 3xx response code), the transaction is committed. Else, the transaction is rolled back. 25 | type TxMiddleware struct { 26 | db *sql.DB 27 | querier Querier 28 | config Config 29 | logger clogger.Logger 30 | } 31 | 32 | // Handle implements the chttp.Middleware interface. See TxMiddleware 33 | func (m *TxMiddleware) Handle(next http.Handler) http.Handler { 34 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | ctx, tx, err := CtxWithTx(r.Context(), m.db, m.config.Dialect) 36 | if err != nil { 37 | m.logger.Error("Failed to create context with database transaction", err) 38 | w.WriteHeader(http.StatusInternalServerError) 39 | return 40 | } 41 | 42 | defer func() { 43 | // Try a rollback in a deferred function to account for panics 44 | err := tx.Rollback() 45 | if err != nil && !errors.Is(err, sql.ErrTxDone) { 46 | m.logger.Error("Failed to rollback database transaction", err) 47 | w.WriteHeader(http.StatusInternalServerError) 48 | return 49 | } 50 | 51 | if err == nil { 52 | m.logger.Warn("Rolled back an unexpectedly open database transaction", nil) 53 | } 54 | }() 55 | 56 | next.ServeHTTP(&txnrw{ 57 | internal: w, 58 | tx: tx, 59 | querier: m.querier, 60 | logger: m.logger, 61 | }, r.WithContext(ctx)) 62 | 63 | err = m.querier.CommitTx(tx) 64 | if err != nil { 65 | m.logger.Error("Failed to commit database transaction", err) 66 | return 67 | } 68 | }) 
69 | } 70 | 71 | type txnrw struct { 72 | internal http.ResponseWriter 73 | tx *sql.Tx 74 | querier Querier 75 | logger clogger.Logger 76 | } 77 | 78 | func (w *txnrw) Header() http.Header { 79 | return w.internal.Header() 80 | } 81 | 82 | func (w *txnrw) Write(b []byte) (int, error) { 83 | err := w.querier.CommitTx(w.tx) 84 | if err != nil { 85 | return 0, cerrors.New(err, "failed to commit database transaction", nil) 86 | } 87 | 88 | return w.internal.Write(b) 89 | } 90 | 91 | func (w *txnrw) WriteHeader(statusCode int) { 92 | const MinErrStatusCode = 400 93 | 94 | if statusCode >= MinErrStatusCode { 95 | err := w.tx.Rollback() 96 | if err != nil && !errors.Is(err, sql.ErrTxDone) { 97 | w.logger.WithTags(map[string]interface{}{ 98 | "originalStatusCode": statusCode, 99 | }).Error("Failed to rollback database transaction", err) 100 | w.internal.WriteHeader(http.StatusInternalServerError) 101 | return 102 | } 103 | 104 | w.internal.WriteHeader(statusCode) 105 | return 106 | } 107 | 108 | err := w.querier.CommitTx(w.tx) 109 | if err != nil { 110 | w.internal.WriteHeader(http.StatusInternalServerError) 111 | return 112 | } 113 | 114 | w.internal.WriteHeader(statusCode) 115 | } 116 | 117 | func (w *txnrw) Hijack() (net.Conn, *bufio.ReadWriter, error) { 118 | h, ok := w.internal.(http.Hijacker) 119 | if !ok { 120 | return nil, nil, errors.New("internal response writer is not http.Hijacker") 121 | } 122 | 123 | return h.Hijack() 124 | } 125 | -------------------------------------------------------------------------------- /csql/tx_middleware_test.go: -------------------------------------------------------------------------------- 1 | package csql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/gocopper/copper/clogger" 11 | "github.com/gocopper/copper/csql" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestTxMiddleware_Handle_Commit(t *testing.T) { 16 | t.Parallel() 17 | 18 
| db, err := sql.Open("sqlite3", ":memory:") 19 | assert.NoError(t, err) 20 | 21 | _, err = db.Exec("create table people (name text)") 22 | assert.NoError(t, err) 23 | 24 | var ( 25 | logger = clogger.NewNoop() 26 | config = csql.Config{Dialect: "sqlite3"} 27 | querier = csql.NewQuerier(db, config, logger) 28 | mw = csql.NewTxMiddleware(db, config, logger) 29 | ) 30 | 31 | req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) 32 | assert.NoError(t, err) 33 | 34 | mw.Handle(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | _, err := querier.Exec(r.Context(), "insert into people (name) values ('test')") 36 | assert.NoError(t, err) 37 | })).ServeHTTP(httptest.NewRecorder(), req) 38 | 39 | rows, err := db.Query("select * from people") 40 | assert.NoError(t, err) 41 | assert.NoError(t, rows.Err()) 42 | 43 | assert.True(t, rows.Next()) 44 | } 45 | 46 | func TestTxMiddleware_Handle_Rollback(t *testing.T) { 47 | t.Parallel() 48 | 49 | db, err := sql.Open("sqlite3", ":memory:") 50 | assert.NoError(t, err) 51 | 52 | _, err = db.Exec("create table people (name text)") 53 | assert.NoError(t, err) 54 | 55 | var ( 56 | logger = clogger.NewNoop() 57 | config = csql.Config{Dialect: "sqlite3"} 58 | querier = csql.NewQuerier(db, config, logger) 59 | mw = csql.NewTxMiddleware(db, config, logger) 60 | ) 61 | 62 | req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) 63 | assert.NoError(t, err) 64 | 65 | mw.Handle(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 66 | _, err := querier.Exec(r.Context(), "insert into people (name) values ('test')") 67 | assert.NoError(t, err) 68 | 69 | w.WriteHeader(http.StatusInternalServerError) 70 | })).ServeHTTP(httptest.NewRecorder(), req) 71 | 72 | rows, err := db.Query("select * from people") 73 | assert.NoError(t, err) 74 | assert.NoError(t, rows.Err()) 75 | 76 | assert.False(t, rows.Next()) 77 | } 78 | 
-------------------------------------------------------------------------------- /csql/wire.go: -------------------------------------------------------------------------------- 1 | package csql 2 | 3 | import "github.com/google/wire" 4 | 5 | // WireModule can be used as part of google/wire setup. 6 | var WireModule = wire.NewSet( 7 | NewDBConnection, 8 | NewQuerier, 9 | NewMigrator, 10 | LoadConfig, 11 | NewTxMiddleware, 12 | 13 | wire.Struct(new(NewMigratorParams), "*"), 14 | ) 15 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package copper encapsulates everything you need to build apps quickly 2 | package copper 3 | -------------------------------------------------------------------------------- /flags.go: -------------------------------------------------------------------------------- 1 | package copper 2 | 3 | import ( 4 | "flag" 5 | 6 | "github.com/gocopper/copper/cconfig" 7 | ) 8 | 9 | // Flags holds flag values passed in via command line. These can be used to configure the app environment 10 | // and override the config directory. 11 | type Flags struct { 12 | ConfigPath cconfig.Path 13 | ConfigOverrides cconfig.Overrides 14 | } 15 | 16 | // NewFlags reads the command line flags and returns Flags with the values set. 17 | func NewFlags() *Flags { 18 | var ( 19 | configPath = flag.String("config", "./config/dev.toml", "Path to config file") 20 | configOverrides = flag.String("set", "", "Config overrides ex. \"chttp.port=5902\". 
Separate multiple overrides with ;") 21 | ) 22 | 23 | flag.Parse() 24 | 25 | return &Flags{ 26 | ConfigPath: cconfig.Path(*configPath), 27 | ConfigOverrides: cconfig.Overrides(*configOverrides), 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/gocopper/copper 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/Masterminds/sprig/v3 v3.2.3 7 | github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf 8 | github.com/google/uuid v1.6.0 9 | github.com/google/wire v0.6.0 10 | github.com/gorilla/mux v1.8.1 11 | github.com/iancoleman/strcase v0.3.0 12 | github.com/jmoiron/sqlx v1.3.5 13 | github.com/mattn/go-sqlite3 v1.14.18 14 | github.com/pelletier/go-toml v1.9.3 15 | github.com/prometheus/client_golang v1.20.3 16 | github.com/rubenv/sql-migrate v1.1.2 17 | github.com/shopspring/decimal v1.2.0 18 | github.com/stretchr/testify v1.9.0 19 | go.uber.org/zap v1.27.0 20 | ) 21 | 22 | require ( 23 | github.com/Masterminds/goutils v1.1.1 // indirect 24 | github.com/Masterminds/semver/v3 v3.2.0 // indirect 25 | github.com/beorn7/perks v1.0.1 // indirect 26 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 27 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 28 | github.com/go-gorp/gorp/v3 v3.0.2 // indirect 29 | github.com/go-sql-driver/mysql v1.7.1 // indirect 30 | github.com/huandu/xstrings v1.3.3 // indirect 31 | github.com/imdario/mergo v0.3.12 // indirect 32 | github.com/lib/pq v1.10.2 // indirect 33 | github.com/mitchellh/copystructure v1.0.0 // indirect 34 | github.com/mitchellh/reflectwalk v1.0.0 // indirect 35 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 36 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 37 | github.com/prometheus/client_model v0.6.1 // indirect 38 | github.com/prometheus/common v0.59.1 // 
indirect 39 | github.com/prometheus/procfs v0.15.1 // indirect 40 | github.com/rogpeppe/go-internal v1.12.0 // indirect 41 | github.com/sirupsen/logrus v1.9.3 // indirect 42 | github.com/spf13/cast v1.7.0 // indirect 43 | go.uber.org/multierr v1.11.0 // indirect 44 | golang.org/x/crypto v0.31.0 // indirect 45 | golang.org/x/sys v0.28.0 // indirect 46 | google.golang.org/protobuf v1.34.2 // indirect 47 | gopkg.in/yaml.v3 v3.0.1 // indirect 48 | ) 49 | -------------------------------------------------------------------------------- /wire.go: -------------------------------------------------------------------------------- 1 | //go:build wireinject 2 | // +build wireinject 3 | 4 | package copper 5 | 6 | import ( 7 | "github.com/gocopper/copper/cconfig" 8 | "github.com/gocopper/copper/clifecycle" 9 | "github.com/gocopper/copper/clogger" 10 | "github.com/google/wire" 11 | ) 12 | 13 | // InitApp creates a new Copper app along with its dependencies. 14 | func InitApp() (*App, error) { 15 | panic( 16 | wire.Build( 17 | NewApp, 18 | NewFlags, 19 | clifecycle.New, 20 | cconfig.NewWithKeyOverrides, 21 | clogger.NewWithConfig, 22 | clogger.LoadConfig, 23 | 24 | wire.FieldsOf(new(*Flags), "ConfigPath", "ConfigOverrides"), 25 | ), 26 | ) 27 | } 28 | 29 | // WireModule can be used as part of google/wire setup to include the app's 30 | // lifecycle, config, and logger. 31 | var WireModule = wire.NewSet( 32 | wire.FieldsOf(new(*App), "Lifecycle", "Config", "Logger"), 33 | ) 34 | -------------------------------------------------------------------------------- /wire_gen.go: -------------------------------------------------------------------------------- 1 | // Code generated by Wire. DO NOT EDIT. 
2 | 3 | //go:generate go run github.com/google/wire/cmd/wire 4 | //go:build !wireinject 5 | // +build !wireinject 6 | 7 | package copper 8 | 9 | import ( 10 | "github.com/gocopper/copper/cconfig" 11 | "github.com/gocopper/copper/clifecycle" 12 | "github.com/gocopper/copper/clogger" 13 | "github.com/google/wire" 14 | ) 15 | 16 | // Injectors from wire.go: 17 | 18 | // InitApp creates a new Copper app along with its dependencies. 19 | func InitApp() (*App, error) { 20 | lifecycle := clifecycle.New() 21 | flags := NewFlags() 22 | path := flags.ConfigPath 23 | overrides := flags.ConfigOverrides 24 | loader, err := cconfig.NewWithKeyOverrides(path, overrides) 25 | if err != nil { 26 | return nil, err 27 | } 28 | config, err := clogger.LoadConfig(loader) 29 | if err != nil { 30 | return nil, err 31 | } 32 | logger, err := clogger.NewWithConfig(config) 33 | if err != nil { 34 | return nil, err 35 | } 36 | app := NewApp(lifecycle, loader, logger) 37 | return app, nil 38 | } 39 | 40 | // wire.go: 41 | 42 | // WireModule can be used as part of google/wire setup to include the app's 43 | // lifecycle, config, and logger. 44 | var WireModule = wire.NewSet(wire.FieldsOf(new(*App), "Lifecycle", "Config", "Logger")) 45 | --------------------------------------------------------------------------------