├── .gitignore
├── .golangci.yaml
├── LICENSE
├── Makefile
├── README.md
├── app.go
├── cconfig
├── cconfigtest
│ ├── doc.go
│ └── util.go
├── doc.go
├── loader.go
├── loader_test.go
└── tree.go
├── cerrors
├── error.go
├── error_test.go
├── utils.go
└── utils_test.go
├── chttp
├── chttptest
│ ├── content_type.go
│ ├── doc.go
│ ├── handler.go
│ ├── html_reader_writer.go
│ ├── json_reader_writer.go
│ ├── route.go
│ └── src
│ │ ├── layouts
│ │ └── main.html
│ │ ├── pages
│ │ ├── index.html
│ │ └── not-found.html
│ │ └── partials
│ │ └── test-component.html
├── config.go
├── doc.go
├── error.html
├── fs.go
├── handler.go
├── handler_test.go
├── html_reader_writer.go
├── html_reader_writer_test.go
├── html_renderer.go
├── html_router.go
├── http.go
├── json_reader_writer.go
├── json_reader_writer_test.go
├── middleware.go
├── panic_logger_mw.go
├── panic_logger_mw_test.go
├── request_id_mw.go
├── request_logger_mw.go
├── request_logger_mw_test.go
├── route_ctx_mw.go
├── route_ctx_mw_test.go
├── server.go
├── server_test.go
└── wire.go
├── clifecycle
├── doc.go
├── lifecycle.go
└── logger.go
├── clogger
├── config.go
├── doc.go
├── json_redactor.go
├── json_redactor_test.go
├── level.go
├── level_test.go
├── logger.go
├── logger_test.go
├── noop.go
├── noop_test.go
├── recorder.go
├── recorder_test.go
├── util.go
└── zap.go
├── cmetrics
├── metrics.go
├── models.go
├── noop.go
└── wire.go
├── csql
├── config.go
├── ctx.go
├── ctx_test.go
├── db.go
├── db_test.go
├── doc.go
├── migrations_test.sql
├── migrator.go
├── migrator_test.go
├── qb
│ ├── query_builder.go
│ └── query_builder_test.go
├── querier.go
├── querier_test.go
├── tx_middleware.go
├── tx_middleware_test.go
└── wire.go
├── doc.go
├── flags.go
├── go.mod
├── go.sum
├── wire.go
└── wire_gen.go
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | vendor/
4 |
--------------------------------------------------------------------------------
/.golangci.yaml:
--------------------------------------------------------------------------------
1 | issues:
2 | exclude-use-default: false
3 | fix: true
4 |
5 | run:
6 | build-tags:
7 | - wireinject
8 |
9 | linters:
10 | disable-all: true
11 | enable:
12 | - bodyclose
13 | - depguard
14 | - dogsled
15 | - dupl
16 | - errcheck
17 | - exportloopref
18 | - exhaustive
19 | - funlen
20 | - gochecknoinits
21 | - goconst
22 | - gocritic
23 | - gocyclo
24 | - gofmt
25 | - goimports
26 | - gomnd
27 | - goprintffuncname
28 | - gosec
29 | - gosimple
30 | - govet
31 | - ineffassign
32 | - misspell
33 | - nakedret
34 | - noctx
35 | - nolintlint
36 | - revive
37 | - rowserrcheck
38 | - staticcheck
39 | - stylecheck
40 | - typecheck
41 | - unconvert
42 | - whitespace
43 | - paralleltest
44 | - tparallel
45 | - testpackage
46 | - thelper
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tushar Soni
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | GO=GO111MODULE=on go
2 | WIRE=wire
3 | COVERAGE_FILE="/tmp/copper_coverage.out"
4 |
5 | .PHONY: all
6 | all: lint generate test
7 |
8 | .PHONY: cover
9 | cover: test
10 | $(GO) tool cover -html=$(COVERAGE_FILE)
11 |
12 | .PHONY: test
13 | test:
14 | $(GO) test -coverprofile=$(COVERAGE_FILE) ./...
15 |
16 | .PHONY: lint
17 | lint: tidy
18 | golangci-lint run
19 |
20 | .PHONY: tidy
21 | tidy:
22 | $(GO) mod tidy
23 |
24 | .PHONY: generate
25 | generate:
26 | $(WIRE) .
27 |
28 | .PHONY: release
29 | release:
30 | git tag -a $(version) -m "Release $(version)"
31 | git push origin $(version)
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | # Copper
19 |
20 |
21 | Copper is a Go toolkit complete with everything you need to build web apps. It focuses on developer productivity and makes building web apps in Go more fun with less boilerplate and out-of-the-box support for common needs.
22 |
23 |
24 | #### 🚀 Fullstack Toolkit
25 | Copper provides a toolkit complete with everything you need to build web apps quickly.
26 |
27 |
28 | #### 📦 One Binary
29 | Build frontend apps along with your backend and ship everything in a single binary.
30 |
31 |
32 | #### 📝 Server-side HTML
33 | Copper includes utilities that help build web apps with server rendered HTML pages.
34 |
35 | #### 💡 Auto Restarts
36 | Copper detects changes and automatically restarts the server to save time.
37 |
38 | #### 🏗 Scaffolding
39 | Skip boilerplate and scaffold code for your packages, database queries and routes.
40 |
41 | #### 🔋 Batteries Included
42 | Includes CLI, lint, dev server, config management, and more!
43 |
44 | #### 🔩 First-party packages
45 | Includes packages for authentication, pub/sub, queues, emails, and websockets.
46 |
47 |
48 |
49 |
50 | ## Intro Video (Hacker News Clone)
51 | [Watch the intro video on Vimeo](https://vimeo.com/723537998)
52 |
53 |
54 |
55 | ## Getting Started
56 |
57 | Head over to the documentation to get started - [https://docs.gocopper.dev/getting-started](https://docs.gocopper.dev/getting-started)
58 |
59 | Join us on Discord here - https://discord.gg/fT2AEZyM6A
60 |
61 |
62 |
63 | ## License
64 | MIT
65 |
--------------------------------------------------------------------------------
/app.go:
--------------------------------------------------------------------------------
1 | package copper
2 |
3 | import (
4 | "log"
5 | "os"
6 | "os/signal"
7 | "syscall"
8 |
9 | "github.com/gocopper/copper/cconfig"
10 | "github.com/gocopper/copper/clifecycle"
11 | "github.com/gocopper/copper/clogger"
12 | )
13 |
// Runner is the unit of work a Copper app executes through App.Run or App.Start.
// Returning a non-nil error from Run signals failure, which makes the app exit with code 1.
// Several packages within Copper implement it, including chttp.Server.
type Runner interface {
	Run() error
}
19 |
20 | // New provides a convenience wrapper around InitApp that logs and exits if there is an error.
21 | func New() *App {
22 | app, err := InitApp()
23 | if err != nil {
24 | log.Fatal(err)
25 | }
26 |
27 | return app
28 | }
29 |
30 | // NewApp creates a new Copper app and returns it along with the app's lifecycle manager,
31 | // config, and the logger.
32 | func NewApp(lifecycle *clifecycle.Lifecycle, config cconfig.Loader, logger clogger.Logger) *App {
33 | return &App{
34 | Lifecycle: lifecycle,
35 | Config: config,
36 | Logger: logger,
37 | }
38 | }
39 |
40 | // App defines a Copper app container that can run provided code in its managed lifecycle.
41 | // It provides functionality to read config in multiple environments as defined by
42 | // command-line flags.
43 | type App struct {
44 | Lifecycle *clifecycle.Lifecycle
45 | Config cconfig.Loader
46 | Logger clogger.Logger
47 | }
48 |
49 | // Run runs the provided funcs. Once all of the functions complete their run,
50 | // the lifecycle's stop funcs are also called. If any of the fns return an error,
51 | // the app exits with an exit code 1.
52 | // Run should be used when none of the fn are long-running. For long-running funcs like
53 | // an HTTP server, use Start.
54 | func (a *App) Run(fns ...Runner) {
55 | for i := range fns {
56 | err := fns[i].Run()
57 | if err != nil {
58 | a.Logger.Error("Failed to run", err)
59 | a.Lifecycle.Stop(a.Logger)
60 | os.Exit(1)
61 | }
62 | }
63 |
64 | a.Lifecycle.Stop(a.Logger)
65 | }
66 |
67 | // Start runs the provided fns and then waits on the OS's INT and TERM signals from the
68 | // user to exit. Once the signal is received, the lifecycle's stop funcs are
69 | // called.
70 | // If any of the fns fail to run and returns an error, the app exits with exit code
71 | // 1.
72 | func (a *App) Start(fns ...Runner) {
73 | for i := range fns {
74 | err := fns[i].Run()
75 | if err != nil {
76 | a.Logger.Error("Failed to run", err)
77 | os.Exit(1)
78 | }
79 | }
80 |
81 | osInt := make(chan os.Signal, 1)
82 |
83 | signal.Notify(osInt, syscall.SIGINT, syscall.SIGTERM)
84 |
85 | <-osInt
86 |
87 | a.Lifecycle.Stop(a.Logger)
88 | }
89 |
--------------------------------------------------------------------------------
/cconfig/cconfigtest/doc.go:
--------------------------------------------------------------------------------
1 | // Package cconfigtest provides helper methods to test the cconfig package
2 | package cconfigtest
3 |
--------------------------------------------------------------------------------
/cconfig/cconfigtest/util.go:
--------------------------------------------------------------------------------
1 | package cconfigtest
2 |
3 | import (
4 | "os"
5 | "path"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | // SetupDirWithConfigs creates a temp directory that can store config files.
12 | // The directory is cleaned up after test run.
13 | func SetupDirWithConfigs(t *testing.T, configs map[string]string) string {
14 | t.Helper()
15 |
16 | dir, err := os.MkdirTemp("", "")
17 | assert.NoError(t, err)
18 |
19 | for fp, data := range configs {
20 | err = os.WriteFile(path.Join(dir, fp), []byte(data), os.ModePerm)
21 | assert.NoError(t, err)
22 | }
23 |
24 | t.Cleanup(func() {
25 | assert.NoError(t, os.RemoveAll(dir))
26 | })
27 |
28 | return dir
29 | }
30 |
--------------------------------------------------------------------------------
/cconfig/doc.go:
--------------------------------------------------------------------------------
1 | // Package cconfig provides utilities to read app config files. See New and Loader for more documentation and example.
2 | package cconfig
3 |
--------------------------------------------------------------------------------
/cconfig/loader.go:
--------------------------------------------------------------------------------
1 | package cconfig
2 |
3 | import (
4 | "github.com/gocopper/copper/cerrors"
5 | "github.com/pelletier/go-toml"
6 | )
7 |
// Path is the filesystem path of the config file to load.
type Path string

// Overrides holds a ';'-separated string of config overrides, each written in TOML format
// (for example `group1.key1="a";group2.key2=1`).
type Overrides string
14 |
// Loader provides methods to load config files into structs.
type Loader interface {
	// Load reads the config values defined under the given key (a TOML table) and sets them
	// into the dest struct. If the config has no value for key, Load returns nil and leaves
	// dest untouched. If the value at key is not a table, an error is returned.
	//
	// For example:
	//
	//	# prod.toml
	//	[my_config]
	//	key1 = "val1"
	//
	//	# config.go
	//	type MyConfig struct {
	//		Key1 string `toml:"key1"`
	//	}
	//
	//	func LoadMyConfig(loader cconfig.Loader) (MyConfig, error) {
	//		var config MyConfig
	//
	//		err := loader.Load("my_config", &config)
	//		if err != nil {
	//			return MyConfig{}, cerrors.New(err, "failed to load configs for my_config", nil)
	//		}
	//
	//		return config, nil
	//	}
	Load(key string, dest interface{}) error
}
39 |
40 | // New provides an implementation of Loader that reads a config file at the given file path. It supports extending the
41 | // config file at the given path by using an 'extends' key. For example, the config file may be extended like so:
42 | //
43 | // # base.toml
44 | // key1 = "val1"
45 | //
46 | // # prod.toml
47 | // extends = "base.toml
48 | // key2 = "val2"
49 | //
50 | // If New is called with the path to prod.toml, it loads both key2 (from prod.toml) and key1 (from base.toml) since
51 | // prod.toml extends base.toml.
52 | //
53 | // The extends key can support multiple files like so:
54 | // extends = ["base.toml", "secrets.toml"]
55 | //
56 | // If a config key is present in multiple files, New returns an error. For example, if prod.toml sets a value for 'key1'
57 | // that has already been set in base.toml, an error will be returned. To enable key overrides see NewWithKeyOverrides.
58 | func New(fp Path, ov Overrides) (Loader, error) {
59 | return newLoader(string(fp), string(ov), true)
60 | }
61 |
62 | // NewWithKeyOverrides works exactly the same way as New except it supports key overrides. For example, this is a valid
63 | // config:
64 | // # base.toml
65 | // key1 = "val1"
66 | //
67 | // # prod.toml
68 | // extends = "base.toml
69 | // key1 = "val2"
70 | //
71 | // If prod.toml is loaded, key1 will be set to "val2" since it has been overridden in prod.toml.
72 | func NewWithKeyOverrides(fp Path, overrides Overrides) (Loader, error) {
73 | return newLoader(string(fp), string(overrides), false)
74 | }
75 |
76 | func newLoader(fp, overrides string, disableKeyOverrides bool) (*loader, error) {
77 | tree, err := loadTree(fp, overrides, disableKeyOverrides)
78 | if err != nil {
79 | return nil, cerrors.New(err, "failed to load config tree", map[string]interface{}{
80 | "path": fp,
81 | })
82 | }
83 |
84 | return &loader{
85 | tree: tree,
86 | }, nil
87 | }
88 |
89 | type loader struct {
90 | tree *toml.Tree
91 | }
92 |
93 | func (l *loader) Load(key string, dest interface{}) error {
94 | if !l.tree.Has(key) {
95 | return nil
96 | }
97 |
98 | keyTree, ok := l.tree.Get(key).(*toml.Tree)
99 | if !ok {
100 | return cerrors.New(nil, "invalid key type", map[string]interface{}{
101 | "key": key,
102 | })
103 | }
104 |
105 | err := keyTree.Unmarshal(dest)
106 | if err != nil {
107 | return cerrors.New(err, "failed to unmarshal config into dest", map[string]interface{}{
108 | "key": key,
109 | })
110 | }
111 |
112 | return nil
113 | }
114 |
--------------------------------------------------------------------------------
/cconfig/loader_test.go:
--------------------------------------------------------------------------------
1 | package cconfig_test
2 |
3 | import (
4 | "path"
5 | "testing"
6 |
7 | "github.com/gocopper/copper/cconfig"
8 | "github.com/gocopper/copper/cconfig/cconfigtest"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestLoader_Load(t *testing.T) {
13 | t.Parallel()
14 |
15 | var (
16 | dir = cconfigtest.SetupDirWithConfigs(t, map[string]string{
17 | "base.toml": `
18 | [group1]
19 | key1 = "val1-base"
20 | key2 = "val2-base"
21 | `,
22 | "test.toml": `
23 | extends = "base.toml"
24 |
25 | [group1]
26 | key1 = "val1-test"
27 | `,
28 | })
29 | fp = cconfig.Path(path.Join(dir, "test.toml"))
30 | ov = cconfig.Overrides("group1.key2=\"val2-override\"")
31 | )
32 |
33 | t.Run("load with extends and key overrides", func(t *testing.T) {
34 | t.Parallel()
35 |
36 | var testConfig struct {
37 | Key1 string `toml:"key1"`
38 | Key2 string `toml:"key2"`
39 | }
40 |
41 | configs, err := cconfig.NewWithKeyOverrides(fp, ov)
42 | assert.NoError(t, err)
43 |
44 | err = configs.Load("group1", &testConfig)
45 | assert.NoError(t, err)
46 |
47 | assert.Equal(t, "val1-test", testConfig.Key1)
48 | assert.Equal(t, "val2-override", testConfig.Key2)
49 | })
50 |
51 | t.Run("error with key overrides disabled", func(t *testing.T) {
52 | t.Parallel()
53 |
54 | _, err := cconfig.New(fp, ov)
55 |
56 | assert.NotNil(t, err)
57 | assert.Contains(t, err.Error(), "key is being overridden when key overrides are disabled")
58 | })
59 | }
60 |
--------------------------------------------------------------------------------
/cconfig/tree.go:
--------------------------------------------------------------------------------
1 | package cconfig
2 |
3 | import (
4 | "html/template"
5 | "os"
6 | "path/filepath"
7 | "reflect"
8 | "strings"
9 |
10 | "github.com/gocopper/copper/cerrors"
11 | "github.com/pelletier/go-toml"
12 | )
13 |
14 | //nolint:funlen
15 | func loadTree(fp, overrides string, disableKeyOverrides bool) (*toml.Tree, error) {
16 | tmpl, err := template.ParseFiles(fp)
17 | if err != nil {
18 | return nil, cerrors.New(err, "failed to parse config file as template", map[string]interface{}{
19 | "path": fp,
20 | })
21 | }
22 |
23 | envVars := make(map[string]string)
24 | for _, e := range os.Environ() {
25 | pair := strings.Split(e, "=")
26 | envVars[pair[0]] = pair[1]
27 | }
28 |
29 | var tomlOut strings.Builder
30 | err = tmpl.Execute(&tomlOut, map[string]interface{}{
31 | "EnvVars": envVars,
32 | })
33 | if err != nil {
34 | return nil, cerrors.New(err, "failed to execute config file template", map[string]interface{}{
35 | "path": fp,
36 | })
37 | }
38 |
39 | tree, err := toml.LoadBytes([]byte(tomlOut.String()))
40 | if err != nil {
41 | return nil, cerrors.New(err, "failed to load config file", map[string]interface{}{
42 | "path": fp,
43 | })
44 | }
45 |
46 | // If the TOML tree does not have a top-level 'extends' key, we can return the tree as-is
47 | if !tree.Has("extends") {
48 | return tree, nil
49 | }
50 |
51 | parentFilePaths := make([]string, 0)
52 |
53 | // The extends key can be a string or a list of strings representing the config file paths that need to be loaded
54 | switch extends := tree.Get("extends").(type) {
55 | case string:
56 | parentFilePaths = append(parentFilePaths, extends)
57 |
58 | // If extends is set to a list, verify each value is a valid string, and add it to parentFilePaths
59 | case []interface{}:
60 | for i := range extends {
61 | parentFilePath, ok := extends[i].(string)
62 | if !ok {
63 | return nil, cerrors.New(nil, "extends can only contain strings", map[string]interface{}{
64 | "path": fp,
65 | "value": extends[i],
66 | })
67 | }
68 |
69 | parentFilePaths = append(parentFilePaths, parentFilePath)
70 | }
71 | default:
72 | return nil, cerrors.New(nil, "'extends' must be string or []string", map[string]interface{}{
73 | "path": fp,
74 | "type": reflect.TypeOf(extends).String(),
75 | })
76 | }
77 |
78 | // Load each parentFilePath in-order
79 | for _, parentFP := range parentFilePaths {
80 | parentFilePath := filepath.Join(filepath.Dir(fp), parentFP)
81 |
82 | // Load the parent tree at the given path defined by the extends key. Note that this is a recursive call
83 | // that will load all ancestors.
84 | parentTree, err := loadTree(parentFilePath, "", disableKeyOverrides)
85 | if err != nil {
86 | return nil, cerrors.New(err, "failed to load parent tree", map[string]interface{}{
87 | "parentPath": parentFilePath,
88 | })
89 | }
90 |
91 | // Once the parent tree and its ancestors are loaded, we need to merge it with our current tree
92 | tree, err = mergeTrees(parentTree, tree, disableKeyOverrides)
93 | if err != nil {
94 | return nil, cerrors.New(err, "failed to merge with parent tree", map[string]interface{}{
95 | "parentPath": parentFilePath,
96 | })
97 | }
98 | }
99 |
100 | // Apply overrides
101 | for _, ov := range strings.Split(overrides, ";") {
102 | t, err := toml.Load(ov)
103 | if err != nil {
104 | return nil, cerrors.New(err, "failed to parse override as TOML", map[string]interface{}{
105 | "override": ov,
106 | })
107 | }
108 |
109 | tree, err = mergeTrees(tree, t, disableKeyOverrides)
110 | if err != nil {
111 | return nil, cerrors.New(err, "failed to merge tree with overrides", map[string]interface{}{
112 | "override": ov,
113 | })
114 | }
115 | }
116 | return tree, nil
117 | }
118 |
119 | //nolint:funlen
120 | func mergeTrees(base, override *toml.Tree, disableKeyOverrides bool) (*toml.Tree, error) {
121 | // For each key in the override tree, we need to apply it to the base tree
122 | for _, key := range override.Keys() {
123 | switch keyVal := override.Get(key).(type) {
124 | // If the value at the given key is a TOML tree (aka a table according to the spec), we need to merge it with
125 | // the base table.
126 | // For example, if the base tree contains:
127 | // [group1]
128 | // key1="val1"
129 | // and the override tree contains:
130 | // [group1]
131 | // key2="val"2
132 | // We need to load it in such a way where group1 contains both key1 and key2
133 | case *toml.Tree:
134 | // If the base does not contain the key, we can set the entire value from the override table as-is
135 | if !base.Has(key) {
136 | base.Set(key, keyVal)
137 | continue
138 | }
139 |
140 | // Verify that the value type in the base and override trees are the same. For example, this is invalid:
141 | // # base.toml
142 | // group1 = "I am a string"
143 | //
144 | // # prod.toml
145 | // extends = "base.toml"
146 | // [group1] # I am a table!
147 | // key1="val"1
148 | //
149 | // The above configuration is invalid because group1 is a table in prod.toml but a string in base.toml. As
150 | // a result, they cannot be merged.
151 | baseTree, ok := base.Get(key).(*toml.Tree)
152 | if !ok {
153 | return nil, cerrors.New(nil, "base and override key types don't match", map[string]interface{}{
154 | "key": key,
155 | })
156 | }
157 |
158 | // Now that we have two trees, we can merge them recursively
159 | mergedTree, err := mergeTrees(baseTree, keyVal, disableKeyOverrides)
160 | if err != nil {
161 | return nil, cerrors.New(err, "failed to merge tree for key", map[string]interface{}{
162 | "key": key,
163 | })
164 | }
165 |
166 | base.Set(key, mergedTree)
167 |
168 | // This handles all non-table keys. The key, as found in the override tree, is set on the base tree. If the base
169 | // tree already has a value for the key, it is only overridden if disableKeyOverrides is false. If a key is
170 | // being overridden with disableKeyOverrides=true, an error is returned.
171 | default:
172 | if base.Has(key) && disableKeyOverrides {
173 | return nil, cerrors.New(nil, "key is being overridden when key overrides are disabled", map[string]interface{}{
174 | "key": key,
175 | })
176 | }
177 |
178 | base.Set(key, keyVal)
179 | }
180 | }
181 |
182 | return base, nil
183 | }
184 |
--------------------------------------------------------------------------------
/cerrors/error.go:
--------------------------------------------------------------------------------
1 | // Package cerrors provides a custom error type that can hold more context than the stdlib error package.
2 | package cerrors
3 |
4 | import (
5 | "errors"
6 | "fmt"
7 | "reflect"
8 | "sort"
9 | "strings"
10 | )
11 |
12 | // Error can wrap an error with additional context such as structured tags.
13 | type Error struct {
14 | Message string
15 | Tags map[string]interface{}
16 | Cause error
17 | }
18 |
19 | // New creates an error by (optionally) wrapping an existing error and
20 | // annotating the error with structured tags.
21 | func New(cause error, msg string, tags map[string]interface{}) error {
22 | return Error{
23 | Message: msg,
24 | Tags: tags,
25 | Cause: cause,
26 | }
27 | }
28 |
29 | // WithTags annotates an existing error with structured tags.
30 | func WithTags(err error, tags map[string]interface{}) error {
31 | cerr, ok := err.(Error) //nolint:errorlint
32 | if !ok {
33 | return Error{
34 | Message: err.Error(),
35 | Tags: tags,
36 | Cause: errors.Unwrap(err),
37 | }
38 | }
39 |
40 | return Error{
41 | Message: cerr.Message,
42 | Tags: tags,
43 | Cause: cerr.Cause,
44 | }
45 | }
46 |
47 | // Unwrap returns the underlying cause of an error (if any).
48 | func (e Error) Unwrap() error {
49 | return e.Cause
50 | }
51 |
52 | // Error returns a human-friendly string that contains the
53 | // entire error chain along with all of the tags on each
54 | // error.
55 | func (e Error) Error() string {
56 | var (
57 | err strings.Builder
58 | tags []string
59 | )
60 |
61 | err.WriteString(e.Message)
62 |
63 | for tag, val := range e.Tags {
64 | reflectVal := reflect.ValueOf(val)
65 | if reflectVal.Kind() == reflect.Ptr && !reflectVal.IsNil() {
66 | tags = append(tags, fmt.Sprintf("%s=%+v", tag, reflectVal.Elem()))
67 | } else {
68 | tags = append(tags, fmt.Sprintf("%s=%+v", tag, val))
69 | }
70 | }
71 |
72 | sort.Strings(tags)
73 |
74 | if len(tags) > 0 {
75 | err.WriteString(" where ")
76 | err.WriteString(strings.Join(tags, ","))
77 | }
78 |
79 | if e.Cause != nil {
80 | err.WriteString(" because\n")
81 | err.WriteString("> ")
82 | err.WriteString(e.Cause.Error())
83 | }
84 |
85 | return err.String()
86 | }
87 |
--------------------------------------------------------------------------------
/cerrors/error_test.go:
--------------------------------------------------------------------------------
1 | package cerrors_test
2 |
3 | import (
4 | "errors"
5 | "testing"
6 |
7 | "github.com/gocopper/copper/cerrors"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNew(t *testing.T) {
12 | t.Parallel()
13 |
14 | err := cerrors.New(nil, "test-err", nil)
15 |
16 | cErr, ok := err.(cerrors.Error) //nolint:errorlint
17 |
18 | assert.True(t, ok)
19 | assert.Equal(t, "test-err", cErr.Message)
20 | assert.Nil(t, cErr.Cause)
21 | assert.Nil(t, cErr.Tags)
22 | }
23 |
24 | func TestWithTags_StdErr(t *testing.T) {
25 | t.Parallel()
26 |
27 | err := cerrors.WithTags(errors.New("test-err"), map[string]interface{}{ //nolint:goerr113
28 | "key": "val",
29 | })
30 |
31 | cErr, ok := err.(cerrors.Error) //nolint:errorlint
32 |
33 | assert.True(t, ok)
34 | assert.Equal(t, "test-err", cErr.Message)
35 | assert.Contains(t, cErr.Tags, "key")
36 | assert.Equal(t, cErr.Tags["key"], "val")
37 | assert.Nil(t, cErr.Cause)
38 | }
39 |
40 | func TestWithTags_CErr(t *testing.T) {
41 | t.Parallel()
42 |
43 | err := cerrors.WithTags(cerrors.New(nil, "test-cerr", nil), map[string]interface{}{
44 | "key": "val",
45 | })
46 |
47 | cErr, ok := err.(cerrors.Error) //nolint:errorlint
48 |
49 | assert.True(t, ok)
50 | assert.Equal(t, "test-cerr", cErr.Message)
51 | assert.Contains(t, cErr.Tags, "key")
52 | assert.Equal(t, cErr.Tags["key"], "val")
53 | assert.Nil(t, cErr.Cause)
54 | }
55 |
56 | func TestError_Unwrap(t *testing.T) {
57 | t.Parallel()
58 |
59 | err := cerrors.New(errors.New("cause-err"), "test-err", nil) //nolint:goerr113
60 | cause := errors.Unwrap(err)
61 |
62 | assert.NotNil(t, cause)
63 | assert.EqualError(t, cause, "cause-err")
64 | }
65 |
66 | func TestError_Unwrap_NoCause(t *testing.T) {
67 | t.Parallel()
68 |
69 | err := cerrors.New(nil, "test-err", nil)
70 |
71 | assert.Nil(t, errors.Unwrap(err))
72 | }
73 |
74 | func TestError_Error(t *testing.T) {
75 | t.Parallel()
76 |
77 | err := cerrors.New(nil, "test-err", nil)
78 |
79 | assert.NotNil(t, err)
80 | assert.Equal(t, "test-err", err.Error())
81 | }
82 |
83 | func TestError_Error_Cause(t *testing.T) {
84 | t.Parallel()
85 |
86 | err := cerrors.New(errors.New("cause-err"), "test-err", nil) //nolint:goerr113
87 |
88 | expectedErr := `test-err because
89 | > cause-err`
90 |
91 | assert.NotNil(t, err)
92 | assert.Equal(t, expectedErr, err.Error())
93 | }
94 |
95 | func TestError_Error_Tags(t *testing.T) {
96 | t.Parallel()
97 |
98 | err := cerrors.New(nil, "test-err", map[string]interface{}{
99 | "key": "val",
100 | })
101 |
102 | assert.NotNil(t, err)
103 | assert.Equal(t, "test-err where key=val", err.Error())
104 | }
105 |
106 | func TestError_Error_PtrTags(t *testing.T) {
107 | t.Parallel()
108 |
109 | val := "val"
110 | err := cerrors.New(nil, "test-err", map[string]interface{}{
111 | "key": &val,
112 | })
113 |
114 | assert.NotNil(t, err)
115 | assert.Equal(t, "test-err where key=val", err.Error())
116 | }
117 |
--------------------------------------------------------------------------------
/cerrors/utils.go:
--------------------------------------------------------------------------------
1 | package cerrors
2 |
3 | import "errors"
4 |
5 | func Tags(err error) map[string]interface{} {
6 | var cerr Error
7 |
8 | if !errors.As(err, &cerr) {
9 | return nil
10 | }
11 |
12 | return mergeTags(cerr.Tags, Tags(cerr.Cause))
13 | }
14 |
15 | func WithoutTags(err error) error {
16 | var cerr Error
17 |
18 | if !errors.As(err, &cerr) {
19 | return err
20 | }
21 |
22 | cerr.Tags = nil
23 | cerr.Cause = WithoutTags(cerr.Cause)
24 |
25 | return cerr
26 | }
27 |
// mergeTags returns a new map containing the entries of t1 and t2. Where both define the same
// key, the value from t2 wins. The result is never nil, even if both inputs are nil.
func mergeTags(t1, t2 map[string]interface{}) map[string]interface{} {
	merged := make(map[string]interface{})

	// Later sources overwrite earlier ones, giving t2 precedence.
	for _, src := range []map[string]interface{}{t1, t2} {
		for key, val := range src {
			merged[key] = val
		}
	}

	return merged
}
41 |
--------------------------------------------------------------------------------
/cerrors/utils_test.go:
--------------------------------------------------------------------------------
1 | package cerrors_test
2 |
3 | import (
4 | "fmt"
5 | "github.com/gocopper/copper/cerrors"
6 | "github.com/stretchr/testify/assert"
7 | "testing"
8 | )
9 |
10 | func TestWithoutTags(t *testing.T) {
11 | err := fmt.Errorf("test-error-0; %w", cerrors.New(cerrors.New(nil, "test-error-2", map[string]interface{}{
12 | "tag": "val",
13 | }), "test-error-1", map[string]interface{}{
14 | "tag": "val",
15 | }))
16 |
17 | out := cerrors.WithoutTags(err)
18 |
19 | assert.NotContains(t, out.Error(), "tag")
20 | }
21 |
--------------------------------------------------------------------------------
/chttp/chttptest/content_type.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
// ContentTypeApplicationJSON is the "application/json" content type, for use as the
// Content-Type of HTTP requests made in tests.
const ContentTypeApplicationJSON = "application/json"
5 |
--------------------------------------------------------------------------------
/chttp/chttptest/doc.go:
--------------------------------------------------------------------------------
1 | // Package chttptest provides utility functions that are useful when testing chttp
2 | package chttptest
3 |
--------------------------------------------------------------------------------
/chttp/chttptest/handler.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gocopper/copper/clogger"
10 |
11 | "github.com/gocopper/copper/chttp"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | // PingRoutes creates a handler using chttp.NewHandler, starts a test
16 | // http server, and calls each provided route. It verifies that each
17 | // route's handler is called successfully.
18 | func PingRoutes(t *testing.T, routes []chttp.Route) {
19 | t.Helper()
20 |
21 | for i := range routes {
22 | body := routes[i].Path
23 | routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
24 | _, err := w.Write([]byte(body))
25 | assert.NoError(t, err)
26 | }
27 | }
28 |
29 | server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
30 | Routers: []chttp.Router{NewRouter(routes)},
31 | GlobalMiddlewares: nil,
32 | Logger: clogger.NewNoop(),
33 | }))
34 | defer server.Close()
35 |
36 | for _, route := range routes {
37 | resp, err := http.Get(server.URL + route.Path) //nolint:noctx
38 | assert.NoError(t, err)
39 |
40 | body, err := io.ReadAll(resp.Body)
41 | assert.NoError(t, resp.Body.Close())
42 | assert.NoError(t, err)
43 |
44 | assert.Equal(t, route.Path, string(body))
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/chttp/chttptest/html_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import (
4 | "embed"
5 | "testing"
6 |
7 | "github.com/gocopper/copper/chttp"
8 | "github.com/gocopper/copper/clogger"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | // HTMLDir embeds a directory that can be used with chttp.ReaderWriter
13 | //
14 | //go:embed src
15 | var HTMLDir embed.FS
16 |
17 | // NewHTMLReaderWriter creates a *chttp.HTMLReaderWriter suitable for use in tests
18 | func NewHTMLReaderWriter(t *testing.T) *chttp.HTMLReaderWriter {
19 | t.Helper()
20 |
21 | r, err := chttp.NewHTMLRenderer(chttp.NewHTMLRendererParams{
22 | HTMLDir: HTMLDir,
23 | StaticDir: nil,
24 | Config: chttp.Config{},
25 | Logger: clogger.NewNoop(),
26 | })
27 | assert.NoError(t, err)
28 |
29 | return chttp.NewHTMLReaderWriter(r, chttp.Config{}, clogger.NewNoop())
30 | }
31 |
--------------------------------------------------------------------------------
/chttp/chttptest/json_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/gocopper/copper/chttp"
7 | "github.com/gocopper/copper/clogger"
8 | )
9 |
10 | // NewJSONReaderWriter creates a *chttp.NewJSONReaderWriter suitable for use in tests
11 | func NewJSONReaderWriter(t *testing.T) *chttp.JSONReaderWriter {
12 | t.Helper()
13 |
14 | return chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
15 | }
16 |
--------------------------------------------------------------------------------
/chttp/chttptest/route.go:
--------------------------------------------------------------------------------
1 | package chttptest
2 |
3 | import "github.com/gocopper/copper/chttp"
4 |
5 | // ReverseRoutes reverses the provided slice of chttp.Route.
6 | func ReverseRoutes(routes []chttp.Route) []chttp.Route {
7 | for i := 0; i < len(routes)/2; i++ {
8 | j := len(routes) - i - 1
9 | routes[i], routes[j] = routes[j], routes[i]
10 | }
11 |
12 | return routes
13 | }
14 |
15 | // NewRouter returns a router that returns the given routes.
16 | func NewRouter(routes []chttp.Route) chttp.Router {
17 | return &router{routes: routes}
18 | }
19 |
20 | type router struct {
21 | routes []chttp.Route
22 | }
23 |
24 | func (ro *router) Routes() []chttp.Route {
25 | return ro.routes
26 | }
27 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/layouts/main.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Test Page
7 |
8 |
9 | {{ template "content" . }}
10 |
11 |
12 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/pages/index.html:
--------------------------------------------------------------------------------
1 | {{ define "content" }}
2 | Test Page
3 | {{ end }}
4 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/pages/not-found.html:
--------------------------------------------------------------------------------
1 | {{ define "content" }}
2 | The content you are looking for was not found.
3 | {{ end }}
4 |
--------------------------------------------------------------------------------
/chttp/chttptest/src/partials/test-component.html:
--------------------------------------------------------------------------------
1 | {{ define "test-component" }}
2 | {{ end }}
3 |
--------------------------------------------------------------------------------
/chttp/config.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "github.com/gocopper/copper/cconfig"
5 | "github.com/gocopper/copper/cerrors"
6 | )
7 |
8 | // LoadConfig loads Config from app's config
9 | func LoadConfig(appConfig cconfig.Loader) (Config, error) {
10 | var config Config
11 |
12 | err := appConfig.Load("chttp", &config)
13 | if err != nil {
14 | return Config{}, cerrors.New(err, "failed to load chttp config", nil)
15 | }
16 |
17 | return config, nil
18 | }
19 |
20 | // Config holds the params needed to configure Server
21 | type Config struct {
22 | Port uint `default:"7501"`
23 | UseLocalHTML bool `toml:"use_local_html"`
24 | RenderHTMLError bool `toml:"render_html_error"`
25 | EnableSinglePageRouting bool `toml:"enable_single_page_routing"`
26 | ReadTimeoutSeconds uint `toml:"read_timeout_seconds" default:"10"`
27 | RedirectURLForUnauthorizedRequests *string `toml:"redirect_url_for_unauthorized_requests"`
28 | BasePath *string `toml:"base_path"`
29 | }
30 |
--------------------------------------------------------------------------------
/chttp/doc.go:
--------------------------------------------------------------------------------
1 | // Package chttp helps set up an HTTP server with routing, middlewares, and more.
2 | package chttp
3 |
--------------------------------------------------------------------------------
/chttp/error.html:
--------------------------------------------------------------------------------
1 |
28 |
29 |
30 |
Failed to handle request
31 |
> {{ .Error }}
32 |
33 |
38 |
39 |
--------------------------------------------------------------------------------
/chttp/fs.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "errors"
5 | "io/fs"
6 | )
7 |
// EmptyFS is a simple implementation of the fs.FS interface that emulates an empty directory: every Open fails.
type EmptyFS struct{}

// emptyFSError is the error wrapped by the *fs.PathError that EmptyFS.Open returns. Its message is "empty fs".
// It matches fs.ErrNotExist under errors.Is, so callers that check for a missing file (net/http's file server,
// for example) treat it as "not found" rather than as an internal failure.
type emptyFSError struct{}

// Error implements the error interface.
func (emptyFSError) Error() string { return "empty fs" }

// Is reports whether target is fs.ErrNotExist (or an error that fs.ErrNotExist itself matches).
func (emptyFSError) Is(target error) bool { return errors.Is(fs.ErrNotExist, target) }

// Open always fails. As the fs.FS contract requires, the error is a *fs.PathError that records the operation and
// the requested name, and it matches fs.ErrNotExist under errors.Is.
func (*EmptyFS) Open(name string) (fs.File, error) {
	return nil, &fs.PathError{Op: "open", Path: name, Err: emptyFSError{}}
}
16 |
--------------------------------------------------------------------------------
/chttp/handler.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "net/http"
5 | "path"
6 | "regexp"
7 | "sort"
8 | "strings"
9 |
10 | "github.com/gocopper/copper/clogger"
11 |
12 | "github.com/gorilla/mux"
13 | )
14 |
15 | // NewHandlerParams holds the params needed for NewHandler.
16 | type NewHandlerParams struct {
17 | Routers []Router
18 | GlobalMiddlewares []Middleware
19 | Config Config
20 | Logger clogger.Logger
21 | }
22 |
23 | // NewHandler creates a http.Handler with the given routes and middlewares.
24 | // The handler can be used with a http.Server or as an argument to StartServer.
25 | func NewHandler(p NewHandlerParams) http.Handler {
26 | var (
27 | muxRouter = mux.NewRouter()
28 | muxHandler = http.NewServeMux()
29 | )
30 |
31 | routes := make([]Route, 0)
32 | for _, router := range p.Routers {
33 | routes = append(routes, router.Routes()...)
34 | }
35 |
36 | sortRoutes(routes)
37 |
38 | for _, route := range routes {
39 | routePath := route.Path
40 |
41 | // If a base path is set in the configuration ensure that the route path
42 | // starts with the base path. If it does not, skip this route.
43 | // Otherwise, remove the base path prefix from the route path to allow for
44 | // proper registration in the router.
45 | if p.Config.BasePath != nil {
46 | if route.RegisterWithBasePath {
47 | routePath = path.Join(*p.Config.BasePath, routePath)
48 | }
49 |
50 | if !strings.HasPrefix(routePath, *p.Config.BasePath) {
51 | continue
52 | }
53 |
54 | routePath = routePath[len(*p.Config.BasePath):]
55 | }
56 |
57 | handler := http.Handler(route.Handler)
58 |
59 | // Register route-level handlers
60 | // Since we are wrapping the handler in middleware functions, the outermost one will run first.
61 | // By applying the middlewares in reverse, we ensure that the first middleware in the list is the outermost one.
62 | for i := len(route.Middlewares) - 1; i >= 0; i-- {
63 | handler = route.Middlewares[i].Handle(handler)
64 | }
65 |
66 | // Register global middlewares
67 | for i := len(p.GlobalMiddlewares) - 1; i >= 0; i-- {
68 | handler = p.GlobalMiddlewares[i].Handle(handler)
69 | }
70 |
71 | handler = setRoutePathInCtxMiddleware(routePath).Handle(handler)
72 | handler = panicLoggerMiddleware(p.Logger).Handle(handler)
73 |
74 | muxRoute := muxRouter.Handle(routePath, handler)
75 |
76 | if len(route.Methods) > 0 {
77 | muxRoute.Methods(route.Methods...)
78 | }
79 | }
80 |
81 | muxHandler.Handle("/", muxRouter)
82 |
83 | return muxHandler
84 | }
85 |
86 | func sortRoutes(routes []Route) {
87 | const matcherPlaceholder = "{{matcher}}"
88 |
89 | re := regexp.MustCompile(`(?U)(\{.*\})`)
90 |
91 | sort.Slice(routes, func(i, j int) bool {
92 | aPath := re.ReplaceAllString(routes[i].Path, matcherPlaceholder)
93 | bPath := re.ReplaceAllString(routes[j].Path, matcherPlaceholder)
94 |
95 | aParts := strings.Split(aPath, "/")
96 | bParts := strings.Split(bPath, "/")
97 |
98 | if aPath == "/" {
99 | aParts = nil
100 | }
101 |
102 | if bPath == "/" {
103 | bParts = nil
104 | }
105 |
106 | if len(aParts) != len(bParts) {
107 | return len(aParts) > len(bParts)
108 | }
109 |
110 | for i, aPart := range aParts {
111 | bPart := bParts[i]
112 |
113 | if aPart == matcherPlaceholder {
114 | return false
115 | }
116 |
117 | if bPart == matcherPlaceholder {
118 | return true
119 | }
120 | }
121 |
122 | return false
123 | })
124 | }
125 |
--------------------------------------------------------------------------------
/chttp/handler_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "io"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gocopper/copper/clogger"
10 |
11 | "github.com/gocopper/copper/chttp"
12 | "github.com/gocopper/copper/chttp/chttptest"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestNewHandler(t *testing.T) {
17 | t.Parallel()
18 |
19 | router := chttptest.NewRouter([]chttp.Route{
20 | {
21 | Path: "/",
22 | Methods: []string{http.MethodGet},
23 | Handler: func(w http.ResponseWriter, r *http.Request) {
24 | _, err := w.Write([]byte("success"))
25 | assert.NoError(t, err)
26 | },
27 | },
28 | })
29 |
30 | server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
31 | Routers: []chttp.Router{router},
32 | GlobalMiddlewares: nil,
33 | Logger: clogger.NewNoop(),
34 | }))
35 | defer server.Close()
36 |
37 | resp, err := http.Get(server.URL) //nolint:noctx
38 | assert.NoError(t, err)
39 |
40 | body, err := io.ReadAll(resp.Body)
41 | assert.NoError(t, resp.Body.Close())
42 | assert.NoError(t, err)
43 |
44 | assert.Equal(t, "success", string(body))
45 | }
46 |
47 | func TestNewHandler_GlobalMiddleware(t *testing.T) {
48 | t.Parallel()
49 |
50 | didCallGlobalMiddleware := false
51 |
52 | router := chttptest.NewRouter([]chttp.Route{
53 | {
54 | Path: "/",
55 | Handler: func(w http.ResponseWriter, r *http.Request) {},
56 | },
57 | })
58 |
59 | server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
60 | Routers: []chttp.Router{router},
61 | GlobalMiddlewares: []chttp.Middleware{
62 | chttp.HandleMiddleware(func(next http.Handler) http.Handler {
63 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64 | didCallGlobalMiddleware = true
65 | next.ServeHTTP(w, r)
66 | })
67 | }),
68 | },
69 | Logger: clogger.NewNoop(),
70 | }))
71 | defer server.Close()
72 |
73 | resp, err := http.Get(server.URL) //nolint:noctx
74 | assert.NoError(t, err)
75 | assert.NoError(t, resp.Body.Close())
76 |
77 | assert.True(t, didCallGlobalMiddleware)
78 | }
79 |
80 | func TestNewHandler_RouteMiddleware(t *testing.T) {
81 | t.Parallel()
82 |
83 | middlewareCalls := make([]string, 0)
84 |
85 | router := chttptest.NewRouter([]chttp.Route{
86 | {
87 | Path: "/",
88 | Middlewares: []chttp.Middleware{
89 | chttp.HandleMiddleware(func(next http.Handler) http.Handler {
90 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91 | middlewareCalls = append(middlewareCalls, "1")
92 | next.ServeHTTP(w, r)
93 | })
94 | }),
95 |
96 | chttp.HandleMiddleware(func(next http.Handler) http.Handler {
97 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98 | middlewareCalls = append(middlewareCalls, "2")
99 | next.ServeHTTP(w, r)
100 | })
101 | }),
102 | },
103 | Handler: func(w http.ResponseWriter, r *http.Request) {},
104 | },
105 | })
106 |
107 | server := httptest.NewServer(chttp.NewHandler(chttp.NewHandlerParams{
108 | Routers: []chttp.Router{router},
109 | GlobalMiddlewares: nil,
110 | Logger: clogger.NewNoop(),
111 | }))
112 | defer server.Close()
113 |
114 | resp, err := http.Get(server.URL) //nolint:noctx
115 | assert.NoError(t, err)
116 | assert.NoError(t, resp.Body.Close())
117 |
118 | assert.Equal(t, []string{"1", "2"}, middlewareCalls)
119 | }
120 |
121 | func TestNewHandler_RoutePriority_WithPlaceholder(t *testing.T) {
122 | t.Parallel()
123 |
124 | routes := []chttp.Route{
125 | {Path: "/foo"},
126 | {Path: "/{id}"},
127 | }
128 |
129 | chttptest.PingRoutes(t, routes)
130 | chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
131 | }
132 |
133 | func TestNewHandler_RoutePriority_WithIndex(t *testing.T) {
134 | t.Parallel()
135 |
136 | routes := []chttp.Route{
137 | {Path: "/foo"},
138 | {Path: "/"},
139 | }
140 |
141 | chttptest.PingRoutes(t, routes)
142 | chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
143 | }
144 |
145 | func TestNewHandler_RoutePriority_Equal(t *testing.T) {
146 | t.Parallel()
147 |
148 | routes := []chttp.Route{
149 | {Path: "/foo"},
150 | {Path: "/bar"},
151 | }
152 |
153 | chttptest.PingRoutes(t, routes)
154 | chttptest.PingRoutes(t, chttptest.ReverseRoutes(routes))
155 | }
156 |
--------------------------------------------------------------------------------
/chttp/html_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | // Used to embed error.html
5 | _ "embed"
6 | "html/template"
7 | "net/http"
8 |
9 | "github.com/gorilla/mux"
10 |
11 | "github.com/gocopper/copper/cerrors"
12 | "github.com/gocopper/copper/clogger"
13 | )
14 |
15 | type (
16 | // WriteHTMLParams holds the params for the WriteHTML function in ReaderWriter
17 | WriteHTMLParams struct {
18 | StatusCode int
19 | Error error
20 | Data interface{}
21 | PageTemplate string
22 | LayoutTemplate string
23 | }
24 |
25 | // WritePartialParams are the parameters for WritePartial
26 | WritePartialParams struct {
27 | Name string
28 | Data interface{}
29 | }
30 |
31 | // HTMLReaderWriter provides functions to read data from HTTP requests and write HTML response bodies
32 | HTMLReaderWriter struct {
33 | html *HTMLRenderer
34 | config Config
35 | logger clogger.Logger
36 | }
37 | )
38 |
39 | // URLParams returns the route variables for the current request, if any
40 | var URLParams = mux.Vars
41 |
42 | //go:embed error.html
43 | var errorHTML string
44 |
45 | // NewHTMLReaderWriter instantiates a new HTMLReaderWriter with its dependencies
46 | func NewHTMLReaderWriter(html *HTMLRenderer, config Config, logger clogger.Logger) *HTMLReaderWriter {
47 | return &HTMLReaderWriter{
48 | html: html,
49 | config: config,
50 | logger: logger,
51 | }
52 | }
53 |
54 | // WriteHTMLError handles the given error. In render_error is configured to true, it writes an HTML page with the error.
55 | // Errors are always logged.
56 | func (rw *HTMLReaderWriter) WriteHTMLError(w http.ResponseWriter, r *http.Request, err error) {
57 | rw.WriteHTML(w, r, WriteHTMLParams{
58 | Error: err,
59 | })
60 | }
61 |
62 | // WriteHTML writes an HTML response to the provided http.ResponseWriter. Using the given WriteHTMLParams, the HTML
63 | // is generated with a layout, page, and component templates.
64 | func (rw *HTMLReaderWriter) WriteHTML(w http.ResponseWriter, r *http.Request, p WriteHTMLParams) {
65 | if p.StatusCode == 0 && p.Error == nil {
66 | p.StatusCode = http.StatusOK
67 | }
68 |
69 | if p.StatusCode == 0 && p.Error != nil {
70 | p.StatusCode = http.StatusInternalServerError
71 | }
72 |
73 | if p.LayoutTemplate == "" {
74 | p.LayoutTemplate = "main.html"
75 | }
76 |
77 | if p.PageTemplate == "" && p.StatusCode == http.StatusInternalServerError {
78 | p.PageTemplate = "internal-error.html"
79 | }
80 |
81 | if p.PageTemplate == "" && p.StatusCode == http.StatusNotFound {
82 | p.PageTemplate = "not-found.html"
83 | }
84 |
85 | if p.Error != nil {
86 | rw.logger.WithTags(map[string]interface{}{
87 | "url": r.URL.String(),
88 | }).Error("Failed to handle request", p.Error)
89 | }
90 |
91 | if p.Error != nil && rw.config.RenderHTMLError {
92 | w.Header().Set("content-type", "text/html")
93 | w.WriteHeader(p.StatusCode)
94 |
95 | errorHTMLTmpl := template.Must(template.New("chtml/error.html").Parse(errorHTML))
96 |
97 | _ = errorHTMLTmpl.Execute(w, map[string]interface{}{
98 | "Error": p.Error.Error(),
99 | })
100 |
101 | return
102 | }
103 |
104 | out, err := rw.html.render(r, p.LayoutTemplate, p.PageTemplate, p.Data)
105 | if err != nil {
106 | rw.logger.Error("Failed to render html template", cerrors.WithTags(err, map[string]interface{}{
107 | "layout": p.LayoutTemplate,
108 | "page": p.PageTemplate,
109 | }))
110 | w.WriteHeader(http.StatusInternalServerError)
111 | return
112 | }
113 |
114 | w.Header().Set("content-type", "text/html")
115 | w.WriteHeader(p.StatusCode)
116 | _, _ = w.Write([]byte(out))
117 | }
118 |
119 | // WritePartial renders a partial template with the given name and data
120 | func (rw *HTMLReaderWriter) WritePartial(w http.ResponseWriter, r *http.Request, p WritePartialParams) {
121 | out, err := rw.html.partial(r)(p.Name, p.Data)
122 | if err != nil {
123 | rw.WriteHTMLError(w, r, cerrors.New(err, "failed to render partial", map[string]interface{}{
124 | "name": p.Name,
125 | }))
126 | return
127 | }
128 |
129 | w.Header().Set("content-type", "text/html")
130 | _, _ = w.Write([]byte(out))
131 | }
132 |
133 | // Unauthorized writes a 401 Unauthorized response to the http.ResponseWriter. If a redirect URL is configured,
134 | // the user is redirected to that URL instead.
135 | func (rw *HTMLReaderWriter) Unauthorized(w http.ResponseWriter, r *http.Request) {
136 | if rw.config.RedirectURLForUnauthorizedRequests != nil {
137 | http.Redirect(w, r, *rw.config.RedirectURLForUnauthorizedRequests, http.StatusSeeOther)
138 | return
139 | }
140 |
141 | w.WriteHeader(http.StatusUnauthorized)
142 | }
143 |
--------------------------------------------------------------------------------
/chttp/html_reader_writer_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gocopper/copper/chttp"
10 | "github.com/gocopper/copper/chttp/chttptest"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestHTMLReaderWriter_WriteHTML(t *testing.T) {
15 | t.Parallel()
16 |
17 | rw := chttptest.NewHTMLReaderWriter(t)
18 | resp := httptest.NewRecorder()
19 | req := httptest.NewRequest(http.MethodGet, "/", nil)
20 |
21 | rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
22 | StatusCode: http.StatusOK,
23 | Data: map[string]string{"user": "test"},
24 | PageTemplate: "index.html",
25 | })
26 |
27 | assert.Equal(t, http.StatusOK, resp.Code)
28 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
29 | assert.Contains(t, resp.Body.String(), `Test Page`)
30 | }
31 |
32 | func TestHTMLReaderWriter_WriteHTML_NotFound(t *testing.T) {
33 | t.Parallel()
34 |
35 | rw := chttptest.NewHTMLReaderWriter(t)
36 | resp := httptest.NewRecorder()
37 | req := httptest.NewRequest(http.MethodGet, "/", nil)
38 |
39 | rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
40 | StatusCode: http.StatusNotFound,
41 | Data: map[string]string{"user": "test"},
42 | })
43 |
44 | assert.Equal(t, http.StatusNotFound, resp.Code)
45 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
46 | assert.Contains(t, resp.Body.String(), `not found`)
47 | }
48 |
49 | func TestHTMLReaderWriter_WriteHTML_WithError(t *testing.T) {
50 | t.Parallel()
51 |
52 | rw := chttptest.NewHTMLReaderWriter(t)
53 | resp := httptest.NewRecorder()
54 | req := httptest.NewRequest(http.MethodGet, "/", nil)
55 |
56 | rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
57 | Error: errors.New("test error"),
58 | Data: map[string]string{"user": "test"},
59 | })
60 |
61 | assert.Equal(t, http.StatusInternalServerError, resp.Code)
62 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
63 | // Since we don't know exactly what the error page looks like, we're not checking content
64 | }
65 |
66 | func TestHTMLReaderWriter_WriteHTMLError(t *testing.T) {
67 | t.Parallel()
68 |
69 | rw := chttptest.NewHTMLReaderWriter(t)
70 | resp := httptest.NewRecorder()
71 | req := httptest.NewRequest(http.MethodGet, "/", nil)
72 | testErr := errors.New("test error")
73 |
74 | rw.WriteHTMLError(resp, req, testErr)
75 |
76 | assert.Equal(t, http.StatusInternalServerError, resp.Code)
77 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
78 | // Since the error rendering depends on config, we're not checking the exact content
79 | }
80 |
81 | func TestHTMLReaderWriter_WriteHTML_CustomLayout(t *testing.T) {
82 | t.Parallel()
83 |
84 | rw := chttptest.NewHTMLReaderWriter(t)
85 | resp := httptest.NewRecorder()
86 | req := httptest.NewRequest(http.MethodGet, "/", nil)
87 |
88 | rw.WriteHTML(resp, req, chttp.WriteHTMLParams{
89 | StatusCode: http.StatusOK,
90 | Data: map[string]string{"user": "test"},
91 | PageTemplate: "index.html",
92 | LayoutTemplate: "main.html", // explicitly set the default layout
93 | })
94 |
95 | assert.Equal(t, http.StatusOK, resp.Code)
96 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
97 | assert.Contains(t, resp.Body.String(), `Test Page`)
98 | }
99 |
100 | func TestHTMLReaderWriter_WritePartial(t *testing.T) {
101 | t.Parallel()
102 |
103 | rw := chttptest.NewHTMLReaderWriter(t)
104 | resp := httptest.NewRecorder()
105 | req := httptest.NewRequest(http.MethodGet, "/", nil)
106 |
107 | rw.WritePartial(resp, req, chttp.WritePartialParams{
108 | Name: "partial.html",
109 | Data: map[string]string{"content": "partial content"},
110 | })
111 |
112 | assert.Equal(t, "text/html", resp.Header().Get("content-type"))
113 | // Test would be more specific if we knew what the partial template contained
114 | }
115 |
116 | func TestHTMLReaderWriter_Unauthorized_WithoutRedirect(t *testing.T) {
117 | t.Parallel()
118 |
119 | // Create a reader/writer with default config (no redirect URL)
120 | rw := chttptest.NewHTMLReaderWriter(t)
121 | resp := httptest.NewRecorder()
122 | req := httptest.NewRequest(http.MethodGet, "/protected", nil)
123 |
124 | rw.Unauthorized(resp, req)
125 |
126 | assert.Equal(t, http.StatusUnauthorized, resp.Code)
127 | }
128 |
129 | func TestHTMLReaderWriter_Unauthorized_WithRedirect(t *testing.T) {
130 | t.Parallel()
131 |
132 | // This test would require mocking a config with RedirectURLForUnauthorizedRequests
133 | // Since we can't easily modify the config in the test helper, this is a demonstration
134 | // of how the test would look
135 |
136 | // Assuming we had a way to create a HTMLReaderWriter with a redirect URL in the config:
137 | // redirectURL := "/login"
138 | // rw := createReaderWriterWithRedirectConfig(t, &redirectURL)
139 |
140 | // resp := httptest.NewRecorder()
141 | // req := httptest.NewRequest(http.MethodGet, "/protected", nil)
142 |
143 | // rw.Unauthorized(resp, req)
144 |
145 | // assert.Equal(t, http.StatusSeeOther, resp.Code)
146 | // assert.Equal(t, "/login", resp.Header().Get("Location"))
147 | }
148 |
--------------------------------------------------------------------------------
/chttp/html_renderer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "html/template"
5 | "io/fs"
6 | "net/http"
7 | "os"
8 | "path"
9 | "path/filepath"
10 | "strings"
11 |
12 | "github.com/gocopper/copper/clogger"
13 |
14 | "github.com/Masterminds/sprig/v3"
15 | "github.com/gocopper/copper/cerrors"
16 | )
17 |
18 | type (
19 | // HTMLDir is a directory that can be embedded or found on the host system. It should contain sub-directories
20 | // and files to support the WriteHTML function in ReaderWriter.
21 | HTMLDir fs.FS
22 |
23 | // StaticDir represents a directory that holds static resources (JS, CSS, images, etc.)
24 | StaticDir fs.FS
25 |
26 | // HTMLRenderer provides functionality in rendering templatized HTML along with HTML components
27 | HTMLRenderer struct {
28 | htmlDir HTMLDir
29 | staticDir StaticDir
30 | renderFuncs []HTMLRenderFunc
31 | }
32 |
33 | // HTMLRenderFunc can be used to register new template functions
34 | HTMLRenderFunc struct {
35 | // Name for the function that can be invoked in a template
36 | Name string
37 |
38 | // Func should return a function that takes in any number of params and returns either a single return value,
39 | // or two return values of which the second has type error.
40 | Func func(r *http.Request) interface{}
41 | }
42 |
43 | // NewHTMLRendererParams holds the params needed to create HTMLRenderer
44 | NewHTMLRendererParams struct {
45 | HTMLDir HTMLDir
46 | StaticDir StaticDir
47 | RenderFuncs []HTMLRenderFunc
48 | Config Config
49 | Logger clogger.Logger
50 | }
51 | )
52 |
53 | // NewHTMLRenderer creates a new HTMLRenderer with HTML templates stored in dir and registers the provided HTML
54 | // components
55 | func NewHTMLRenderer(p NewHTMLRendererParams) (*HTMLRenderer, error) {
56 | hr := HTMLRenderer{
57 | htmlDir: p.HTMLDir,
58 | staticDir: p.StaticDir,
59 | renderFuncs: p.RenderFuncs,
60 | }
61 |
62 | if p.Config.UseLocalHTML {
63 | wd, err := os.Getwd()
64 | if err != nil {
65 | return nil, cerrors.New(err, "failed to get current working directory", nil)
66 | }
67 |
68 | hr.htmlDir = os.DirFS(filepath.Join(wd, "web"))
69 | }
70 |
71 | return &hr, nil
72 | }
73 |
74 | func (r *HTMLRenderer) funcMap(req *http.Request) template.FuncMap {
75 | var funcMap = sprig.FuncMap()
76 |
77 | for i := range r.renderFuncs {
78 | funcMap[r.renderFuncs[i].Name] = r.renderFuncs[i].Func(req)
79 | }
80 |
81 | funcMap["partial"] = r.partial(req)
82 |
83 | return funcMap
84 | }
85 |
86 | func (r *HTMLRenderer) render(req *http.Request, layout, page string, data interface{}) (template.HTML, error) {
87 | var dest strings.Builder
88 |
89 | tmpl, err := template.New(layout).
90 | Funcs(r.funcMap(req)).
91 | ParseFS(r.htmlDir,
92 | path.Join("src", "layouts", layout),
93 | path.Join("src", "pages", page),
94 | )
95 | if err != nil {
96 | return "", cerrors.New(err, "failed to parse templates in html dir", map[string]interface{}{
97 | "layout": layout,
98 | "page": page,
99 | })
100 | }
101 |
102 | err = tmpl.Execute(&dest, data)
103 | if err != nil {
104 | return "", cerrors.New(err, "failed to execute template", nil)
105 | }
106 |
107 | //nolint:gosec
108 | return template.HTML(dest.String()), nil
109 | }
110 |
111 | func (r *HTMLRenderer) partial(req *http.Request) func(name string, data interface{}) (template.HTML, error) {
112 | return func(name string, data interface{}) (template.HTML, error) {
113 | var dest strings.Builder
114 |
115 | tmpl, err := template.New(name).
116 | Funcs(r.funcMap(req)).
117 | ParseFS(r.htmlDir,
118 | path.Join("src", "partials", "*html"),
119 | )
120 | if err != nil {
121 | return "", cerrors.New(err, "failed to parse partial template", map[string]interface{}{
122 | "name": name,
123 | })
124 | }
125 |
126 | for _, ext := range []string{".html", ".gohtml"} {
127 | t := tmpl.Lookup(name + ext)
128 | if t != nil {
129 | tmpl = t
130 | break
131 | }
132 | }
133 |
134 | err = tmpl.Execute(&dest, data)
135 | if err != nil {
136 | return "", cerrors.New(err, "failed to execute partial template", map[string]interface{}{
137 | "name": name,
138 | })
139 | }
140 |
141 | //nolint:gosec
142 | return template.HTML(dest.String()), nil
143 | }
144 | }
145 |
--------------------------------------------------------------------------------
/chttp/html_router.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "net/http"
5 | "path"
6 | "strings"
7 | )
8 |
9 | type (
10 | // HTMLRouter provides routes to serve (1) static assets (2) index page for an SPA
11 | HTMLRouter struct {
12 | rw *HTMLReaderWriter
13 | staticDir StaticDir
14 | config Config
15 | }
16 |
17 | // NewHTMLRouterParams holds the params needed to instantiate a new Router
18 | NewHTMLRouterParams struct {
19 | StaticDir StaticDir
20 | RW *HTMLReaderWriter
21 | Config Config
22 | }
23 | )
24 |
25 | // NewHTMLRouter instantiates a new Router
26 | func NewHTMLRouter(p NewHTMLRouterParams) (*HTMLRouter, error) {
27 | return &HTMLRouter{
28 | rw: p.RW,
29 | staticDir: p.StaticDir,
30 | config: p.Config,
31 | }, nil
32 | }
33 |
34 | // Routes defines the HTTP routes for this router
35 | func (ro *HTMLRouter) Routes() []Route {
36 | routes := []Route{
37 | {
38 | Path: "/static/{path:.*}",
39 | Methods: []string{http.MethodGet},
40 | Handler: ro.HandleStaticFile,
41 | RegisterWithBasePath: true,
42 | },
43 | }
44 |
45 | if ro.config.EnableSinglePageRouting {
46 | routes = append(routes, Route{
47 | Path: "/{path:.*}",
48 | Methods: []string{http.MethodGet},
49 | Handler: ro.HandleIndexPage,
50 | RegisterWithBasePath: true,
51 | })
52 | }
53 |
54 | return routes
55 | }
56 |
57 | // HandleStaticFile serves the requested static file as found in the web/public directory. In non-dev env, the static
58 | // files are embedded in the binary.
59 | func (ro *HTMLRouter) HandleStaticFile(w http.ResponseWriter, r *http.Request) {
60 | // Disable directory listing
61 | if strings.HasSuffix(r.URL.Path, "/") {
62 | http.NotFound(w, r)
63 | return
64 | }
65 |
66 | if ro.config.UseLocalHTML {
67 | http.ServeFile(w, r, path.Join("web", "public", URLParams(r)["path"]))
68 | return
69 | }
70 |
71 | http.FileServer(http.FS(ro.staticDir)).ServeHTTP(w, r)
72 | }
73 |
74 | // HandleIndexPage renders the index.html page
75 | func (ro *HTMLRouter) HandleIndexPage(w http.ResponseWriter, r *http.Request) {
76 | ro.rw.WriteHTML(w, r, WriteHTMLParams{
77 | PageTemplate: "index.html",
78 | })
79 | }
80 |
--------------------------------------------------------------------------------
/chttp/http.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "net/http"
5 | )
6 |
7 | // Route represents a single HTTP route (ex. /api/profile) that can be configured with middlewares, path,
8 | // HTTP methods, and a handler.
9 | type Route struct {
10 | Middlewares []Middleware
11 | Path string
12 | Methods []string
13 | Handler http.HandlerFunc
14 |
15 | // RegisterWithBasePath ensures the route is registered with the base path,
16 | // even if its original path does not include the base path prefix
17 | RegisterWithBasePath bool
18 | }
19 |
20 | // Router is used to group routes together that are returned by the Routes method.
21 | type Router interface {
22 | Routes() []Route
23 | }
24 |
--------------------------------------------------------------------------------
/chttp/json_reader_writer.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "io"
7 | "net/http"
8 |
9 | "github.com/asaskevich/govalidator"
10 | "github.com/gocopper/copper/cerrors"
11 | "github.com/gocopper/copper/clogger"
12 | )
13 |
type (
	// WriteJSONParams holds the params for the WriteJSON function in JSONReaderWriter
	WriteJSONParams struct {
		// StatusCode is written only when > 0; otherwise net/http's implicit 200 applies.
		StatusCode int
		// Data is the response payload. nil means "no body"; an error value is sent as {"error": err.Error()};
		// anything else is JSON-encoded as-is.
		Data interface{}
	}

	// JSONReaderWriter provides functions to read and write JSON data from/to HTTP requests/responses
	JSONReaderWriter struct {
		config Config
		// logger receives encoding failures and request validation warnings.
		logger clogger.Logger
	}
)
27 |
28 | // NewJSONReaderWriter instantiates a new JSONReaderWriter with its dependencies
29 | func NewJSONReaderWriter(config Config, logger clogger.Logger) *JSONReaderWriter {
30 | return &JSONReaderWriter{
31 | config: config,
32 | logger: logger,
33 | }
34 | }
35 |
36 | // WriteJSON writes a JSON response to the http.ResponseWriter. It can be configured with status code and data using
37 | // WriteJSONParams.
38 | func (rw *JSONReaderWriter) WriteJSON(w http.ResponseWriter, p WriteJSONParams) {
39 | w.Header().Set("Content-Type", "application/json")
40 |
41 | if p.StatusCode > 0 {
42 | w.WriteHeader(p.StatusCode)
43 | }
44 |
45 | if p.Data == nil {
46 | return
47 | }
48 |
49 | errData, ok := p.Data.(error)
50 | if ok {
51 | err := json.NewEncoder(w).Encode(map[string]string{
52 | "error": errData.Error(),
53 | })
54 | if err != nil {
55 | rw.logger.Error("Failed to marshal error response as json", err)
56 | w.WriteHeader(http.StatusInternalServerError)
57 | }
58 |
59 | return
60 | }
61 |
62 | err := json.NewEncoder(w).Encode(p.Data)
63 | if err != nil {
64 | rw.logger.Error("Failed to marshal response as json", err)
65 | w.WriteHeader(http.StatusInternalServerError)
66 |
67 | return
68 | }
69 | }
70 |
71 | // ReadJSON reads JSON from the http.Request into the body var. If the body struct has validate tags on it, the
72 | // struct is also validated. If the validation fails, a BadRequest response is sent back and the function returns
73 | // false.
74 | func (rw *JSONReaderWriter) ReadJSON(w http.ResponseWriter, req *http.Request, body interface{}) bool {
75 | url := req.URL.String()
76 |
77 | err := json.NewDecoder(req.Body).Decode(body)
78 | if err != nil && errors.Is(err, io.EOF) {
79 | rw.WriteJSON(w, WriteJSONParams{
80 | StatusCode: http.StatusBadRequest,
81 | Data: map[string]string{"error": "empty body"},
82 | })
83 |
84 | return false
85 | } else if err != nil {
86 | rw.WriteJSON(w, WriteJSONParams{
87 | StatusCode: http.StatusBadRequest,
88 | Data: cerrors.New(err, "invalid body json", map[string]interface{}{
89 | "url": url,
90 | }),
91 | })
92 |
93 | return false
94 | }
95 |
96 | ok, err := govalidator.ValidateStruct(body)
97 | if !ok {
98 | rw.logger.Warn("Failed to read body", cerrors.New(err, "data validation failed", map[string]interface{}{
99 | "url": url,
100 | }))
101 |
102 | rw.WriteJSON(w, WriteJSONParams{
103 | StatusCode: http.StatusBadRequest,
104 | Data: err,
105 | })
106 |
107 | return false
108 | }
109 |
110 | return true
111 | }
112 |
113 | // Unauthorized writes a 401 Unauthorized JSON response
114 | func (rw *JSONReaderWriter) Unauthorized(w http.ResponseWriter) {
115 | rw.WriteJSON(w, WriteJSONParams{
116 | StatusCode: http.StatusUnauthorized,
117 | Data: map[string]string{"error": "unauthorized"},
118 | })
119 | }
120 |
--------------------------------------------------------------------------------
/chttp/json_reader_writer_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/gocopper/copper/chttp"
11 | "github.com/gocopper/copper/clogger"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestJSONReaderWriter_ReadJSON(t *testing.T) {
16 | t.Parallel()
17 |
18 | var body struct {
19 | Key string `json:"key"`
20 | }
21 |
22 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
23 |
24 | ok := rw.ReadJSON(
25 | httptest.NewRecorder(),
26 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{"key": "value"}`))),
27 | &body,
28 | )
29 |
30 | assert.True(t, ok)
31 | assert.Equal(t, "value", body.Key)
32 | }
33 |
34 | func TestJSONReaderWriter_ReadJSON_Invalid_Body(t *testing.T) {
35 | t.Parallel()
36 |
37 | var body struct {
38 | Key string `json:"key"`
39 | }
40 |
41 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
42 | resp := httptest.NewRecorder()
43 |
44 | ok := rw.ReadJSON(
45 | resp,
46 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{ invalid json }`))),
47 | &body,
48 | )
49 |
50 | assert.False(t, ok)
51 | assert.Equal(t, http.StatusBadRequest, resp.Code)
52 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
53 | assert.Contains(t, resp.Body.String(), "invalid body json")
54 | }
55 |
56 | func TestJSONReaderWriter_ReadJSON_Empty_Body(t *testing.T) {
57 | t.Parallel()
58 |
59 | var body struct {
60 | Key string `json:"key"`
61 | }
62 |
63 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
64 | resp := httptest.NewRecorder()
65 |
66 | ok := rw.ReadJSON(
67 | resp,
68 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(``))),
69 | &body,
70 | )
71 |
72 | assert.False(t, ok)
73 | assert.Equal(t, http.StatusBadRequest, resp.Code)
74 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
75 | assert.Contains(t, resp.Body.String(), "empty body")
76 | }
77 |
78 | func TestJSONReaderWriter_ReadJSON_Validator(t *testing.T) {
79 | t.Parallel()
80 |
81 | var body struct {
82 | Key string `json:"key" valid:"email"`
83 | }
84 |
85 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
86 | resp := httptest.NewRecorder()
87 |
88 | ok := rw.ReadJSON(
89 | resp,
90 | httptest.NewRequest(http.MethodGet, "/", bytes.NewReader([]byte(`{"key": "value"}`))),
91 | &body,
92 | )
93 |
94 | assert.False(t, ok)
95 | assert.Equal(t, http.StatusBadRequest, resp.Code)
96 | }
97 |
98 | func TestJSONReaderWriter_WriteJSON_Data(t *testing.T) {
99 | t.Parallel()
100 |
101 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
102 | resp := httptest.NewRecorder()
103 |
104 | rw.WriteJSON(resp, chttp.WriteJSONParams{
105 | Data: map[string]string{
106 | "key": "val",
107 | },
108 | })
109 |
110 | assert.Equal(t, http.StatusOK, resp.Code)
111 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
112 | assert.Contains(t, resp.Body.String(), `{"key":"val"}`)
113 | }
114 |
115 | func TestJSONReaderWriter_WriteJSON_StatusCode(t *testing.T) {
116 | t.Parallel()
117 |
118 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
119 | resp := httptest.NewRecorder()
120 |
121 | rw.WriteJSON(resp, chttp.WriteJSONParams{
122 | StatusCode: http.StatusCreated,
123 | Data: map[string]string{
124 | "key": "val",
125 | },
126 | })
127 |
128 | assert.Equal(t, http.StatusCreated, resp.Code)
129 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
130 | assert.Contains(t, resp.Body.String(), `{"key":"val"}`)
131 | }
132 |
133 | func TestJSONReaderWriter_WriteJSON_Error(t *testing.T) {
134 | t.Parallel()
135 |
136 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
137 | resp := httptest.NewRecorder()
138 |
139 | rw.WriteJSON(resp, chttp.WriteJSONParams{
140 | StatusCode: http.StatusBadRequest,
141 | Data: errors.New("test-err"), //nolint:goerr113
142 | })
143 |
144 | assert.Equal(t, http.StatusBadRequest, resp.Code)
145 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
146 | assert.Contains(t, resp.Body.String(), `{"error":"test-err"}`)
147 | }
148 |
149 | func TestJSONReaderWriter_WriteJSON_NilData(t *testing.T) {
150 | t.Parallel()
151 |
152 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
153 | resp := httptest.NewRecorder()
154 |
155 | rw.WriteJSON(resp, chttp.WriteJSONParams{
156 | StatusCode: http.StatusOK,
157 | Data: nil,
158 | })
159 |
160 | assert.Equal(t, http.StatusOK, resp.Code)
161 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
162 | assert.Empty(t, resp.Body.String())
163 | }
164 |
165 | func TestJSONReaderWriter_Unauthorized(t *testing.T) {
166 | t.Parallel()
167 |
168 | rw := chttp.NewJSONReaderWriter(chttp.Config{}, clogger.NewNoop())
169 | resp := httptest.NewRecorder()
170 |
171 | rw.Unauthorized(resp)
172 |
173 | assert.Equal(t, http.StatusUnauthorized, resp.Code)
174 | assert.Equal(t, "application/json", resp.Header().Get("content-type"))
175 | assert.Contains(t, resp.Body.String(), `{"error":"unauthorized"}`)
176 | }
177 |
--------------------------------------------------------------------------------
/chttp/middleware.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import "net/http"
4 |
// Middleware defines the interface with a Handle func which is of the MiddlewareFunc type. Implementations of
// Middleware can be used with NewHandler for global middlewares or Route for route-specific middlewares.
type Middleware interface {
	// Handle returns a handler that wraps next, typically running code before and/or after next.ServeHTTP.
	Handle(next http.Handler) http.Handler
}
10 |
// MiddlewareFunc is a function that takes in a http.Handler and returns one as well. It allows you to execute
// code before or after calling the handler. Wrap one with HandleMiddleware to obtain a Middleware.
type MiddlewareFunc func(next http.Handler) http.Handler
14 |
15 | // HandleMiddleware returns an implementation of Middleware that runs the provided func.
16 | func HandleMiddleware(fn MiddlewareFunc) Middleware {
17 | return &middlewareFuncHandler{fn: fn}
18 | }
19 |
// middlewareFuncHandler is the Middleware implementation returned by HandleMiddleware; it delegates to fn.
type middlewareFuncHandler struct {
	fn MiddlewareFunc
}
23 |
24 | func (mw *middlewareFuncHandler) Handle(next http.Handler) http.Handler {
25 | return mw.fn(next)
26 | }
27 |
--------------------------------------------------------------------------------
/chttp/panic_logger_mw.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "net/http"
5 | "runtime/debug"
6 |
7 | "github.com/gocopper/copper/clogger"
8 | )
9 |
10 | func panicLoggerMiddleware(logger clogger.Logger) Middleware {
11 | mw := func(next http.Handler) http.Handler {
12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13 | defer func() {
14 | log := logger.WithTags(map[string]interface{}{
15 | "path": r.URL.Path,
16 | })
17 |
18 | switch r := recover().(type) {
19 | case nil:
20 | break
21 | case error:
22 | log.WithTags(map[string]interface{}{
23 | "stack": string(debug.Stack()),
24 | }).Error("Recovered from a panic while handling HTTP request", r)
25 | w.WriteHeader(http.StatusInternalServerError)
26 | default:
27 | log.WithTags(map[string]interface{}{
28 | "error": r,
29 | "stack": string(debug.Stack()),
30 | }).Error("Recovered from a panic while handling HTTP request", nil)
31 | w.WriteHeader(http.StatusInternalServerError)
32 | }
33 | }()
34 |
35 | next.ServeHTTP(w, r)
36 | })
37 | }
38 |
39 | return HandleMiddleware(mw)
40 | }
41 |
--------------------------------------------------------------------------------
/chttp/panic_logger_mw_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/gocopper/copper/clogger"
10 |
11 | "github.com/gocopper/copper/chttp"
12 | "github.com/gocopper/copper/chttp/chttptest"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestPanicLoggerMiddleware_PanicError(t *testing.T) {
17 | t.Parallel()
18 |
19 | var (
20 | router = chttptest.NewRouter([]chttp.Route{
21 | {
22 | Path: "/",
23 | Methods: []string{http.MethodGet},
24 | Handler: func(w http.ResponseWriter, r *http.Request) {
25 | panic(errors.New("test-error"))
26 | },
27 | },
28 | })
29 |
30 | logs = make([]clogger.RecordedLog, 0)
31 |
32 | handler = chttp.NewHandler(chttp.NewHandlerParams{
33 | Routers: []chttp.Router{router},
34 | Logger: clogger.NewRecorder(&logs),
35 | })
36 | )
37 |
38 | server := httptest.NewServer(handler)
39 | defer server.Close()
40 |
41 | resp, err := http.Get(server.URL) //nolint:noctx
42 | assert.NoError(t, err)
43 | assert.NoError(t, resp.Body.Close())
44 |
45 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
46 | assert.Equal(t, 1, len(logs))
47 | assert.Equal(t, "Recovered from a panic while handling HTTP request", logs[0].Msg)
48 | assert.Equal(t, clogger.LevelError, logs[0].Level)
49 | assert.Equal(t, errors.New("test-error"), logs[0].Error)
50 | assert.Contains(t, logs[0].Tags["stack"], "panic_logger_mw.go")
51 | }
52 |
53 | func TestPanicLoggerMiddleware_NoPanic(t *testing.T) {
54 | t.Parallel()
55 |
56 | var (
57 | router = chttptest.NewRouter([]chttp.Route{
58 | {
59 | Path: "/",
60 | Methods: []string{http.MethodGet},
61 | Handler: func(w http.ResponseWriter, r *http.Request) {
62 | w.WriteHeader(http.StatusOK)
63 | },
64 | },
65 | })
66 |
67 | logs = make([]clogger.RecordedLog, 0)
68 |
69 | handler = chttp.NewHandler(chttp.NewHandlerParams{
70 | Routers: []chttp.Router{router},
71 | Logger: clogger.NewRecorder(&logs),
72 | })
73 | )
74 |
75 | server := httptest.NewServer(handler)
76 | defer server.Close()
77 |
78 | resp, err := http.Get(server.URL) //nolint:noctx
79 | assert.NoError(t, err)
80 | assert.NoError(t, resp.Body.Close())
81 |
82 | assert.Equal(t, http.StatusOK, resp.StatusCode)
83 | assert.Equal(t, 0, len(logs))
84 | }
85 |
86 | func TestPanicLoggerMiddleware_PanicNonError(t *testing.T) {
87 | t.Parallel()
88 |
89 | var (
90 | router = chttptest.NewRouter([]chttp.Route{
91 | {
92 | Path: "/",
93 | Methods: []string{http.MethodGet},
94 | Handler: func(w http.ResponseWriter, r *http.Request) {
95 | panic("test-error")
96 | },
97 | },
98 | })
99 |
100 | logs = make([]clogger.RecordedLog, 0)
101 |
102 | handler = chttp.NewHandler(chttp.NewHandlerParams{
103 | Routers: []chttp.Router{router},
104 | Logger: clogger.NewRecorder(&logs),
105 | })
106 | )
107 |
108 | server := httptest.NewServer(handler)
109 | defer server.Close()
110 |
111 | resp, err := http.Get(server.URL) //nolint:noctx
112 | assert.NoError(t, err)
113 | assert.NoError(t, resp.Body.Close())
114 |
115 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
116 | assert.Equal(t, 1, len(logs))
117 | assert.Equal(t, "Recovered from a panic while handling HTTP request", logs[0].Msg)
118 | assert.Equal(t, clogger.LevelError, logs[0].Level)
119 | assert.Equal(t, "test-error", logs[0].Tags["error"])
120 | assert.Nil(t, logs[0].Error)
121 | assert.Contains(t, logs[0].Tags["stack"], "panic_logger_mw.go")
122 | }
123 |
--------------------------------------------------------------------------------
/chttp/request_id_mw.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "context"
5 | "net/http"
6 |
7 | "github.com/google/uuid"
8 | )
9 |
// ctxRequestID is an unexported key type so the request-id context value cannot collide with context keys
// defined by other packages (context.WithValue recommends non-built-in key types).
type ctxRequestID string

// ctxRequestIDKey is the context key under which SetRequestIDInCtxMiddleware stores the request id.
const ctxRequestIDKey = ctxRequestID("chttp/request-id")
13 |
14 | // SetRequestIDInCtxMiddleware sets a unique request id in the context
15 | func SetRequestIDInCtxMiddleware() Middleware {
16 | var mw = func(next http.Handler) http.Handler {
17 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18 | ctx := context.WithValue(r.Context(), ctxRequestIDKey, uuid.New().String())
19 |
20 | next.ServeHTTP(w, r.WithContext(ctx))
21 | })
22 | }
23 |
24 | return HandleMiddleware(mw)
25 | }
26 |
27 | // GetRequestID returns the request id from the context.
28 | // If the request id is not found in the context, it returns
29 | // an empty string.
30 | // It should be used only after the request id middleware is applied.
31 | func GetRequestID(ctx context.Context) string {
32 | requestID, ok := ctx.Value(ctxRequestIDKey).(string)
33 | if !ok || requestID == "" {
34 | return ""
35 | }
36 |
37 | return requestID
38 | }
39 |
--------------------------------------------------------------------------------
/chttp/request_logger_mw.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "bufio"
5 | "errors"
6 | "fmt"
7 | "github.com/gocopper/copper/cmetrics"
8 | "net"
9 | "net/http"
10 | "time"
11 |
12 | "github.com/gocopper/copper/clogger"
13 | )
14 |
// errRWIsNotHijacker is returned by requestLoggerRw.Hijack when the wrapped ResponseWriter does not implement
// http.Hijacker.
var errRWIsNotHijacker = errors.New("internal response writer is not http.Hijacker")
16 |
17 | // NewRequestLoggerMiddleware creates a new RequestLoggerMiddleware.
18 | func NewRequestLoggerMiddleware(metrics cmetrics.Metrics, logger clogger.Logger) *RequestLoggerMiddleware {
19 | return &RequestLoggerMiddleware{
20 | metrics: metrics,
21 | logger: logger,
22 | }
23 | }
24 |
// RequestLoggerMiddleware logs each request's HTTP method, path, and status code along with the username from
// HTTP basic auth, if any. It also records request count and duration metrics.
type RequestLoggerMiddleware struct {
	// metrics receives the http_requests_total counter and http_request_duration_seconds histogram.
	metrics cmetrics.Metrics
	// logger receives one info-level line per request.
	logger clogger.Logger
}
31 |
32 | // Handle wraps the current request with a request/response recorder. It records the method path and the
33 | // return status code. It logs this with the given logger.
34 | func (mw *RequestLoggerMiddleware) Handle(next http.Handler) http.Handler {
35 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36 | var (
37 | loggerRw = requestLoggerRw{
38 | internal: w,
39 | statusCode: http.StatusOK,
40 | }
41 |
42 | tags = map[string]interface{}{
43 | "method": r.Method,
44 | "url": r.URL.Path,
45 | }
46 |
47 | begin = time.Now()
48 | )
49 |
50 | user, _, ok := r.BasicAuth()
51 | if ok {
52 | tags["user"] = user
53 | }
54 |
55 | next.ServeHTTP(&loggerRw, r)
56 |
57 | tags["statusCode"] = loggerRw.statusCode
58 | tags["duration"] = time.Since(begin).String()
59 |
60 | mw.metrics.CounterInc("http_requests_total", map[string]string{
61 | "status_code": fmt.Sprintf("%d", loggerRw.statusCode),
62 | "path": RawRoutePath(r),
63 | })
64 | mw.metrics.HistogramObserve("http_request_duration_seconds", map[string]string{
65 | "status_code": fmt.Sprintf("%d", loggerRw.statusCode),
66 | "path": RawRoutePath(r),
67 | }, time.Since(begin).Seconds())
68 |
69 | mw.logger.WithTags(tags).Info(fmt.Sprintf("%s %s %d", r.Method, r.URL.Path, loggerRw.statusCode))
70 | })
71 | }
72 |
// requestLoggerRw wraps an http.ResponseWriter and remembers the status code most recently passed to
// WriteHeader so RequestLoggerMiddleware can log and measure it.
type requestLoggerRw struct {
	// internal is the real ResponseWriter every call is forwarded to.
	internal http.ResponseWriter
	// statusCode is initialised to http.StatusOK by RequestLoggerMiddleware.Handle and overwritten by WriteHeader.
	statusCode int
}
77 |
78 | func (rw *requestLoggerRw) Hijack() (net.Conn, *bufio.ReadWriter, error) {
79 | h, ok := rw.internal.(http.Hijacker)
80 | if !ok {
81 | return nil, nil, errRWIsNotHijacker
82 | }
83 |
84 | return h.Hijack()
85 | }
86 |
87 | func (rw *requestLoggerRw) Header() http.Header {
88 | return rw.internal.Header()
89 | }
90 |
91 | func (rw *requestLoggerRw) Write(b []byte) (int, error) {
92 | return rw.internal.Write(b)
93 | }
94 |
95 | func (rw *requestLoggerRw) WriteHeader(statusCode int) {
96 | rw.internal.WriteHeader(statusCode)
97 | rw.statusCode = statusCode
98 | }
99 |
--------------------------------------------------------------------------------
/chttp/request_logger_mw_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "github.com/gocopper/copper/cmetrics"
5 | "io"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/gocopper/copper/chttp"
11 | "github.com/gocopper/copper/chttp/chttptest"
12 | "github.com/gocopper/copper/clogger"
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestNewRequestLoggerMiddleware(t *testing.T) {
17 | t.Parallel()
18 |
19 | var (
20 | logs = make([]clogger.RecordedLog, 0)
21 | logger = clogger.NewRecorder(&logs)
22 | metrics = cmetrics.NewNoopMetrics()
23 | router = chttptest.NewRouter([]chttp.Route{
24 | {
25 | Middlewares: []chttp.Middleware{
26 | chttp.NewRequestLoggerMiddleware(metrics, logger),
27 | },
28 | Path: "/test",
29 | Methods: []string{http.MethodGet},
30 | Handler: func(w http.ResponseWriter, r *http.Request) {
31 | w.WriteHeader(201)
32 |
33 | _, err := w.Write([]byte("OK"))
34 | assert.NoError(t, err)
35 | },
36 | },
37 | })
38 | handler = chttp.NewHandler(chttp.NewHandlerParams{
39 | Routers: []chttp.Router{router},
40 | GlobalMiddlewares: nil,
41 | Logger: clogger.NewNoop(),
42 | })
43 | )
44 |
45 | server := httptest.NewServer(handler)
46 | defer server.Close()
47 |
48 | resp, err := http.Get(server.URL + "/test") //nolint:noctx
49 | assert.NoError(t, err)
50 |
51 | body, err := io.ReadAll(resp.Body)
52 | assert.NoError(t, err)
53 | assert.NoError(t, resp.Body.Close())
54 |
55 | assert.Equal(t, "OK", string(body))
56 | assert.Equal(t, clogger.LevelInfo, logs[0].Level)
57 | assert.Equal(t, "GET /test 201", logs[0].Msg)
58 | }
59 |
--------------------------------------------------------------------------------
/chttp/route_ctx_mw.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | )
7 |
// ctxRoutePath is an unexported key type so the route-path context value cannot collide with context keys
// defined by other packages.
type ctxRoutePath string

// ctxRoutePathKey is the context key under which setRoutePathInCtxMiddleware stores the matched route path.
const ctxRoutePathKey = ctxRoutePath("chttp/route-path")
11 |
12 | func setRoutePathInCtxMiddleware(path string) Middleware {
13 | var mw = func(next http.Handler) http.Handler {
14 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15 | ctx := context.WithValue(r.Context(), ctxRoutePathKey, path)
16 |
17 | next.ServeHTTP(w, r.WithContext(ctx))
18 | })
19 | }
20 |
21 | return HandleMiddleware(mw)
22 | }
23 |
24 | // RawRoutePath returns the route path that matched for the given http.Request. This path includes the raw URL
25 | // variables. For example, a route path "/foo/{id}" will be returned as-is (i.e. {id} will NOT be replaced with the
26 | // actual url path)
27 | func RawRoutePath(r *http.Request) string {
28 | return r.Context().Value(ctxRoutePathKey).(string)
29 | }
30 |
--------------------------------------------------------------------------------
/chttp/route_ctx_mw_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/gocopper/copper/clogger"
9 |
10 | "github.com/gocopper/copper/chttp"
11 | "github.com/gocopper/copper/chttp/chttptest"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestRoutePathInCtxMiddleware(t *testing.T) {
16 | t.Parallel()
17 |
18 | var (
19 | routeMWRawRoutePath string
20 | globalMWRawRoutePath string
21 |
22 | router = chttptest.NewRouter([]chttp.Route{
23 | {
24 | Path: "/foo/{id}",
25 | Methods: []string{http.MethodGet},
26 | Handler: func(w http.ResponseWriter, r *http.Request) {
27 | routeMWRawRoutePath = chttp.RawRoutePath(r)
28 | },
29 | },
30 | })
31 |
32 | handler = chttp.NewHandler(chttp.NewHandlerParams{
33 | Routers: []chttp.Router{router},
34 | GlobalMiddlewares: []chttp.Middleware{
35 | chttp.HandleMiddleware(func(next http.Handler) http.Handler {
36 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37 | globalMWRawRoutePath = chttp.RawRoutePath(r)
38 |
39 | next.ServeHTTP(w, r)
40 | })
41 | }),
42 | },
43 | Logger: clogger.NewNoop(),
44 | })
45 | )
46 |
47 | server := httptest.NewServer(handler)
48 | defer server.Close()
49 |
50 | resp, err := http.Get(server.URL + "/foo/bar") //nolint:noctx
51 | assert.NoError(t, err)
52 | assert.NoError(t, resp.Body.Close())
53 |
54 | assert.Equal(t, "/foo/{id}", globalMWRawRoutePath)
55 | assert.Equal(t, "/foo/{id}", routeMWRawRoutePath)
56 | }
57 |
--------------------------------------------------------------------------------
/chttp/server.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/gocopper/copper/clifecycle"
	"github.com/gocopper/copper/clogger"
)
13 |
// NewServerParams holds the params needed to create a server.
type NewServerParams struct {
	// Handler serves every request the server accepts.
	Handler http.Handler
	// Lifecycle is used by Run to register a graceful-shutdown stop func.
	Lifecycle *clifecycle.Lifecycle
	// Config supplies the listen port and the read timeout (in seconds).
	Config Config
	// Logger receives startup, shutdown, and serve-error messages.
	Logger clogger.Logger
}
21 |
22 | // NewServer creates a new server.
23 | func NewServer(p NewServerParams) *Server {
24 | return &Server{
25 | handler: p.Handler,
26 | config: p.Config,
27 | logger: p.Logger,
28 | lc: p.Lifecycle,
29 | internal: http.Server{
30 | ReadTimeout: time.Duration(p.Config.ReadTimeoutSeconds) * time.Second,
31 | },
32 | }
33 | }
34 |
// Server represents a configurable HTTP server that supports graceful shutdown.
type Server struct {
	handler http.Handler
	config  Config
	logger  clogger.Logger
	// lc receives the stop func that shuts the server down gracefully.
	lc *clifecycle.Lifecycle

	// internal is the wrapped net/http server; Run sets its Addr and Handler.
	internal http.Server
}
44 |
45 | // Run configures an HTTP server using the provided app config and starts it.
46 | func (s *Server) Run() error {
47 | s.internal.Addr = fmt.Sprintf(":%d", s.config.Port)
48 | s.internal.Handler = s.handler
49 |
50 | s.lc.OnStop(func(ctx context.Context) error {
51 | s.logger.Info("Shutting down http server..")
52 |
53 | return s.internal.Shutdown(ctx)
54 | })
55 |
56 | go func() {
57 | s.logger.
58 | WithTags(map[string]interface{}{"port": s.config.Port}).
59 | Info("Starting http server..")
60 |
61 | err := s.internal.ListenAndServe()
62 | if err != nil && !errors.Is(err, http.ErrServerClosed) {
63 | s.logger.Error("Server did not close cleanly", err)
64 | }
65 | }()
66 |
67 | return nil
68 | }
69 |
--------------------------------------------------------------------------------
/chttp/server_test.go:
--------------------------------------------------------------------------------
1 | package chttp_test
2 |
3 | import (
4 | "net/http"
5 | "testing"
6 | "time"
7 |
8 | "github.com/gocopper/copper/chttp"
9 | "github.com/gocopper/copper/clifecycle"
10 | "github.com/gocopper/copper/clogger"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestServer_Run(t *testing.T) {
15 | t.Parallel()
16 |
17 | logger := clogger.New()
18 | lc := clifecycle.New()
19 |
20 | server := chttp.NewServer(chttp.NewServerParams{
21 | Handler: http.NotFoundHandler(),
22 | Config: chttp.Config{Port: 8999},
23 | Logger: logger,
24 | Lifecycle: lc,
25 | })
26 |
27 | go func() {
28 | err := server.Run()
29 | assert.NoError(t, err)
30 | }()
31 |
32 | time.Sleep(50 * time.Millisecond) // wait for server to start
33 |
34 | resp, err := http.Get("http://127.0.0.1:8999") //nolint:noctx
35 | assert.NoError(t, err)
36 | assert.NoError(t, resp.Body.Close())
37 |
38 | assert.NoError(t, resp.Body.Close())
39 | assert.Equal(t, http.StatusNotFound, resp.StatusCode)
40 |
41 | lc.Stop(logger)
42 |
43 | time.Sleep(50 * time.Millisecond) // wait for server to stop
44 |
45 | _, err = http.Get("http://127.0.0.1:8999") //nolint:noctx,bodyclose
46 | assert.EqualError(t, err, "Get \"http://127.0.0.1:8999\": dial tcp 127.0.0.1:8999: connect: connection refused")
47 | }
48 |
--------------------------------------------------------------------------------
/chttp/wire.go:
--------------------------------------------------------------------------------
1 | package chttp
2 |
3 | import (
4 | "github.com/google/wire"
5 | )
6 |
// WireModule can be used as part of google/wire setup.
//
// It provides:
//   - chttp's config loader
//   - the JSON and HTML reader/writers
//   - the request logger middleware
//   - the HTTP server
//   - the HTML router and renderer
//
// The server, router, and renderer are built from their New*Params structs with every field injected
// (wire.Struct with "*").
var WireModule = wire.NewSet( //nolint:gochecknoglobals
	LoadConfig,
	NewJSONReaderWriter,
	NewHTMLReaderWriter,
	NewRequestLoggerMiddleware,
	wire.Struct(new(NewServerParams), "*"),
	NewServer,
	wire.Struct(new(NewHTMLRouterParams), "*"),
	NewHTMLRouter,
	wire.Struct(new(NewHTMLRendererParams), "*"),
	NewHTMLRenderer,
)
20 |
// WireModuleEmptyHTML provides empty/default values for html and static dirs. This can be used to satisfy
// wire when the project does not use/need html rendering.
var WireModuleEmptyHTML = wire.NewSet( //nolint:gochecknoglobals
	// Both directory interfaces are bound to an empty filesystem.
	wire.InterfaceValue(new(HTMLDir), &EmptyFS{}),
	wire.InterfaceValue(new(StaticDir), &EmptyFS{}),
	// No extra template render funcs.
	wire.Value([]HTMLRenderFunc{}),
)
28 |
--------------------------------------------------------------------------------
/clifecycle/doc.go:
--------------------------------------------------------------------------------
1 | // Package clifecycle provides a Copper app's lifecycle management. It allows hooks to be registered that are run
2 | // during various checkpoints in the app's lifecycle.
3 | package clifecycle
4 |
--------------------------------------------------------------------------------
/clifecycle/lifecycle.go:
--------------------------------------------------------------------------------
1 | package clifecycle
2 |
3 | import (
4 | "context"
5 | "time"
6 | )
7 |
// defaultStopTimeout is the timeout New assigns to a Lifecycle; Stop applies it to each stop func separately.
const defaultStopTimeout = 10 * time.Second
9 |
10 | // New instantiates and returns a new Lifecycle that can be used with
11 | // New to create a Copper app.
12 | func New() *Lifecycle {
13 | return &Lifecycle{
14 | onStop: make([]func(ctx context.Context) error, 0),
15 | stopTimeout: defaultStopTimeout,
16 | }
17 | }
18 |
// Lifecycle represents the lifecycle of an app. Most importantly, it
// allows various parts of the app to register stop funcs that are run
// before the app exits.
// Packages such as chttp use Lifecycle to gracefully stop the HTTP
// server before the app exits.
// NOTE(review): OnStop and Stop do no locking; confirm registration never races with Stop.
type Lifecycle struct {
	// onStop holds the registered stop funcs in registration order.
	onStop []func(ctx context.Context) error
	// stopTimeout bounds each stop func individually (Stop creates a fresh timeout context per func).
	stopTimeout time.Duration
}
28 |
29 | // OnStop registers the provided fn to run before the app exits. The fn
30 | // is given a context with a deadline. Once the deadline expires, the
31 | // app may exit forcefully.
32 | func (lc *Lifecycle) OnStop(fn func(ctx context.Context) error) {
33 | lc.onStop = append(lc.onStop, fn)
34 | }
35 |
36 | // Stop runs all of the registered stop funcs in order along with a
37 | // context with a configured timeout.
38 | func (lc *Lifecycle) Stop(logger Logger) {
39 | for _, fn := range lc.onStop {
40 | ctx, cancel := context.WithTimeout(context.Background(), lc.stopTimeout)
41 |
42 | err := fn(ctx)
43 | if err != nil {
44 | logger.Error("Failed to run cleanup func", err)
45 | }
46 |
47 | cancel()
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/clifecycle/logger.go:
--------------------------------------------------------------------------------
1 | package clifecycle
2 |
// Logger provides the methods needed by Lifecycle to log errors.
type Logger interface {
	// Error logs msg together with err; Stop calls it when a stop func fails.
	Error(msg string, err error)
}
7 |
--------------------------------------------------------------------------------
/clogger/config.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "github.com/gocopper/copper/cconfig"
5 | "github.com/gocopper/copper/cerrors"
6 | )
7 |
// Format represents the output format of log statements
type Format string

// Formats supported by Logger. LoadConfig replaces any other configured value with FormatPlain.
const (
	FormatPlain = Format("plain")
	FormatJSON  = Format("json")
)
16 |
17 | // LoadConfig loads Config from app's config
18 | func LoadConfig(appConfig cconfig.Loader) (Config, error) {
19 | var config Config
20 |
21 | err := appConfig.Load("clogger", &config)
22 | if err != nil {
23 | return Config{}, cerrors.New(err, "failed to load clogger config", nil)
24 | }
25 |
26 | if config.Format != FormatPlain && config.Format != FormatJSON {
27 | config.Format = FormatPlain
28 | }
29 |
30 | return config, nil
31 | }
32 |
// Config holds the params needed to configure Logger. LoadConfig reads it from the "clogger" section of the
// app config.
type Config struct {
	// Out and Err presumably name the destinations for regular and error output — TODO confirm against the
	// logger construction code.
	Out string `toml:"out"`
	Err string `toml:"err"`
	// Format is normalised by LoadConfig: anything other than FormatPlain/FormatJSON becomes FormatPlain.
	Format Format `toml:"format"`
	// RedactFields presumably feeds the JSON redactor, which redacts keys containing any of these fragments
	// case-insensitively — confirm where this is wired.
	RedactFields []string `toml:"redact_fields"`
}
40 |
--------------------------------------------------------------------------------
/clogger/doc.go:
--------------------------------------------------------------------------------
// Package clogger provides a Logger interface that can be used to log messages and errors.
2 | package clogger
3 |
--------------------------------------------------------------------------------
/clogger/json_redactor.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "strings"
7 | )
8 |
// redactJSONObject returns a copy of in in which the value of every key matching one of
// redactFields (as defined by redactJSON), at any nesting depth, is replaced by the
// string "redacted".
//
// The copy is produced by a JSON round-trip, so values come back as generic JSON types
// (map[string]interface{}, []interface{}, string, bool, nil, json.Number). Numbers are
// decoded as json.Number instead of float64 so integers larger than 2^53 keep their
// exact value when they are logged. A value that cannot be marshaled to JSON results
// in an error.
func redactJSONObject(in map[string]interface{}, redactFields []string) (map[string]interface{}, error) {
	var b bytes.Buffer

	enc := json.NewEncoder(&b)
	enc.SetEscapeHTML(false) // keep <, > and & readable in log output

	if err := enc.Encode(in); err != nil {
		return nil, err
	}

	redacted, err := redactJSON(b.Bytes(), redactFields)
	if err != nil {
		return nil, err
	}

	dec := json.NewDecoder(bytes.NewReader(redacted))
	dec.UseNumber() // avoid float64 precision loss on large integers

	var out map[string]interface{}
	if err := dec.Decode(&out); err != nil {
		return nil, err
	}

	return out, nil
}

// redactJSON walks the JSON document in and returns a copy in which the value of every
// object key that contains, case-insensitively, any entry of redactFields is replaced by
// the JSON string "redacted". Objects and arrays are walked recursively (and re-marshaled,
// so object keys come back sorted and insignificant whitespace is removed). Any other
// value, and empty input, is returned unchanged.
func redactJSON(in json.RawMessage, redactFields []string) (json.RawMessage, error) {
	// Lower-case the fields once instead of once per key at every nesting level.
	lowered := make([]string, len(redactFields))
	for i := range redactFields {
		lowered[i] = strings.ToLower(redactFields[i])
	}

	return redactJSONLowered(in, lowered)
}

// redactJSONLowered implements redactJSON; every entry of lowered must already be lower case.
func redactJSONLowered(in json.RawMessage, lowered []string) (json.RawMessage, error) {
	// Look past leading whitespace so a valid document like ` {...}` is still redacted
	// instead of being passed through as if it were a scalar.
	trimmed := bytes.TrimLeft(in, " \t\r\n")
	if len(trimmed) == 0 {
		return in, nil
	}

	switch trimmed[0] {
	case '{':
		var obj map[string]json.RawMessage
		if err := json.Unmarshal(in, &obj); err != nil {
			return nil, err
		}

		for k, v := range obj {
			lk := strings.ToLower(k)

			redact := false
			for _, f := range lowered {
				if strings.Contains(lk, f) {
					redact = true
					break
				}
			}

			if redact {
				obj[k] = json.RawMessage(`"redacted"`)
				continue
			}

			r, err := redactJSONLowered(v, lowered)
			if err != nil {
				return nil, err
			}

			obj[k] = r
		}

		return json.Marshal(obj)
	case '[':
		var arr []json.RawMessage
		if err := json.Unmarshal(in, &arr); err != nil {
			return nil, err
		}

		for i, v := range arr {
			r, err := redactJSONLowered(v, lowered)
			if err != nil {
				return nil, err
			}

			arr[i] = r
		}

		return json.Marshal(arr)
	default:
		return in, nil
	}
}
87 |
--------------------------------------------------------------------------------
/clogger/json_redactor_test.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "github.com/shopspring/decimal"
7 | "github.com/stretchr/testify/assert"
8 | "testing"
9 | )
10 |
11 | type FooDecimal struct {
12 | decimal.Decimal
13 | }
14 |
15 | func (d *FooDecimal) MarshalJSON() ([]byte, error) {
16 | return json.Marshal("0x" + d.BigInt().Text(16))
17 | }
18 |
19 | func TestRedactJSONObject(t *testing.T) {
20 | d := FooDecimal{decimal.NewFromInt(100)}
21 | o, err := json.Marshal(&d)
22 | assert.NoError(t, err)
23 |
24 | fmt.Println("====> ", string(o))
25 |
26 | var t1 = map[string]interface{}{
27 | "a": &d,
28 | }
29 |
30 | _, err = redactJSONObject(t1, []string{"b"})
31 | assert.NoError(t, err)
32 | }
33 |
34 | func TestRedactJSON(t *testing.T) {
35 | var t1 = map[string]interface{}{
36 | "a": 1,
37 | "b": "foo",
38 | "c": map[string]interface{}{"d": 2},
39 | "e": []interface{}{1, 2, map[string]interface{}{"f": 3}},
40 | }
41 |
42 | in, err := json.Marshal(t1)
43 | assert.NoError(t, err)
44 |
45 | out, err := redactJSON(in, []string{"f"})
46 | assert.NoError(t, err)
47 |
48 | assert.Equal(t, `{"a":1,"b":"foo","c":{"d":2},"e":[1,2,{"f":"redacted"}]}`, string(out))
49 | }
50 |
--------------------------------------------------------------------------------
/clogger/level.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
// Level is the severity of a log statement.
type Level int

// Pre-defined log levels in increasing order. They start at 1, so the zero Level
// is not one of them.
const (
	LevelDebug Level = iota + 1
	LevelInfo
	LevelWarn
	LevelError
)

// String returns the upper-case name of l ("DEBUG", "INFO", "WARN" or "ERROR"),
// or "UNKNOWN" for any value that is not a pre-defined level.
func (l Level) String() string {
	names := [...]string{"DEBUG", "INFO", "WARN", "ERROR"}

	idx := int(l - LevelDebug)
	if idx < 0 || idx >= len(names) {
		return "UNKNOWN"
	}

	return names[idx]
}
28 |
--------------------------------------------------------------------------------
/clogger/level_test.go:
--------------------------------------------------------------------------------
1 | package clogger_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/gocopper/copper/clogger"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestLevel_String(t *testing.T) {
11 | t.Parallel()
12 |
13 | assert.Equal(t, "DEBUG", clogger.LevelDebug.String())
14 | assert.Equal(t, "INFO", clogger.LevelInfo.String())
15 | assert.Equal(t, "WARN", clogger.LevelWarn.String())
16 | assert.Equal(t, "ERROR", clogger.LevelError.String())
17 | assert.Equal(t, "UNKNOWN", clogger.Level(-99).String())
18 | }
19 |
--------------------------------------------------------------------------------
/clogger/logger.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "log"
7 | "os"
8 | "strings"
9 | "time"
10 |
11 | "github.com/gocopper/copper/cerrors"
12 | )
13 |
14 | // Logger can be used to log messages and errors.
15 | type Logger interface {
16 | WithTags(tags map[string]interface{}) Logger
17 |
18 | Debug(msg string)
19 | Info(msg string)
20 | Warn(msg string, err error)
21 | Error(msg string, err error)
22 | }
23 |
24 | // New returns a Logger implementation that can logs to console.
25 | func New() Logger {
26 | return NewWithWriters(os.Stdout, os.Stderr, FormatPlain, nil)
27 | }
28 |
29 | // NewWithConfig creates a Logger based on the provided config.
30 | func NewWithConfig(config Config) (Logger, error) {
31 | const LogFilePerms = 0666
32 |
33 | var (
34 | outFile io.Writer = os.Stdout
35 | errFile io.Writer = os.Stderr
36 | err error
37 | )
38 |
39 | if config.Out != "" {
40 | outFile, err = os.OpenFile(config.Out, os.O_APPEND|os.O_CREATE|os.O_WRONLY, LogFilePerms)
41 | if err != nil {
42 | return nil, cerrors.New(err, "failed to open log file", map[string]interface{}{
43 | "path": config.Out,
44 | })
45 | }
46 | }
47 |
48 | if config.Out == config.Err {
49 | errFile = outFile
50 | } else if config.Err != "" {
51 | errFile, err = os.OpenFile(config.Err, os.O_APPEND|os.O_CREATE|os.O_WRONLY, LogFilePerms)
52 | if err != nil {
53 | return nil, cerrors.New(err, "failed to open error log file", map[string]interface{}{
54 | "path": config.Err,
55 | })
56 | }
57 | }
58 |
59 | return NewWithWriters(outFile, errFile, config.Format, config.RedactFields), nil
60 | }
61 |
62 | // NewWithWriters creates a Logger that uses the provided writers. out is
63 | // used for debug and info levels. err is used for warn and error levels.
64 | func NewWithWriters(out, err io.Writer, format Format, redactFields []string) Logger {
65 | return &logger{
66 | out: out,
67 | err: err,
68 | tags: make(map[string]interface{}),
69 | format: format,
70 | redactFields: expandRedactedFields(redactFields),
71 | }
72 | }
73 |
74 | type logger struct {
75 | out io.Writer
76 | err io.Writer
77 | tags map[string]interface{}
78 | format Format
79 | redactFields []string
80 | }
81 |
82 | func (l *logger) WithTags(tags map[string]interface{}) Logger {
83 | return &logger{
84 | out: l.out,
85 | err: l.err,
86 | tags: mergeTags(l.tags, tags),
87 | format: l.format,
88 | redactFields: l.redactFields,
89 | }
90 | }
91 |
92 | func (l *logger) Debug(msg string) {
93 | l.log(l.out, LevelDebug, msg, nil) //nolint:goerr113
94 | }
95 |
96 | func (l *logger) Info(msg string) {
97 | l.log(l.out, LevelInfo, msg, nil) //nolint:goerr113
98 | }
99 |
100 | func (l *logger) Warn(msg string, err error) {
101 | l.log(l.err, LevelWarn, msg, err)
102 | }
103 |
104 | func (l *logger) Error(msg string, err error) {
105 | l.log(l.err, LevelError, msg, err)
106 | }
107 |
108 | func (l *logger) log(dest io.Writer, lvl Level, msg string, err error) {
109 | switch l.format {
110 | case FormatJSON:
111 | l.logJSON(dest, lvl, msg, err)
112 | case FormatPlain:
113 | fallthrough
114 | default:
115 | l.logPlain(dest, lvl, msg, err)
116 | }
117 | }
118 |
119 | func (l *logger) logJSON(dest io.Writer, lvl Level, msg string, err error) {
120 | var dict = map[string]interface{}{
121 | "ts": time.Now().Format(time.RFC3339),
122 | "level": lvl.String(),
123 | "msg": msg,
124 | }
125 |
126 | if err != nil {
127 | dict["error"] = cerrors.WithoutTags(err).Error()
128 | }
129 |
130 | if redactedTags, err := redactJSONObject(mergeTags(cerrors.Tags(err), l.tags), l.redactFields); err != nil {
131 | dict["tags"] = cerrors.New(err, "tag redaction failed", nil).Error()
132 | } else {
133 | dict["tags"] = redactedTags
134 | }
135 |
136 | enc := json.NewEncoder(dest)
137 | enc.SetEscapeHTML(false)
138 |
139 | _ = enc.Encode(dict)
140 | }
141 |
142 | func (l *logger) logPlain(dest io.Writer, lvl Level, msg string, err error) {
143 | var (
144 | logErr = cerrors.New(nil, msg, l.tags).Error()
145 |
146 | o strings.Builder
147 | )
148 |
149 | if len(l.redactFields) == 0 {
150 | o.WriteString(logErr)
151 |
152 | if err != nil {
153 | o.WriteString(" because\n> ")
154 | o.WriteString(err.Error())
155 | }
156 | } else {
157 | o.WriteString("")
158 | }
159 |
160 | log.New(dest, "", log.LstdFlags).Printf("[%s] %s", lvl.String(), o.String())
161 | }
162 |
--------------------------------------------------------------------------------
/clogger/logger_test.go:
--------------------------------------------------------------------------------
1 | package clogger_test
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "os"
7 | "testing"
8 |
9 | "github.com/gocopper/copper/cerrors"
10 |
11 | "github.com/gocopper/copper/clogger"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestNew(t *testing.T) {
16 | t.Parallel()
17 |
18 | logger := clogger.New()
19 | assert.NotNil(t, logger)
20 | }
21 |
22 | func TestNewWithConfig(t *testing.T) {
23 | t.Parallel()
24 |
25 | log, err := os.CreateTemp("", "*")
26 | assert.NoError(t, err)
27 |
28 | t.Cleanup(func() {
29 | assert.NoError(t, os.Remove(log.Name()))
30 | })
31 |
32 | logger, err := clogger.NewWithConfig(clogger.Config{
33 | Out: log.Name(),
34 | Err: log.Name(),
35 | Format: clogger.FormatPlain,
36 | })
37 | assert.NoError(t, err)
38 | assert.NotNil(t, logger)
39 | }
40 |
41 | func TestNewWithConfig_OutFileErr(t *testing.T) {
42 | t.Parallel()
43 |
44 | log, err := os.CreateTemp("", "*")
45 | assert.NoError(t, err)
46 |
47 | assert.NoError(t, os.Chmod(log.Name(), 0000))
48 |
49 | t.Cleanup(func() {
50 | assert.NoError(t, os.Remove(log.Name()))
51 | })
52 |
53 | _, err = clogger.NewWithConfig(clogger.Config{
54 | Out: log.Name(),
55 | Format: clogger.FormatPlain,
56 | })
57 | assert.Error(t, err)
58 | }
59 |
60 | func TestNewWithConfig_ErrFileErr(t *testing.T) {
61 | t.Parallel()
62 |
63 | log, err := os.CreateTemp("", "*")
64 | assert.NoError(t, err)
65 |
66 | assert.NoError(t, os.Chmod(log.Name(), 0000))
67 |
68 | t.Cleanup(func() {
69 | assert.NoError(t, os.Remove(log.Name()))
70 | })
71 |
72 | _, err = clogger.NewWithConfig(clogger.Config{
73 | Err: log.Name(),
74 | Format: clogger.FormatPlain,
75 | })
76 | assert.Error(t, err)
77 | }
78 |
79 | func TestNewWithParams(t *testing.T) {
80 | t.Parallel()
81 |
82 | logger := clogger.NewWithWriters(nil, nil, clogger.FormatPlain, nil)
83 | assert.NotNil(t, logger)
84 | }
85 |
86 | func TestLogger_Debug(t *testing.T) {
87 | t.Parallel()
88 |
89 | var (
90 | buf bytes.Buffer
91 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil)
92 | )
93 |
94 | logger.Debug("test debug log")
95 |
96 | assert.Contains(t, buf.String(), "[DEBUG] test debug log")
97 | }
98 |
99 | func TestLogger_WithTags_Debug(t *testing.T) {
100 | t.Parallel()
101 |
102 | var (
103 | buf bytes.Buffer
104 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil)
105 | )
106 |
107 | logger.
108 | WithTags(map[string]interface{}{
109 | "key": "val",
110 | }).
111 | WithTags(map[string]interface{}{
112 | "key2": "val2",
113 | }).Debug("test debug log")
114 |
115 | assert.Contains(t, buf.String(), "[DEBUG] test debug log where key2=val2,key=val")
116 | }
117 |
118 | func TestLogger_WithTags_RedactedFields(t *testing.T) {
119 | t.Parallel()
120 |
121 | for _, format := range []clogger.Format{clogger.FormatJSON, clogger.FormatPlain} {
122 | var (
123 | buf bytes.Buffer
124 | logger = clogger.NewWithWriters(&buf, &buf, format, []string{
125 | "secret", "password", "userPin",
126 | })
127 |
128 | testErr = cerrors.New(nil, "test-error", map[string]interface{}{
129 | "secret": "my_api_key",
130 | "user-pin": "12456",
131 | "data": map[string]string{
132 | "password": "abc123",
133 | },
134 | })
135 | )
136 |
137 | logger.WithTags(map[string]interface{}{
138 | "passwordOwner": "abc123",
139 | "USER_PIN": "123456",
140 | "params": map[string]string{
141 | "myPassword": "abc123",
142 | },
143 | }).Error("test debug log", testErr)
144 |
145 | assert.NotContains(t, buf.String(), "my_api_key")
146 | assert.NotContains(t, buf.String(), "12456")
147 | assert.NotContains(t, buf.String(), "123456")
148 | assert.NotContains(t, buf.String(), "abc123")
149 | assert.Contains(t, buf.String(), "redact")
150 | }
151 | }
152 |
153 | func TestLogger_Info(t *testing.T) {
154 | t.Parallel()
155 |
156 | var (
157 | buf bytes.Buffer
158 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil)
159 | )
160 |
161 | logger.Info("test info log")
162 |
163 | assert.Contains(t, buf.String(), "[INFO] test info log", nil)
164 | }
165 |
166 | func TestLogger_Warn(t *testing.T) {
167 | t.Parallel()
168 |
169 | var (
170 | buf bytes.Buffer
171 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil)
172 | )
173 |
174 | logger.Warn("test warn log", errors.New("test-error")) //nolint:goerr113
175 |
176 | assert.Contains(t, buf.String(), "[WARN] test warn log because\n> test-error")
177 | }
178 |
179 | func TestLogger_Error(t *testing.T) {
180 | t.Parallel()
181 |
182 | var (
183 | buf bytes.Buffer
184 | logger = clogger.NewWithWriters(&buf, &buf, clogger.FormatPlain, nil)
185 | )
186 |
187 | logger.Error("test error log", errors.New("test-error")) //nolint:goerr113
188 |
189 | assert.Contains(t, buf.String(), "[ERROR] test error log because\n> test-error")
190 | }
191 |
--------------------------------------------------------------------------------
/clogger/noop.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | // NewNoop returns a no-op implementation of Logger.
4 | // Useful in passing it as a valid logger in unit tests.
5 | func NewNoop() Logger {
6 | return &noop{}
7 | }
8 |
9 | type noop struct{}
10 |
11 | func (l *noop) WithTags(tags map[string]interface{}) Logger {
12 | return l
13 | }
14 |
15 | func (l *noop) Debug(msg string) {}
16 |
17 | func (l *noop) Info(msg string) {}
18 |
19 | func (l *noop) Warn(msg string, err error) {}
20 |
21 | func (l *noop) Error(msg string, err error) {}
22 |
--------------------------------------------------------------------------------
/clogger/noop_test.go:
--------------------------------------------------------------------------------
1 | package clogger_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/gocopper/copper/clogger"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestNewNoop(t *testing.T) {
11 | t.Parallel()
12 |
13 | logger := clogger.NewNoop()
14 | assert.NotNil(t, logger)
15 | }
16 |
17 | func TestNoopLogger_WithTags(t *testing.T) {
18 | t.Parallel()
19 |
20 | logger := clogger.NewNoop().WithTags(nil)
21 |
22 | assert.NotNil(t, logger)
23 | }
24 |
25 | func TestNoopLogger_Debug(t *testing.T) {
26 | t.Parallel()
27 |
28 | logger := clogger.NewNoop().WithTags(nil)
29 |
30 | logger.Debug("test-debug")
31 | }
32 |
33 | func TestNoopLogger_Info(t *testing.T) {
34 | t.Parallel()
35 |
36 | logger := clogger.NewNoop().WithTags(nil)
37 |
38 | logger.Info("info")
39 | }
40 |
41 | func TestNoopLogger_Warn(t *testing.T) {
42 | t.Parallel()
43 |
44 | logger := clogger.NewNoop().WithTags(nil)
45 |
46 | logger.Warn("warn", nil)
47 | }
48 |
49 | func TestNoopLogger_Error(t *testing.T) {
50 | t.Parallel()
51 |
52 | logger := clogger.NewNoop().WithTags(nil)
53 |
54 | logger.Error("error", nil)
55 | }
56 |
--------------------------------------------------------------------------------
/clogger/recorder.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | // NewRecorder returns an implementation of Logger that keeps
4 | // a record of each log. Useful in unit tests when logs need
5 | // to be tested.
6 | func NewRecorder(logs *[]RecordedLog) Logger {
7 | return &recorder{
8 | Logs: logs,
9 | tags: make(map[string]interface{}),
10 | }
11 | }
12 |
13 | // RecordedLog represents a single log.
14 | type RecordedLog struct {
15 | Level Level
16 | Tags map[string]interface{}
17 | Msg string
18 | Error error
19 | }
20 |
21 | type recorder struct {
22 | Logs *[]RecordedLog
23 | tags map[string]interface{}
24 | }
25 |
26 | func (l *recorder) WithTags(tags map[string]interface{}) Logger {
27 | return &recorder{
28 | Logs: l.Logs,
29 | tags: mergeTags(l.tags, tags),
30 | }
31 | }
32 |
33 | func (l *recorder) Debug(msg string) {
34 | *l.Logs = append(*l.Logs, RecordedLog{
35 | Level: LevelDebug,
36 | Tags: mergeTags(l.tags, nil),
37 | Msg: msg,
38 | Error: nil,
39 | })
40 | }
41 |
42 | func (l *recorder) Info(msg string) {
43 | *l.Logs = append(*l.Logs, RecordedLog{
44 | Level: LevelInfo,
45 | Tags: mergeTags(l.tags, nil),
46 | Msg: msg,
47 | Error: nil,
48 | })
49 | }
50 |
51 | func (l *recorder) Warn(msg string, err error) {
52 | *l.Logs = append(*l.Logs, RecordedLog{
53 | Level: LevelWarn,
54 | Tags: mergeTags(l.tags, nil),
55 | Msg: msg,
56 | Error: err,
57 | })
58 | }
59 |
60 | func (l *recorder) Error(msg string, err error) {
61 | *l.Logs = append(*l.Logs, RecordedLog{
62 | Level: LevelError,
63 | Tags: mergeTags(l.tags, nil),
64 | Msg: msg,
65 | Error: err,
66 | })
67 | }
68 |
--------------------------------------------------------------------------------
/clogger/recorder_test.go:
--------------------------------------------------------------------------------
1 | package clogger_test
2 |
3 | import (
4 | "errors"
5 | "testing"
6 |
7 | "github.com/gocopper/copper/clogger"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNewRecorder(t *testing.T) {
12 | t.Parallel()
13 |
14 | logger := clogger.NewRecorder(nil)
15 | assert.NotNil(t, logger)
16 | }
17 |
18 | func TestRecorder_Debug(t *testing.T) {
19 | t.Parallel()
20 |
21 | var (
22 | logs = make([]clogger.RecordedLog, 0)
23 | logger = clogger.NewRecorder(&logs)
24 | )
25 |
26 | logger.Debug("test debug log")
27 |
28 | assert.Len(t, logs, 1)
29 |
30 | log := logs[0]
31 |
32 | assert.Equal(t, clogger.LevelDebug, log.Level)
33 | assert.Equal(t, "test debug log", log.Msg)
34 | assert.Empty(t, log.Tags)
35 | assert.Nil(t, log.Error)
36 | }
37 |
38 | func TestRecorder_WithTags_Debug(t *testing.T) {
39 | t.Parallel()
40 |
41 | var (
42 | logs = make([]clogger.RecordedLog, 0)
43 | logger = clogger.NewRecorder(&logs)
44 | )
45 |
46 | logger.
47 | WithTags(map[string]interface{}{
48 | "key": "val",
49 | }).
50 | WithTags(map[string]interface{}{
51 | "key2": "val2",
52 | }).Debug("test debug log")
53 |
54 | assert.Len(t, logs, 1)
55 |
56 | log := logs[0]
57 |
58 | assert.Equal(t, clogger.LevelDebug, log.Level)
59 | assert.Equal(t, "test debug log", log.Msg)
60 | assert.Equal(t, map[string]interface{}{
61 | "key": "val",
62 | "key2": "val2",
63 | }, log.Tags)
64 | assert.Nil(t, log.Error)
65 | }
66 |
67 | func TestRecorder_Info(t *testing.T) {
68 | t.Parallel()
69 |
70 | var (
71 | logs = make([]clogger.RecordedLog, 0)
72 | logger = clogger.NewRecorder(&logs)
73 | )
74 |
75 | logger.Info("test info log")
76 |
77 | assert.Len(t, logs, 1)
78 |
79 | log := logs[0]
80 |
81 | assert.Equal(t, clogger.LevelInfo, log.Level)
82 | assert.Equal(t, "test info log", log.Msg)
83 | assert.Empty(t, log.Tags)
84 | assert.Nil(t, log.Error)
85 | }
86 |
87 | func TestRecorder_Warn(t *testing.T) {
88 | t.Parallel()
89 |
90 | var (
91 | logs = make([]clogger.RecordedLog, 0)
92 | logger = clogger.NewRecorder(&logs)
93 | )
94 |
95 | logger.Warn("test warn log", errors.New("test-err")) //nolint:goerr113
96 |
97 | assert.Len(t, logs, 1)
98 |
99 | log := logs[0]
100 |
101 | assert.Equal(t, clogger.LevelWarn, log.Level)
102 | assert.Equal(t, "test warn log", log.Msg)
103 | assert.Empty(t, log.Tags)
104 | assert.EqualError(t, log.Error, "test-err")
105 | }
106 |
107 | func TestRecorder_Error(t *testing.T) {
108 | t.Parallel()
109 |
110 | var (
111 | logs = make([]clogger.RecordedLog, 0)
112 | logger = clogger.NewRecorder(&logs)
113 | )
114 |
115 | logger.Error("test error log", errors.New("test-err")) //nolint:goerr113
116 |
117 | assert.Len(t, logs, 1)
118 |
119 | log := logs[0]
120 |
121 | assert.Equal(t, clogger.LevelError, log.Level)
122 | assert.Equal(t, "test error log", log.Msg)
123 | assert.Empty(t, log.Tags)
124 | assert.EqualError(t, log.Error, "test-err")
125 | }
126 |
--------------------------------------------------------------------------------
/clogger/util.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "github.com/iancoleman/strcase"
5 | )
6 |
// mergeTags returns a new map holding every entry of t1 and t2; when both contain a
// key, the value from t2 wins. Neither input is modified, and nil inputs are allowed.
func mergeTags(t1, t2 map[string]interface{}) map[string]interface{} {
	merged := make(map[string]interface{}, len(t1)+len(t2))

	for _, src := range [...]map[string]interface{}{t1, t2} {
		for key, val := range src {
			merged[key] = val
		}
	}

	return merged
}

// tagsToKVs flattens tags into an alternating key, value, key, value... slice, the
// shape accepted by zap's SugaredLogger *w methods. Pair order follows map iteration
// order and is therefore unspecified.
func tagsToKVs(tags map[string]interface{}) []interface{} {
	kvs := make([]interface{}, 0, 2*len(tags))

	for key, val := range tags {
		kvs = append(kvs, key, val)
	}

	return kvs
}
30 |
31 | func formatToZapEncoding(f Format) string {
32 | switch f {
33 | case FormatJSON:
34 | return "json"
35 | case FormatPlain:
36 | return "console"
37 | default:
38 | return "console"
39 | }
40 | }
41 |
42 | func expandRedactedFields(redactedFields []string) []string {
43 | expanded := make([]string, 0, len(redactedFields))
44 | for _, f := range redactedFields {
45 | expanded = append(expanded, f,
46 | strcase.ToSnake(f),
47 | strcase.ToKebab(f),
48 | strcase.ToDelimited(f, '-'),
49 | strcase.ToDelimited(f, '.'),
50 | strcase.ToCamel(f),
51 | )
52 | }
53 | return expanded
54 | }
55 |
--------------------------------------------------------------------------------
/clogger/zap.go:
--------------------------------------------------------------------------------
1 | package clogger
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/gocopper/copper/cerrors"
7 | "github.com/gocopper/copper/clifecycle"
8 | "go.uber.org/zap"
9 | )
10 |
11 | // NewZapLogger creates a Logger that internally uses go.uber.org/zap for logging
12 | func NewZapLogger(config Config, lc *clifecycle.Lifecycle) (Logger, error) {
13 | const OutStdErr = "stderr"
14 |
15 | var (
16 | outPath = OutStdErr
17 | errOutPath = OutStdErr
18 | )
19 |
20 | if config.Out != "" {
21 | outPath = config.Out
22 | }
23 |
24 | if config.Err != "" {
25 | errOutPath = config.Err
26 | }
27 |
28 | encoderConfig := zap.NewDevelopmentEncoderConfig()
29 | if config.Format == FormatJSON {
30 | encoderConfig = zap.NewProductionEncoderConfig()
31 | }
32 |
33 | z, err := zap.Config{
34 | Level: zap.NewAtomicLevelAt(zap.DebugLevel),
35 | Encoding: formatToZapEncoding(config.Format),
36 | EncoderConfig: encoderConfig,
37 | OutputPaths: []string{outPath},
38 | ErrorOutputPaths: []string{errOutPath},
39 | }.Build(zap.AddCallerSkip(1))
40 | if err != nil {
41 | return nil, cerrors.New(err, "failed to create zap logger", nil)
42 | }
43 |
44 | lc.OnStop(func(ctx context.Context) error {
45 | // Skip sync if logs are written to stderr because it will throw an error:
46 | // https://github.com/uber-go/zap/issues/880
47 | if outPath == OutStdErr && errOutPath == OutStdErr {
48 | return nil
49 | }
50 |
51 | return z.Sync()
52 | })
53 |
54 | return &zapLogger{
55 | zap: z.Sugar(),
56 | tags: make(map[string]interface{}),
57 | }, nil
58 | }
59 |
60 | type zapLogger struct {
61 | zap *zap.SugaredLogger
62 | tags map[string]interface{}
63 | }
64 |
65 | func (l *zapLogger) WithTags(tags map[string]interface{}) Logger {
66 | return &zapLogger{
67 | zap: l.zap,
68 | tags: mergeTags(l.tags, tags),
69 | }
70 | }
71 |
72 | func (l *zapLogger) Debug(msg string) {
73 | l.zap.Debugw(msg, tagsToKVs(l.tags)...)
74 | }
75 |
76 | func (l *zapLogger) Info(msg string) {
77 | l.zap.Infow(msg, tagsToKVs(l.tags)...)
78 | }
79 |
80 | func (l *zapLogger) Warn(msg string, err error) {
81 | l.zap.With("error", err).Warnw(msg, tagsToKVs(l.tags)...)
82 | }
83 |
84 | func (l *zapLogger) Error(msg string, err error) {
85 | l.zap.With("error", err).Errorw(msg, tagsToKVs(l.tags)...)
86 | }
87 |
--------------------------------------------------------------------------------
/cmetrics/metrics.go:
--------------------------------------------------------------------------------
1 | package cmetrics
2 |
3 | import (
4 | "github.com/gocopper/copper/cerrors"
5 | "github.com/gocopper/copper/clogger"
6 | "github.com/prometheus/client_golang/prometheus"
7 | )
8 |
9 | type Metrics interface {
10 | CounterInc(name string, labels map[string]string)
11 | HistogramObserve(name string, labels map[string]string, value float64)
12 | }
13 |
14 | type metrics struct {
15 | counters map[string]*prometheus.CounterVec
16 | histograms map[string]*prometheus.HistogramVec
17 |
18 | logger clogger.Logger
19 | }
20 |
21 | func NewMetrics(registry *Registry, logger clogger.Logger) (Metrics, error) {
22 | countersByName := make(map[string]*prometheus.CounterVec)
23 | for i := range registry.Counters {
24 | countersByName[registry.Counters[i].Name] = prometheus.NewCounterVec(prometheus.CounterOpts{
25 | Name: registry.Counters[i].Name,
26 | }, registry.Counters[i].Labels)
27 |
28 | err := prometheus.DefaultRegisterer.Register(countersByName[registry.Counters[i].Name])
29 | if err != nil {
30 | return nil, cerrors.New(err, "failed to register counter metric", map[string]interface{}{
31 | "name": registry.Counters[i].Name,
32 | })
33 | }
34 | }
35 |
36 | histogramsByName := make(map[string]*prometheus.HistogramVec)
37 | for i := range registry.Histograms {
38 | histogramsByName[registry.Histograms[i].Name] = prometheus.NewHistogramVec(prometheus.HistogramOpts{
39 | Name: registry.Histograms[i].Name,
40 | Buckets: registry.Histograms[i].Buckets,
41 | }, registry.Histograms[i].Labels)
42 |
43 | err := prometheus.DefaultRegisterer.Register(histogramsByName[registry.Histograms[i].Name])
44 | if err != nil {
45 | return nil, cerrors.New(err, "failed to register histogram metric", map[string]interface{}{
46 | "name": registry.Histograms[i].Name,
47 | })
48 |
49 | }
50 | }
51 |
52 | return &metrics{
53 | counters: countersByName,
54 | histograms: histogramsByName,
55 |
56 | logger: logger,
57 | }, nil
58 | }
59 |
60 | func (m *metrics) HistogramObserve(name string, labels map[string]string, value float64) {
61 | histogram, ok := m.histograms[name]
62 | if !ok {
63 | m.logger.WithTags(map[string]interface{}{
64 | "name": name,
65 | }).Warn("Histogram is not registered. Ignoring..", nil)
66 | return
67 | }
68 |
69 | metric, err := histogram.GetMetricWith(labels)
70 | if err != nil {
71 | m.logger.WithTags(map[string]interface{}{
72 | "name": name,
73 | }).Warn("Failed to get histogram metric with labels", err)
74 | return
75 |
76 | }
77 |
78 | metric.Observe(value)
79 | }
80 |
81 | func (m *metrics) CounterInc(name string, labels map[string]string) {
82 | counter, ok := m.counters[name]
83 | if !ok {
84 | m.logger.WithTags(map[string]interface{}{
85 | "name": name,
86 | }).Warn("Counter is not registered. Ignoring..", nil)
87 | return
88 | }
89 |
90 | metric, err := counter.GetMetricWith(labels)
91 | if err != nil {
92 | m.logger.WithTags(map[string]interface{}{
93 | "name": name,
94 | }).Warn("Failed to get counter metric with labels", err)
95 | return
96 | }
97 |
98 | metric.Inc()
99 | }
100 |
--------------------------------------------------------------------------------
/cmetrics/models.go:
--------------------------------------------------------------------------------
1 | package cmetrics
2 |
type (
	// Registry lists the counters and histograms that NewMetrics creates and registers.
	Registry struct {
		Counters   []Counter
		Histograms []Histogram
	}

	// Counter describes a counter vector by metric name and label names.
	Counter struct {
		Name   string
		Labels []string
	}

	// Histogram describes a histogram vector by metric name, label names and the
	// buckets passed to prometheus.HistogramOpts.
	Histogram struct {
		Name    string
		Labels  []string
		Buckets []float64
	}
)

// NewRegistryParams holds the application-defined metrics passed to NewRegistry.
type NewRegistryParams struct {
	Counters   []Counter
	Histograms []Histogram
}

var (
	internalCounters = []Counter{
		{
			Name:   "http_requests_total",
			Labels: []string{"status_code", "path"},
		},
	}

	internalHistograms = []Histogram{
		{
			Name:    "http_request_duration_seconds",
			Labels:  []string{"status_code", "path"},
			Buckets: []float64{0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0},
		},
	}
)

// NewRegistry returns a Registry with the metrics from p followed by the built-in
// http_* metrics. The registry's slices are freshly allocated. Appending to p's slices
// afterwards therefore cannot overwrite registry entries, which could happen when
// append reused spare capacity of p's slices.
func NewRegistry(p NewRegistryParams) *Registry {
	counters := make([]Counter, 0, len(p.Counters)+len(internalCounters))
	counters = append(counters, p.Counters...)
	counters = append(counters, internalCounters...)

	histograms := make([]Histogram, 0, len(p.Histograms)+len(internalHistograms))
	histograms = append(histograms, p.Histograms...)
	histograms = append(histograms, internalHistograms...)

	return &Registry{
		Counters:   counters,
		Histograms: histograms,
	}
}
49 |
--------------------------------------------------------------------------------
/cmetrics/noop.go:
--------------------------------------------------------------------------------
1 | package cmetrics
2 |
3 | func NewNoopMetrics() Metrics {
4 | return &noop{}
5 | }
6 |
7 | type noop struct{}
8 |
9 | func (m *noop) CounterInc(name string, labels map[string]string) {
10 | }
11 |
12 | func (m *noop) HistogramObserve(name string, labels map[string]string, value float64) {
13 | }
14 |
--------------------------------------------------------------------------------
/cmetrics/wire.go:
--------------------------------------------------------------------------------
1 | package cmetrics
2 |
3 | import "github.com/google/wire"
4 |
5 | var WireModule = wire.NewSet(
6 | NewMetrics,
7 | )
8 |
--------------------------------------------------------------------------------
/csql/config.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import (
4 | "strings"
5 |
6 | "github.com/gocopper/copper/cconfig"
7 | "github.com/gocopper/copper/cerrors"
8 | migrate "github.com/rubenv/sql-migrate"
9 | )
10 |
11 | // MigrationsSource are valid options for csql.migrations.source configuration option.
12 | // Use "dir" to load migrations from the local filesystem.
13 | // Use "embed" to load migrations from the embedded directory in the binary.
14 | const (
15 | MigrationsSourceDir = "dir"
16 | MigrationsSourceEmbed = "embed"
17 | )
18 |
19 | // MigrationsDirection are valid options for csql.migrations.direction configuration option.
20 | // Use "up" when running forward migrations and "down" when rolling back migrations.
21 | const (
22 | MigrationsDirectionUp = "up"
23 | MigrationsDirectionDown = "down"
24 | )
25 |
26 | // LoadConfig loads the csql config from the app config
27 | func LoadConfig(appConfig cconfig.Loader) (Config, error) {
28 | config := Config{
29 | Migrations: ConfigMigrations{
30 | Direction: MigrationsDirectionUp,
31 | Source: MigrationsSourceEmbed,
32 | },
33 | }
34 |
35 | err := appConfig.Load("csql", &config)
36 | if err != nil {
37 | return Config{}, cerrors.New(err, "failed to load sql config", nil)
38 | }
39 |
40 | return config, nil
41 | }
42 |
43 | type (
44 | // Config configures the csql module
45 | Config struct {
46 | Dialect string `toml:"dialect"`
47 | DSN string `toml:"dsn"`
48 | Migrations ConfigMigrations `toml:"migrations"`
49 | MaxOpenConnections *int `toml:"max_open_connections"`
50 | }
51 |
52 | // ConfigMigrations configures the migrations
53 | ConfigMigrations struct {
54 | Direction string `toml:"direction"`
55 | Source string `toml:"source"`
56 | }
57 | )
58 |
59 | func (cm ConfigMigrations) sqlMigrateDirection() (migrate.MigrationDirection, error) {
60 | switch strings.ToLower(cm.Direction) {
61 | case MigrationsDirectionUp:
62 | return migrate.Up, nil
63 | case MigrationsDirectionDown:
64 | return migrate.Down, nil
65 | default:
66 | return 0, cerrors.New(nil, "invalid migration direction", map[string]interface{}{
67 | "direction": cm.Direction,
68 | })
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/csql/ctx.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 |
7 | "github.com/gocopper/copper/cerrors"
8 | "github.com/jmoiron/sqlx"
9 | )
10 |
// ctxKey is an unexported key type so that csql's context values cannot collide with keys from other packages.
type ctxKey string

// connCtxKey is the context key under which CtxWithTx stores the *sqlx.Tx.
const connCtxKey ctxKey = "csql/*sqlx.Tx"
14 |
// CtxWithTx begins a new database transaction on db and returns a child of parentCtx that carries it, together
// with the underlying *sql.Tx. Querier methods called with the returned context run inside this transaction.
// The transaction is started with Beginx, so it is not bound to parentCtx; the caller must commit or roll it back.
func CtxWithTx(parentCtx context.Context, db *sql.DB, dialect string) (context.Context, *sql.Tx, error) {
	sqlxDB := sqlx.NewDb(db, dialect)

	tx, err := sqlxDB.Beginx()
	if err != nil {
		tags := map[string]interface{}{"dialect": dialect}
		return nil, nil, cerrors.New(err, "failed to begin db transaction", tags)
	}

	ctx := context.WithValue(parentCtx, connCtxKey, tx)

	return ctx, tx.Tx, nil
}
27 |
// TxFromCtx returns the database transaction that CtxWithTx stored in ctx. It returns an error when ctx does
// not carry a transaction.
func TxFromCtx(ctx context.Context) (*sql.Tx, error) {
	sqlxTx, err := txFromCtx(ctx)
	if err == nil {
		return sqlxTx.Tx, nil
	}

	return nil, err
}
38 |
// txFromCtx extracts the *sqlx.Tx stored under connCtxKey, or returns an error if ctx has none.
func txFromCtx(ctx context.Context) (*sqlx.Tx, error) {
	if tx, ok := ctx.Value(connCtxKey).(*sqlx.Tx); ok {
		return tx, nil
	}

	return nil, cerrors.New(nil, "no database transaction in the context", nil)
}
47 |
// mustTxFromCtx behaves like txFromCtx but panics when ctx carries no transaction. It is meant for call sites
// that have already checked for the transaction.
func mustTxFromCtx(ctx context.Context) *sqlx.Tx {
	tx, err := txFromCtx(ctx)
	if err == nil {
		return tx
	}

	panic(err)
}
56 |
--------------------------------------------------------------------------------
/csql/ctx_test.go:
--------------------------------------------------------------------------------
1 | package csql_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "testing"
7 |
8 | "github.com/gocopper/copper/csql"
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
// TestCtxWithTx checks that the transaction returned by CtxWithTx is the one TxFromCtx finds in the context.
func TestCtxWithTx(t *testing.T) {
	t.Parallel()

	db, err := sql.Open("sqlite3", ":memory:")
	if !assert.NoError(t, err) {
		return
	}

	// Registered first so it runs last, after the transaction below has been rolled back.
	t.Cleanup(func() { _ = db.Close() })

	ctx, tx1, err := csql.CtxWithTx(context.Background(), db, "sqlite3")
	if !assert.NoError(t, err) {
		return
	}

	// Finish the transaction so the connection it holds is released.
	t.Cleanup(func() { _ = tx1.Rollback() })

	tx2, err := csql.TxFromCtx(ctx)
	assert.NoError(t, err)

	assert.Equal(t, tx1, tx2)
}
26 |
--------------------------------------------------------------------------------
/csql/db.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 |
7 | "github.com/gocopper/copper/cerrors"
8 | "github.com/gocopper/copper/clifecycle"
9 | "github.com/gocopper/copper/clogger"
10 | )
11 |
// NewDBConnection opens a database connection pool for config.Dialect and config.DSN and verifies it with a
// ping. It applies config.MaxOpenConnections when set and registers a hook on lc that closes the pool when the
// app stops. If the ping fails, the pool is closed before the error is returned.
func NewDBConnection(lc *clifecycle.Lifecycle, config Config, logger clogger.Logger) (*sql.DB, error) {
	logger.WithTags(map[string]interface{}{
		"dialect": config.Dialect,
	}).Info("Opening a database connection..")

	db, err := sql.Open(config.Dialect, config.DSN)
	if err != nil {
		return nil, cerrors.New(err, "failed to open db connection", map[string]interface{}{
			"dialect": config.Dialect,
		})
	}

	err = db.Ping()
	if err != nil {
		// The pool is not handed to the caller and the stop hook is not registered yet, so nothing else would
		// ever close it. The ping error is the one worth reporting, so the close error is ignored.
		_ = db.Close()

		return nil, cerrors.New(err, "failed to ping db", nil)
	}

	if config.MaxOpenConnections != nil {
		db.SetMaxOpenConns(*config.MaxOpenConnections)
	}

	lc.OnStop(func(ctx context.Context) error {
		logger.Info("Closing database connection..")

		err := db.Close()
		if err != nil {
			return cerrors.New(err, "failed to close db connection", nil)
		}

		return nil
	})

	return db, nil
}
47 |
--------------------------------------------------------------------------------
/csql/db_test.go:
--------------------------------------------------------------------------------
1 | package csql_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/gocopper/copper/clifecycle"
7 | "github.com/gocopper/copper/clogger"
8 | "github.com/gocopper/copper/csql"
9 | _ "github.com/mattn/go-sqlite3"
10 | "github.com/stretchr/testify/assert"
11 | )
12 |
13 | func TestNewDBConnection(t *testing.T) {
14 | t.Parallel()
15 |
16 | var (
17 | logger = clogger.New()
18 | lc = clifecycle.New()
19 | )
20 |
21 | db, err := csql.NewDBConnection(lc, csql.Config{
22 | Dialect: "sqlite3",
23 | DSN: ":memory:",
24 | }, logger)
25 | assert.NoError(t, err)
26 |
27 | assert.NoError(t, db.Ping())
28 |
29 | lc.Stop(logger)
30 |
31 | assert.Error(t, db.Ping())
32 | }
33 |
--------------------------------------------------------------------------------
/csql/doc.go:
--------------------------------------------------------------------------------
1 | // Package csql helps create and manage database connections
2 | package csql
3 |
--------------------------------------------------------------------------------
/csql/migrations_test.sql:
--------------------------------------------------------------------------------
-- Fixture for TestMigrator_Run: "up" creates and seeds the people table, "down" drops it.
-- +migrate Up
create table people (name text);
insert into people (name) values ('test');

-- +migrate Down
drop table people;
7 |
--------------------------------------------------------------------------------
/csql/migrator.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import (
4 | "crypto/sha256"
5 | "database/sql"
6 | "embed"
7 | "fmt"
8 | "io"
9 |
10 | "github.com/gocopper/copper/cerrors"
11 | "github.com/gocopper/copper/clogger"
12 | migrate "github.com/rubenv/sql-migrate"
13 | )
14 |
// Migrations is a collection of .sql files that represent the database schema. It wraps an embed.FS whose root
// directory holds the migration files; Migrator reads them from there when the migrations source is "embed".
type Migrations embed.FS
17 |
// NewMigratorParams holds the dependencies needed by NewMigrator. WireModule registers it with
// wire.Struct(..., "*"), so every field is injected.
type NewMigratorParams struct {
	// DB is the database the migrations are applied to.
	DB *sql.DB

	// Migrations are the embedded migration files.
	Migrations Migrations

	// Config provides the dialect and the migrations direction and source.
	Config Config

	// Logger receives progress messages while migrating.
	Logger clogger.Logger
}
25 |
// NewMigrator builds a Migrator from its params, unwrapping Migrations back into an embed.FS.
func NewMigrator(p NewMigratorParams) *Migrator {
	m := new(Migrator)
	m.db = p.DB
	m.migrations = embed.FS(p.Migrations)
	m.config = p.Config
	m.logger = p.Logger

	return m
}
35 |
// Migrator runs database migrations, reading them either from the embedded files or from the ./migrations
// directory depending on the configured source.
type Migrator struct {
	// db is the database the migrations are applied to.
	db *sql.DB

	// migrations are the embedded migration files. They are also the files inspected to decide whether there
	// is anything to migrate.
	migrations embed.FS

	// config provides the dialect and the migrations direction and source.
	config Config

	logger clogger.Logger
}
43 |
44 | // Run runs the provided database migrations
45 | func (m *Migrator) Run() error {
46 | m.logger.WithTags(map[string]interface{}{
47 | "direction": m.config.Migrations.Direction,
48 | "source": m.config.Migrations.Source,
49 | }).Info("Running database migrations..")
50 |
51 | direction, err := m.config.Migrations.sqlMigrateDirection()
52 | if err != nil {
53 | return cerrors.New(err, "failed to get sql migrate direction from config", nil)
54 | }
55 |
56 | hasMigrations, err := m.hasMigrations()
57 | if err != nil {
58 | return cerrors.New(err, "failed to check for migrations", nil)
59 | }
60 |
61 | if !hasMigrations {
62 | m.logger.Info("No migrations found")
63 | return nil
64 | }
65 |
66 | source := migrate.MigrationSource(migrate.EmbedFileSystemMigrationSource{
67 | FileSystem: m.migrations,
68 | Root: ".",
69 | })
70 | if m.config.Migrations.Source == MigrationsSourceDir {
71 | source = migrate.FileMigrationSource{
72 | Dir: "./migrations",
73 | }
74 | }
75 |
76 | migrateMax := 0 // no limit
77 | if direction == migrate.Down {
78 | migrateMax = 1 // only run 1 migration when reverting
79 | }
80 |
81 | dialect := m.config.Dialect
82 | if dialect == "pgx" {
83 | dialect = "postgres"
84 | }
85 |
86 | n, err := migrate.ExecMax(m.db, dialect, source, direction, migrateMax)
87 | if err != nil {
88 | return cerrors.New(err, "failed to exec database migrations", nil)
89 | }
90 |
91 | m.logger.WithTags(map[string]interface{}{
92 | "count": n,
93 | }).Info("Successfully applied migrations")
94 |
95 | return nil
96 | }
97 |
// hasMigrations reports whether the embedded migrations directory contains at least one non-empty migration
// file. Several entries always count as "has migrations". A single entry counts only if its SHA-256 digest
// differs from emptyMigrationsChecksum, the digest of a placeholder file with no migrations in it.
// NOTE(review): presumably the placeholder generated for new projects; confirm.
func (m *Migrator) hasMigrations() (bool, error) {
	const emptyMigrationsChecksum = "fba9ab24993a94e181dc952f2568a4e98b47e331d89772af3115fe1c7b90d27f"

	entries, err := m.migrations.ReadDir(".")
	if err != nil {
		return false, cerrors.New(err, "failed to read migrations dir", nil)
	}

	switch {
	case len(entries) == 0:
		return false, nil
	case len(entries) > 1:
		return true, nil
	}

	name := entries[0].Name()

	file, err := m.migrations.Open(name)
	if err != nil {
		return false, cerrors.New(err, "failed to open migrations file", map[string]interface{}{
			"name": name,
		})
	}
	defer func() { _ = file.Close() }()

	hasher := sha256.New()
	if _, err := io.Copy(hasher, file); err != nil {
		return false, cerrors.New(err, "failed to calculate sha256 for migration file", nil)
	}

	digest := fmt.Sprintf("%x", hasher.Sum(nil))

	return digest != emptyMigrationsChecksum, nil
}
135 |
--------------------------------------------------------------------------------
/csql/migrator_test.go:
--------------------------------------------------------------------------------
1 | package csql_test
2 |
3 | import (
4 | "database/sql"
5 | "embed"
6 | "testing"
7 |
8 | "github.com/gocopper/copper/clogger"
9 | "github.com/gocopper/copper/csql"
10 | _ "github.com/mattn/go-sqlite3"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// Migrations embeds migrations_test.sql so TestMigrator_Run can exercise the embedded migrations source.
//
//go:embed migrations_test.sql
var Migrations embed.FS
16 |
// TestMigrator_Run applies the fixture migration up and then down, checking the table exists in between and is
// gone afterwards.
func TestMigrator_Run(t *testing.T) {
	t.Parallel()

	db, err := sql.Open("sqlite3", ":memory:")
	assert.NoError(t, err)

	// Every connection to ":memory:" opens its own empty database. Pin the pool to one connection so that both
	// migrations and every query below operate on the same database.
	db.SetMaxOpenConns(1)

	// migrate up

	migratorUp := csql.NewMigrator(csql.NewMigratorParams{
		DB:         db,
		Migrations: csql.Migrations(Migrations),
		Config: csql.Config{
			Dialect: "sqlite3",
			Migrations: csql.ConfigMigrations{
				Direction: csql.MigrationsDirectionUp,
				Source:    csql.MigrationsSourceEmbed,
			},
		},
		Logger: clogger.NewNoop(),
	})

	err = migratorUp.Run()
	assert.NoError(t, err)

	res, err := db.Query("select * from people")
	assert.NoError(t, err)

	assert.True(t, res.Next())
	assert.NoError(t, res.Err())

	// Release the only connection; otherwise the down migration could not run on it.
	assert.NoError(t, res.Close())

	// migrate down

	migratorDown := csql.NewMigrator(csql.NewMigratorParams{
		DB:         db,
		Migrations: csql.Migrations(Migrations),
		Config: csql.Config{
			Dialect: "sqlite3",
			Migrations: csql.ConfigMigrations{
				Direction: csql.MigrationsDirectionDown,
				Source:    csql.MigrationsSourceEmbed,
			},
		},
		Logger: clogger.NewNoop(),
	})

	err = migratorDown.Run()
	assert.NoError(t, err)

	_, err = db.Query("select * from people") //nolint:rowserrcheck
	assert.EqualError(t, err, "no such table: people")
}
68 |
--------------------------------------------------------------------------------
/csql/qb/query_builder.go:
--------------------------------------------------------------------------------
1 | package qb
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "strings"
7 | )
8 |
// fieldExtractor inspects one struct field and its value. It returns the column text, the value to bind, and
// whether the field takes part in the query.
type fieldExtractor func(field reflect.StructField, value reflect.Value) (column string, arg any, ok bool)
10 |
// ValuePlaceholders returns one "?" bind placeholder per db-tagged column of model, readonly columns included,
// joined with ", " (e.g. "?, ?, ?"). It pairs with Columns and Values in an INSERT statement.
func ValuePlaceholders(model any) string {
	count := len(extractRawColumns(model))
	if count == 0 {
		return ""
	}

	return strings.Repeat("?, ", count-1) + "?"
}
19 |
// Columns returns the comma-separated column names of model's db-tagged fields. When a non-empty alias is given
// (only the first one is used), every column is qualified with it, e.g. "t.id, t.name".
func Columns(model any, alias ...string) string {
	columns := extractRawColumns(model)

	if len(alias) == 0 || alias[0] == "" {
		return strings.Join(columns, ", ")
	}

	qualified := make([]string, 0, len(columns))
	for _, col := range columns {
		qualified = append(qualified, alias[0]+"."+col)
	}

	return strings.Join(qualified, ", ")
}
37 |
// SetColumns returns the "column = ?" assignments for model's writable (not readonly) db-tagged fields, joined
// with ", ", for use in an UPDATE ... SET clause together with SetValues.
func SetColumns(model any) string {
	assignments := extractSetColumns(model)
	return strings.Join(assignments, ", ")
}
41 |
42 | func Values(model any) []any {
43 | return extractValues(model, columnExtractor)
44 | }
45 |
46 | func SetValues(model any) []any {
47 | return extractValues(model, setColumnExtractor)
48 | }
49 |
50 | func ValuesAndSetValues(model any) []any {
51 | values := extractValues(model, columnExtractor)
52 | setValues := extractValues(model, setColumnExtractor)
53 | return append(values, setValues...)
54 | }
55 |
// extractRawColumns returns, in field order, the column name of every field whose db tag is neither "" nor "-".
// The column name is the tag text before the first comma. Embedded structs are descended into.
func extractRawColumns(model any) []string {
	var names []string

	processModelFields(model, func(field reflect.StructField, _ reflect.Value) bool {
		tag := field.Tag.Get("db")
		if tag == "" || tag == "-" {
			return false
		}

		name, _, _ := strings.Cut(tag, ",")
		names = append(names, name)

		return true
	})

	return names
}
70 |
// extractSetColumns returns, in field order, a "column = ?" assignment for every db-tagged field that is not
// readonly. Only the first tag option is consulted, so "col,readonly" is skipped but "col,x,readonly" is not.
func extractSetColumns(model any) []string {
	var assignments []string

	processModelFields(model, func(field reflect.StructField, _ reflect.Value) bool {
		tag := field.Tag.Get("db")
		if tag == "" || tag == "-" {
			return false
		}

		name, opts, _ := strings.Cut(tag, ",")
		if firstOpt, _, _ := strings.Cut(opts, ","); firstOpt == "readonly" {
			return false
		}

		assignments = append(assignments, name+" = ?")

		return true
	})

	return assignments
}
89 |
// extractValues runs extractor over every exported field of model, descending into embedded structs. It
// collects, in field order, the values of the fields the extractor accepts.
func extractValues(model any, extractor fieldExtractor) []any {
	var collected []any

	processModelFields(model, func(field reflect.StructField, value reflect.Value) bool {
		_, arg, keep := extractor(field, value)
		if !keep {
			return false
		}

		collected = append(collected, arg)

		return true
	})

	return collected
}
101 |
// columnExtractor accepts every field whose db tag is neither "" nor "-". It returns the column name (the tag
// text before the first comma) and the field's current value.
func columnExtractor(field reflect.StructField, value reflect.Value) (string, any, bool) {
	tag := field.Tag.Get("db")
	if tag == "-" || tag == "" {
		return "", nil, false
	}

	column, _, _ := strings.Cut(tag, ",")

	return column, value.Interface(), true
}
111 |
// setColumnExtractor accepts db-tagged fields whose first tag option is not "readonly". It returns a
// "column = ?" assignment and the field's current value.
func setColumnExtractor(field reflect.StructField, value reflect.Value) (string, any, bool) {
	tag := field.Tag.Get("db")
	if tag == "" || tag == "-" {
		return "", nil, false
	}

	column, opts, _ := strings.Cut(tag, ",")
	if firstOpt, _, _ := strings.Cut(opts, ","); firstOpt == "readonly" {
		return "", nil, false
	}

	return column + " = ?", value.Interface(), true
}
125 |
// processModelFields calls processor for every exported field of model, which must be a struct or a pointer to
// one. An exported embedded struct is not passed to processor itself; its fields are visited recursively in its
// place. The processor's return value is ignored.
func processModelFields(model any, processor func(reflect.StructField, reflect.Value) bool) {
	typ := reflect.TypeOf(model)
	val := reflect.ValueOf(model)

	if typ.Kind() == reflect.Ptr {
		typ, val = typ.Elem(), val.Elem()
	}

	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if !field.IsExported() {
			continue
		}

		fieldVal := val.Field(i)

		if field.Anonymous && field.Type.Kind() == reflect.Struct {
			processModelFields(fieldVal.Interface(), processor)
		} else {
			processor(field, fieldVal)
		}
	}
}
150 |
--------------------------------------------------------------------------------
/csql/qb/query_builder_test.go:
--------------------------------------------------------------------------------
1 | package qb_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/gocopper/copper/csql/qb"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
// TestModel exercises the qb tag handling. It has a readonly column, two writable columns, and a field excluded
// with `db:"-"`.
type TestModel struct {
	ID      int    `db:"id,readonly"`
	Name    string `db:"name"`
	Email   string `db:"email"`
	Ignored string `db:"-"`
}
16 |
// TestPlaceholders checks that every db column, including the readonly one, gets a placeholder.
func TestPlaceholders(t *testing.T) {
	got := qb.ValuePlaceholders(TestModel{ID: 1, Name: "Test", Email: "test@example.com"})

	assert.Equal(t, "?, ?, ?", got)
}
22 |
23 | func TestColumns(t *testing.T) {
24 | model := TestModel{}
25 |
26 | // Without alias
27 | result := qb.Columns(model)
28 | assert.Equal(t, "id, name, email", result)
29 |
30 | // With alias
31 | resultWithAlias := qb.Columns(model, "t")
32 | assert.Equal(t, "t.id, t.name, t.email", resultWithAlias)
33 | }
34 |
// TestSetColumns checks that the readonly id column is not assignable.
func TestSetColumns(t *testing.T) {
	got := qb.SetColumns(TestModel{})

	assert.Equal(t, "name = ?, email = ?", got)
}
40 |
// TestValues checks that every db column's value is returned in field order.
func TestValues(t *testing.T) {
	got := qb.Values(TestModel{ID: 1, Name: "Test", Email: "test@example.com"})

	assert.Equal(t, []any{1, "Test", "test@example.com"}, got)
}
50 |
// TestSetValues checks that only the writable columns' values are returned.
func TestSetValues(t *testing.T) {
	got := qb.SetValues(TestModel{ID: 1, Name: "Test", Email: "test@example.com"})

	assert.Equal(t, []any{"Test", "test@example.com"}, got)
}
59 |
// TestWithEmbeddedStruct checks that the fields of an embedded struct are listed in place of the embedded field.
func TestWithEmbeddedStruct(t *testing.T) {
	type BaseModel struct {
		ID      int `db:"id,readonly"`
		Version int `db:"version"`
	}

	type UserModel struct {
		BaseModel
		Name string `db:"name"`
		Skip string `db:"-"`
	}

	user := UserModel{
		BaseModel: BaseModel{ID: 1, Version: 2},
		Name:      "Test",
		Skip:      "Ignored",
	}

	assert.Equal(t, "id, version, name", qb.Columns(user))
	assert.Equal(t, "version = ?, name = ?", qb.SetColumns(user))
}
84 |
--------------------------------------------------------------------------------
/csql/querier.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"sync"
	"time"

	"github.com/gocopper/copper/cerrors"
	"github.com/gocopper/copper/clogger"
	"github.com/jmoiron/sqlx"
)
14 |
15 | // Querier provides a set of helpful methods to run database queries. It can be used to run parameterized queries
16 | // and scan results into Go structs or slices.
17 | type Querier interface {
18 | CtxWithTx(ctx context.Context) (context.Context, *sql.Tx, error)
19 | InTx(ctx context.Context, fn func(context.Context) error) error
20 | WithIn() Querier
21 | Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error
22 | Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error
23 | Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
24 | OnCommit(ctx context.Context, cb func(context.Context) error) error
25 | CommitTx(tx *sql.Tx) error
26 | }
27 |
28 | // NewQuerier returns a querier using the given database connection and the dialect
29 | func NewQuerier(db *sql.DB, config Config, logger clogger.Logger) Querier {
30 | return &querier{
31 | db: sqlx.NewDb(db, config.Dialect),
32 | dialect: config.Dialect,
33 | in: false,
34 | logger: logger,
35 | callbacksByTx: make(map[*sql.Tx][]func(context.Context) error),
36 | }
37 | }
38 |
39 | type querier struct {
40 | db *sqlx.DB
41 | dialect string
42 | in bool
43 | logger clogger.Logger
44 |
45 | callbacksByTx map[*sql.Tx][]func(context.Context) error
46 | }
47 |
48 | func (q *querier) OnCommit(ctx context.Context, cb func(context.Context) error) error {
49 | tx, err := TxFromCtx(ctx)
50 | if err != nil {
51 | return cerrors.New(err, "failed to get database transaction from context", nil)
52 | }
53 |
54 | if _, ok := q.callbacksByTx[tx]; !ok {
55 | q.callbacksByTx[tx] = make([]func(context.Context) error, 0)
56 | }
57 |
58 | q.callbacksByTx[tx] = append(q.callbacksByTx[tx], cb)
59 |
60 | return nil
61 | }
62 |
63 | func (q *querier) CommitTx(tx *sql.Tx) error {
64 | err := tx.Commit()
65 | if err != nil && !errors.Is(err, sql.ErrTxDone) && !strings.Contains(err.Error(), "commit unexpectedly resulted in rollback") {
66 | return err
67 | }
68 |
69 | if err != nil && strings.Contains(err.Error(), "commit unexpectedly resulted in rollback") {
70 | q.logger.Warn(err.Error(), nil)
71 | }
72 |
73 | if callbacks, ok := q.callbacksByTx[tx]; ok {
74 | for i := range callbacks {
75 | go func(cb func(context.Context) error) {
76 | ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
77 |
78 | err := cb(ctx)
79 | if err != nil {
80 | q.logger.Error("Failed to run callback", err)
81 | }
82 |
83 | cancel()
84 | }(callbacks[i])
85 | }
86 | }
87 |
88 | delete(q.callbacksByTx, tx)
89 |
90 | return nil
91 | }
92 |
93 | func (q *querier) CtxWithTx(ctx context.Context) (context.Context, *sql.Tx, error) {
94 | return CtxWithTx(ctx, q.db.DB, q.dialect)
95 | }
96 |
97 | func (q *querier) InTx(ctx context.Context, fn func(context.Context) error) error {
98 | ctx, tx, err := CtxWithTx(ctx, q.db.DB, q.dialect)
99 | if err != nil {
100 | return cerrors.New(err, "failed to create context with database transaction", nil)
101 | }
102 |
103 | defer func() {
104 | // Try a rollback in a deferred function to account for panics
105 | err := tx.Rollback()
106 | if err != nil && !errors.Is(err, sql.ErrTxDone) {
107 | q.logger.Error("Failed to rollback database transaction", err)
108 | return
109 | }
110 |
111 | if err == nil {
112 | q.logger.Warn("Rolled back an unexpectedly open database transaction", nil)
113 | }
114 | }()
115 |
116 | err = fn(ctx)
117 | if err != nil {
118 | rollbackErr := tx.Rollback()
119 | if rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
120 | q.logger.Error("Failed to rollback database transaction", err)
121 | }
122 | return err
123 | }
124 |
125 | err = q.CommitTx(tx)
126 | if err != nil {
127 | return cerrors.New(err, "failed to commit database transaction", nil)
128 | }
129 |
130 | return nil
131 | }
132 |
133 | func (q *querier) WithIn() Querier {
134 | return &querier{
135 | db: q.db,
136 | dialect: q.dialect,
137 | in: true,
138 | logger: q.logger,
139 | }
140 | }
141 |
// Get prepares query (IN expansion when enabled, then rebinding for the dialect). It runs the query in the
// transaction carried by ctx and scans the single resulting row into dest.
func (q *querier) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	boundQuery, boundArgs, err := q.mkQueryWithArgs(ctx, query, args)
	if err != nil {
		return err
	}

	tx := mustTxFromCtx(ctx)

	return tx.GetContext(ctx, dest, boundQuery, boundArgs...)
}
150 |
// Select prepares query like Get, runs it in the transaction carried by ctx, and scans every resulting row
// into dest.
func (q *querier) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	boundQuery, boundArgs, err := q.mkQueryWithArgs(ctx, query, args)
	if err != nil {
		return err
	}

	tx := mustTxFromCtx(ctx)

	return tx.SelectContext(ctx, dest, boundQuery, boundArgs...)
}
159 |
// Exec prepares query like Get and executes it in the transaction carried by ctx without returning rows.
func (q *querier) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	boundQuery, boundArgs, err := q.mkQueryWithArgs(ctx, query, args)
	if err != nil {
		return nil, err
	}

	tx := mustTxFromCtx(ctx)

	return tx.ExecContext(ctx, boundQuery, boundArgs...)
}
168 |
169 | func (q *querier) mkQueryWithArgs(ctx context.Context, query string, args []interface{}) (string, []interface{}, error) {
170 | var err error
171 |
172 | if q.in {
173 | query, args, err = sqlx.In(query, args...)
174 | if err != nil {
175 | return "", nil, cerrors.New(err, "failed to create IN query", nil)
176 | }
177 | }
178 |
179 | tx, err := txFromCtx(ctx)
180 | if err != nil {
181 | return "", nil, err
182 | }
183 |
184 | return tx.Rebind(query), args, nil
185 | }
186 |
--------------------------------------------------------------------------------
/csql/querier_test.go:
--------------------------------------------------------------------------------
1 | package csql_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "testing"
7 |
8 | "github.com/gocopper/copper/clogger"
9 | "github.com/gocopper/copper/csql"
10 | _ "github.com/mattn/go-sqlite3"
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
// TestQuerier_Read exercises Get, Select and WithIn().Select against a single seeded row.
func TestQuerier_Read(t *testing.T) {
	t.Parallel()

	db, err := sql.Open("sqlite3", ":memory:")
	assert.NoError(t, err)

	_, err = db.Exec("create table people (name text);insert into people (name) values ('test');")
	assert.NoError(t, err)

	querier := csql.NewQuerier(db, csql.Config{Dialect: "sqlite3"}, clogger.NewNoop())

	ctx, _, err := csql.CtxWithTx(context.Background(), db, "sqlite3")
	assert.NoError(t, err)

	// The subtests run in parallel, so each declares its own err; assigning the parent's err would race.

	t.Run("get", func(t *testing.T) {
		t.Parallel()

		var dest struct {
			Name string
		}

		err := querier.Get(ctx, &dest, "select * from people")
		assert.NoError(t, err)

		assert.Equal(t, "test", dest.Name)
	})

	t.Run("select", func(t *testing.T) {
		t.Parallel()

		var dest []struct {
			Name string
		}

		err := querier.Select(ctx, &dest, "select * from people")
		assert.NoError(t, err)

		assert.Equal(t, 1, len(dest))
		assert.Equal(t, "test", dest[0].Name)
	})

	t.Run("select in", func(t *testing.T) {
		t.Parallel()

		var dest []struct {
			Name string
		}

		err := querier.WithIn().Select(ctx, &dest, "select * from people where name in (?)", []string{"test"})
		assert.NoError(t, err)

		assert.Equal(t, 1, len(dest))
		assert.Equal(t, "test", dest[0].Name)
	})
}
69 |
70 | func TestQuerier_Exec(t *testing.T) {
71 | t.Parallel()
72 |
73 | db, err := sql.Open("sqlite3", ":memory:")
74 | assert.NoError(t, err)
75 |
76 | _, err = db.Exec("create table people (name text);insert into people (name) values ('test');")
77 | assert.NoError(t, err)
78 |
79 | querier := csql.NewQuerier(db, csql.Config{Dialect: "sqlite3"}, clogger.NewNoop())
80 |
81 | ctx, _, err := csql.CtxWithTx(context.Background(), db, "sqlite3")
82 | assert.NoError(t, err)
83 |
84 | res, err := querier.Exec(ctx, "delete from people")
85 | assert.NoError(t, err)
86 |
87 | n, err := res.RowsAffected()
88 | assert.NoError(t, err)
89 |
90 | assert.Equal(t, int64(1), n)
91 | }
92 |
--------------------------------------------------------------------------------
/csql/tx_middleware.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import (
4 | "bufio"
5 | "database/sql"
6 | "errors"
7 | "github.com/gocopper/copper/cerrors"
8 | "github.com/gocopper/copper/clogger"
9 | "net"
10 | "net/http"
11 | )
12 |
// NewTxMiddleware returns a TxMiddleware. It begins per-request transactions on db using config.Dialect and
// commits them through querier.
func NewTxMiddleware(db *sql.DB, querier Querier, config Config, logger clogger.Logger) *TxMiddleware {
	mw := TxMiddleware{db: db, querier: querier, config: config, logger: logger}

	return &mw
}
22 |
// TxMiddleware is a chttp.Middleware that wraps an HTTP request in a database transaction. The transaction is
// rolled back when the handler responds with a status code >= 400. It is committed when the handler sends a
// status code below 400, writes a body, or returns without writing anything.
type TxMiddleware struct {
	// db is the database the per-request transactions are started on.
	db *sql.DB

	// querier commits the transactions.
	querier Querier

	// config supplies the dialect passed to CtxWithTx.
	config Config

	logger clogger.Logger
}
31 |
// Handle implements the chttp.Middleware interface. See TxMiddleware.
//
// The wrapped ResponseWriter usually finishes the transaction as soon as the handler writes a status or body.
// If the handler writes neither, the transaction is committed after the handler returns. A failed commit is
// then answered with a 500 instead of the implicit 200 net/http would send.
func (m *TxMiddleware) Handle(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx, tx, err := CtxWithTx(r.Context(), m.db, m.config.Dialect)
		if err != nil {
			m.logger.Error("Failed to create context with database transaction", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		defer func() {
			// Try a rollback in a deferred function to account for panics
			err := tx.Rollback()
			if err != nil && !errors.Is(err, sql.ErrTxDone) {
				m.logger.Error("Failed to rollback database transaction", err)
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			if err == nil {
				m.logger.Warn("Rolled back an unexpectedly open database transaction", nil)
			}
		}()

		next.ServeHTTP(&txnrw{
			internal: w,
			tx:       tx,
			querier:  m.querier,
			logger:   m.logger,
		}, r.WithContext(ctx))

		err = m.querier.CommitTx(tx)
		if err != nil {
			m.logger.Error("Failed to commit database transaction", err)
			// Committing an already finished transaction is not an error, and the wrapped writer finishes the
			// transaction on the first status or body write. An error here therefore means the handler wrote
			// nothing and no status has been sent, so report the failure instead of an implicit 200.
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
	})
}
70 |
// txnrw wraps the handler's http.ResponseWriter so that the request's transaction is finished before anything
// reaches the client. An error status (>= 400) rolls it back; other statuses and body writes commit it through
// querier.
type txnrw struct {
	// internal is the wrapped ResponseWriter.
	internal http.ResponseWriter

	// tx is the request's transaction.
	tx *sql.Tx

	// querier commits tx.
	querier Querier

	logger clogger.Logger
}
77 |
// Header returns the header map of the wrapped ResponseWriter.
func (w *txnrw) Header() http.Header {
	inner := w.internal

	return inner.Header()
}
81 |
// Write commits the transaction before writing b to the client, so data is persisted before the response
// starts. With the package's querier, committing an already finished transaction is not an error, which makes
// repeated writes cheap.
//
// If the commit fails, no status has been sent yet: any earlier WriteHeader or Write would already have
// finished the transaction. The error is logged, a 500 is sent instead of the implicit 200, and the write is
// aborted.
func (w *txnrw) Write(b []byte) (int, error) {
	err := w.querier.CommitTx(w.tx)
	if err != nil {
		w.logger.Error("Failed to commit database transaction", err)
		w.internal.WriteHeader(http.StatusInternalServerError)

		return 0, cerrors.New(err, "failed to commit database transaction", nil)
	}

	return w.internal.Write(b)
}
90 |
// WriteHeader finishes the transaction before sending the status code. Codes >= 400 roll it back; any other
// final status commits it. net/http does not treat informational 1xx codes (other than 101 Switching
// Protocols) as final, so they are forwarded without touching the transaction. If finishing the transaction
// fails, the error is logged and a 500 is sent instead of statusCode.
func (w *txnrw) WriteHeader(statusCode int) {
	const minErrStatusCode = 400

	// Committing on an informational response would persist data before the handler has chosen its outcome.
	if statusCode >= 100 && statusCode < 200 && statusCode != http.StatusSwitchingProtocols {
		w.internal.WriteHeader(statusCode)
		return
	}

	if statusCode >= minErrStatusCode {
		err := w.tx.Rollback()
		if err != nil && !errors.Is(err, sql.ErrTxDone) {
			w.logger.WithTags(map[string]interface{}{
				"originalStatusCode": statusCode,
			}).Error("Failed to rollback database transaction", err)
			w.internal.WriteHeader(http.StatusInternalServerError)
			return
		}

		w.internal.WriteHeader(statusCode)
		return
	}

	err := w.querier.CommitTx(w.tx)
	if err != nil {
		// Previously this failure was swallowed silently; log it so the 500 can be diagnosed.
		w.logger.WithTags(map[string]interface{}{
			"originalStatusCode": statusCode,
		}).Error("Failed to commit database transaction", err)
		w.internal.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.internal.WriteHeader(statusCode)
}
116 |
// Hijack hands the underlying connection to the handler when the wrapped ResponseWriter is an http.Hijacker.
// Hijacking does not finish the transaction; TxMiddleware.Handle commits it after the handler returns.
func (w *txnrw) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := w.internal.(http.Hijacker); ok {
		return hijacker.Hijack()
	}

	return nil, nil, errors.New("internal response writer is not http.Hijacker")
}
125 |
--------------------------------------------------------------------------------
/csql/tx_middleware_test.go:
--------------------------------------------------------------------------------
1 | package csql_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 |
10 | "github.com/gocopper/copper/clogger"
11 | "github.com/gocopper/copper/csql"
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
// TestTxMiddleware_Handle_Commit checks that a handler which writes nothing has its transaction committed.
func TestTxMiddleware_Handle_Commit(t *testing.T) {
	t.Parallel()

	db, err := sql.Open("sqlite3", ":memory:")
	assert.NoError(t, err)

	_, err = db.Exec("create table people (name text)")
	assert.NoError(t, err)

	// NewTxMiddleware takes the querier it commits through; pass the same one the handler uses.
	var (
		logger  = clogger.NewNoop()
		config  = csql.Config{Dialect: "sqlite3"}
		querier = csql.NewQuerier(db, config, logger)
		mw      = csql.NewTxMiddleware(db, querier, config, logger)
	)

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
	assert.NoError(t, err)

	mw.Handle(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := querier.Exec(r.Context(), "insert into people (name) values ('test')")
		assert.NoError(t, err)
	})).ServeHTTP(httptest.NewRecorder(), req)

	rows, err := db.Query("select * from people")
	assert.NoError(t, err)
	assert.NoError(t, rows.Err())

	assert.True(t, rows.Next())
}
45 |
// TestTxMiddleware_Handle_Rollback checks that an error status rolls back the handler's writes.
func TestTxMiddleware_Handle_Rollback(t *testing.T) {
	t.Parallel()

	db, err := sql.Open("sqlite3", ":memory:")
	assert.NoError(t, err)

	_, err = db.Exec("create table people (name text)")
	assert.NoError(t, err)

	// NewTxMiddleware takes the querier it commits through; pass the same one the handler uses.
	var (
		logger  = clogger.NewNoop()
		config  = csql.Config{Dialect: "sqlite3"}
		querier = csql.NewQuerier(db, config, logger)
		mw      = csql.NewTxMiddleware(db, querier, config, logger)
	)

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
	assert.NoError(t, err)

	mw.Handle(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := querier.Exec(r.Context(), "insert into people (name) values ('test')")
		assert.NoError(t, err)

		w.WriteHeader(http.StatusInternalServerError)
	})).ServeHTTP(httptest.NewRecorder(), req)

	rows, err := db.Query("select * from people")
	assert.NoError(t, err)
	assert.NoError(t, rows.Err())

	assert.False(t, rows.Next())
}
78 |
--------------------------------------------------------------------------------
/csql/wire.go:
--------------------------------------------------------------------------------
1 | package csql
2 |
3 | import "github.com/google/wire"
4 |
5 | // WireModule is the google/wire provider set for csql. It provides LoadConfig,
6 | // NewDBConnection, NewQuerier, NewTxMiddleware, and NewMigrator, with
7 | // NewMigratorParams assembled by injecting every one of its fields.
8 | var WireModule = wire.NewSet(
9 | 	LoadConfig,
10 | 	NewDBConnection,
11 | 	NewQuerier,
12 | 	NewTxMiddleware,
13 |
14 | 	NewMigrator,
15 | 	wire.Struct(new(NewMigratorParams), "*"),
16 | )
17 |
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | // Package copper encapsulates everything you need to build apps quickly.
2 | package copper
3 |
--------------------------------------------------------------------------------
/flags.go:
--------------------------------------------------------------------------------
1 | package copper
2 |
3 | import (
4 | "flag"
5 |
6 | "github.com/gocopper/copper/cconfig"
7 | )
8 |
9 | // Flags holds the values of the command-line flags that control how the app
10 | // loads its configuration.
11 | type Flags struct {
12 | 	ConfigPath      cconfig.Path      // -config: path of the config file to load
13 | 	ConfigOverrides cconfig.Overrides // -set: config key overrides, separated by ";"
14 | }
15 |
16 | // NewFlags registers the -config and -set flags on the global flag set, parses
17 | // the process command line, and returns the parsed values. Defining the same
18 | // flag name twice on that set panics, so call it at most once per process.
19 | func NewFlags() *Flags {
20 | 	var path, overrides string
21 |
22 | 	flag.StringVar(&path, "config", "./config/dev.toml", "Path to config file")
23 | 	flag.StringVar(&overrides, "set", "", "Config overrides ex. \"chttp.port=5902\". Separate multiple overrides with ;")
24 | 	flag.Parse()
25 |
26 | 	flags := Flags{
27 | 		ConfigPath:      cconfig.Path(path),
28 | 		ConfigOverrides: cconfig.Overrides(overrides),
29 | 	}
30 |
31 | 	return &flags
32 | }
33 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/gocopper/copper
2 |
3 | go 1.22
4 |
5 | require (
6 | github.com/Masterminds/sprig/v3 v3.2.3
7 | github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf
8 | github.com/google/uuid v1.6.0
9 | github.com/google/wire v0.6.0
10 | github.com/gorilla/mux v1.8.1
11 | github.com/iancoleman/strcase v0.3.0
12 | github.com/jmoiron/sqlx v1.3.5
13 | github.com/mattn/go-sqlite3 v1.14.18
14 | github.com/pelletier/go-toml v1.9.3
15 | github.com/prometheus/client_golang v1.20.3
16 | github.com/rubenv/sql-migrate v1.1.2
17 | github.com/shopspring/decimal v1.2.0
18 | github.com/stretchr/testify v1.9.0
19 | go.uber.org/zap v1.27.0
20 | )
21 |
22 | require (
23 | github.com/Masterminds/goutils v1.1.1 // indirect
24 | github.com/Masterminds/semver/v3 v3.2.0 // indirect
25 | github.com/beorn7/perks v1.0.1 // indirect
26 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
27 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
28 | github.com/go-gorp/gorp/v3 v3.0.2 // indirect
29 | github.com/go-sql-driver/mysql v1.7.1 // indirect
30 | github.com/huandu/xstrings v1.3.3 // indirect
31 | github.com/imdario/mergo v0.3.12 // indirect
32 | github.com/lib/pq v1.10.2 // indirect
33 | github.com/mitchellh/copystructure v1.0.0 // indirect
34 | github.com/mitchellh/reflectwalk v1.0.0 // indirect
35 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
36 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
37 | github.com/prometheus/client_model v0.6.1 // indirect
38 | github.com/prometheus/common v0.59.1 // indirect
39 | github.com/prometheus/procfs v0.15.1 // indirect
40 | github.com/rogpeppe/go-internal v1.12.0 // indirect
41 | github.com/sirupsen/logrus v1.9.3 // indirect
42 | github.com/spf13/cast v1.7.0 // indirect
43 | go.uber.org/multierr v1.11.0 // indirect
44 | golang.org/x/crypto v0.31.0 // indirect
45 | golang.org/x/sys v0.28.0 // indirect
46 | google.golang.org/protobuf v1.34.2 // indirect
47 | gopkg.in/yaml.v3 v3.0.1 // indirect
48 | )
49 |
--------------------------------------------------------------------------------
/wire.go:
--------------------------------------------------------------------------------
1 | //go:build wireinject
2 | // +build wireinject
3 |
4 | package copper
5 |
6 | import (
7 | "github.com/gocopper/copper/cconfig"
8 | "github.com/gocopper/copper/clifecycle"
9 | "github.com/gocopper/copper/clogger"
10 | "github.com/google/wire"
11 | )
12 |
13 | // InitApp creates a new Copper app along with its dependencies.
14 | //
15 | // It reads the -config and -set command-line flags, loads the config file with
16 | // those key overrides, loads the logger config and builds the logger from it,
17 | // and passes a new lifecycle, the config loader, and the logger to NewApp. It
18 | // returns an error if any of the config-loading or logger-building steps fails.
19 | //
20 | // The flags are registered on the global flag set, where defining a flag name
21 | // twice panics, so InitApp should be called at most once per process.
22 | func InitApp() (*App, error) {
23 | // This is a google/wire injector template: wire.Build only declares the
24 | // providers, and the wire tool generates the real body in wire_gen.go.
25 | panic(
26 | wire.Build(
27 | NewApp,
28 | NewFlags,
29 | clifecycle.New,
30 | cconfig.NewWithKeyOverrides,
31 | clogger.NewWithConfig,
32 | clogger.LoadConfig,
33 |
34 | // Expose the parsed flag values as the path and overrides that
35 | // cconfig.NewWithKeyOverrides consumes.
36 | wire.FieldsOf(new(*Flags), "ConfigPath", "ConfigOverrides"),
37 | ),
38 | )
39 | }
40 |
41 | // WireModule can be used as part of google/wire setup to include the app's
42 | // lifecycle, config, and logger: it provides the Lifecycle, Config, and Logger
43 | // fields of *App, so an injector that has an *App can depend on them directly.
44 | var WireModule = wire.NewSet(
45 | wire.FieldsOf(new(*App), "Lifecycle", "Config", "Logger"),
46 | )
47 |
--------------------------------------------------------------------------------
/wire_gen.go:
--------------------------------------------------------------------------------
1 | // Code generated by Wire. DO NOT EDIT.
2 |
3 | //go:generate go run github.com/google/wire/cmd/wire
4 | //go:build !wireinject
5 | // +build !wireinject
6 |
7 | package copper
8 |
9 | import (
10 | "github.com/gocopper/copper/cconfig"
11 | "github.com/gocopper/copper/clifecycle"
12 | "github.com/gocopper/copper/clogger"
13 | "github.com/google/wire"
14 | )
15 |
16 | // Injectors from wire.go:
17 |
18 | // InitApp creates a new Copper app along with its dependencies.
19 | //
20 | // It reads the -config and -set command-line flags, loads the config file with
21 | // those key overrides, loads the logger config and builds the logger from it,
22 | // and passes a new lifecycle, the config loader, and the logger to NewApp. It
23 | // returns an error if any of the config-loading or logger-building steps fails.
24 | //
25 | // The flags are registered on the global flag set, where defining a flag name
26 | // twice panics, so InitApp should be called at most once per process.
27 | func InitApp() (*App, error) {
28 | lifecycle := clifecycle.New()
29 | flags := NewFlags()
30 | path := flags.ConfigPath
31 | overrides := flags.ConfigOverrides
32 | loader, err := cconfig.NewWithKeyOverrides(path, overrides)
33 | if err != nil {
34 | return nil, err
35 | }
36 | config, err := clogger.LoadConfig(loader)
37 | if err != nil {
38 | return nil, err
39 | }
40 | logger, err := clogger.NewWithConfig(config)
41 | if err != nil {
42 | return nil, err
43 | }
44 | app := NewApp(lifecycle, loader, logger)
45 | return app, nil
46 | }
47 |
48 | // wire.go:
49 |
50 | // WireModule can be used as part of google/wire setup to include the app's
51 | // lifecycle, config, and logger: it provides the Lifecycle, Config, and Logger
52 | // fields of *App, so an injector that has an *App can depend on them directly.
53 | var WireModule = wire.NewSet(wire.FieldsOf(new(*App), "Lifecycle", "Config", "Logger"))
54 |
--------------------------------------------------------------------------------