├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE.md └── workflows │ ├── cgo.yml │ ├── codeql.yaml │ ├── mysql-volumeless.yml │ └── nocgo.yml ├── .gitignore ├── .gitleaksignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── admin ├── README.md ├── admin.go ├── admin_test.go ├── health.go ├── health_test.go └── pprof_test.go ├── api └── common.yaml ├── build ├── log.go └── log_test.go ├── config ├── config.go ├── config_test.go └── testdata │ ├── with-search-and-security.yml │ ├── with-widget-secrets.yml │ └── with-widgets.yml ├── configs ├── config.app.yml ├── config.default.yml ├── config.extra.yml └── config.secrets.yml ├── database ├── database.go ├── database_test.go ├── error.go ├── migrator.go ├── model_config.go ├── model_config_test.go ├── mysql.go ├── mysql_test.go ├── pkger.go ├── postgres.go ├── postgres_test.go ├── spanner.go ├── spanner_test.go ├── testdata │ └── gencerts.sh ├── testdb │ ├── mysql.go │ ├── postgres.go │ └── spanner.go ├── tls.go ├── tls_test.go └── tx.go ├── doc.go ├── docker-compose.yml ├── docker ├── docker.go └── docker_test.go ├── error.go ├── error_test.go ├── go.mod ├── go.sum ├── http ├── bind │ ├── bind.go │ └── bind_test.go ├── doc.go ├── response.go ├── response_test.go ├── server.go └── server_test.go ├── id.go ├── id_test.go ├── k8s ├── kubernetes.go └── kubernetes_test.go ├── log ├── README.md ├── logger.go ├── logger_impl.go ├── logger_test.go ├── model_fields.go ├── model_levels.go ├── model_stacktrace.go ├── model_valuer.go ├── model_valuer_test.go ├── struct_context.go └── struct_context_test.go ├── makefile ├── mask ├── password.go └── password_test.go ├── migrations ├── 001_create_tests.up.sql ├── 002_create_tests.up.mysql.sql ├── 002_create_tests.up.postgres.sql ├── 002_create_tests.up.spanner.sql └── 002_create_tests.up.sqlite.sql ├── mysql ├── Dockerfile ├── README.md ├── makefile └── test-mysql-is-ready.sh ├── package.go ├── randx ├── randx.go └── randx_test.go ├── 
renovate.json ├── sql ├── db.go ├── sql.go ├── sql_test.go ├── stmt.go └── tx.go ├── stime ├── static_time.go └── system_time.go ├── strx ├── strx.go └── strx_test.go ├── telemetry ├── README.md ├── attributes.go ├── attributes_test.go ├── collector.go ├── config.go ├── config_test.go ├── env.go ├── exporter.go ├── honey.go ├── linked.go ├── stdout.go ├── tracer.go └── tracer_test.go ├── time.go ├── time.rb └── time_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @adamdecaf @infernojj -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | **What were you trying to do?** 4 | 5 | 6 | 7 | **What did you expect to see?** 8 | 9 | 10 | 11 | **What did you see?** 12 | 13 | 14 | 15 | **How can we reproduce the problem?** 16 | -------------------------------------------------------------------------------- /.github/workflows/cgo.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | name: Go Build (CGO) 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest, macos-latest, windows-latest] 16 | version: [stable, oldstable] 17 | steps: 18 | - name: Set up Go 1.x 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: ${{ matrix.version }} 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@v4 26 | 27 | - name: Install make (Windows) 28 | if: runner.os == 'Windows' 29 | run: choco install -y make mingw 30 | 31 | - name: Setup 32 | if: runner.os == 'Linux' 33 | run: make setup 34 | 35 | - name: Check 36 | if: runner.os == 'Linux' 37 | run: make check 38 | 39 | - name: Short Check 40 | if: runner.os != 'Linux' 41 | run: make 
check 42 | env: 43 | GOTEST_FLAGS: "-short" 44 | 45 | - name: Logs 46 | if: failure() && runner.os == 'Linux' 47 | run: docker compose logs 48 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yaml: -------------------------------------------------------------------------------- 1 | name: CodeQL Analysis 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: '0 0 * * 0' 8 | 9 | jobs: 10 | CodeQL-Build: 11 | strategy: 12 | fail-fast: false 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v4 17 | 18 | - name: Initialize CodeQL 19 | uses: github/codeql-action/init@v3 20 | with: 21 | languages: go 22 | 23 | - name: Perform CodeQL Analysis 24 | uses: github/codeql-action/analyze@v3 25 | -------------------------------------------------------------------------------- /.github/workflows/mysql-volumeless.yml: -------------------------------------------------------------------------------- 1 | name: Publish MySQL Docker Image 2 | 3 | on: 4 | push: 5 | tags: [ "v*.*.*" ] 6 | 7 | jobs: 8 | docker: 9 | name: Build and push MySQL volumeless docker image 10 | runs-on: ubuntu-latest 11 | defaults: 12 | run: 13 | working-directory: ./mysql 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | 18 | - name: Build docker image from MySQL initialized DB inside container 19 | run: make build 20 | 21 | - name: Docker Push 22 | run: |+ 23 | echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin 24 | make push 25 | env: 26 | DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} 27 | DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} 28 | -------------------------------------------------------------------------------- /.github/workflows/nocgo.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 
| jobs: 10 | build: 11 | name: Go Build (No CGO) 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest, macos-latest, windows-latest] 16 | version: [stable, oldstable] 17 | steps: 18 | - name: Set up Go 1.x 19 | uses: actions/setup-go@v5 20 | with: 21 | go-version: ${{ matrix.version }} 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@v4 26 | 27 | - name: Install make (Windows) 28 | if: runner.os == 'Windows' 29 | run: choco install -y make mingw 30 | 31 | - name: Setup 32 | if: runner.os == 'Linux' 33 | run: make setup 34 | 35 | - name: Check 36 | if: runner.os == 'Linux' 37 | run: make check 38 | env: 39 | CGO_ENABLED: "0" 40 | 41 | - name: Short Check 42 | if: runner.os != 'Linux' 43 | run: make check 44 | env: 45 | CGO_ENABLED: "0" 46 | GOTEST_FLAGS: "-short" 47 | 48 | - name: Logs 49 | if: failure() && runner.os == 'Linux' 50 | run: docker compose logs 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #### joe made this: http://goel.io/joe 2 | 3 | #####=== Go ===##### 4 | 5 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 6 | *.o 7 | *.a 8 | *.so 9 | 10 | # Folders 11 | _obj 12 | _test 13 | 14 | # VSCode ignore 15 | .vscode/* 16 | !.vscode/settings.json 17 | !.vscode/tasks.json 18 | !.vscode/launch.json 19 | !.vscode/extensions.json 20 | 21 | # Architecture specific extensions/prefixes 22 | *.[568vq] 23 | [568vq].out 24 | 25 | *.cgo1.go 26 | *.cgo2.c 27 | _cgo_defun.c 28 | _cgo_gotypes.go 29 | _cgo_export.* 30 | 31 | _testmain.go 32 | 33 | *.exe 34 | *.test 35 | *.prof 36 | 37 | /vendor/ 38 | /bin/ 39 | misspell* 40 | staticcheck* 41 | lint-project.sh 42 | gitleaks.tar.gz 43 | 44 | .vscode/launch.json 45 | 46 | # code coverage 47 | coverage.html 48 | cover.out 49 | coverage.txt 50 | 51 | *.pyc 52 | 53 | .idea/* 54 | 55 | 
testcerts/* -------------------------------------------------------------------------------- /.gitleaksignore: -------------------------------------------------------------------------------- 1 | testcerts/server.key 2 | testcerts/client.key:private-key:1 3 | testcerts/root.key:private-key:1 4 | testcerts/server.key:private-key:1 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable 
behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at wade@wadearnold.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] 44 | 45 | [homepage]: http://contributor-covenant.org 46 | [version]: http://contributor-covenant.org/version/1/4/ 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Wow, we really appreciate that you even looked at this section! We are trying to make the world's best atomic building blocks for financial services that accelerate innovation in banking and we need your help! 4 | 5 | You only have a fresh set of eyes once! The easiest way to contribute is to give feedback on the documentation that you are reading right now. This can be as simple as sending a message to our Google Group with your feedback or updating the markdown in this documentation and issuing a pull request. 6 | 7 | Stability is the hallmark of any good software. If you find an edge case that isn't handled, please open a GitHub issue with the example data so that we can make our software more robust for everyone. We also welcome pull requests if you want to get your hands dirty. 8 | 9 | Have a use case that we don't handle, or don't handle well? Start the discussion on our Google Group or open a GitHub Issue. We want to make the project meet the needs of the community and keep you using our code. 10 | 11 | Please review our [Code of Conduct](CODE_OF_CONDUCT.md) to ensure you agree with the values of this project. 12 | 13 | We use GitHub to manage reviews of pull requests. 14 | 15 | * If you have a trivial fix or improvement, go ahead and create a pull 16 | request, addressing (with `@...`) one or more of the maintainers 17 | (see [CODEOWNERS](.github/CODEOWNERS)) in the description of the pull request.
18 | 19 | * If you plan to do something more involved, first propose your ideas 20 | in a GitHub issue. This will avoid unnecessary work and surely give 21 | you and us a good deal of inspiration. 22 | 23 | * Relevant coding style guidelines are the [Go Code Review 24 | Comments](https://go.dev/wiki/CodeReviewComments) 25 | and the _Formatting and style_ section of Peter Bourgon's [Go: Best 26 | Practices for Production 27 | Environments](http://peter.bourgon.org/go-in-production/#formatting-and-style). 28 | 29 | * When in doubt follow the [Go Proverbs](https://go-proverbs.github.io/) 30 | 31 | ## Getting the code 32 | 33 | We recommend using additional git remotes for pushing/pulling code. Go cares about where the `base` project lives relative to [GOPATH](http://golang.org/doc/code.html#GOPATH). 34 | 35 | To pull our source code run: 36 | 37 | ``` 38 | $ go get github.com/moov-io/base 39 | ``` 40 | 41 | Then, add your (or another user's) fork. 42 | 43 | ``` 44 | $ cd $GOPATH/src/github.com/moov-io/base 45 | 46 | $ git remote add $user git@github.com:$user/base.git 47 | 48 | $ git fetch $user 49 | ``` 50 | 51 | Now, feel free to branch and push (`git push $user $branch`) to your remote and send us Pull Requests! 52 | 53 | ## Pull Requests 54 | 55 | A good quality PR will have the following characteristics: 56 | 57 | * It will be a complete piece of work that adds value in some way. 58 | * It will have a title that reflects the work within, and a summary that helps to understand the context of the change. 59 | * There will be well-written commit messages, with well-crafted commits that tell the story of the development of this work. 60 | * Ideally it will be small and easy to understand. Single commit PRs are usually easy to submit, review, and merge. 61 | * The code contained within will meet the best practices set by the team wherever possible. 62 | * The code is able to be merged. 63 | * A PR does not end at submission though.
A code change is not made until it is merged and used in production. 64 | 65 | A good PR should be able to flow through a peer review system easily and quickly. 66 | 67 | Our build pipeline utilizes [GitHub Actions](https://github.com/moov-io/base/actions) to enforce many tools that you should add to your editor before issuing a pull request. Learn more about these tools on our [Go Report Card](https://goreportcard.com/report/github.com/moov-io/base) 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | moov-io/base 2 | === 3 | [![GoDoc](https://godoc.org/github.com/moov-io/base?status.svg)](https://godoc.org/github.com/moov-io/base) 4 | [![Build Status](https://github.com/moov-io/base/workflows/Go/badge.svg)](https://github.com/moov-io/base/actions) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/moov-io/base)](https://goreportcard.com/report/github.com/moov-io/base) 6 | [![Apache 2 licensed](https://img.shields.io/badge/license-Apache2-blue.svg)](https://raw.githubusercontent.com/moov-io/base/master/LICENSE) 7 | 8 | Package `github.com/moov-io/base` implements core libraries used in multiple Moov projects. Refer to each project's documentation for more details. 9 | 10 | ## Getting Started 11 | 12 | You can either clone down the code (`git clone git@github.com:moov-io/base.git`) or grab the modules into your cache (`go get -u github.com/moov-io/base`).
13 | 14 | ## Configuration 15 | 16 | | Environment Variable | Description | Default | 17 | |---------------------------------------|----------------------------------------|----------------------------------| 18 | | `KUBERNETES_SERVICE_ACCOUNT_FILEPATH` | Filepath to Kubernetes service account | `/var/run/secrets/kubernetes.io` | 19 | 20 | ## Getting Help 21 | 22 | channel | info 23 | ------- | ------- 24 | Twitter [@moov](https://twitter.com/moov) | You can follow Moov.io's Twitter feed to get updates on our project(s). You can also tweet us questions or just share blogs or stories. 25 | [GitHub Issue](https://github.com/moov-io/base/issues) | If you are able to reproduce a problem, please open a GitHub Issue under the specific project that caused the error. 26 | [moov-io slack](https://slack.moov.io/) | Join our slack channel to have an interactive discussion about the development of the project. 27 | 28 | ## Supported and Tested Platforms 29 | 30 | - 64-bit Linux (Ubuntu, Debian), macOS, and Windows 31 | 32 | ## Contributing 33 | 34 | Yes please! Please review our [Contributing guide](CONTRIBUTING.md) and [Code of Conduct](CODE_OF_CONDUCT.md) to get started! 35 | 36 | This project uses [Go Modules](https://github.com/golang/go/wiki/Modules) and uses Go 1.14 or higher. See [Golang's install instructions](https://golang.org/doc/install) for help setting up Go. You can download the source code and we offer [tagged and released versions](https://github.com/moov-io/base/releases/latest) as well. We highly recommend you use a tagged release for production. 37 | 38 | ## License 39 | 40 | Apache License 2.0. See [LICENSE](LICENSE) for details. 41 | -------------------------------------------------------------------------------- /admin/README.md: -------------------------------------------------------------------------------- 1 | ## moov-io/base/admin 2 | 3 | Package admin implements an `http.Server` which can be used for operations and monitoring tools.
It's designed to be shipped (and run) inside an existing Go service. 4 | 5 | Here's an example of adding `admin.Server` to serve Prometheus metrics: 6 | 7 | ```Go 8 | import ( 9 | "fmt" 10 | "os" 11 | 12 | "github.com/moov-io/base/admin" 13 | 14 | "github.com/go-kit/log" 15 | ) 16 | 17 | var logger log.Logger 18 | 19 | // in main.go or cmd/server/main.go 20 | 21 | adminServer, err := admin.New(admin.Opts{ 22 | Addr: ":9090", 23 | }) 24 | if err != nil { 25 | // handle error 26 | } 27 | go func() { 28 | logger.Log("admin", fmt.Sprintf("listening on %s", adminServer.BindAddr())) 29 | if err := adminServer.Listen(); err != nil { 30 | err = fmt.Errorf("problem starting admin http: %v", err) 31 | logger.Log("admin", err) 32 | // errs <- err // send err to shutdown channel 33 | } 34 | }() 35 | defer adminServer.Shutdown() 36 | ``` 37 | 38 | ### Endpoints 39 | 40 | An Admin server has some default endpoints that are useful for operational support and monitoring. 41 | 42 | #### Liveness Probe 43 | 44 | This endpoint inspects a set of liveness functions and returns `200 OK` if all functions return without errors. If errors are found then a `400 Bad Request` response with a JSON object is returned describing the errors. 45 | 46 | ``` 47 | GET /live 48 | ``` 49 | 50 | Liveness probes can be registered with the following method: 51 | 52 | ``` 53 | func (s *Server) AddLivenessCheck(name string, f func() error) 54 | ``` 55 | 56 | #### Readiness Probe 57 | 58 | This endpoint inspects a set of readiness functions and returns `200 OK` if all functions return without errors. If errors are found then a `400 Bad Request` response with a JSON object is returned describing the errors.
59 | 60 | ``` 61 | GET /ready 62 | ``` 63 | 64 | Readiness probes can be registered with the following method: 65 | 66 | ``` 67 | func (s *Server) AddReadinessCheck(name string, f func() error) 68 | ``` 69 | 70 | ### Metrics 71 | 72 | This endpoint returns Prometheus metrics registered to the [prometheus/client_golang](https://github.com/prometheus/client_golang) singleton metrics registry. Their `promauto` package can be used to add Counters, Gauges, Histograms, etc. The default Go metrics provided by `prometheus/client_golang` are included. 73 | 74 | ``` 75 | GET /metrics 76 | 77 | ... 78 | # HELP promhttp_metric_handler_requests_total Total number of scrapes by HTTP status code. 79 | # TYPE promhttp_metric_handler_requests_total counter 80 | promhttp_metric_handler_requests_total{code="200"} 0 81 | promhttp_metric_handler_requests_total{code="500"} 0 82 | promhttp_metric_handler_requests_total{code="503"} 0 83 | # HELP stream_file_processing_errors Counter of stream submitted ACH files that failed processing 84 | # TYPE stream_file_processing_errors counter 85 | stream_file_processing_errors 0 86 | ``` 87 | -------------------------------------------------------------------------------- /admin/admin.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package admin implements an http.Server which can be used for operations 6 | // and monitoring tools. It's designed to be shipped (and ran) inside 7 | // an existing Go service.
8 | package admin 9 | 10 | import ( 11 | "context" 12 | "fmt" 13 | "net" 14 | "net/http" 15 | "net/http/pprof" 16 | "os" 17 | "runtime" 18 | "strings" 19 | "time" 20 | 21 | "github.com/gorilla/mux" 22 | "github.com/prometheus/client_golang/prometheus/promhttp" 23 | ) 24 | 25 | type Opts struct { 26 | Addr string 27 | Timeout time.Duration 28 | } 29 | 30 | // New returns an admin.Server instance that handles Prometheus metrics and pprof requests. 31 | // Callers can use ':0' to bind onto a random port and call BindAddr() for the address. 32 | func New(opts Opts) (*Server, error) { 33 | timeout, _ := time.ParseDuration("45s") 34 | if opts.Timeout >= 0*time.Second { 35 | timeout = opts.Timeout 36 | } 37 | 38 | var listener net.Listener 39 | var err error 40 | if opts.Addr == "" || opts.Addr == ":0" { 41 | listener, err = net.Listen("tcp", "127.0.0.1:0") 42 | } else { 43 | listener, err = net.Listen("tcp", opts.Addr) 44 | } 45 | if err != nil { 46 | return nil, fmt.Errorf("listening on %s failed: %v", opts.Addr, err) 47 | } 48 | 49 | router := handler() 50 | svc := &Server{ 51 | router: router, 52 | listener: listener, 53 | svc: &http.Server{ 54 | Addr: listener.Addr().String(), 55 | Handler: router, 56 | ReadTimeout: timeout, 57 | WriteTimeout: timeout, 58 | IdleTimeout: timeout, 59 | }, 60 | } 61 | 62 | svc.AddHandler("/live", svc.livenessHandler()) 63 | svc.AddHandler("/ready", svc.readinessHandler()) 64 | return svc, nil 65 | } 66 | 67 | // Server represents a holder around a net/http Server which 68 | // is used for admin endpoints. (i.e. metrics, healthcheck) 69 | type Server struct { 70 | router *mux.Router 71 | svc *http.Server 72 | listener net.Listener 73 | 74 | liveChecks []*healthCheck 75 | readyChecks []*healthCheck 76 | } 77 | 78 | // BindAddr returns the server's bind address. This is in Go's format so :8080 is valid. 
79 | func (s *Server) BindAddr() string { 80 | if s == nil || s.svc == nil { 81 | return "" 82 | } 83 | return s.listener.Addr().String() 84 | } 85 | 86 | func (s *Server) SetReadTimeout(timeout time.Duration) { 87 | if s == nil || s.svc == nil { 88 | return 89 | } 90 | s.svc.ReadTimeout = timeout 91 | } 92 | 93 | func (s *Server) SetWriteTimeout(timeout time.Duration) { 94 | if s == nil || s.svc == nil { 95 | return 96 | } 97 | s.svc.WriteTimeout = timeout 98 | } 99 | 100 | func (s *Server) SetIdleTimeout(timeout time.Duration) { 101 | if s == nil || s.svc == nil { 102 | return 103 | } 104 | s.svc.IdleTimeout = timeout 105 | } 106 | 107 | // Listen brings up the admin HTTP server. This call blocks until the server is Shutdown or panics. 108 | func (s *Server) Listen() error { 109 | if s == nil || s.svc == nil || s.listener == nil { 110 | return nil 111 | } 112 | return s.svc.Serve(s.listener) 113 | } 114 | 115 | // Shutdown unbinds the HTTP server. 116 | func (s *Server) Shutdown() { 117 | if s == nil || s.svc == nil { 118 | return 119 | } 120 | s.svc.Shutdown(context.TODO()) 121 | } 122 | 123 | // AddHandler will append an http.HandlerFunc to the admin Server 124 | func (s *Server) AddHandler(path string, hf http.HandlerFunc) { 125 | s.router.HandleFunc(path, hf) 126 | } 127 | 128 | // AddVersionHandler will append 'GET /version' route returning the provided version 129 | func (s *Server) AddVersionHandler(version string) { 130 | s.AddHandler("/version", func(w http.ResponseWriter, r *http.Request) { 131 | w.WriteHeader(http.StatusOK) 132 | w.Write([]byte(version)) 133 | }) 134 | } 135 | 136 | // Subrouter creates and returns a subrouter with the specific prefix. 137 | // 138 | // The returned subrouter can use middleware without impacting 139 | // the parent router. 
For example: 140 | // 141 | // svr, err := New(Opts{ 142 | // Addr: ":9090", 143 | // }) 144 | // 145 | // subRouter := svr.Subrouter("/prefix") 146 | // subRouter.Use(someMiddleware) 147 | // subRouter.HandleFunc("/resource", ResourceHandler) 148 | // 149 | // Here, requests for "/prefix/resource" would go through someMiddleware while 150 | // the liveliness and readiness routes added to the parent router by New() 151 | // would not. 152 | func (s *Server) Subrouter(pathPrefix string) *mux.Router { 153 | return s.router.PathPrefix(pathPrefix).Subrouter() 154 | } 155 | 156 | // profileEnabled returns if a given pprof handler should be 157 | // enabled according to pprofHandlers and the PPROF_* environment 158 | // variables. 159 | // 160 | // These profiles can be disabled by setting the appropriate PPROF_* 161 | // environment variable. (i.e. PPROF_ALLOCS=no) 162 | // 163 | // An empty string, "yes", or "true" enables the profile. Any other 164 | // value disables the profile. 165 | func profileEnabled(name string) bool { 166 | k := fmt.Sprintf("PPROF_%s", strings.ToUpper(name)) 167 | v := strings.ToLower(os.Getenv(k)) 168 | return v == "" || v == "yes" || v == "true" 169 | } 170 | 171 | // Handler returns an http.Handler for the admin http service. 172 | // This contains metrics and pprof handlers. 173 | // 174 | // No metrics specific to the handler are recorded. 175 | // 176 | // We only want to expose on the admin servlet because these 177 | // profiles/dumps can contain sensitive info (raw memory). 
178 | func Handler() http.Handler { 179 | return handler() 180 | } 181 | 182 | func handler() *mux.Router { 183 | r := mux.NewRouter() 184 | 185 | // prometheus metrics 186 | r.Path("/metrics").Handler(promhttp.Handler()) 187 | 188 | // always register index and cmdline handlers 189 | r.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) 190 | r.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) 191 | 192 | if profileEnabled("profile") { 193 | r.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) 194 | } 195 | if profileEnabled("symbol") { 196 | r.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) 197 | } 198 | if profileEnabled("trace") { 199 | r.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) 200 | } 201 | 202 | // Register runtime/pprof handlers 203 | if profileEnabled("allocs") { 204 | r.Handle("/debug/pprof/allocs", pprof.Handler("allocs")) 205 | } 206 | if profileEnabled("block") { 207 | runtime.SetBlockProfileRate(1) 208 | r.Handle("/debug/pprof/block", pprof.Handler("block")) 209 | } 210 | if profileEnabled("goroutine") { 211 | r.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine")) 212 | } 213 | if profileEnabled("heap") { 214 | r.Handle("/debug/pprof/heap", pprof.Handler("heap")) 215 | } 216 | if profileEnabled("mutex") { 217 | runtime.SetMutexProfileFraction(1) 218 | r.Handle("/debug/pprof/mutex", pprof.Handler("mutex")) 219 | } 220 | if profileEnabled("threadcreate") { 221 | r.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate")) 222 | } 223 | 224 | return r 225 | } 226 | -------------------------------------------------------------------------------- /admin/admin_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | 5 | package admin 6 | 7 | import ( 8 | "io/ioutil" 9 | "net/http" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestAdmin__pprof(t *testing.T) { 17 | svc, err := New(Opts{Addr: ":13983"}) // hopefully nothing locally has this 18 | require.NoError(t, err) 19 | 20 | go svc.Listen() 21 | defer svc.Shutdown() 22 | 23 | // Check for Prometheus metrics endpoint 24 | resp, err := http.DefaultClient.Get("http://localhost:13983/metrics") 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | if resp.StatusCode != http.StatusOK { 29 | t.Errorf("bogus HTTP status code: %s", resp.Status) 30 | } 31 | resp.Body.Close() 32 | 33 | // Check always on pprof endpoint 34 | resp, err = http.DefaultClient.Get("http://localhost:13983/debug/pprof/cmdline") 35 | if err != nil { 36 | t.Fatal(err) 37 | } 38 | if resp.StatusCode != http.StatusOK { 39 | t.Errorf("bogus HTTP status code: %s", resp.Status) 40 | } 41 | resp.Body.Close() 42 | } 43 | 44 | func TestAdmin__AddHandler(t *testing.T) { 45 | svc, err := New(Opts{Addr: ":13984"}) 46 | require.NoError(t, err) 47 | 48 | go svc.Listen() 49 | defer svc.Shutdown() 50 | 51 | special := func(w http.ResponseWriter, r *http.Request) { 52 | if r.URL.Path != "/special-path" { 53 | w.WriteHeader(http.StatusBadRequest) 54 | return 55 | } 56 | w.WriteHeader(http.StatusOK) 57 | w.Write([]byte("special")) 58 | } 59 | svc.AddHandler("/special-path", special) 60 | 61 | req, err := http.NewRequest("GET", "http://localhost:13984/special-path", nil) 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | resp, err := http.DefaultClient.Do(req) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | defer resp.Body.Close() 70 | 71 | if resp.StatusCode != http.StatusOK { 72 | t.Errorf("bogus HTTP status: %d", resp.StatusCode) 73 | } 74 | bs, _ := ioutil.ReadAll(resp.Body) 75 | if v := string(bs); v != "special" { 76 | t.Errorf("response was %q", v) 77 | } 78 | } 79 | 80 | func 
TestAdmin__fullAddress(t *testing.T) { 81 | svc, err := New(Opts{Addr: "127.0.0.1:13985"}) 82 | require.NoError(t, err) 83 | 84 | go svc.Listen() 85 | defer svc.Shutdown() 86 | 87 | resp, err := http.DefaultClient.Get("http://localhost:13985/metrics") 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | if resp.StatusCode != http.StatusOK { 92 | t.Errorf("bogus HTTP status code: %s", resp.Status) 93 | } 94 | resp.Body.Close() 95 | } 96 | 97 | func TestAdmin__AddVersionHandler(t *testing.T) { 98 | svc, err := New(Opts{Addr: ":0"}) 99 | require.NoError(t, err) 100 | 101 | go svc.Listen() 102 | defer svc.Shutdown() 103 | 104 | svc.AddVersionHandler("v0.1.0") 105 | 106 | req, err := http.NewRequest("GET", "http://"+svc.BindAddr()+"/version", nil) 107 | if err != nil { 108 | t.Fatal(err) 109 | } 110 | resp, err := http.DefaultClient.Do(req) 111 | if err != nil { 112 | t.Fatal(err) 113 | } 114 | defer resp.Body.Close() 115 | 116 | if resp.StatusCode != http.StatusOK { 117 | t.Errorf("bogus HTTP status: %d", resp.StatusCode) 118 | } 119 | bs, _ := ioutil.ReadAll(resp.Body) 120 | if v := string(bs); v != "v0.1.0" { 121 | t.Errorf("got %s", v) 122 | } 123 | } 124 | 125 | func TestAdmin__Listen(t *testing.T) { 126 | svc := &Server{} 127 | if err := svc.Listen(); err != nil { 128 | t.Error("expected no error") 129 | } 130 | 131 | svc = nil 132 | if err := svc.Listen(); err != nil { 133 | t.Error("expected no error") 134 | } 135 | } 136 | 137 | func TestAdmin__BindAddr(t *testing.T) { 138 | svc, err := New(Opts{Addr: ":0"}) 139 | require.NoError(t, err) 140 | 141 | svc.AddHandler("/test/ping", func(w http.ResponseWriter, _ *http.Request) { 142 | w.WriteHeader(http.StatusOK) 143 | }) 144 | 145 | go svc.Listen() 146 | defer svc.Shutdown() 147 | 148 | if v := svc.BindAddr(); v == ":0" { 149 | t.Errorf("BindAddr: %v", v) 150 | } 151 | 152 | resp, err := http.DefaultClient.Get("http://" + svc.BindAddr() + "/test/ping") 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | defer 
resp.Body.Close() 157 | if resp.StatusCode != http.StatusOK { 158 | t.Errorf("bogus HTTP status code: %d", resp.StatusCode) 159 | } 160 | } 161 | 162 | func TestServer_Subrouter(t *testing.T) { 163 | svc, err := New(Opts{Addr: ":0"}) 164 | require.NoError(t, err) 165 | 166 | subrouter := svc.Subrouter("/sub") 167 | subrouter.Use(func(h http.Handler) http.Handler { 168 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 169 | w.Write([]byte("middleware\n")) 170 | h.ServeHTTP(w, r) 171 | }) 172 | }) 173 | subrouter.Path("/test").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 174 | w.Write([]byte("/sub/test")) 175 | }) 176 | go svc.Listen() 177 | defer svc.Shutdown() 178 | 179 | // This request is expected to go through the subrouter with its middleware 180 | resp, err := http.DefaultClient.Get("http://" + svc.BindAddr() + "/sub/test") 181 | require.NoError(t, err) 182 | defer resp.Body.Close() 183 | 184 | assert.Equal(t, http.StatusOK, resp.StatusCode) 185 | body, err := ioutil.ReadAll(resp.Body) 186 | require.NoError(t, err) 187 | assert.Equal(t, "middleware\n/sub/test", string(body)) 188 | 189 | // This request hits the main router, so should not have a path prefix or middleware 190 | liveResponse, err := http.DefaultClient.Get("http://" + svc.BindAddr() + "/live") 191 | require.NoError(t, err) 192 | defer liveResponse.Body.Close() 193 | 194 | assert.Equal(t, http.StatusOK, liveResponse.StatusCode) 195 | liveBody, err := ioutil.ReadAll(liveResponse.Body) 196 | require.NoError(t, err) 197 | assert.NotContains(t, string(liveBody), "middleware") 198 | } 199 | -------------------------------------------------------------------------------- /admin/health.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | 5 | package admin 6 | 7 | import ( 8 | "encoding/json" 9 | "errors" 10 | "net/http" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | var ( 16 | errTimeout = errors.New("timeout exceeded") 17 | 18 | httpCheckTimeout = 10 * time.Second 19 | ) 20 | 21 | type healthCheck struct { 22 | name string 23 | check func() error 24 | } 25 | 26 | // Error executes the health check and will block until the result returns 27 | func (hc *healthCheck) Error() error { 28 | if hc == nil || hc.check == nil { 29 | return nil 30 | } 31 | return hc.check() 32 | } 33 | 34 | type result struct { 35 | name string 36 | err error 37 | } 38 | 39 | // AddLivenessCheck will register a new health check that is executed on every 40 | // HTTP request of 'GET /live' against the admin server. 41 | // 42 | // Every check will timeout after 10s and return a timeout error. 43 | // 44 | // These checks are designed to be unhealthy only when the application has started but 45 | // a dependency is unreachable or unhealthy. 46 | func (s *Server) AddLivenessCheck(name string, f func() error) { 47 | s.liveChecks = append(s.liveChecks, &healthCheck{ 48 | name: name, 49 | check: f, 50 | }) 51 | } 52 | 53 | func (s *Server) livenessHandler() http.HandlerFunc { 54 | return func(w http.ResponseWriter, r *http.Request) { 55 | results := processChecks(s.liveChecks) 56 | if len(results) == 0 { 57 | w.WriteHeader(http.StatusOK) 58 | return 59 | } 60 | status := http.StatusOK 61 | kv := make(map[string]string) 62 | for i := range results { 63 | if results[i].err != nil { 64 | status = http.StatusBadRequest 65 | kv[results[i].name] = results[i].err.Error() 66 | } else { 67 | kv[results[i].name] = "good" 68 | } 69 | } 70 | bs, _ := json.Marshal(kv) 71 | w.WriteHeader(status) 72 | w.Write(bs) 73 | } 74 | } 75 | 76 | // AddReadinessCheck will register a new health check that is executed on every 77 | // HTTP request of 'GET /ready' against the admin server. 
78 | // 79 | // Every check will timeout after 10s and return a timeout error. 80 | // 81 | // These checks are designed to be unhealthy while the application is starting. 82 | func (s *Server) AddReadinessCheck(name string, f func() error) { 83 | s.readyChecks = append(s.readyChecks, &healthCheck{ 84 | name: name, 85 | check: f, 86 | }) 87 | } 88 | 89 | func (s *Server) readinessHandler() http.HandlerFunc { 90 | return func(w http.ResponseWriter, r *http.Request) { 91 | results := processChecks(s.readyChecks) 92 | if len(results) == 0 { 93 | w.WriteHeader(http.StatusOK) 94 | return 95 | } 96 | status := http.StatusOK 97 | kv := make(map[string]string) 98 | for i := range results { 99 | if results[i].err != nil { 100 | status = http.StatusBadRequest 101 | kv[results[i].name] = results[i].err.Error() 102 | } else { 103 | kv[results[i].name] = "good" 104 | } 105 | } 106 | bs, _ := json.Marshal(kv) 107 | w.WriteHeader(status) 108 | w.Write(bs) 109 | } 110 | } 111 | 112 | func processChecks(checks []*healthCheck) []result { 113 | var results []result 114 | var mu sync.Mutex 115 | 116 | var wg sync.WaitGroup 117 | wg.Add(len(checks)) 118 | 119 | for i := range checks { 120 | go func(check *healthCheck) { 121 | defer wg.Done() 122 | err := try(func() error { return check.Error() }, httpCheckTimeout) 123 | mu.Lock() 124 | results = append(results, result{ 125 | name: check.name, 126 | err: err, 127 | }) 128 | mu.Unlock() 129 | }(checks[i]) 130 | } 131 | 132 | wg.Wait() 133 | return results 134 | } 135 | 136 | // try will attempt to call f, but only for as long as t. If the function is still 137 | // processing after t has elapsed then errTimeout will be returned. 
138 | func try(f func() error, t time.Duration) error { 139 | answer := make(chan error) 140 | go func() { 141 | answer <- f() 142 | }() 143 | select { 144 | case err := <-answer: 145 | return err 146 | case <-time.After(t): 147 | return errTimeout 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /admin/health_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package admin 6 | 7 | import ( 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "net/http" 12 | "testing" 13 | "time" 14 | 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func TestHealth_healthCheck(t *testing.T) { 19 | c := &healthCheck{"example", func() error { 20 | return errors.New("example error") 21 | }} 22 | if err := c.Error(); err == nil { 23 | t.Error("expected error") 24 | } 25 | } 26 | 27 | func TestHealth_processChecks(t *testing.T) { 28 | checks := []*healthCheck{ 29 | {"good", func() error { return nil }}, 30 | {"bad", func() error { return errors.New("bad") }}, 31 | } 32 | results := processChecks(checks) 33 | if len(results) != 2 { 34 | t.Fatalf("Got %v", results) 35 | } 36 | for i := range results { 37 | if results[i].name == "good" && results[i].err != nil { 38 | t.Errorf("%q got err=%v", results[i].name, results[i].err) 39 | continue 40 | } 41 | if results[i].name == "bad" && results[i].err.Error() != "bad" { 42 | t.Errorf("%q got err=%v", results[i].name, results[i].err) 43 | continue 44 | } 45 | } 46 | } 47 | 48 | func TestHealth__LiveHTTP(t *testing.T) { 49 | svc, err := New(Opts{ 50 | Addr: ":13993", 51 | }) // hopefully nothing locally has this 52 | require.NoError(t, err) 53 | go svc.Listen() 54 | defer svc.Shutdown() 55 | 56 | // no checks, should be healthy 57 | resp, err := 
http.DefaultClient.Get("http://localhost:13993/live") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | if resp.StatusCode != http.StatusOK { 62 | t.Errorf("bogus HTTP status: %s", resp.Status) 63 | } 64 | resp.Body.Close() 65 | 66 | // add a healthy check 67 | svc.AddLivenessCheck("live-good", func() error { 68 | return nil 69 | }) 70 | resp, err = http.DefaultClient.Get("http://localhost:13993/live") 71 | if err != nil { 72 | t.Fatal(err) 73 | } 74 | if resp.StatusCode != http.StatusOK { 75 | t.Errorf("bogus HTTP status: %s", resp.Status) 76 | } 77 | resp.Body.Close() 78 | 79 | // one bad check, should fail 80 | svc.AddLivenessCheck("live-bad", func() error { 81 | return errors.New("unhealthy") 82 | }) 83 | resp, err = http.DefaultClient.Get("http://localhost:13993/live") 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | if resp.StatusCode != http.StatusBadRequest { 88 | t.Errorf("bogus HTTP status: %s", resp.Status) 89 | } 90 | defer resp.Body.Close() 91 | 92 | // Read JSON response body 93 | var checks map[string]string 94 | if err := json.NewDecoder(resp.Body).Decode(&checks); err != nil { 95 | t.Fatal(err) 96 | } 97 | if len(checks) != 2 { 98 | t.Errorf("checks: %#v", checks) 99 | } 100 | if v := fmt.Sprintf("%v", checks["live-good"]); v != "good" { 101 | t.Errorf("live-good: %s", v) 102 | } 103 | if v := fmt.Sprintf("%v", checks["live-bad"]); v != "unhealthy" { 104 | t.Errorf("live-bad: %s", v) 105 | } 106 | } 107 | 108 | func TestHealth__ReadyHTTP(t *testing.T) { 109 | svc, err := New(Opts{ 110 | Addr: ":13994", 111 | }) // hopefully nothing locally has this 112 | require.NoError(t, err) 113 | go svc.Listen() 114 | defer svc.Shutdown() 115 | 116 | // no checks, should be healthy 117 | resp, err := http.DefaultClient.Get("http://localhost:13994/ready") 118 | if err != nil { 119 | t.Fatal(err) 120 | } 121 | if resp.StatusCode != http.StatusOK { 122 | t.Errorf("bogus HTTP status: %s", resp.Status) 123 | } 124 | resp.Body.Close() 125 | 126 | // add a healthy 
check 127 | svc.AddReadinessCheck("ready-good", func() error { 128 | return nil 129 | }) 130 | resp, err = http.DefaultClient.Get("http://localhost:13994/ready") 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | if resp.StatusCode != http.StatusOK { 135 | t.Errorf("bogus HTTP status: %s", resp.Status) 136 | } 137 | resp.Body.Close() 138 | 139 | // one bad check, should fail 140 | svc.AddReadinessCheck("ready-bad", func() error { 141 | return errors.New("unhealthy") 142 | }) 143 | resp, err = http.DefaultClient.Get("http://localhost:13994/ready") 144 | if err != nil { 145 | t.Fatal(err) 146 | } 147 | if resp.StatusCode != http.StatusBadRequest { 148 | t.Errorf("bogus HTTP status: %s", resp.Status) 149 | } 150 | defer resp.Body.Close() 151 | 152 | // Read JSON response body 153 | var checks map[string]string 154 | if err := json.NewDecoder(resp.Body).Decode(&checks); err != nil { 155 | t.Fatal(err) 156 | } 157 | if len(checks) != 2 { 158 | t.Errorf("checks: %#v", checks) 159 | } 160 | if v := fmt.Sprintf("%v", checks["ready-good"]); v != "good" { 161 | t.Errorf("ready-good: %s", v) 162 | } 163 | if v := fmt.Sprintf("%v", checks["ready-bad"]); v != "unhealthy" { 164 | t.Errorf("ready-bad: %s", v) 165 | } 166 | } 167 | 168 | func TestHealth_try(t *testing.T) { 169 | // happy path, no timeout 170 | if err := try(func() error { return nil }, 1*time.Second); err != nil { 171 | t.Error("expected no error") 172 | } 173 | 174 | // error returned, no timeout 175 | if err := try(func() error { return errors.New("error") }, 1*time.Second); err == nil { 176 | t.Error("expected error, got none") 177 | } else { 178 | if err.Error() != "error" { 179 | t.Errorf("got %v", err) 180 | } 181 | } 182 | 183 | // timeout 184 | f := func() error { 185 | time.Sleep(1 * time.Second) 186 | return errors.New("after sleep") 187 | } 188 | if err := try(f, 10*time.Millisecond); err == nil { 189 | t.Errorf("expected (timeout) error, got none") 190 | } else { 191 | if err != errTimeout { 192 | 
t.Errorf("unknown error: %v", err) 193 | } 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /admin/pprof_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | package admin 5 | 6 | import ( 7 | "fmt" 8 | "testing" 9 | ) 10 | 11 | func TestAdmin__profileEnabled(t *testing.T) { 12 | cases := map[string]bool{ 13 | // enable 14 | "yes": true, 15 | " true ": true, 16 | "": true, 17 | // disable 18 | "no": false, 19 | "jsadlsaj": false, 20 | } 21 | for value, enabled := range cases { 22 | t.Setenv("PPROF_TESTING_VALUE", fmt.Sprintf("%v", enabled)) 23 | 24 | if v := profileEnabled("TESTING_VALUE"); v != enabled { 25 | t.Errorf("value=%q, got=%v, expected=%v", value, v, enabled) 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /api/common.yaml: -------------------------------------------------------------------------------- 1 | openapi: 3.0.2 2 | info: 3 | x-fragment: true 4 | title: An include file to define common schemas. 5 | version: 1.0.0 6 | 7 | paths: {} 8 | 9 | components: 10 | schemas: 11 | Error: 12 | required: 13 | - error 14 | properties: 15 | error: 16 | type: string 17 | description: An error message describing the problem intended for humans. 
18 | example: Example error, see description 19 | -------------------------------------------------------------------------------- /build/log.go: -------------------------------------------------------------------------------- 1 | package build 2 | 3 | import ( 4 | "runtime/debug" 5 | "strings" 6 | 7 | "github.com/moov-io/base/log" 8 | ) 9 | 10 | func Log(logger log.Logger) { 11 | info, ok := debug.ReadBuildInfo() 12 | if info == nil || !ok { 13 | logger.Error().Log("unable to read build info, pleasure ensure go module support") 14 | return 15 | } 16 | 17 | logger = logger.With(log.Fields{ 18 | "build_path": log.String(info.Path), 19 | "build_go_version": log.String(info.GoVersion), 20 | }) 21 | 22 | for _, mod := range info.Deps { 23 | mod = runningModule(mod) 24 | 25 | if strings.Contains(strings.ToLower(mod.Path), "/moov") { 26 | logger.With(log.Fields{ 27 | "build_mod_path": log.String(mod.Path), 28 | "build_mod_version": log.String(mod.Version), 29 | }).Log("") 30 | } 31 | } 32 | } 33 | 34 | // Recurse through all the replaces to find whats actually running 35 | func runningModule(mod *debug.Module) *debug.Module { 36 | if mod.Replace != nil { 37 | return runningModule(mod.Replace) 38 | } else { 39 | return mod 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /build/log_test.go: -------------------------------------------------------------------------------- 1 | package build_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/moov-io/base/build" 7 | "github.com/moov-io/base/log" 8 | ) 9 | 10 | func Test_LogDeps(t *testing.T) { 11 | _, logger := log.NewBufferLogger() 12 | 13 | // Running it purely to make sure it doesn't panic as it requires a compiled binary to work. 
14 | build.Log(logger) 15 | } 16 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "io/fs" 7 | "os" 8 | "reflect" 9 | "regexp" 10 | "strings" 11 | 12 | "github.com/moov-io/base/log" 13 | 14 | "github.com/go-viper/mapstructure/v2" 15 | "github.com/markbates/pkger" 16 | "github.com/spf13/viper" 17 | ) 18 | 19 | const APP_CONFIG = "APP_CONFIG" 20 | const APP_CONFIG_SECRETS = "APP_CONFIG_SECRETS" //nolint:gosec 21 | 22 | type Service struct { 23 | logger log.Logger 24 | } 25 | 26 | func NewService(logger log.Logger) Service { 27 | return Service{ 28 | logger: logger.Set("component", log.String("Service")), 29 | } 30 | } 31 | 32 | func (s *Service) Load(config interface{}) error { 33 | if err := s.LoadFile(pkger.Include("/configs/config.default.yml"), config); err != nil { 34 | return err 35 | } 36 | 37 | return s.MergeEnvironments(config) 38 | } 39 | 40 | func (s *Service) LoadFromFS(config interface{}, fs fs.FS) error { 41 | if err := s.LoadEmbeddedFile("configs/config.default.yml", config, fs); err != nil { 42 | return err 43 | } 44 | 45 | return s.MergeEnvironments(config) 46 | } 47 | 48 | func (s *Service) MergeEnvironments(config interface{}) error { 49 | v := viper.New() 50 | v.SetConfigType("yaml") 51 | 52 | if err := LoadEnvironmentFile(s.logger, APP_CONFIG, v); err != nil { 53 | return err 54 | } 55 | 56 | if err := LoadEnvironmentFile(s.logger, APP_CONFIG_SECRETS, v); err != nil { 57 | return err 58 | } 59 | 60 | return v.UnmarshalExact(config, overwriteConfig) 61 | } 62 | 63 | func (s *Service) LoadFile(file string, config interface{}) error { 64 | logger := s.logger.Set("file", log.String(file)) 65 | logger.Info().Logf("loading config file") 66 | 67 | f, err := pkger.Open(file) 68 | if err != nil { 69 | return logger.LogErrorf("pkger unable to load %s: %w", file, 
err).Err() 70 | } 71 | 72 | if err := configFromReader(config, f); err != nil { 73 | return logger.LogError(err).Err() 74 | } 75 | 76 | return nil 77 | } 78 | 79 | func (s *Service) LoadEmbeddedFile(file string, config interface{}, fs fs.FS) error { 80 | logger := s.logger.Set("file", log.String(file)) 81 | logger.Info().Logf("loading config file") 82 | 83 | f, err := fs.Open(file) 84 | if err != nil { 85 | return logger.LogErrorf("go:embed FS unable to load %s: %w", file, err).Err() 86 | } 87 | 88 | if err := configFromReader(config, f); err != nil { 89 | return logger.LogError(err).Err() 90 | } 91 | 92 | return nil 93 | } 94 | 95 | func configFromReader(config interface{}, f io.Reader) error { 96 | deflt := viper.New() 97 | deflt.SetConfigType("yaml") 98 | if err := deflt.ReadConfig(f); err != nil { 99 | return fmt.Errorf("unable to load the defaults: %w", err) 100 | } 101 | 102 | if err := deflt.UnmarshalExact(config, overwriteConfig); err != nil { 103 | return fmt.Errorf("unable to unmarshal the defaults: %w", err) 104 | } 105 | 106 | return nil 107 | } 108 | 109 | func LoadEnvironmentFile(logger log.Logger, envVar string, v *viper.Viper) error { 110 | if file, ok := os.LookupEnv(envVar); ok && strings.TrimSpace(file) != "" { 111 | 112 | logger := logger.Set(envVar, log.String(file)) 113 | logger.Info().Logf("Loading %s config file", envVar) 114 | 115 | logger = logger.Set("file", log.String(file)) 116 | logger.Info().Logf("loading config file") 117 | 118 | v.SetConfigFile(file) 119 | 120 | if err := v.MergeInConfig(); err != nil { 121 | return logger.LogErrorf("merging config failed: %w", err).Err() 122 | } 123 | } 124 | 125 | return nil 126 | } 127 | 128 | func overwriteConfig(cfg *mapstructure.DecoderConfig) { 129 | cfg.ErrorUnused = true 130 | cfg.ZeroFields = true 131 | 132 | cfg.DecodeHook = mapstructure.ComposeDecodeHookFunc(decodeRegexHook, mapstructure.StringToTimeDurationHookFunc()) 133 | } 134 | 135 | func decodeRegexHook(t1 reflect.Type, t2 
reflect.Type, value interface{}) (interface{}, error) { 136 | decodingRegex := t2.String() == "regexp.Regexp" 137 | if decodingRegex { 138 | if stringValue, ok := value.(string); ok { 139 | return regexp.Compile(stringValue) 140 | } 141 | } 142 | return value, nil 143 | } 144 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "path/filepath" 5 | "regexp" 6 | "testing" 7 | "time" 8 | 9 | "github.com/moov-io/base" 10 | "github.com/moov-io/base/config" 11 | "github.com/moov-io/base/log" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | type GlobalConfigModel struct { 16 | Config ConfigModel 17 | } 18 | 19 | type ConfigModel struct { 20 | Default string 21 | App string 22 | Secret string 23 | Values []string 24 | Zero string 25 | 26 | Widgets map[string]Widget 27 | 28 | Search SearchConfig 29 | Security SecurityConfig 30 | } 31 | 32 | type Widget struct { 33 | Name string 34 | Credentials Credentials 35 | Nested Nested 36 | } 37 | 38 | type Credentials struct { 39 | Username string 40 | Password string 41 | } 42 | 43 | type Nested struct { 44 | Nested2 Nested2 45 | } 46 | 47 | type Nested2 struct { 48 | Nested3 Nested3 49 | } 50 | 51 | type Nested3 struct { 52 | Value string 53 | } 54 | 55 | type SearchConfig struct { 56 | Patterns []*regexp.Regexp 57 | 58 | MaxResults int 59 | Timeout time.Duration 60 | } 61 | 62 | type SecurityConfig struct { 63 | Audience []string `mapstructure:"x-audience"` 64 | Cluster string `mapstructure:"x-cluster"` 65 | Service string `mapstructure:"x-service"` 66 | } 67 | 68 | func Test_Load(t *testing.T) { 69 | t.Setenv(config.APP_CONFIG, filepath.Join("..", "configs", "config.app.yml")) 70 | t.Setenv(config.APP_CONFIG_SECRETS, filepath.Join("..", "configs", "config.secrets.yml")) 71 | 72 | cfg := &GlobalConfigModel{} 73 | 74 | service := 
config.NewService(log.NewDefaultLogger()) 75 | err := service.Load(cfg) 76 | require.Nil(t, err) 77 | 78 | require.Equal(t, "default", cfg.Config.Default) 79 | require.Equal(t, "app", cfg.Config.App) 80 | require.Equal(t, "keep secret!", cfg.Config.Secret) 81 | 82 | // This test documents some unexpected behavior where slices are merged and not overwritten. 83 | // Slices in secrets should overwrite the slice in app and default. 84 | require.Len(t, cfg.Config.Values, 1) 85 | require.Equal(t, "secret", cfg.Config.Values[0]) 86 | 87 | require.Equal(t, "", cfg.Config.Zero) 88 | 89 | // Verify attempting to load from our default file errors on extra fields 90 | cfg = &GlobalConfigModel{} 91 | err = service.LoadFile("/configs/config.extra.yml", &cfg) 92 | require.NotNil(t, err) 93 | require.Contains(t, err.Error(), `'Config' has invalid keys: extra`) 94 | 95 | // Verify attempting to load additional fields via env vars errors out 96 | t.Setenv(config.APP_CONFIG, filepath.Join("..", "configs", "config.extra.yml")) 97 | cfg = &GlobalConfigModel{} 98 | err = service.Load(cfg) 99 | require.NotNil(t, err) 100 | require.Contains(t, err.Error(), `'Config' has invalid keys: extra`) 101 | } 102 | 103 | func Test_Embedded_Load(t *testing.T) { 104 | t.Setenv(config.APP_CONFIG, filepath.Join("..", "configs", "config.app.yml")) 105 | t.Setenv(config.APP_CONFIG_SECRETS, filepath.Join("..", "configs", "config.secrets.yml")) 106 | 107 | cfg := &GlobalConfigModel{} 108 | 109 | service := config.NewService(log.NewDefaultLogger()) 110 | err := service.LoadFromFS(cfg, base.ConfigDefaults) 111 | require.Nil(t, err) 112 | 113 | require.Equal(t, "default", cfg.Config.Default) 114 | require.Equal(t, "app", cfg.Config.App) 115 | require.Equal(t, "keep secret!", cfg.Config.Secret) 116 | 117 | // This test documents some unexpected behavior where slices are merged and not overwritten. 118 | // Slices in secrets should overwrite the slice in app and default. 
119 | require.Len(t, cfg.Config.Values, 1) 120 | require.Equal(t, "secret", cfg.Config.Values[0]) 121 | 122 | require.Equal(t, "", cfg.Config.Zero) 123 | 124 | // Verify attempting to load from our default file errors on extra fields 125 | cfg = &GlobalConfigModel{} 126 | err = service.LoadFile("/configs/config.extra.yml", &cfg) 127 | require.NotNil(t, err) 128 | require.Contains(t, err.Error(), `'Config' has invalid keys: extra`) 129 | 130 | // Verify attempting to load additional fields via env vars errors out 131 | t.Setenv(config.APP_CONFIG, filepath.Join("..", "configs", "config.extra.yml")) 132 | cfg = &GlobalConfigModel{} 133 | err = service.Load(cfg) 134 | require.NotNil(t, err) 135 | require.Contains(t, err.Error(), `'Config' has invalid keys: extra`) 136 | } 137 | 138 | func Test_WidgetsConfig(t *testing.T) { 139 | t.Setenv(config.APP_CONFIG, filepath.Join("testdata", "with-widgets.yml")) 140 | t.Setenv(config.APP_CONFIG_SECRETS, filepath.Join("testdata", "with-widget-secrets.yml")) 141 | 142 | cfg := &GlobalConfigModel{} 143 | 144 | service := config.NewService(log.NewDefaultLogger()) 145 | err := service.LoadFromFS(cfg, base.ConfigDefaults) 146 | require.Nil(t, err) 147 | 148 | w, ok := cfg.Config.Widgets["aaa"] 149 | require.True(t, ok) 150 | 151 | require.Equal(t, "aaa", w.Name) 152 | require.Equal(t, "u1", w.Credentials.Username) 153 | require.Equal(t, "p2", w.Credentials.Password) 154 | require.Equal(t, "v1", w.Nested.Nested2.Nested3.Value) 155 | } 156 | 157 | func Test_SearchAndSecurityConfig(t *testing.T) { 158 | t.Setenv(config.APP_CONFIG, filepath.Join("testdata", "with-search-and-security.yml")) 159 | t.Setenv(config.APP_CONFIG_SECRETS, "") 160 | 161 | cfg := &GlobalConfigModel{} 162 | 163 | service := config.NewService(log.NewDefaultLogger()) 164 | err := service.LoadFromFS(cfg, base.ConfigDefaults) 165 | require.Nil(t, err) 166 | 167 | // Search 168 | patterns := cfg.Config.Search.Patterns 169 | require.Len(t, patterns, 1) 170 | 
require.Equal(t, "a(b+)c", patterns[0].String()) 171 | 172 | require.Equal(t, 100, cfg.Config.Search.MaxResults) 173 | require.Equal(t, 30*time.Second, cfg.Config.Search.Timeout) 174 | 175 | // Security 176 | require.Len(t, cfg.Config.Security.Audience, 1) 177 | require.Equal(t, "platform", cfg.Config.Security.Cluster) 178 | require.Equal(t, "roles", cfg.Config.Security.Service) 179 | } 180 | -------------------------------------------------------------------------------- /config/testdata/with-search-and-security.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | Search: 3 | Patterns: 4 | - "a(b+)c" 5 | MaxResults: 100 6 | Timeout: 30s 7 | Security: 8 | x-audience: 9 | - service:cards 10 | x-cluster: platform 11 | x-service: roles 12 | -------------------------------------------------------------------------------- /config/testdata/with-widget-secrets.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | Widgets: 3 | AAA: 4 | Credentials: 5 | Password: "p2" 6 | Nested: 7 | Nested2: 8 | Nested3: 9 | Value: "v1" 10 | -------------------------------------------------------------------------------- /config/testdata/with-widgets.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | Widgets: 3 | AAA: 4 | Name: "aaa" 5 | Credentials: 6 | Username: "u1" 7 | Password: "p1" 8 | -------------------------------------------------------------------------------- /configs/config.app.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | App: "app" 3 | Values: 4 | - "app" 5 | - "test2" 6 | Zero: "" 7 | -------------------------------------------------------------------------------- /configs/config.default.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | Default: "default" 3 | Values: 4 | - "default" 5 | Zero: "hero" 6 | 
-------------------------------------------------------------------------------- /configs/config.extra.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | App: "app" 3 | Extra: 4 | Object: "value" 5 | -------------------------------------------------------------------------------- /configs/config.secrets.yml: -------------------------------------------------------------------------------- 1 | Config: 2 | Secret: "keep secret!" 3 | Values: 4 | - "secret" 5 | -------------------------------------------------------------------------------- /database/database.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package database 6 | 7 | import ( 8 | "context" 9 | "database/sql" 10 | "fmt" 11 | 12 | "github.com/moov-io/base/log" 13 | ) 14 | 15 | // New establishes a database connection according to the type and environmental 16 | // variables for that specific database. 
17 | func New(ctx context.Context, logger log.Logger, config DatabaseConfig) (*sql.DB, error) { 18 | if config.MySQL != nil { 19 | preppedDb, err := mysqlConnection(logger, config.MySQL, config.DatabaseName) 20 | if err != nil { 21 | return nil, fmt.Errorf("configuring mysql connection: %v", err) 22 | } 23 | 24 | db, err := preppedDb.Connect(ctx) 25 | if err != nil { 26 | return nil, fmt.Errorf("connecting to mysql: %w", err) 27 | } 28 | 29 | return ApplyConnectionsConfig(db, &config.MySQL.Connections, logger), nil 30 | 31 | } else if config.Spanner != nil { 32 | db, err := spannerConnection(logger, *config.Spanner, config.DatabaseName) 33 | if err != nil { 34 | return nil, fmt.Errorf("connecting to spanner: %w", err) 35 | } 36 | return db, nil 37 | } else if config.Postgres != nil { 38 | db, err := postgresConnection(ctx, logger, *config.Postgres, config.DatabaseName) 39 | if err != nil { 40 | return nil, fmt.Errorf("connecting to postgres: %w", err) 41 | } 42 | return ApplyConnectionsConfig(db, &config.Postgres.Connections, logger), nil 43 | } 44 | 45 | return nil, fmt.Errorf("database config not defined") 46 | } 47 | 48 | func NewAndMigrate(ctx context.Context, logger log.Logger, config DatabaseConfig, opts ...MigrateOption) (*sql.DB, error) { 49 | if logger == nil { 50 | logger = log.NewNopLogger() 51 | } 52 | 53 | if ctx == nil { 54 | ctx = context.Background() 55 | } 56 | 57 | // run migrations first 58 | if err := RunMigrations(logger, config, opts...); err != nil { 59 | return nil, err 60 | } 61 | 62 | // create DB connection for our service 63 | db, err := New(ctx, logger, config) 64 | if err != nil { 65 | return nil, err 66 | } 67 | 68 | return db, nil 69 | } 70 | 71 | // UniqueViolation returns true when the provided error matches a database error 72 | // for duplicate entries (violating a unique table constraint). 
73 | func UniqueViolation(err error) bool { 74 | return MySQLUniqueViolation(err) || SpannerUniqueViolation(err) || PostgresUniqueViolation(err) 75 | } 76 | 77 | func DataTooLong(err error) bool { 78 | return MySQLDataTooLong(err) 79 | } 80 | 81 | func DeadlockFound(err error) bool { 82 | return MySQLDeadlockFound(err) || PostgresDeadlockFound(err) 83 | } 84 | 85 | func ApplyConnectionsConfig(db *sql.DB, connections *ConnectionsConfig, logger log.Logger) *sql.DB { 86 | if connections.MaxOpen > 0 { 87 | logger.Logf("setting SQL max open connections to %d", connections.MaxOpen) 88 | db.SetMaxOpenConns(connections.MaxOpen) 89 | } 90 | 91 | if connections.MaxIdle > 0 { 92 | logger.Logf("setting SQL max idle connections to %d", connections.MaxIdle) 93 | db.SetMaxIdleConns(connections.MaxIdle) 94 | } 95 | 96 | // Due to a known issue https://github.com/golang/go/issues/45993#issuecomment-1427873850, 97 | // maxIdleTime must be specified before MaxLifetime or else it will not be honored. 98 | if connections.MaxIdleTime > 0 { 99 | logger.Logf("setting SQL max idle time to %v", connections.MaxIdleTime) 100 | db.SetConnMaxIdleTime(connections.MaxIdleTime) 101 | } 102 | 103 | if connections.MaxLifetime > 0 { 104 | logger.Logf("setting SQL max lifetime to %v", connections.MaxLifetime) 105 | db.SetConnMaxLifetime(connections.MaxLifetime) 106 | } 107 | 108 | return db 109 | } 110 | -------------------------------------------------------------------------------- /database/database_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | package database_test 5 | 6 | import ( 7 | "bytes" 8 | "errors" 9 | "os" 10 | "testing" 11 | 12 | gomysql "github.com/go-sql-driver/mysql" 13 | "github.com/jackc/pgx/v5/pgconn" 14 | "github.com/moov-io/base/database" 15 | 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestUniqueViolation(t *testing.T) { 20 | // mysql 21 | mysqlErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1062 (23000): Duplicate entry '282f6ffcd9ba5b029afbf2b739ee826e22d9df3b' for key 'PRIMARY'`) 22 | if !database.UniqueViolation(mysqlErr) { 23 | t.Error("should have matched mysql unique violation") 24 | } 25 | gomysqlErr := &gomysql.MySQLError{ 26 | Number: 1062, 27 | } 28 | if !database.UniqueViolation(gomysqlErr) { 29 | t.Error("should have matched go mysql driver unique violation") 30 | } 31 | 32 | // postgres 33 | psqlErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": ERROR: duplicate key value violates unique constraint "depository" (SQLSTATE 23505)`) 34 | if !database.UniqueViolation(psqlErr) { 35 | t.Error("should have matched postgres unique violation") 36 | } 37 | pgconnErr := &pgconn.PgError{ 38 | Code: "23505", 39 | } 40 | if !database.UniqueViolation(pgconnErr) { 41 | t.Error("should have matched PgError unique violation") 42 | } 43 | 44 | // no violation 45 | noViolationErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1061 (23000): Something went wrong`) 46 | if database.UniqueViolation(noViolationErr) { 47 | t.Error("should not have matched unique violation") 48 | } 49 | gomysqlErr.Number = 1061 50 | if database.UniqueViolation(gomysqlErr) { 51 | t.Error("should not have matched go mysql driver unique violation") 52 | } 53 | pgconnErr.Code = "23504" 54 | if 
database.UniqueViolation(pgconnErr) { 55 | t.Error("should not have matched PgError unique violation") 56 | } 57 | } 58 | 59 | func TestDeadlockFound(t *testing.T) { 60 | // mysql 61 | mysqlErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1213 (40001): Deadlock found when trying to get lock; try restarting transaction`) 62 | if !database.DeadlockFound(mysqlErr) { 63 | t.Error("should have matched mysql deadlock found") 64 | } 65 | gomysqlErr := &gomysql.MySQLError{ 66 | Number: 1213, 67 | } 68 | if !database.DeadlockFound(gomysqlErr) { 69 | t.Error("should have matched go mysql driver deadlock found") 70 | } 71 | 72 | // postgres 73 | psqlErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": ERROR: deadlock detected (SQLSTATE 40P01)`) 74 | if !database.DeadlockFound(psqlErr) { 75 | t.Error("should have matched postgres deadlock found") 76 | } 77 | pgconnErr := &pgconn.PgError{ 78 | Code: "40P01", 79 | } 80 | if !database.DeadlockFound(pgconnErr) { 81 | t.Error("should have matched PgError deadlock found") 82 | } 83 | 84 | // no deadlock found 85 | noDeadlockErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1061 (23000): Something went wrong`) 86 | if database.DeadlockFound(noDeadlockErr) { 87 | t.Error("should not have matched deadlock found") 88 | } 89 | 90 | gomysqlErr.Number = 1231 91 | if database.DeadlockFound(gomysqlErr) { 92 | t.Error("should not have matched go mysql driver deadlock found") 93 | } 94 | pgconnErr.Code = "40P02" 95 | if database.DeadlockFound(pgconnErr) { 96 | t.Error("should not have matched PgError deadlock found") 97 | } 98 | } 99 | 100 | func TestDataTooLong(t *testing.T) { 101 | // mysql 102 | mysqlErr := errors.New(`problem upserting 
depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1406 (22001): Data too long for column 'depository' at row 1`) 103 | if !database.DataTooLong(mysqlErr) { 104 | t.Error("should have matched mysql data too long") 105 | } 106 | gomysqlErr := &gomysql.MySQLError{ 107 | Number: 1406, 108 | } 109 | if !database.DataTooLong(gomysqlErr) { 110 | t.Error("should have matched go mysql driver data too long") 111 | } 112 | 113 | // no data too long 114 | noDataTooLongErr := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1062 (23000): Something went wrong`) 115 | if database.DataTooLong(noDataTooLongErr) { 116 | t.Error("should not have mysql matched data too long") 117 | } 118 | gomysqlErr.Number = 1062 119 | if database.DataTooLong(gomysqlErr) { 120 | t.Error("should not have matched go mysql driver data too long") 121 | } 122 | } 123 | 124 | func TestConnectionsConfigOrder(t *testing.T) { 125 | bs, err := os.ReadFile("database.go") 126 | require.NoError(t, err) 127 | 128 | // SetConnMaxIdleTime must be specified first 129 | // See: https://github.com/golang/go/issues/45993#issuecomment-1427873850 130 | maxIdleTimeIdx := bytes.Index(bs, []byte("db.SetConnMaxIdleTime")) 131 | maxLifetimeIdx := bytes.Index(bs, []byte("db.SetConnMaxLifetime")) 132 | 133 | if maxIdleTimeIdx > maxLifetimeIdx { 134 | t.Error(".SetConnMaxIdleTime must come first") 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /database/error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | package database 5 | 6 | import ( 7 | "fmt" 8 | ) 9 | 10 | // ErrOpenConnections describes the number of open connections that should have been closed by a call to Close(). 11 | // All queries/transactions should call Close() to prevent unused, open connections. 12 | type ErrOpenConnections struct { 13 | Database string 14 | NumConnections int 15 | } 16 | 17 | func (e ErrOpenConnections) Error() string { 18 | return fmt.Sprintf("found %d open connection(s) in %s", e.NumConnections, e.Database) 19 | } 20 | -------------------------------------------------------------------------------- /database/migrator.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | package database 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "io/fs" 11 | "sync" 12 | "time" 13 | 14 | "github.com/moov-io/base/log" 15 | "github.com/moov-io/base/telemetry" 16 | 17 | "github.com/golang-migrate/migrate/v4" 18 | "github.com/golang-migrate/migrate/v4/database" 19 | migmysql "github.com/golang-migrate/migrate/v4/database/mysql" 20 | migpostgres "github.com/golang-migrate/migrate/v4/database/postgres" 21 | "github.com/golang-migrate/migrate/v4/source" 22 | "github.com/golang-migrate/migrate/v4/source/iofs" 23 | 24 | "go.opentelemetry.io/otel/attribute" 25 | "go.opentelemetry.io/otel/trace" 26 | ) 27 | 28 | var migrationMutex sync.Mutex 29 | 30 | func RunMigrations(logger log.Logger, config DatabaseConfig, opts ...MigrateOption) error { 31 | return RunMigrationsContext(context.Background(), logger, config, opts...) 
32 | } 33 | 34 | func RunMigrationsContext(ctx context.Context, logger log.Logger, config DatabaseConfig, opts ...MigrateOption) error { 35 | _, span := telemetry.StartSpan(ctx, "run-migrations", trace.WithAttributes( 36 | attribute.String("db.database_name", config.DatabaseName), 37 | )) 38 | defer span.End() 39 | 40 | logger.Info().Log("Running Migrations") 41 | 42 | // apply all of our optional arguments 43 | o := &migrateOptions{} 44 | for _, opt := range opts { 45 | if err := opt(o); err != nil { 46 | return err 47 | } 48 | } 49 | 50 | source, driver, err := getDriver(logger, config, o) 51 | if err != nil { 52 | return err 53 | } 54 | defer driver.Close() 55 | 56 | migrationMutex.Lock() 57 | m, err := migrate.NewWithInstance( 58 | source.name, 59 | source, 60 | config.DatabaseName, 61 | driver, 62 | ) 63 | if err != nil { 64 | return logger.Fatal().LogErrorf("Error running migration: %w", err).Err() 65 | } 66 | 67 | if o.timeout != nil { 68 | m.LockTimeout = *o.timeout 69 | } 70 | 71 | previousVersion, dirty, err := m.Version() 72 | if err != nil { 73 | if err != migrate.ErrNilVersion { 74 | return logger.Fatal().LogErrorf("Error getting current DB version: %w", err).Err() 75 | } 76 | // set sane values 77 | previousVersion = 0 78 | dirty = false 79 | } 80 | span.SetAttributes(attribute.Int64("db.previous_version", int64(previousVersion))) //nolint:gosec 81 | 82 | err = m.Up() 83 | migrationMutex.Unlock() 84 | 85 | switch err { 86 | case nil: 87 | case migrate.ErrNoChange: 88 | logger.Info().Logf("Database already at version %d (dirty: %v)", previousVersion, dirty) 89 | default: 90 | return logger.Fatal().LogErrorf("Error running migrations (current: %d, dirty: %v): %w", previousVersion, dirty, err).Err() 91 | } 92 | 93 | newVersion, newDirty, err := m.Version() 94 | if err != nil { 95 | if err != migrate.ErrNilVersion { 96 | return logger.Fatal().LogErrorf("Error getting new DB version: %w", err).Err() 97 | } 98 | // set sane values 99 | newVersion = 0 100 | 
newDirty = false 101 | } 102 | span.SetAttributes(attribute.Int64("db.new_version", int64(newVersion))) //nolint:gosec 103 | 104 | logger.Info().Logf("Migrations complete: previous: %d (dirty:%v) -> new: %d (dirty:%v)", previousVersion, dirty, newVersion, newDirty) 105 | 106 | return nil 107 | } 108 | 109 | // Deprecated: Here to not break compatibility since it was once public. 110 | func GetDriver(logger log.Logger, config DatabaseConfig) (source.Driver, database.Driver, error) { 111 | return getDriver(logger, config, &migrateOptions{}) 112 | } 113 | 114 | func getDriver(logger log.Logger, config DatabaseConfig, opts *migrateOptions) (*SourceDriver, database.Driver, error) { 115 | var err error 116 | 117 | if config.MySQL != nil { 118 | if opts.source == nil { 119 | src, err := NewPkgerSource("mysql", true) 120 | if err != nil { 121 | return nil, nil, err 122 | } 123 | opts.source = &SourceDriver{ 124 | name: "pkger-mysql", 125 | Driver: src, 126 | } 127 | } 128 | 129 | if opts.driver == nil { 130 | db, err := New(context.Background(), logger, config) 131 | if err != nil { 132 | return nil, nil, err 133 | } 134 | 135 | opts.driver, err = MySQLDriver(db) 136 | if err != nil { 137 | return nil, nil, err 138 | } 139 | } 140 | 141 | } else if config.Spanner != nil { 142 | if opts.source == nil { 143 | src, err := NewPkgerSource("spanner", false) 144 | if err != nil { 145 | return nil, nil, err 146 | } 147 | opts.source = &SourceDriver{ 148 | name: "pkger-spanner", 149 | Driver: src, 150 | } 151 | } 152 | 153 | if opts.driver == nil { 154 | opts.driver, err = SpannerDriver(config) 155 | if err != nil { 156 | return nil, nil, err 157 | } 158 | } 159 | } else if config.Postgres != nil { 160 | if opts.source == nil { 161 | src, err := NewPkgerSource("postgres", false) 162 | if err != nil { 163 | return nil, nil, err 164 | } 165 | opts.source = &SourceDriver{ 166 | name: "pkger-postgres", 167 | Driver: src, 168 | } 169 | } 170 | 171 | if opts.driver == nil { 172 | db, err 
:= New(context.Background(), logger, config) 173 | if err != nil { 174 | return nil, nil, err 175 | } 176 | 177 | opts.driver, err = PostgresDriver(db) 178 | if err != nil { 179 | return nil, nil, err 180 | } 181 | } 182 | } 183 | 184 | if opts.source == nil || opts.driver == nil { 185 | return nil, nil, fmt.Errorf("database config not defined") 186 | } 187 | 188 | return opts.source, opts.driver, nil 189 | } 190 | 191 | func MySQLDriver(db *sql.DB) (database.Driver, error) { 192 | return migmysql.WithInstance(db, &migmysql.Config{}) 193 | } 194 | 195 | func SpannerDriver(config DatabaseConfig) (database.Driver, error) { 196 | return SpannerMigrationDriver(*config.Spanner, config.DatabaseName) 197 | } 198 | 199 | func PostgresDriver(db *sql.DB) (database.Driver, error) { 200 | return migpostgres.WithInstance(db, &migpostgres.Config{}) 201 | } 202 | 203 | type MigrateOption func(o *migrateOptions) error 204 | 205 | type SourceDriver struct { 206 | name string 207 | source.Driver 208 | } 209 | 210 | type migrateOptions struct { 211 | source *SourceDriver 212 | driver database.Driver 213 | 214 | timeout *time.Duration 215 | } 216 | 217 | func WithEmbeddedMigrations(f fs.FS) MigrateOption { 218 | return func(o *migrateOptions) error { 219 | src, err := iofs.New(f, "migrations") 220 | if err != nil { 221 | return err 222 | } 223 | o.source = &SourceDriver{ 224 | name: "embedded", 225 | Driver: src, 226 | } 227 | return nil 228 | } 229 | } 230 | 231 | func WithTimeout(dur time.Duration) MigrateOption { 232 | return func(o *migrateOptions) error { 233 | o.timeout = &dur 234 | return nil 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /database/model_config.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | package database 5 | 6 | import ( 7 | "encoding/json" 8 | "time" 9 | 10 | "github.com/moov-io/base/mask" 11 | ) 12 | 13 | type DatabaseConfig struct { 14 | MySQL *MySQLConfig 15 | Spanner *SpannerConfig 16 | Postgres *PostgresConfig 17 | DatabaseName string 18 | } 19 | 20 | type SpannerConfig struct { 21 | Project string 22 | Instance string 23 | 24 | DisableCleanStatements bool 25 | } 26 | 27 | type PostgresConfig struct { 28 | Address string 29 | User string 30 | Password string 31 | Connections ConnectionsConfig 32 | TLS *PostgresTLSConfig 33 | Alloy *PostgresAlloyConfig 34 | } 35 | 36 | type PostgresTLSConfig struct { 37 | CACertFile string 38 | ClientKeyFile string 39 | ClientCertFile string 40 | } 41 | 42 | type PostgresAlloyConfig struct { 43 | InstanceURI string 44 | UseIAM bool 45 | UsePSC bool 46 | } 47 | 48 | type MySQLConfig struct { 49 | Address string 50 | User string 51 | Password string 52 | Connections ConnectionsConfig 53 | UseTLS bool 54 | TLSCAFile string 55 | VerifyCAFile bool 56 | TLSClientCerts []TLSClientCertConfig 57 | 58 | // InsecureSkipVerify is a dangerous option which should be used with extreme caution. 59 | // This setting disables multiple security checks performed with TLS connections. 
60 | InsecureSkipVerify bool 61 | } 62 | 63 | type TLSClientCertConfig struct { 64 | CertFilePath string 65 | KeyFilePath string 66 | } 67 | 68 | func (m *MySQLConfig) MarshalJSON() ([]byte, error) { 69 | type Aux struct { 70 | Address string 71 | User string 72 | Password string 73 | Connections ConnectionsConfig 74 | UseTLS bool 75 | TLSCAFile string 76 | InsecureSkipVerify bool 77 | VerifyCAFile bool 78 | } 79 | return json.Marshal(Aux{ 80 | Address: m.Address, 81 | User: m.User, 82 | Password: mask.Password(m.Password), 83 | Connections: m.Connections, 84 | UseTLS: m.UseTLS, 85 | TLSCAFile: m.TLSCAFile, 86 | InsecureSkipVerify: m.InsecureSkipVerify, 87 | VerifyCAFile: m.VerifyCAFile, 88 | }) 89 | } 90 | 91 | type ConnectionsConfig struct { 92 | MaxOpen int 93 | MaxIdle int 94 | MaxLifetime time.Duration 95 | MaxIdleTime time.Duration 96 | } 97 | -------------------------------------------------------------------------------- /database/model_config_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | package database_test 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "testing" 10 | 11 | "github.com/moov-io/base/database" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestMySQLConfig(t *testing.T) { 16 | cfg := &database.MySQLConfig{ 17 | Address: "tcp(localhost:3306)", 18 | User: "app", 19 | Password: "secret", 20 | Connections: database.ConnectionsConfig{ 21 | MaxOpen: 100, 22 | }, 23 | } 24 | 25 | var buf bytes.Buffer 26 | err := json.NewEncoder(&buf).Encode(cfg) 27 | require.NoError(t, err) 28 | require.Contains(t, buf.String(), `"Password":"s*****t"`) 29 | } 30 | -------------------------------------------------------------------------------- /database/mysql_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | package database_test 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "errors" 10 | "testing" 11 | "time" 12 | 13 | "github.com/moov-io/base" 14 | "github.com/moov-io/base/database" 15 | "github.com/moov-io/base/database/testdb" 16 | "github.com/moov-io/base/log" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | func TestMySQL__basic(t *testing.T) { 21 | if testing.Short() { 22 | t.Skip("-short flag enabled") 23 | } 24 | 25 | // create a phony MySQL 26 | mysqlConfig := database.DatabaseConfig{ 27 | DatabaseName: "moov", 28 | MySQL: &database.MySQLConfig{ 29 | User: "moov", 30 | Password: "moov", 31 | Address: "tcp(127.0.0.1:3306)", 32 | }, 33 | } 34 | 35 | ctx, cancelFunc := context.WithCancel(context.Background()) 36 | defer cancelFunc() 37 | 38 | m, err := database.New(ctx, log.NewNopLogger(), mysqlConfig) 39 | require.NoError(t, err) 40 | defer m.Close() 41 | 42 | require.NotNil(t, m) 43 | 44 | // Inspect the global and session SQL modes 45 | // See: 
https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sql-mode-setting 46 | sqlModes := readSQLModes(t, m, "SELECT @@SESSION.sql_mode;") 47 | require.Contains(t, sqlModes, "ALLOW_INVALID_DATES") 48 | require.Contains(t, sqlModes, "STRICT_ALL_TABLES") 49 | 50 | require.Equal(t, 1, m.Stats().OpenConnections) 51 | } 52 | 53 | func TestMySQLUniqueViolation(t *testing.T) { 54 | err := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1062: Duplicate entry '282f6ffcd9ba5b029afbf2b739ee826e22d9df3b' for key 'PRIMARY'`) 55 | if !database.UniqueViolation(err) { 56 | t.Error("should have matched unique violation") 57 | } 58 | } 59 | 60 | func TestMySQLUniqueViolation_WithStateValue(t *testing.T) { 61 | err := errors.New(`problem upserting depository="282f6ffcd9ba5b029afbf2b739ee826e22d9df3b", userId="f25f48968da47ef1adb5b6531a1c2197295678ce": Error 1062 (23000): Duplicate entry '282f6ffcd9ba5b029afbf2b739ee826e22d9df3b' for key 'PRIMARY'`) 62 | if !database.UniqueViolation(err) { 63 | t.Error("should have matched unique violation") 64 | } 65 | } 66 | 67 | func TestMySQLDataTooLong(t *testing.T) { 68 | err := errors.New("Error 1406: Data too long") 69 | if !database.MySQLDataTooLong(err) { 70 | t.Error("should have matched") 71 | } 72 | } 73 | 74 | func TestMySQLDataTooLong_WithStateValue(t *testing.T) { 75 | err := errors.New("Error 1406 (22001): Data too long") 76 | if !database.MySQLDataTooLong(err) { 77 | t.Error("should have matched") 78 | } 79 | } 80 | 81 | func readSQLModes(t *testing.T, db *sql.DB, query string) string { 82 | stmt, err := db.Prepare(query) 83 | require.NoError(t, err) 84 | defer stmt.Close() 85 | 86 | row := stmt.QueryRow() 87 | require.NoError(t, row.Err()) 88 | 89 | var sqlModes string 90 | require.NoError(t, row.Scan(&sqlModes)) 91 | return sqlModes 92 | } 93 | 94 | func Test_MySQL_Embedded_Migration(t *testing.T) { 95 | if testing.Short() { 96 | 
t.Skip("-short flag enabled") 97 | } 98 | 99 | // create a phony MySQL 100 | mysqlConfig := database.DatabaseConfig{ 101 | DatabaseName: "moov2" + base.ID(), 102 | MySQL: &database.MySQLConfig{ 103 | User: "root", 104 | Password: "root", 105 | Address: "tcp(127.0.0.1:3306)", 106 | Connections: database.ConnectionsConfig{ 107 | MaxOpen: 1, 108 | MaxIdle: 1, 109 | MaxLifetime: time.Minute, 110 | MaxIdleTime: time.Second, 111 | }, 112 | }, 113 | } 114 | 115 | err := testdb.NewMySQLDatabase(t, mysqlConfig) 116 | require.NoError(t, err) 117 | 118 | db, err := database.NewAndMigrate(context.Background(), log.NewDefaultLogger(), mysqlConfig, database.WithEmbeddedMigrations(base.MySQLMigrations)) 119 | require.NoError(t, err) 120 | defer db.Close() 121 | } 122 | -------------------------------------------------------------------------------- /database/pkger.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "strings" 8 | 9 | "github.com/golang-migrate/migrate/v4/source" 10 | mpkger "github.com/golang-migrate/migrate/v4/source/pkger" 11 | "github.com/markbates/pkger" 12 | "github.com/markbates/pkger/pkging/mem" 13 | ) 14 | 15 | const MIGRATIONS_DIR = "/migrations/" 16 | 17 | func NewPkgerSource(database string, allowGeneric bool) (source.Driver, error) { 18 | database = strings.ToLower(database) 19 | 20 | hereInfo, err := pkger.Current() 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | pmem, err := mem.New(hereInfo) 26 | if err != nil { 27 | return nil, err 28 | } 29 | 30 | pmem.MkdirAll(MIGRATIONS_DIR, 0755) 31 | 32 | err = pkger.Walk(MIGRATIONS_DIR, func(path string, info os.FileInfo, err error) error { 33 | if err != nil { 34 | return err 35 | } 36 | 37 | if info.IsDir() { 38 | return nil 39 | } 40 | 41 | splits := strings.Split(info.Name(), ".") 42 | slen := len(splits) 43 | if slen < 3 { 44 | return fmt.Errorf("doesn't follow format of 
{version}_{title}.up{.db}?.sql - %s", info.Name()) 45 | } 46 | 47 | if splits[slen-1] != "sql" { 48 | return fmt.Errorf("must end in .sql") 49 | } 50 | 51 | var run bool 52 | switch splits[slen-2] { 53 | case "up": 54 | run = true && allowGeneric 55 | case "down": 56 | run = true && allowGeneric 57 | case database: 58 | run = true 59 | default: 60 | run = false 61 | } 62 | 63 | if !run { 64 | return nil 65 | } 66 | 67 | cur, err := pkger.Open(MIGRATIONS_DIR + info.Name()) 68 | if err != nil { 69 | return err 70 | } 71 | 72 | nw, err := pmem.Create(path) 73 | if err != nil { 74 | return err 75 | } 76 | 77 | _, err = io.Copy(nw, cur) 78 | if err != nil { 79 | return err 80 | } 81 | 82 | nw.Close() 83 | 84 | return nil 85 | }) 86 | if err != nil { 87 | return nil, fmt.Errorf("walking the migrations directory: %w", err) 88 | } 89 | 90 | drv, err := mpkger.WithInstance(pmem, MIGRATIONS_DIR) 91 | if err != nil { 92 | return nil, fmt.Errorf("unable to instantiate driver - %w", err) 93 | } 94 | 95 | return drv, nil 96 | } 97 | -------------------------------------------------------------------------------- /database/postgres.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "fmt" 8 | "net" 9 | "strings" 10 | 11 | "cloud.google.com/go/alloydbconn" 12 | "github.com/jackc/pgx/v5" 13 | "github.com/jackc/pgx/v5/pgconn" 14 | "github.com/jackc/pgx/v5/stdlib" 15 | "github.com/moov-io/base/log" 16 | ) 17 | 18 | const ( 19 | // PostgreSQL Error Codes 20 | // https://www.postgresql.org/docs/current/errcodes-appendix.html 21 | postgresErrUniqueViolation = "23505" 22 | postgresErrDeadlockFound = "40P01" 23 | ) 24 | 25 | func postgresConnection(ctx context.Context, logger log.Logger, config PostgresConfig, databaseName string) (*sql.DB, error) { 26 | var connStr string 27 | if config.Alloy != nil { 28 | c, err := getAlloyDBConnectorConnStr(ctx, config, databaseName) 
29 | if err != nil { 30 | return nil, logger.LogErrorf("creating alloydb connection: %w", err).Err() 31 | } 32 | connStr = c 33 | } else { 34 | c, err := getPostgresConnStr(config, databaseName) 35 | if err != nil { 36 | return nil, logger.LogErrorf("creating postgres connection: %w", err).Err() 37 | } 38 | connStr = c 39 | } 40 | 41 | db, err := sql.Open("pgx", connStr) 42 | if err != nil { 43 | return nil, logger.LogErrorf("opening database: %w", err).Err() 44 | } 45 | 46 | err = db.Ping() 47 | if err != nil { 48 | _ = db.Close() 49 | return nil, logger.LogErrorf("connecting to database: %w", err).Err() 50 | } 51 | 52 | return db, nil 53 | } 54 | 55 | func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) { 56 | url := fmt.Sprintf("postgres://%s:%s@%s/%s", config.User, config.Password, config.Address, databaseName) 57 | 58 | params := "" 59 | 60 | if config.TLS != nil { 61 | params += "sslmode=verify-full" 62 | 63 | if config.TLS.CACertFile == "" { 64 | return "", fmt.Errorf("missing TLS CA file") 65 | } 66 | params += "&sslrootcert=" + config.TLS.CACertFile 67 | 68 | if config.TLS.ClientCertFile != "" { 69 | params += "&sslcert=" + config.TLS.ClientCertFile 70 | } 71 | 72 | if config.TLS.ClientKeyFile != "" { 73 | params += "&sslkey=" + config.TLS.ClientKeyFile 74 | } 75 | } 76 | 77 | connStr := fmt.Sprintf("%s?%s", url, params) 78 | return connStr, nil 79 | } 80 | 81 | func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, databaseName string) (string, error) { 82 | if config.Alloy == nil { 83 | return "", fmt.Errorf("missing alloy config") 84 | } 85 | 86 | var dialer *alloydbconn.Dialer 87 | var dsn string 88 | 89 | if config.Alloy.UseIAM { 90 | d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN()) 91 | if err != nil { 92 | return "", fmt.Errorf("creating alloydb dialer: %v", err) 93 | } 94 | dialer = d 95 | dsn = fmt.Sprintf( 96 | // sslmode is disabled because the alloy db connection dialer will 
handle it 97 | // no password is used with IAM 98 | "user=%s dbname=%s sslmode=disable", 99 | config.User, databaseName, 100 | ) 101 | } else { 102 | d, err := alloydbconn.NewDialer(ctx) 103 | if err != nil { 104 | return "", fmt.Errorf("creating alloydb dialer: %v", err) 105 | } 106 | dialer = d 107 | dsn = fmt.Sprintf( 108 | // sslmode is disabled because the alloy db connection dialer will handle it 109 | "user=%s password=%s dbname=%s sslmode=disable", 110 | config.User, config.Password, databaseName, 111 | ) 112 | } 113 | 114 | // TODO 115 | //cleanup := func() error { return d.Close() } 116 | 117 | connConfig, err := pgx.ParseConfig(dsn) 118 | if err != nil { 119 | return "", fmt.Errorf("failed to parse pgx config: %v", err) 120 | } 121 | 122 | var connOptions []alloydbconn.DialOption 123 | if config.Alloy.UsePSC { 124 | connOptions = append(connOptions, alloydbconn.WithPSC()) 125 | } 126 | 127 | connConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) { 128 | return dialer.Dial(ctx, config.Alloy.InstanceURI, connOptions...) 129 | } 130 | 131 | connStr := stdlib.RegisterConnConfig(connConfig) 132 | return connStr, nil 133 | } 134 | 135 | // PostgresUniqueViolation returns true when the provided error matches the Postgres code 136 | // for unique violation. 137 | func PostgresUniqueViolation(err error) bool { 138 | if err == nil { 139 | return false 140 | } 141 | 142 | var pgError *pgconn.PgError 143 | if errors.As(err, &pgError) && pgError.Code == postgresErrUniqueViolation { 144 | return true 145 | } 146 | 147 | return strings.Contains(err.Error(), postgresErrUniqueViolation) 148 | } 149 | 150 | // PostgresDeadlockFound returns true when the provided error matches the Postgres code 151 | // for deadlock found. 
152 | func PostgresDeadlockFound(err error) bool { 153 | if err == nil { 154 | return false 155 | } 156 | 157 | var pgError *pgconn.PgError 158 | if errors.As(err, &pgError) && pgError.Code == postgresErrDeadlockFound { 159 | return true 160 | } 161 | 162 | return strings.Contains(err.Error(), postgresErrDeadlockFound) 163 | } 164 | -------------------------------------------------------------------------------- /database/postgres_test.go: -------------------------------------------------------------------------------- 1 | package database_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | "time" 9 | 10 | "github.com/moov-io/base" 11 | "github.com/moov-io/base/database" 12 | "github.com/moov-io/base/database/testdb" 13 | "github.com/moov-io/base/log" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | var ( 18 | alloydbInstanceURI = os.Getenv("ALLOYDB_INSTANCE_URI") 19 | alloydbDBName = os.Getenv("ALLOYDB_DBNAME") 20 | alloydbIAMUser = os.Getenv("ALLOYDB_IAM_USER") 21 | alloydbNativeUser = os.Getenv("ALLOYDB_NATIVE_USER") 22 | alloydbNativePassword = os.Getenv("ALLOYDB_NATIVE_PASSWORD") 23 | ) 24 | 25 | func TestPostgres_Basic(t *testing.T) { 26 | if testing.Short() { 27 | t.Skip("-short flag enabled") 28 | } 29 | 30 | config := database.DatabaseConfig{ 31 | DatabaseName: "moov", 32 | Postgres: &database.PostgresConfig{ 33 | Address: "127.0.0.1:5432", 34 | User: "moov", 35 | Password: "moov", 36 | Connections: database.ConnectionsConfig{ 37 | MaxOpen: 4, 38 | MaxIdle: 4, 39 | MaxLifetime: time.Minute * 2, 40 | MaxIdleTime: time.Minute * 2, 41 | }, 42 | }, 43 | } 44 | 45 | db, err := database.New(context.Background(), log.NewTestLogger(), config) 46 | require.NoError(t, err) 47 | require.NotNil(t, db) 48 | defer db.Close() 49 | } 50 | 51 | func TestPostgres_TLS(t *testing.T) { 52 | if testing.Short() { 53 | t.Skip("-short flag enabled") 54 | } 55 | 56 | config := database.DatabaseConfig{ 57 | DatabaseName: "moov", 58 | 
Postgres: &database.PostgresConfig{ 59 | Address: "127.0.0.1:5432", 60 | User: "moov", 61 | Password: "moov", 62 | TLS: &database.PostgresTLSConfig{ 63 | CACertFile: filepath.Join("..", "testcerts", "root.crt"), 64 | ClientCertFile: filepath.Join("..", "testcerts", "client.crt"), 65 | ClientKeyFile: filepath.Join("..", "testcerts", "client.key"), 66 | }, 67 | }, 68 | } 69 | 70 | db, err := database.New(context.Background(), log.NewTestLogger(), config) 71 | require.NoError(t, err) 72 | require.NotNil(t, db) 73 | defer db.Close() 74 | } 75 | 76 | func TestProstres_Alloy(t *testing.T) { 77 | if testing.Short() { 78 | t.Skip("-short flag enabled") 79 | } 80 | 81 | if alloydbInstanceURI == "" || alloydbDBName == "" || alloydbNativeUser == "" || alloydbNativePassword == "" { 82 | t.Skip("missing required environment variables") 83 | } 84 | 85 | config := database.DatabaseConfig{ 86 | DatabaseName: alloydbDBName, 87 | Postgres: &database.PostgresConfig{ 88 | User: alloydbNativeUser, 89 | Password: alloydbNativePassword, 90 | Alloy: &database.PostgresAlloyConfig{ 91 | InstanceURI: alloydbInstanceURI, 92 | UseIAM: false, 93 | UsePSC: true, 94 | }, 95 | }, 96 | } 97 | 98 | db, err := database.New(context.Background(), log.NewTestLogger(), config) 99 | require.NoError(t, err) 100 | require.NotNil(t, db) 101 | defer db.Close() 102 | } 103 | 104 | func TestProstres_Alloy_IAM(t *testing.T) { 105 | if testing.Short() { 106 | t.Skip("-short flag enabled") 107 | } 108 | 109 | if alloydbInstanceURI == "" || alloydbDBName == "" || alloydbIAMUser == "" { 110 | t.Skip("missing required environment variables") 111 | } 112 | 113 | config := database.DatabaseConfig{ 114 | DatabaseName: alloydbDBName, 115 | Postgres: &database.PostgresConfig{ 116 | User: alloydbIAMUser, 117 | Alloy: &database.PostgresAlloyConfig{ 118 | InstanceURI: alloydbInstanceURI, 119 | UseIAM: true, 120 | UsePSC: true, 121 | }, 122 | }, 123 | } 124 | 125 | db, err := database.New(context.Background(), 
log.NewTestLogger(), config) 126 | require.NoError(t, err) 127 | require.NotNil(t, db) 128 | defer db.Close() 129 | } 130 | 131 | func Test_Postgres_Embedded_Migration(t *testing.T) { 132 | if testing.Short() { 133 | t.Skip("-short flag enabled") 134 | } 135 | 136 | // create a test postgres db 137 | config := database.DatabaseConfig{ 138 | DatabaseName: "postgres" + base.ID(), 139 | Postgres: &database.PostgresConfig{ 140 | Address: "127.0.0.1:5432", 141 | User: "moov", 142 | Password: "moov", 143 | }, 144 | } 145 | 146 | err := testdb.NewPostgresDatabase(t, config) 147 | require.NoError(t, err) 148 | 149 | db, err := database.NewAndMigrate(context.Background(), log.NewDefaultLogger(), config, database.WithEmbeddedMigrations(base.PostgresMigrations)) 150 | require.NoError(t, err) 151 | defer db.Close() 152 | } 153 | 154 | func Test_Postgres_Alloy_Migrations(t *testing.T) { 155 | if testing.Short() { 156 | t.Skip("-short flag enabled") 157 | } 158 | 159 | if alloydbInstanceURI == "" || alloydbDBName == "" || alloydbNativeUser == "" || alloydbNativePassword == "" { 160 | t.Skip("missing required environment variables") 161 | } 162 | 163 | config := database.DatabaseConfig{ 164 | DatabaseName: alloydbDBName, 165 | Postgres: &database.PostgresConfig{ 166 | User: alloydbNativeUser, 167 | Password: alloydbNativePassword, 168 | Alloy: &database.PostgresAlloyConfig{ 169 | InstanceURI: alloydbInstanceURI, 170 | UseIAM: false, 171 | UsePSC: true, 172 | }, 173 | }, 174 | } 175 | 176 | // migrating database given by ALLOYDB_DBNAME env var 177 | 178 | db, err := database.NewAndMigrate(context.Background(), log.NewDefaultLogger(), config, database.WithEmbeddedMigrations(base.PostgresMigrations)) 179 | require.NoError(t, err) 180 | defer db.Close() 181 | } 182 | 183 | func Test_Postgres_UniqueViolation(t *testing.T) { 184 | if testing.Short() { 185 | t.Skip("-short flag enabled") 186 | } 187 | 188 | // create a test postgres db 189 | config := database.DatabaseConfig{ 190 | 
DatabaseName: "postgres" + base.ID(), 191 | Postgres: &database.PostgresConfig{ 192 | Address: "127.0.0.1:5432", 193 | User: "moov", 194 | Password: "moov", 195 | }, 196 | } 197 | 198 | err := testdb.NewPostgresDatabase(t, config) 199 | require.NoError(t, err) 200 | 201 | db, err := database.New(context.Background(), log.NewDefaultLogger(), config) 202 | require.NoError(t, err) 203 | 204 | createQry := `CREATE TABLE names (id SERIAL PRIMARY KEY, name VARCHAR(255));` 205 | _, err = db.Exec(createQry) 206 | require.NoError(t, err) 207 | 208 | insertQry := `INSERT INTO names (id, name) VALUES ($1, $2);` 209 | _, err = db.Exec(insertQry, 1, "James") 210 | require.NoError(t, err) 211 | 212 | _, err = db.Exec(insertQry, 1, "James") 213 | require.Error(t, err) 214 | require.True(t, database.UniqueViolation(err)) 215 | } 216 | -------------------------------------------------------------------------------- /database/spanner.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "strings" 7 | 8 | "cloud.google.com/go/spanner" 9 | "github.com/golang-migrate/migrate/v4/database" 10 | migspanner "github.com/golang-migrate/migrate/v4/database/spanner" 11 | _ "github.com/googleapis/go-sql-spanner" 12 | "google.golang.org/grpc/codes" 13 | 14 | "github.com/moov-io/base/log" 15 | ) 16 | 17 | func spannerConnection(_ log.Logger, cfg SpannerConfig, databaseName string) (*sql.DB, error) { 18 | db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", cfg.Project, cfg.Instance, databaseName)) 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | return db, nil 24 | } 25 | 26 | func SpannerMigrationDriver(cfg SpannerConfig, databaseName string) (database.Driver, error) { 27 | clean := !cfg.DisableCleanStatements 28 | 29 | s := migspanner.Spanner{} 30 | return 
s.Open(fmt.Sprintf("spanner://projects/%s/instances/%s/databases/%s?x-migrations-table=spanner_schema_migrations&x-clean-statements=%t", cfg.Project, cfg.Instance, databaseName, clean)) 31 | } 32 | 33 | // SpannerUniqueViolation returns true when the provided error matches the Spanner code 34 | // for duplicate entries (violating a unique table constraint). 35 | // Refer to https://cloud.google.com/spanner/docs/error-codes for Spanner error definitions, 36 | // and https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto for error codes 37 | func SpannerUniqueViolation(err error) bool { 38 | if err == nil { 39 | return false 40 | } 41 | return spanner.ErrCode(err) == codes.AlreadyExists || 42 | strings.Contains(err.Error(), "AlreadyExists") 43 | } 44 | -------------------------------------------------------------------------------- /database/spanner_test.go: -------------------------------------------------------------------------------- 1 | package database_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "cloud.google.com/go/spanner" 9 | "github.com/googleapis/gax-go/v2/apierror" 10 | "github.com/stretchr/testify/require" 11 | "google.golang.org/grpc/codes" 12 | "google.golang.org/grpc/status" 13 | 14 | "github.com/moov-io/base" 15 | "github.com/moov-io/base/database" 16 | "github.com/moov-io/base/database/testdb" 17 | "github.com/moov-io/base/log" 18 | ) 19 | 20 | func Test_OpenConnection(t *testing.T) { 21 | if testing.Short() { 22 | t.Skip("-short flag enabled") 23 | } 24 | 25 | // Switches the spanner driver into using the emulator and bypassing the auth checks. 
26 | testdb.SetSpannerEmulator(nil) 27 | 28 | cfg := database.DatabaseConfig{ 29 | DatabaseName: "my-database", 30 | Spanner: &database.SpannerConfig{ 31 | Project: "my-project", 32 | Instance: "my-instance", 33 | }, 34 | } 35 | 36 | db, err := database.New(context.Background(), log.NewDefaultLogger(), cfg) 37 | require.NoError(t, err) 38 | defer db.Close() 39 | } 40 | 41 | func Test_Migration(t *testing.T) { 42 | if testing.Short() { 43 | t.Skip("-short flag enabled") 44 | } 45 | 46 | // Switches the spanner driver into using the emulator and bypassing the auth checks. 47 | testdb.SetSpannerEmulator(nil) 48 | 49 | cfg, err := testdb.NewSpannerDatabase("mydb", nil) 50 | require.NoError(t, err) 51 | 52 | err = database.RunMigrations(log.NewDefaultLogger(), cfg) 53 | require.NoError(t, err) 54 | } 55 | 56 | func Test_IdempotentCreate(t *testing.T) { 57 | if testing.Short() { 58 | t.Skip("-short flag enabled") 59 | } 60 | 61 | // Switches the spanner driver into using the emulator and bypassing the auth checks. 62 | testdb.SetSpannerEmulator(nil) 63 | 64 | spanner := &database.SpannerConfig{ 65 | Project: "basetest", 66 | Instance: "idempotent", 67 | } 68 | 69 | cfg1, err := testdb.NewSpannerDatabase("mydb", spanner) 70 | require.NoError(t, err) 71 | 72 | cfg2, err := testdb.NewSpannerDatabase("mydb", cfg1.Spanner) 73 | require.NoError(t, err) 74 | 75 | require.Equal(t, cfg1.Spanner, spanner) 76 | require.Equal(t, cfg1, cfg2) 77 | } 78 | 79 | func Test_MigrateAndRun(t *testing.T) { 80 | if testing.Short() { 81 | t.Skip("-short flag enabled") 82 | } 83 | 84 | // Switches the spanner driver into using the emulator and bypassing the auth checks. 
85 | testdb.SetSpannerEmulator(nil) 86 | 87 | cfg, err := testdb.NewSpannerDatabase("mydb", nil) 88 | require.NoError(t, err) 89 | 90 | err = database.RunMigrations(log.NewDefaultLogger(), cfg) 91 | require.NoError(t, err) 92 | 93 | db, err := database.New(context.Background(), log.NewDefaultLogger(), cfg) 94 | require.NoError(t, err) 95 | defer db.Close() 96 | 97 | rows, err := db.Query("SELECT * FROM MigrationTest") 98 | require.NoError(t, err) 99 | defer rows.Close() 100 | require.NoError(t, rows.Err()) 101 | } 102 | 103 | func Test_Embedded_Migration(t *testing.T) { 104 | if testing.Short() { 105 | t.Skip("-short flag enabled") 106 | } 107 | 108 | // Switches the spanner driver into using the emulator and bypassing the auth checks. 109 | testdb.SetSpannerEmulator(nil) 110 | 111 | cfg, err := testdb.NewSpannerDatabase("mydb", nil) 112 | require.NoError(t, err) 113 | 114 | db, err := database.NewAndMigrate(context.Background(), log.NewDefaultLogger(), cfg, database.WithEmbeddedMigrations(base.SpannerMigrations)) 115 | require.NoError(t, err) 116 | defer db.Close() 117 | } 118 | 119 | func TestSpannerUniqueViolation(t *testing.T) { 120 | errMsg := "Failed to insert row with primary key ({pk#primary_key:\"282f6ffcd9ba5b029afbf2b739ee826e22d9df3b\"}) due to previously existing row" 121 | // Test backwards-compatible parsing of spanner.Error (soon to be deprecated) from Spanner client 122 | statusErr := status.New(codes.AlreadyExists, errMsg).Err() 123 | oldSpannerErr := spanner.ToSpannerError(statusErr) 124 | if !database.SpannerUniqueViolation(oldSpannerErr) { 125 | t.Error("should have matched unique violation") 126 | } 127 | 128 | // Test new apirerror.APIError response from Spanner client 129 | newSpannerErr, parseErr := apierror.FromError(statusErr) 130 | require.True(t, parseErr) 131 | if !database.SpannerUniqueViolation(newSpannerErr) { 132 | t.Error("should have matched unique violation") 133 | } 134 | 135 | // Test wrapped spanner error 136 | wrappedErr := 
fmt.Errorf("wrapped err: %w", statusErr) 137 | if !database.SpannerUniqueViolation(wrappedErr) { 138 | t.Error("should have matched unique violation") 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /database/testdata/gencerts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # This script generates TLS certs for local development: 5 | # - Creates a self-signed root CA 6 | # - Generates server and client certs signed by the root CA 7 | # - Sets up certs for localhost use (e.g., local HTTPS and mTLS testing) 8 | # Note: These certs are for development/testing only, not for production use. 9 | mkdir -p testcerts 10 | cd testcerts 11 | 12 | echo "STARTING Generating test certificates" 13 | openssl genrsa -out root.key 2048 14 | openssl req -new -x509 -days 365 -key root.key -subj "/C=CN/ST=GD/L=SZ/O=Moov, Inc./CN=Moov Root CA" -out root.crt 15 | openssl req -newkey rsa:2048 -nodes -keyout server.key -subj "/C=CN/ST=GD/L=SZ/O=Moov, Inc./CN=localhost" -out server.csr 16 | openssl x509 -req -extfile <(printf "subjectAltName=DNS:localhost,IP:127.0.0.1") -days 365 -in server.csr -CA root.crt -CAkey root.key -CAcreateserial -out server.crt 17 | openssl req -newkey rsa:2048 -nodes -keyout client.key -subj "/C=CN/ST=GD/L=SZ/O=Moov, Inc./CN=moov" -out client.csr 18 | openssl x509 -req -extfile <(printf "subjectAltName=DNS:localhost,IP:127.0.0.1") -days 365 -in client.csr -CA root.crt -CAkey root.key -CAcreateserial -out client.crt 19 | 20 | rm -f server.csr client.csr 21 | ls -l 22 | 23 | echo "FINIHSED Generating test certificates" 24 | -------------------------------------------------------------------------------- /database/testdb/mysql.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/moov-io/base/database" 9 | ) 10 | 
11 | func NewMySQLDatabase(t *testing.T, cfg database.DatabaseConfig) error { 12 | t.Helper() 13 | if cfg.MySQL == nil { 14 | return fmt.Errorf("mysql config not defined") 15 | } 16 | 17 | rootDb, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s/", cfg.MySQL.User, cfg.MySQL.Password, cfg.MySQL.Address)) 18 | if err != nil { 19 | return err 20 | } 21 | 22 | if err := rootDb.Ping(); err != nil { 23 | return err 24 | } 25 | 26 | _, err = rootDb.Exec(fmt.Sprintf("CREATE DATABASE %s", cfg.DatabaseName)) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | t.Cleanup(func() { 32 | rootDb.Exec(fmt.Sprintf("DROP DATABASE %s", cfg.DatabaseName)) 33 | rootDb.Close() 34 | }) 35 | 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /database/testdb/postgres.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/moov-io/base/database" 9 | ) 10 | 11 | func NewPostgresDatabase(t *testing.T, cfg database.DatabaseConfig) error { 12 | t.Helper() 13 | if cfg.Postgres == nil { 14 | return fmt.Errorf("postgres config not defined") 15 | } 16 | 17 | db, err := sql.Open("pgx", fmt.Sprintf("postgres://%s:%s@%s", cfg.Postgres.User, cfg.Postgres.Password, cfg.Postgres.Address)) 18 | if err != nil { 19 | return err 20 | } 21 | 22 | if err := db.Ping(); err != nil { 23 | return err 24 | } 25 | 26 | _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", cfg.DatabaseName)) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | t.Cleanup(func() { 32 | db.Exec(fmt.Sprintf("DROP DATABASE %s", cfg.DatabaseName)) 33 | db.Close() 34 | }) 35 | 36 | return nil 37 | } 38 | -------------------------------------------------------------------------------- /database/testdb/spanner.go: -------------------------------------------------------------------------------- 1 | package testdb 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 
7 | 8 | spannerdb "cloud.google.com/go/spanner/admin/database/apiv1" 9 | "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" 10 | instance "cloud.google.com/go/spanner/admin/instance/apiv1" 11 | "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" 12 | "github.com/moov-io/base" 13 | "github.com/moov-io/base/database" 14 | "google.golang.org/grpc/codes" 15 | "google.golang.org/grpc/status" 16 | ) 17 | 18 | // Must be called if using the docker spanner emulator 19 | func SetSpannerEmulator(hostOverride *string) { 20 | host := "localhost:9010" 21 | if hostOverride != nil { 22 | host = *hostOverride 23 | } 24 | 25 | os.Setenv("SPANNER_EMULATOR_HOST", host) 26 | } 27 | 28 | func NewSpannerDatabase(databaseName string, spannerCfg *database.SpannerConfig) (database.DatabaseConfig, error) { 29 | if spannerCfg == nil { 30 | spannerCfg = &database.SpannerConfig{ 31 | Project: "proj" + base.ID()[0:26], 32 | Instance: "test", 33 | } 34 | } 35 | 36 | cfg := database.DatabaseConfig{ 37 | DatabaseName: databaseName, 38 | Spanner: spannerCfg, 39 | } 40 | 41 | if err := createInstance(cfg.Spanner); err != nil { 42 | return cfg, err 43 | } 44 | 45 | if err := createDatabase(cfg); err != nil { 46 | return cfg, err 47 | } 48 | 49 | return cfg, nil 50 | } 51 | 52 | func createInstance(cfg *database.SpannerConfig) error { 53 | ctx := context.Background() 54 | instanceAdmin, err := instance.NewInstanceAdminClient(ctx) 55 | if err != nil { 56 | return err 57 | } 58 | defer instanceAdmin.Close() 59 | 60 | op, err := instanceAdmin.CreateInstance(ctx, &instancepb.CreateInstanceRequest{ 61 | Parent: fmt.Sprintf("projects/%s", cfg.Project), 62 | InstanceId: cfg.Instance, 63 | Instance: &instancepb.Instance{ 64 | Config: fmt.Sprintf("projects/%s/instanceConfigs/%s", cfg.Project, "emulator-config"), 65 | DisplayName: cfg.Instance, 66 | NodeCount: 1, 67 | }, 68 | }) 69 | if err != nil { 70 | if status.Code(err) == codes.AlreadyExists { 71 | return nil 72 | } 73 | return 
fmt.Errorf("could not create instance %s: %v", fmt.Sprintf("projects/%s/instances/%s", cfg.Project, cfg.Instance), err) 74 | } 75 | 76 | // Wait for the instance creation to finish. 77 | if _, err := op.Wait(ctx); err != nil { 78 | return fmt.Errorf("waiting for instance creation to finish failed: %v", err) 79 | } 80 | 81 | return nil 82 | } 83 | 84 | func createDatabase(cfg database.DatabaseConfig) error { 85 | ctx := context.Background() 86 | databaseAdminClient, err := spannerdb.NewDatabaseAdminClient(ctx) 87 | if err != nil { 88 | return err 89 | } 90 | defer databaseAdminClient.Close() 91 | 92 | opDB, err := databaseAdminClient.CreateDatabase(ctx, &databasepb.CreateDatabaseRequest{ 93 | Parent: fmt.Sprintf("projects/%s/instances/%s", cfg.Spanner.Project, cfg.Spanner.Instance), 94 | CreateStatement: fmt.Sprintf("CREATE DATABASE `%s`", cfg.DatabaseName), 95 | }) 96 | if err != nil { 97 | if status.Code(err) == codes.AlreadyExists { 98 | return nil 99 | } 100 | return err 101 | } 102 | 103 | // Wait for the database creation to finish. 
104 | if _, err := opDB.Wait(ctx); err != nil { 105 | return fmt.Errorf("waiting for database creation to finish failed: %v", err) 106 | } 107 | 108 | return nil 109 | } 110 | -------------------------------------------------------------------------------- /database/tls.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "crypto/tls" 5 | 6 | "github.com/moov-io/base/log" 7 | ) 8 | 9 | func LoadTLSClientCertsFromConfig(logger log.Logger, config *MySQLConfig) ([]tls.Certificate, error) { 10 | var clientCerts []tls.Certificate 11 | 12 | for _, clientCert := range config.TLSClientCerts { 13 | cert, err := LoadTLSClientCertFromFile(logger, clientCert.CertFilePath, clientCert.KeyFilePath) 14 | if err != nil { 15 | return []tls.Certificate{}, err 16 | } 17 | clientCerts = append(clientCerts, cert) 18 | } 19 | 20 | return clientCerts, nil 21 | } 22 | 23 | func LoadTLSClientCertFromFile(logger log.Logger, certFile, keyFile string) (tls.Certificate, error) { 24 | if certFile == "" || keyFile == "" { 25 | return tls.Certificate{}, logger.LogErrorf("cert path or key path not provided").Err() 26 | } 27 | 28 | cert, err := tls.LoadX509KeyPair(certFile, keyFile) 29 | if err != nil { 30 | return tls.Certificate{}, logger.LogErrorf("error loading client cert/key from file: %v", err).Err() 31 | } 32 | return cert, nil 33 | } 34 | -------------------------------------------------------------------------------- /database/tls_test.go: -------------------------------------------------------------------------------- 1 | package database_test 2 | 3 | import ( 4 | "path/filepath" 5 | "testing" 6 | 7 | "github.com/madflojo/testcerts" 8 | "github.com/moov-io/base/database" 9 | "github.com/moov-io/base/log" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func Test_LoadClientCertsFromConfig(t *testing.T) { 14 | dir := t.TempDir() 15 | 16 | certFilepath := filepath.Join(dir, "client_cert.pem") 17 | 
keyFilepath := filepath.Join(dir, "client_cert_private_key.pem") 18 | 19 | err := testcerts.GenerateCertsToFile(certFilepath, keyFilepath) 20 | require.Nil(t, err) 21 | 22 | config := &database.MySQLConfig{ 23 | TLSClientCerts: []database.TLSClientCertConfig{ 24 | { 25 | CertFilePath: certFilepath, 26 | KeyFilePath: keyFilepath, 27 | }, 28 | }, 29 | } 30 | 31 | clientCerts, err := database.LoadTLSClientCertsFromConfig(log.NewNopLogger(), config) 32 | require.Nil(t, err) 33 | 34 | require.Len(t, clientCerts, 1) 35 | require.Len(t, clientCerts[0].Certificate, 1) 36 | } 37 | -------------------------------------------------------------------------------- /database/tx.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | package database 5 | 6 | type RunInTx func() error 7 | 8 | func NopInTx() error { 9 | return nil 10 | } 11 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package base implements core libraries used in multiple Moov projects. Refer to each projects documentation for more details. 
6 | package base 7 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | mysql: 4 | image: mysql:9-oracle 5 | restart: always 6 | ports: 7 | - "3306:3306" 8 | environment: 9 | - MYSQL_DATABASE=moov 10 | - MYSQL_USER=moov 11 | - MYSQL_PASSWORD=moov 12 | - MYSQL_ROOT_PASSWORD=root 13 | networks: 14 | - intranet 15 | healthcheck: 16 | test: ["CMD", "mysqladmin" ,"ping", "-h", "localhost"] 17 | timeout: 20s 18 | retries: 10 19 | tmpfs: # Run this mysql in memory as its used for testing 20 | - /var/lib/mysql 21 | 22 | spanner: 23 | image: gcr.io/cloud-spanner-emulator/emulator 24 | restart: always 25 | ports: 26 | - "9010:9010" 27 | - "9020:9020" 28 | networks: 29 | - intranet 30 | 31 | postgres: 32 | image: postgres:17.4 33 | restart: always 34 | ports: 35 | - "5432:5432" 36 | # https://github.com/docker-library/postgres/issues/1059#issuecomment-1467077098 37 | command: | 38 | sh -c 'chown postgres:postgres /opt/moov/certs/*.key && chmod 0644 /opt/moov/certs/*.crt && ls -l /opt/moov/certs/ && exec docker-entrypoint.sh -c ssl=on -c ssl_cert_file=/opt/moov/certs/server.crt -c ssl_key_file=/opt/moov/certs/server.key -c ssl_ca_file=/opt/moov/certs/root.crt' 39 | healthcheck: 40 | test: ["CMD-SHELL", "pg_isready -U moov"] 41 | interval: 5s 42 | timeout: 5s 43 | retries: 5 44 | environment: 45 | - POSTGRES_DB=moov 46 | - POSTGRES_USER=moov 47 | - POSTGRES_PASSWORD=moov 48 | networks: 49 | - intranet 50 | volumes: 51 | - ./testcerts/root.crt:/opt/moov/certs/root.crt 52 | - ./testcerts/server.crt:/opt/moov/certs/server.crt 53 | - ./testcerts/server.key:/opt/moov/certs/server.key 54 | 55 | networks: 56 | intranet: 57 | -------------------------------------------------------------------------------- /docker/docker.go: -------------------------------------------------------------------------------- 1 | // Copyright 
2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package docker 6 | 7 | import ( 8 | "os/exec" 9 | ) 10 | 11 | // Enabled returns true if Docker is available when called. 12 | func Enabled() bool { 13 | bin, err := exec.LookPath("docker") 14 | return bin != "" && err == nil // 'docker' was found on PATH 15 | } 16 | -------------------------------------------------------------------------------- /docker/docker_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package docker 6 | 7 | import ( 8 | "os" 9 | "runtime" 10 | "testing" 11 | ) 12 | 13 | func TestDocker(t *testing.T) { 14 | osname := os.Getenv("TRAVIS_OS_NAME") 15 | if osname == "" { 16 | t.Skip("docker: only testing in CI") 17 | } 18 | 19 | if runtime.GOOS == "darwin" { 20 | if Enabled() { 21 | t.Error("docker on travis-ci osx/macOS available now?") 22 | } 23 | } else { 24 | if !Enabled() { 25 | t.Errorf("expected Docker to be enabled in %s CI", runtime.GOOS) 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | 5 | package base 6 | 7 | import ( 8 | "bytes" 9 | "encoding/json" 10 | "errors" 11 | "fmt" 12 | "io" 13 | "reflect" 14 | ) 15 | 16 | // UnwrappableError is an interface for errors that wrap another error with some extra context 17 | // The interface allows these errors to get automatically unwrapped by the Match function 18 | type UnwrappableError interface { 19 | Error() string 20 | Unwrap() error 21 | } 22 | 23 | // ParseError is returned for parsing reader errors. 24 | // The first line is 1. 25 | type ParseError struct { 26 | Line int // Line number where the error occurred 27 | Record string // Name of the record type being parsed 28 | Err error // The actual error 29 | } 30 | 31 | func (e ParseError) Error() string { 32 | if e.Record == "" { 33 | return fmt.Sprintf("line:%d %T %s", e.Line, e.Err, e.Err) 34 | } 35 | return fmt.Sprintf("line:%d record:%s %T %s", e.Line, e.Record, e.Err, e.Err) 36 | } 37 | 38 | // Unwrap implements the UnwrappableError interface for ParseError 39 | func (e ParseError) Unwrap() error { 40 | return e.Err 41 | } 42 | 43 | // ErrorList represents an array of errors which is also an error itself. 44 | type ErrorList []error 45 | 46 | // Add appends err onto the ErrorList. Errors are kept in append order. 47 | func (e *ErrorList) Add(err error) { 48 | *e = append(*e, err) 49 | } 50 | 51 | // Err returns the first error (or nil). 52 | func (e ErrorList) Err() error { 53 | if len(e) == 0 { 54 | return nil 55 | } 56 | return e[0] 57 | } 58 | 59 | // Error implements the error interface 60 | func (e ErrorList) Error() string { 61 | if len(e) == 0 { 62 | return "" 63 | } 64 | var buf bytes.Buffer 65 | e.Print(&buf) 66 | return buf.String() 67 | } 68 | 69 | // Print formats the ErrorList into a string written to w. 70 | // If ErrorList contains multiple errors those after the first 71 | // are indented. 
72 | func (e ErrorList) Print(w io.Writer) { 73 | if w == nil || len(e) == 0 { 74 | fmt.Fprintf(w, "") 75 | return 76 | } 77 | 78 | fmt.Fprintf(w, "%s", e[0]) 79 | if len(e) > 1 { 80 | fmt.Fprintf(w, "\n") 81 | } 82 | 83 | for i := 1; i < len(e); i++ { 84 | fmt.Fprintf(w, " %s", e[i]) 85 | if i < len(e)-1 { // don't add \n to last error 86 | fmt.Fprintf(w, "\n") 87 | } 88 | } 89 | } 90 | 91 | // Empty no errors to return 92 | func (e ErrorList) Empty() bool { 93 | return len(e) == 0 94 | } 95 | 96 | // MarshalJSON marshals error list 97 | func (e ErrorList) MarshalJSON() ([]byte, error) { 98 | return json.Marshal(e.Error()) 99 | } 100 | 101 | // Match takes in two errors and compares them, returning true if they match and false if they don't 102 | // The matching is done by basic equality for simple errors (i.e. defined by errors.New) and by type 103 | // for other errors. If errA is wrapped with an error supporting the UnwrappableError interface it 104 | // will also unwrap it and then recursively compare the unwrapped error with errB. 105 | func Match(errA, errB error) bool { 106 | if errA == nil { 107 | return errB == nil 108 | } 109 | 110 | // typed errors can be compared by type 111 | if reflect.TypeOf(errA) == reflect.TypeOf(errB) { 112 | simpleError := errors.New("simple error") 113 | if reflect.TypeOf(errB) == reflect.TypeOf(simpleError) { 114 | // simple errors all have the same type, so we need to compare them directly 115 | return errA == errB 116 | } 117 | return true 118 | } 119 | 120 | // match wrapped errors 121 | uwErr, ok := errA.(UnwrappableError) 122 | if ok { 123 | return Match(uwErr.Unwrap(), errB) 124 | } 125 | 126 | return false 127 | } 128 | 129 | // Has takes in a (potential) list of errors, and an error to check for. If any of the errors 130 | // in the list have the same type as the error to check, it returns true. 
If the "list" isn't 131 | // actually a list (typically because it is nil), or no errors in the list match the other error 132 | // it returns false. So it can be used as an easy way to check for a particular kind of error. 133 | func Has(list error, err error) bool { 134 | el, ok := list.(ErrorList) 135 | if !ok { 136 | return false 137 | } 138 | for i := 0; i < len(el); i++ { 139 | if Match(el[i], err) { 140 | return true 141 | } 142 | } 143 | return false 144 | } 145 | -------------------------------------------------------------------------------- /error_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package base 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "fmt" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestParseError_Error(t *testing.T) { 18 | errorList := ErrorList{} 19 | errorList.Add(errors.New("testing")) 20 | 21 | pse := ParseError{ 22 | Err: errorList, 23 | Line: 5, 24 | Record: "ABC", 25 | } 26 | 27 | if !strings.Contains(pse.Error(), "testing") { 28 | t.Errorf("got %s", errorList.Error()) 29 | } 30 | 31 | if pse.Record != "ABC" { 32 | t.Errorf("got %s", pse.Record) 33 | } 34 | 35 | if pse.Line != 5 { 36 | t.Errorf("got %v", pse.Line) 37 | } 38 | 39 | } 40 | 41 | func TestParseErrorRecordNull_Error(t *testing.T) { 42 | errorList := ErrorList{} 43 | errorList.Add(errors.New("testing")) 44 | 45 | pse := ParseError{ 46 | Err: errorList, 47 | Line: 5, 48 | Record: "", 49 | } 50 | 51 | e1 := pse.Error() 52 | 53 | if e1 != "line:5 base.ErrorList testing" { 54 | t.Errorf("got %s", e1) 55 | } 56 | } 57 | 58 | func TestErrorList_Add(t *testing.T) { 59 | errorList := ErrorList{} 60 | errorList.Add(errors.New("testing")) 61 | 62 | es := errorList.Error() 63 | 64 | if es != "testing" { 65 | 
t.Errorf("got %s", errorList.Error()) 66 | } 67 | 68 | if errorList.Empty() { 69 | t.Errorf("ErrorList is empty: %v", errorList) 70 | } 71 | 72 | errorList.Add(errors.New("continued testing")) 73 | 74 | if errorList.Empty() { 75 | t.Errorf("ErrorList is empty: %v", errorList) 76 | } 77 | } 78 | 79 | func TestErrorList_Err(t *testing.T) { 80 | errorList := ErrorList{} 81 | errorList.Add(errors.New("testing")) 82 | 83 | e1 := errorList.Err() 84 | 85 | if e1.Error() != "testing" { 86 | t.Errorf("got %q", e1) 87 | } 88 | 89 | } 90 | 91 | func TestErrorList_Print(t *testing.T) { 92 | errorList := ErrorList{} 93 | errorList.Add(errors.New("testing")) 94 | errorList.Add(errors.New("continued testing")) 95 | 96 | var buf bytes.Buffer 97 | errorList.Print(&buf) 98 | 99 | if v := errorList.Error(); v == "" { 100 | t.Errorf("got %q", v) 101 | } 102 | buf.Reset() 103 | 104 | } 105 | 106 | func TestErrorList_Empty(t *testing.T) { 107 | errorList := ErrorList{} 108 | 109 | e1 := errorList.Err() 110 | 111 | if e1 != nil { 112 | t.Errorf("got %q", e1) 113 | } 114 | if errorList.Error() != "" { 115 | t.Errorf("got %s", errorList.Error()) 116 | } 117 | 118 | var buf bytes.Buffer 119 | errorList.Print(&buf) 120 | buf.Reset() 121 | } 122 | 123 | func TestErrorList__EmptyThenNot(t *testing.T) { 124 | var el ErrorList 125 | require.NoError(t, el.Err()) 126 | require.Equal(t, "", el.Error()) 127 | require.True(t, el.Empty()) 128 | 129 | el.Add(errors.New("bad thing")) 130 | require.Error(t, el.Err()) 131 | require.Equal(t, "bad thing", el.Error()) 132 | require.False(t, el.Empty()) 133 | } 134 | 135 | func TestErrorList_MarshalJSON(t *testing.T) { 136 | errorList := ErrorList{} 137 | errorList.Add(errors.New("testing")) 138 | errorList.Add(errors.New("continued testing")) 139 | errorList.Add(errors.New("testing again")) 140 | errorList.Add(errors.New("continued testing again")) 141 | 142 | b, err := errorList.MarshalJSON() 143 | 144 | if len(b) == 0 { 145 | t.Errorf("got %s", 
errorList.Error()) 146 | } 147 | if err != nil { 148 | t.Errorf("got %s", errorList.Error()) 149 | } 150 | } 151 | 152 | // testMatch validates the Match error function 153 | func TestMatch(t *testing.T) { 154 | testError := errors.New("Test error") 155 | 156 | if !Match(nil, nil) { 157 | t.Error("Match should be reflexive on nil") 158 | } 159 | 160 | if !Match(testError, testError) { 161 | t.Error("Match should be reflexive") 162 | } 163 | 164 | p := ParseError{Err: testError} 165 | if !Match(p, testError) { 166 | t.Error("Match should match wrapped errors implementing the UnwrappableError interface") 167 | } 168 | 169 | differentError := errors.New("Different error") 170 | if Match(testError, differentError) { 171 | t.Error("Match should return false for different simple errors") 172 | } 173 | 174 | q := ParseError{Err: differentError} 175 | if !Match(p, q) { 176 | t.Error("Match should match two different ParseErrors to each other since they have the same type") 177 | } 178 | 179 | errorList := ErrorList{} 180 | if Match(errorList, p) { 181 | t.Error("Match should return false for errors with different types") 182 | } 183 | } 184 | 185 | // testHas validates the Has error function 186 | func TestHas(t *testing.T) { 187 | err := errors.New("Non list error") 188 | 189 | if Has(err, err) { 190 | t.Error("Has should return false when given a non-list error as the first arg") 191 | } 192 | 193 | if Has(nil, err) { 194 | t.Error("Has should not return true if there are no errors") 195 | } 196 | 197 | if Has(ErrorList([]error{}), err) { 198 | t.Error("Has should not return true if there are no errors") 199 | } 200 | 201 | if !Has(ErrorList([]error{err}), err) { 202 | t.Error("Has should return true if the error list has the test error") 203 | } 204 | } 205 | 206 | func TestErrorList_Panic(t *testing.T) { 207 | var el ErrorList 208 | require.Equal(t, "", fmt.Sprintf("%v", el)) 209 | require.Equal(t, "", fmt.Errorf("%w", el).Error()) 210 | } 211 | 
-------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/moov-io/base 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | cloud.google.com/go/alloydbconn v1.15.1 7 | cloud.google.com/go/spanner v1.79.0 8 | github.com/go-kit/kit v0.13.0 9 | github.com/go-kit/log v0.2.1 10 | github.com/go-sql-driver/mysql v1.9.2 11 | github.com/go-viper/mapstructure/v2 v2.2.1 12 | github.com/golang-migrate/migrate/v4 v4.18.3 13 | github.com/google/uuid v1.6.0 14 | github.com/googleapis/gax-go/v2 v2.14.1 15 | github.com/googleapis/go-sql-spanner v1.13.0 16 | github.com/gorilla/mux v1.8.1 17 | github.com/jackc/pgx/v5 v5.7.4 18 | github.com/madflojo/testcerts v1.4.0 19 | github.com/markbates/pkger v0.17.1 20 | github.com/prometheus/client_golang v1.22.0 21 | github.com/rickar/cal/v2 v2.1.23 22 | github.com/spf13/viper v1.20.1 23 | github.com/stretchr/testify v1.10.0 24 | go.opentelemetry.io/otel v1.36.0 25 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 26 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.36.0 27 | go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.36.0 28 | go.opentelemetry.io/otel/sdk v1.36.0 29 | go.opentelemetry.io/otel/trace v1.36.0 30 | google.golang.org/grpc v1.72.1 31 | ) 32 | 33 | require ( 34 | cel.dev/expr v0.20.0 // indirect 35 | cloud.google.com/go v0.120.0 // indirect 36 | cloud.google.com/go/alloydb v1.15.0 // indirect 37 | cloud.google.com/go/auth v0.15.0 // indirect 38 | cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect 39 | cloud.google.com/go/compute/metadata v0.6.0 // indirect 40 | cloud.google.com/go/iam v1.4.2 // indirect 41 | cloud.google.com/go/longrunning v0.6.6 // indirect 42 | cloud.google.com/go/monitoring v1.24.1 // indirect 43 | filippo.io/edwards25519 v1.1.0 // indirect 44 | github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.2 // indirect 45 | 
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect 46 | github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect 47 | github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect 48 | github.com/beorn7/perks v1.0.1 // indirect 49 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 50 | github.com/cenkalti/backoff/v5 v5.0.2 // indirect 51 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 52 | github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 // indirect 53 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 54 | github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect 55 | github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect 56 | github.com/felixge/httpsnoop v1.0.4 // indirect 57 | github.com/fsnotify/fsnotify v1.8.0 // indirect 58 | github.com/go-jose/go-jose/v4 v4.0.5 // indirect 59 | github.com/go-logfmt/logfmt v0.6.0 // indirect 60 | github.com/go-logr/logr v1.4.2 // indirect 61 | github.com/go-logr/stdr v1.2.2 // indirect 62 | github.com/gobuffalo/here v0.6.7 // indirect 63 | github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect 64 | github.com/google/s2a-go v0.1.9 // indirect 65 | github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect 66 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect 67 | github.com/hashicorp/errwrap v1.1.0 // indirect 68 | github.com/hashicorp/go-multierror v1.1.1 // indirect 69 | github.com/jackc/pgpassfile v1.0.0 // indirect 70 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 71 | github.com/jackc/puddle/v2 v2.2.2 // indirect 72 | github.com/lib/pq v1.10.9 // indirect 73 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 74 | github.com/pelletier/go-toml/v2 v2.2.3 // indirect 75 | github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // 
indirect 76 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 77 | github.com/prometheus/client_model v0.6.1 // indirect 78 | github.com/prometheus/common v0.62.0 // indirect 79 | github.com/prometheus/procfs v0.15.1 // indirect 80 | github.com/sagikazarmark/locafero v0.7.0 // indirect 81 | github.com/sourcegraph/conc v0.3.0 // indirect 82 | github.com/spf13/afero v1.12.0 // indirect 83 | github.com/spf13/cast v1.7.1 // indirect 84 | github.com/spf13/pflag v1.0.6 // indirect 85 | github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect 86 | github.com/subosito/gotenv v1.6.0 // indirect 87 | github.com/zeebo/errs v1.4.0 // indirect 88 | go.opencensus.io v0.24.0 // indirect 89 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 90 | go.opentelemetry.io/contrib/detectors/gcp v1.35.0 // indirect 91 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect 92 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect 93 | go.opentelemetry.io/otel/metric v1.36.0 // indirect 94 | go.opentelemetry.io/otel/sdk/metric v1.35.0 // indirect 95 | go.opentelemetry.io/proto/otlp v1.6.0 // indirect 96 | go.uber.org/atomic v1.11.0 // indirect 97 | go.uber.org/multierr v1.11.0 // indirect 98 | golang.org/x/crypto v0.38.0 // indirect 99 | golang.org/x/net v0.40.0 // indirect 100 | golang.org/x/oauth2 v0.29.0 // indirect 101 | golang.org/x/sync v0.14.0 // indirect 102 | golang.org/x/sys v0.33.0 // indirect 103 | golang.org/x/text v0.25.0 // indirect 104 | golang.org/x/time v0.11.0 // indirect 105 | google.golang.org/api v0.228.0 // indirect 106 | google.golang.org/genproto v0.0.0-20250303144028-a0af3efb3deb // indirect 107 | google.golang.org/genproto/googleapis/api v0.0.0-20250519155744-55703ea1f237 // indirect 108 | google.golang.org/genproto/googleapis/rpc v0.0.0-20250519155744-55703ea1f237 // indirect 109 | google.golang.org/protobuf v1.36.6 // indirect 110 | gopkg.in/yaml.v3 v3.0.1 // indirect 
111 | ) 112 | -------------------------------------------------------------------------------- /http/bind/bind.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package bind returns well known HTTP local bind addresses for Moov services. 6 | // The package is intended for services to use for discovery during local development. 7 | // 8 | // This package also returns admin ports, useable with the github.com/moov-io/base/admin 9 | // package. 10 | package bind 11 | 12 | import ( 13 | "fmt" 14 | "strconv" 15 | "strings" 16 | ) 17 | 18 | // serviceBinds is a map between a service name and its local bind address. 19 | // The returned values will always be of the form ":XXXX" where XXXX is a 20 | // valid port above 1024. 21 | var serviceBinds = map[string]string{ 22 | // Never change existing records, just add new records. 23 | "ach": ":8080", 24 | "auth": ":8081", 25 | "paygate": ":8082", 26 | "x9": ":8083", // x9 was renamed to icl 27 | "icl": ":8083", 28 | "ofac": ":8084", // ofac was renamed to watchman 29 | "watchman": ":8084", 30 | "gl": ":8085", // GL was renamed to accounts 31 | "accounts": ":8085", 32 | "fed": ":8086", 33 | "customers": ":8087", 34 | "wire": ":8088", 35 | "apitest": ":8089", 36 | "console": ":8100", 37 | } 38 | 39 | // HTTP returns the local bind address for a Moov service. 40 | func HTTP(serviceName string) string { 41 | v, ok := serviceBinds[strings.ToLower(serviceName)] 42 | if !ok { 43 | return "" 44 | } 45 | return v 46 | } 47 | 48 | // Admin returns the local bind address for a Moov service's admin server. 49 | // This server typically serves metrics and debugging endpoints. 
50 | func Admin(serviceName string) string { 51 | http := HTTP(serviceName) 52 | if http == "" { 53 | return "" 54 | } 55 | http = strings.TrimPrefix(http, ":") 56 | n, err := strconv.Atoi(http) 57 | if err != nil { 58 | return "" 59 | } 60 | n += 1000 // 90XX 61 | n += 10 // 909X 62 | return fmt.Sprintf(":%d", n) 63 | } 64 | -------------------------------------------------------------------------------- /http/bind/bind_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package bind 6 | 7 | import ( 8 | "testing" 9 | ) 10 | 11 | func TestBind(t *testing.T) { 12 | // valid 13 | http := HTTP("auth") 14 | if http != ":8081" { 15 | t.Errorf("got %s", http) 16 | } 17 | admin := Admin("auth") 18 | if admin != ":9091" { 19 | t.Errorf("got %s", admin) 20 | } 21 | if port := HTTP("console"); port != ":8100" { 22 | t.Errorf("got %s", port) 23 | } 24 | if port := Admin("console"); port != ":9110" { 25 | t.Errorf("got %s", port) 26 | } 27 | 28 | // invalid 29 | if v := HTTP("other"); v != "" { 30 | t.Errorf("got %s", v) 31 | } 32 | if v := Admin("other"); v != "" { 33 | t.Errorf("got %s", v) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /http/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package http implements a core suite of HTTP functions for use inside Moov. These packages are designed to 6 | // be used in production to provide insight without an excessive performance tradeoff. 
7 | // 8 | // This package implements several opinionated response functions (See Problem, InternalError) and stateless CORS 9 | // handling under our load balancing setup. They may not work for you. 10 | // 11 | // This package also implements a wrapper around http.ResponseWriter to log X-Request-ID, timing and the resulting status code. 12 | package http 13 | -------------------------------------------------------------------------------- /http/response.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package http 6 | 7 | import ( 8 | "net/http" 9 | "time" 10 | 11 | "github.com/moov-io/base/log" 12 | 13 | "github.com/go-kit/kit/metrics" 14 | ) 15 | 16 | // ResponseWriter wraps an http.ResponseWriter and, on WriteHeader, adds CORS headers, records response timing and logs request metadata 17 | type ResponseWriter struct { 18 | http.ResponseWriter 19 | 20 | start time.Time // set by Wrap; used to compute the response duration 21 | request *http.Request 22 | metric metrics.Histogram // optional; observed in seconds 23 | 24 | headersWritten bool // set on the first WriteHeader call; later calls are no-ops 25 | 26 | log log.Logger // optional; nil disables request logging 27 | } 28 | 29 | // WriteHeader sets CORS headers, observes the elapsed time in the metric (when set), defaults Content-Type to text/plain, 30 | // and logs the request metadata when a logger is set and the request carries an X-Request-Id. Only the first call has any effect. 31 | func (w *ResponseWriter) WriteHeader(code int) { 32 | if w == nil || w.headersWritten { 33 | return 34 | } 35 | w.headersWritten = true 36 | 37 | // CORS headers, based on the request's Origin 38 | SetAccessControlAllowHeaders(w, w.request.Header.Get("Origin")) 39 | defer w.ResponseWriter.WriteHeader(code) // deferred so the header changes below are still sent 40 | 41 | // Record route timing 42 | diff := time.Since(w.start) 43 | if w.metric != nil { 44 | w.metric.Observe(diff.Seconds()) 45 | } 46 | 47 | // Skip Go's content sniff here to speed up response timing for client 48 | if w.ResponseWriter.Header().Get("Content-Type") == "" { 49 | w.ResponseWriter.Header().Set("Content-Type", "text/plain") 50 |
w.ResponseWriter.Header().Set("X-Content-Type-Options", "nosniff") 51 | } 52 | 53 | if requestID := GetRequestID(w.request); requestID != "" && w.log != nil { 54 | w.log.With(log.Fields{ 55 | "method": log.String(w.request.Method), 56 | "path": log.String(w.request.URL.Path), 57 | "status": log.Int(code), 58 | "duration": log.TimeDuration(diff), 59 | "requestID": log.String(requestID), 60 | }).Send() 61 | } 62 | } 63 | 64 | // Wrap returns a ResponseWriter usable by applications. No parts of the Request are inspected or ResponseWriter modified. 65 | func Wrap(logger log.Logger, m metrics.Histogram, w http.ResponseWriter, r *http.Request) *ResponseWriter { 66 | now := time.Now() 67 | return &ResponseWriter{ 68 | ResponseWriter: w, 69 | start: now, 70 | request: r, 71 | metric: m, 72 | log: logger, 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /http/response_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 
4 | 5 | package http 6 | 7 | import ( 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "github.com/go-kit/kit/metrics/prometheus" 13 | stdprometheus "github.com/prometheus/client_golang/prometheus" 14 | ) 15 | 16 | var ( 17 | routeHistogram = prometheus.NewHistogramFrom(stdprometheus.HistogramOpts{ 18 | Name: "http_response_duration_seconds", 19 | Help: "Histogram representing the http response durations", 20 | }, nil) 21 | ) 22 | 23 | func TestResponse__Wrap(t *testing.T) { 24 | req := httptest.NewRequest("GET", "https://api.moov.io/v1/ach/ping", nil) 25 | req.Header.Set("Origin", "https://moov.io/demo") 26 | 27 | w := httptest.NewRecorder() 28 | 29 | ww := Wrap(nil, routeHistogram, w, req) 30 | ww.WriteHeader(http.StatusTeapot) 31 | w.Flush() 32 | 33 | if w.Code != http.StatusTeapot { 34 | t.Errorf("got HTTP code: %d", w.Code) 35 | } 36 | if v := w.Header().Get("Access-Control-Allow-Origin"); v == "" { 37 | t.Error("expected CORS heders") 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /http/server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package http 6 | 7 | import ( 8 | "encoding/json" 9 | "errors" 10 | "fmt" 11 | "math" 12 | "net/http" 13 | "path/filepath" 14 | "runtime" 15 | "strconv" 16 | "strings" 17 | 18 | "github.com/gorilla/mux" 19 | "github.com/moov-io/base/strx" 20 | ) 21 | 22 | const ( 23 | maxHeaderLength = 36 24 | ) 25 | 26 | // Problem writes err to w while also setting the HTTP status code, content-type and marshaling 27 | // err as the response body. 
28 | func Problem(w http.ResponseWriter, err error) { 29 | if err == nil { 30 | return 31 | } 32 | w.Header().Set("Content-Type", "application/json; charset=utf-8") // must precede WriteHeader, later header changes are not sent 33 | w.WriteHeader(http.StatusBadRequest) 34 | json.NewEncoder(w).Encode(map[string]interface{}{ 35 | "error": err.Error(), 36 | }) 37 | } 38 | 39 | // InternalError writes an HTTP 500 Internal Server Error status to w. err is currently unused: 40 | // nothing is written to the response body. 41 | // 42 | // Returned is the file and line number of the outermost github.com/moov-io (or main package) frame among the nearest callers, e.g. server.go:33 43 | func InternalError(w http.ResponseWriter, err error) string { 44 | w.WriteHeader(http.StatusInternalServerError) 45 | 46 | pcs := make([]uintptr, 5) // some limit 47 | n := runtime.Callers(1, pcs) 48 | 49 | file, line := "", 0 50 | 51 | // Sometimes InternalError will be wrapped by helper methods inside an application, 52 | // so walk the captured callers and keep the last one inside github.com/moov-io (or main); 53 | // the frames after it likely belong to the stdlib or a router. 54 | // 55 | // Note: This might not work for code already outside github.com/moov-io, please report 56 | // feedback if this works or not. 57 | frames := runtime.CallersFrames(pcs[:n]) 58 | for { 59 | f, more := frames.Next() 60 | 61 | // f.Function is package-qualified (i.e. github.com/moov-io/...), so match Moov packages 62 | // or the main package. f.File and f.Line also account for inlined calls. 63 | if strings.Contains(f.Function, "github.com/moov-io") || strings.HasPrefix(f.Function, "main.") { 64 | file, line = f.File, f.Line 65 | } 66 | // Check more only after using f: the final valid frame is returned with more == false. 67 | if !more { 68 | break 69 | } 70 | } 71 | 72 | // Get the filename, file was a full path 73 | _, file = filepath.Split(file) 74 | return fmt.Sprintf("%s:%d", file, line) 75 | } 76 | 77 | // AddCORSHandler captures Cross Origin Resource Sharing (CORS) requests 78 | // by looking at all OPTIONS requests for the Origin header, parsing that 79 |
80 | // 81 | // Docs: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS 82 | func AddCORSHandler(r *mux.Router) { 83 | r.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 84 | origin := r.Header.Get("Origin") 85 | if origin == "" { 86 | w.WriteHeader(http.StatusBadRequest) 87 | return 88 | } 89 | SetAccessControlAllowHeaders(w, origin) 90 | w.WriteHeader(http.StatusOK) 91 | }) 92 | } 93 | 94 | // SetAccessControlAllowHeaders writes Access-Control-Allow-* headers to a response to allow 95 | // for further CORS-allowed requests. Origins other than http://localhost:* and https://* get no Access-Control-Allow-* headers. 96 | func SetAccessControlAllowHeaders(w http.ResponseWriter, origin string) { 97 | // Access-Control-Allow-Origin can't be '*' with requests that send credentials, so echo the request's Origin instead. 98 | // Because these headers depend on Origin, always send Vary: Origin so shared caches key responses on it. 99 | // Allow requests from anyone's localhost and only from secure pages (NOTE(review): every https:// origin is echoed with credentials allowed; confirm that is intended). 100 | w.Header().Add("Vary", "Origin") 101 | if strings.HasPrefix(origin, "http://localhost:") || strings.HasPrefix(origin, "https://") { 102 | w.Header().Set("Access-Control-Allow-Origin", origin) 103 | w.Header().Set("Access-Control-Allow-Methods", "GET,POST,PATCH,DELETE,OPTIONS") 104 | w.Header().Set("Access-Control-Allow-Headers", "Cookie,X-User-Id,X-Request-Id,Content-Type") 105 | w.Header().Set("Access-Control-Allow-Credentials", "true") 106 | } 107 | } 108 | 109 | // GetRequestID returns the Moov header value for request IDs 110 | func GetRequestID(r *http.Request) string { 111 | return r.Header.Get("X-Request-Id") 112 | } 113 | 114 | // GetUserID returns the Moov userId from HTTP headers 115 | func GetUserID(r *http.Request) string { 116 | return strx.Or(r.Header.Get("X-User"), r.Header.Get("X-User-Id")) 117 | } 118 | 119 | // GetSkipAndCount returns the skip and count pagination values from the query parameters 120 | // - skip is the number of records to pass over before starting a search (max math.MaxInt32) 121 | // - count is the number of records to retrieve in the search (max 10,000) 122 | //
- exists indicates if skip or count was passed into the request URL 123 | func GetSkipAndCount(r *http.Request) (skip int, count int, exists bool, err error) { 124 | return readSkipCount(r, math.MaxInt32, 10000) 125 | } 126 | 127 | // LimitedSkipCount returns the skip and count pagination values from the request's query parameters 128 | // See GetSkipAndCount for descriptions of each return value; skipLimit and countLimit replace its default maximums 129 | func LimitedSkipCount(r *http.Request, skipLimit, countLimit int) (skip int, count int, exists bool, err error) { 130 | return readSkipCount(r, skipLimit, countLimit) 131 | } 132 | 133 | func readSkipCount(r *http.Request, skipMax, countMax int) (skip int, count int, exists bool, err error) { 134 | query := r.URL.Query() // parse the query string once 135 | skipVal, countVal := query.Get("skip"), query.Get("count") 136 | exists = len(skipVal) > 0 || len(countVal) > 0 137 | 138 | // Parse skip 139 | skip, err = strconv.Atoi(skipVal) 140 | if err != nil && len(skipVal) > 0 { 141 | skip = 0 142 | return skip, count, exists, err 143 | } 144 | // Limit skip to [0, skipMax] with integer min/max (a float64 round-trip loses precision for very large limits) 145 | skip = min(skip, skipMax) 146 | skip = max(0, skip) 147 | 148 | // Parse count 149 | count, err = strconv.Atoi(countVal) 150 | if err != nil && len(countVal) > 0 { 151 | count = 0 152 | return skip, count, exists, err 153 | } 154 | 155 | // Limit count to [0, countMax], defaulting to 200 when absent or zero (NOTE(review): the default is not capped by countMax) 156 | count = min(count, countMax) 157 | count = max(0, count) 158 | if count == 0 { 159 | count = 200 160 | } 161 | 162 | return skip, count, exists, nil 163 | } 164 | 165 | type Direction string 166 | 167 | const ( 168 | Ascending Direction = "ASC" 169 | Descending Direction = "DESC" 170 | ) 171 | 172 | type OrderBy struct { 173 | Name string 174 | Direction Direction 175 | } 176 | 177 | // GetOrderBy returns the field names and direction to order the response by, parsed from the orderBy query parameter (e.g. "name:asc,createdAt:desc") 178 | func GetOrderBy(r *http.Request) ([]OrderBy, error) { 179 | orderByParam := r.URL.Query().Get("orderBy") 180 | if orderByParam == "" { 181 |
return []OrderBy{}, nil 182 | } 183 | 184 | paramSplit := strings.Split(orderByParam, ",") 185 | orderBys := make([]OrderBy, 0, len(paramSplit)) 186 | for _, split := range paramSplit { 187 | orderBy := strings.Split(split, ":") 188 | if len(orderBy) != 2 { 189 | return nil, fmt.Errorf("invalid orderBy: %q", split) 190 | } 191 | 192 | name := strings.TrimSpace(orderBy[0]) 193 | if name == "" { 194 | return nil, errors.New("missing orderBy name") 195 | } 196 | 197 | directionStr := strings.TrimSpace(orderBy[1]) 198 | if directionStr == "" { 199 | return nil, errors.New("missing orderBy direction") 200 | } 201 | directionStr = strings.ToLower(directionStr) 202 | 203 | var direction Direction 204 | if strings.HasPrefix(directionStr, "asc") { 205 | direction = Ascending 206 | } else if strings.HasPrefix(directionStr, "desc") { 207 | direction = Descending 208 | } else { 209 | return nil, fmt.Errorf("invalid orderBy direction: %q", orderBy[1]) // report the caller's value; direction is still empty here 210 | } 211 | 212 | orderBys = append(orderBys, OrderBy{ 213 | Name: name, 214 | Direction: direction, 215 | }) 216 | } 217 | return orderBys, nil 218 | } 219 | -------------------------------------------------------------------------------- /id.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package base 6 | 7 | import ( 8 | "crypto/rand" 9 | "encoding/hex" 10 | "strings" 11 | ) 12 | 13 | // ID creates a new random string for Moov systems. 14 | // Do not assume anything about these ID's other than they are non-empty strings. 15 | func ID() string { 16 | // NOTE(adam): Moov's apps depend on the length and hex encoding of these ID's to cleanup HTTP Prometheus metrics.
17 | bs := make([]byte, 20) 18 | n, err := rand.Read(bs) 19 | if err != nil || n == 0 { 20 | return "" 21 | } 22 | return strings.ToLower(hex.EncodeToString(bs)) 23 | } 24 | -------------------------------------------------------------------------------- /id_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package base 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestID(t *testing.T) { 14 | for i := 0; i < 1000; i++ { 15 | id := ID() 16 | require.NotEmpty(t, id) 17 | require.Len(t, id, 40) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /k8s/kubernetes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package k8s 6 | 7 | import ( 8 | "os" 9 | ) 10 | 11 | var serviceAccountFilepaths = []string{ 12 | // https://stackoverflow.com/a/49045575 13 | "/var/run/secrets/kubernetes.io", 14 | 15 | // https://github.com/hashicorp/vault/blob/master/command/agent/auth/kubernetes/kubernetes.go#L20 16 | "/var/run/secrets/kubernetes.io/serviceaccount/token", 17 | } 18 | 19 | // Inside returns true if ran from inside a Kubernetes cluster. 
20 | func Inside() bool { 21 | // Allow a user override path 22 | paths := append(serviceAccountFilepaths, os.Getenv("KUBERNETES_SERVICE_ACCOUNT_FILEPATH")) 23 | 24 | for i := range paths { 25 | if _, err := os.Stat(paths[i]); err == nil { 26 | return true 27 | } 28 | } 29 | 30 | return false 31 | } 32 | -------------------------------------------------------------------------------- /k8s/kubernetes_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package k8s 6 | 7 | import ( 8 | "os" 9 | "testing" 10 | ) 11 | 12 | func TestK8SInside(t *testing.T) { 13 | if Inside() { 14 | t.Errorf("not inside k8s") 15 | } 16 | 17 | // Create a file and pretend it's the Kubernetes service account filepath 18 | fd, err := os.Create("k8s-service-account") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | defer os.Remove(fd.Name()) 23 | if err := fd.Sync(); err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | // Pretend 28 | t.Setenv("KUBERNETES_SERVICE_ACCOUNT_FILEPATH", fd.Name()) 29 | 30 | if !Inside() { 31 | t.Error("we should be pretending to be in a Kubernetes cluster") 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /log/README.md: -------------------------------------------------------------------------------- 1 | # Log Package 2 | 3 | The log package provides structured logging capabilities for Moov applications. 
4 | 5 | ## Usage 6 | 7 | ### Basic Logging 8 | 9 | ```go 10 | import "github.com/moov-io/base/log" 11 | 12 | // Create a new logger 13 | logger := log.NewDefaultLogger() 14 | 15 | // Log a message with different levels 16 | logger.Info().Log("Application started") 17 | logger.Debug().Log("Debug information") 18 | logger.Warn().Log("Warning message") 19 | logger.Error().Log("Error occurred") 20 | 21 | // Log with key-value pairs 22 | logger.Info().Set("request_id", log.String("12345")).Log("Processing request") 23 | 24 | // Log formatted messages 25 | logger.Infof("Processing request %s", "12345") 26 | 27 | // Log errors 28 | err := someFunction() 29 | if err != nil { 30 | logger.LogError(err) 31 | } 32 | ``` 33 | 34 | ### Using Fields 35 | 36 | ```go 37 | import "github.com/moov-io/base/log" 38 | 39 | // Create a map of fields 40 | fields := log.Fields{ 41 | "request_id": log.String("12345"), 42 | "user_id": log.Int(42), 43 | "timestamp": log.Time(time.Now()), 44 | } 45 | 46 | // Log with fields 47 | logger.With(fields).Info().Log("Request processed") 48 | ``` 49 | 50 | ### Using StructContext 51 | 52 | The `StructContext` function allows you to log struct fields automatically by using tags. 
53 | 54 | ```go 55 | import "github.com/moov-io/base/log" 56 | 57 | // Define a struct with log tags 58 | type User struct { 59 | ID int `log:"id"` 60 | Username string `log:"username"` 61 | Email string `log:"email,omitempty"` // won't be logged if empty 62 | Address Address `log:"address"` // nested struct must have log tag 63 | Hidden string // no log tag, won't be logged 64 | } 65 | 66 | type Address struct { 67 | Street string `log:"street"` 68 | City string `log:"city"` 69 | Country string `log:"country"` 70 | } 71 | 72 | // Create a user 73 | user := User{ 74 | ID: 1, 75 | Username: "johndoe", 76 | Email: "john@example.com", 77 | Address: Address{ 78 | Street: "123 Main St", 79 | City: "New York", 80 | Country: "USA", 81 | }, 82 | Hidden: "secret", 83 | } 84 | 85 | // Log with struct context 86 | logger.With(log.StructContext(user)).Info().Log("User logged in") 87 | 88 | // Log with struct context and prefix 89 | logger.With(log.StructContext(user, log.WithPrefix("user"))).Info().Log("User details") 90 | 91 | // Using custom tag other than "log" 92 | type Product struct { 93 | ID int `otel:"product_id"` 94 | Name string `otel:"product_name"` 95 | Price float64 `otel:"price,omitempty"` 96 | } 97 | 98 | product := Product{ 99 | ID: 42, 100 | Name: "Widget", 101 | Price: 19.99, 102 | } 103 | 104 | // Use otel tags instead of log tags 105 | logger.With(log.StructContext(product, log.WithTag("otel"))).Info().Log("Product details") 106 | ``` 107 | 108 | The above will produce log entries with the following fields: 109 | - `id=1` 110 | - `username=johndoe` 111 | - `email=john@example.com` 112 | - `address.street=123 Main St` 113 | - `address.city=New York` 114 | - `address.country=USA` 115 | 116 | With the prefix option, the fields will be: 117 | - `user.id=1` 118 | - `user.username=johndoe` 119 | - `user.email=john@example.com` 120 | - `user.address.street=123 Main St` 121 | - `user.address.city=New York` 122 | - `user.address.country=USA` 123 | 124 | With the 
custom tag option, the fields will be extracted from the tag you specify (such as `otel`): 125 | - `product_id=42` 126 | - `product_name=Widget` 127 | - `price=19.99` 128 | 129 | Note that nested structs or pointers to structs must have the specified tag to be included in the context. 130 | 131 | ## Features 132 | 133 | - Structured logging with key-value pairs 134 | - Multiple log levels (Debug, Info, Warn, Error, Fatal) 135 | - JSON and LogFmt output formats 136 | - Context-based logging 137 | - Automatic struct field logging with StructContext 138 | - Support for various value types (string, int, float, bool, time, etc.) 139 | 140 | ## Configuration 141 | 142 | The default logger format is determined by the `MOOV_LOG_FORMAT` environment variable: 143 | - `json`: JSON format 144 | - `logfmt`: LogFmt format (default) 145 | - `nop` or `noop`: No-op logger that discards all logs 146 | -------------------------------------------------------------------------------- /log/logger.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | type Logger interface { 4 | Set(key string, value Valuer) Logger 5 | With(ctxs ...Context) Logger 6 | Details() map[string]interface{} 7 | 8 | Debug() Logger 9 | Info() Logger 10 | Warn() Logger 11 | Error() Logger 12 | Fatal() Logger 13 | 14 | Log(message string) 15 | Logf(format string, args ...interface{}) 16 | Send() 17 | 18 | LogError(error error) LoggedError 19 | LogErrorf(format string, args ...interface{}) LoggedError 20 | } 21 | 22 | type Context interface { 23 | Context() map[string]Valuer 24 | } 25 | -------------------------------------------------------------------------------- /log/logger_impl.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "runtime" 7 | "sort" 8 | "strings" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "github.com/go-kit/log" 14 | ) 15 | 16 | func 
NewDefaultLogger() Logger { 17 | switch strings.ToLower(os.Getenv("MOOV_LOG_FORMAT")) { 18 | case "json": 19 | return NewJSONLogger() 20 | case "nop", "noop": 21 | return NewNopLogger() 22 | case "logfmt": 23 | return NewLogFmtLogger() 24 | default: 25 | return NewLogFmtLogger() 26 | } 27 | } 28 | 29 | func NewNopLogger() Logger { 30 | return NewLogger(log.NewNopLogger()) 31 | } 32 | 33 | func NewLogFmtLogger() Logger { 34 | return NewLogger(log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr))) 35 | } 36 | 37 | func NewJSONLogger() Logger { 38 | return NewLogger(log.NewJSONLogger(log.NewSyncWriter(os.Stderr))) 39 | } 40 | 41 | func NewTestLogger() Logger { 42 | if testing.Verbose() { 43 | return NewDefaultLogger() 44 | } 45 | return NewNopLogger() 46 | } 47 | 48 | func NewBufferLogger() (*BufferedLogger, Logger) { 49 | buffer := &BufferedLogger{ 50 | buf: &strings.Builder{}, 51 | } 52 | writer := log.NewLogfmtLogger(log.NewSyncWriter(buffer)) 53 | log := NewLogger(writer) 54 | return buffer, log 55 | } 56 | 57 | type BufferedLogger struct { 58 | mu sync.RWMutex 59 | buf *strings.Builder 60 | } 61 | 62 | func (bl *BufferedLogger) Write(p []byte) (n int, err error) { 63 | bl.mu.Lock() 64 | defer bl.mu.Unlock() 65 | 66 | return bl.buf.Write(p) 67 | } 68 | 69 | func (bl *BufferedLogger) Reset() { 70 | bl.mu.Lock() 71 | defer bl.mu.Unlock() 72 | 73 | bl.buf.Reset() 74 | } 75 | 76 | func (bl *BufferedLogger) String() string { 77 | bl.mu.RLock() 78 | defer bl.mu.RUnlock() 79 | 80 | return bl.buf.String() 81 | } 82 | 83 | func NewLogger(writer log.Logger) Logger { 84 | l := &logger{ 85 | writer: writer, 86 | ctx: make(map[string]Valuer), 87 | } 88 | 89 | // Default logs to be info until changed 90 | return l.Info() 91 | } 92 | 93 | var _ Logger = (*logger)(nil) 94 | 95 | type logger struct { 96 | writer log.Logger 97 | ctx map[string]Valuer 98 | } 99 | 100 | func (l *logger) Set(key string, value Valuer) Logger { 101 | return l.With(Fields{ 102 | key: value, 103 | }) 104 | } 
105 | 106 | // With returns a new Logger with the contexts added to its own. 107 | func (l *logger) With(ctxs ...Context) (out Logger) { 108 | defer func() { 109 | if r := recover(); r != nil { 110 | var file string 111 | var line int 112 | var ok bool 113 | 114 | // Search the call stack for the first non-Go file 115 | for i := 1; i < 10; i++ { 116 | _, file, line, ok = runtime.Caller(i) 117 | if !ok || !strings.Contains(file, "/src/runtime/") { 118 | break 119 | } 120 | } 121 | 122 | if ok { 123 | l.writer.Log( 124 | "level", "error", 125 | "file", file, 126 | "line", fmt.Sprintf("%d", line), 127 | "msg", fmt.Sprintf("recovered from %T - %v", r, r), 128 | ) 129 | } 130 | 131 | out = l // make the caller whole 132 | return 133 | } 134 | }() 135 | 136 | // Estimation assuming that for each ctxs has at least 1 value. 137 | combined := make(map[string]Valuer, len(l.ctx)+len(ctxs)) 138 | 139 | for k, v := range l.ctx { 140 | combined[k] = v 141 | } 142 | 143 | for _, c := range ctxs { 144 | if c == nil { 145 | continue 146 | } 147 | 148 | itemCtx := c.Context() 149 | for k, v := range itemCtx { 150 | combined[k] = v 151 | } 152 | } 153 | 154 | return &logger{ 155 | writer: l.writer, 156 | ctx: combined, 157 | } 158 | } 159 | 160 | func (l *logger) Details() map[string]interface{} { 161 | m := make(map[string]interface{}, len(l.ctx)) 162 | for k, v := range l.ctx { 163 | m[k] = v.getValue() 164 | } 165 | return m 166 | } 167 | 168 | func (l *logger) Debug() Logger { 169 | return l.With(Debug) 170 | } 171 | 172 | func (l *logger) Info() Logger { 173 | return l.With(Info) 174 | } 175 | 176 | func (l *logger) Warn() Logger { 177 | return l.With(Warn) 178 | } 179 | 180 | func (l *logger) Error() Logger { 181 | return l.With(Error) 182 | } 183 | 184 | func (l *logger) Fatal() Logger { 185 | return l.With(Fatal) 186 | } 187 | 188 | func (l *logger) Log(msg string) { 189 | // Frontload the timestamp and msg 190 | keyvals := []interface{}{ 191 | "ts", 
time.Now().UTC().Format(time.RFC3339), 192 | } 193 | if msg != "" { 194 | keyvals = append(keyvals, "msg", msg) 195 | } 196 | 197 | // Sort the rest of the list so the log lines look similar 198 | details := l.Details() 199 | keys := make([]string, 0, len(l.ctx)) 200 | for k := range details { 201 | keys = append(keys, k) 202 | } 203 | sort.Strings(keys) 204 | 205 | // Append the context fields, in sorted key order, after ts and msg 206 | for _, k := range keys { 207 | keyvals = append(keyvals, k, details[k]) 208 | } 209 | 210 | _ = l.writer.Log(keyvals...) // the writer's error is discarded 211 | } 212 | // Logf formats the message with fmt.Sprintf and logs it via Log 213 | func (l *logger) Logf(format string, args ...interface{}) { 214 | msg := fmt.Sprintf(format, args...) 215 | l.Log(msg) 216 | } 217 | 218 | // Send logs the accumulated fields without a msg key; it is equivalent to calling Log("") 219 | func (l *logger) Send() { 220 | l.Log("") 221 | } 222 | // LogError logs err.Error() as the message with errored=true and returns err wrapped in a LoggedError 223 | func (l *logger) LogError(err error) LoggedError { 224 | l.Set("errored", Bool(true)).Log(err.Error()) // err must be non-nil: err.Error() is called unconditionally 225 | return LoggedError{err} 226 | } 227 | 228 | // LogErrorf builds an error with fmt.Errorf(format, args...), logs it through LogError and returns it 229 | func (l *logger) LogErrorf(format string, args ...interface{}) LoggedError { 230 | err := fmt.Errorf(format, args...)
231 | return l.LogError(err) 232 | } 233 | // LoggedError wraps an error that was logged by LogError or LogErrorf 234 | type LoggedError struct { 235 | err error 236 | } 237 | // Err returns the wrapped error 238 | func (l LoggedError) Err() error { 239 | return l.err 240 | } 241 | // Nil always returns nil, ignoring the wrapped error 242 | func (l LoggedError) Nil() error { 243 | return nil 244 | } 245 | -------------------------------------------------------------------------------- /log/model_fields.go: -------------------------------------------------------------------------------- 1 | package log 2 | // Fields is a Context whose key/value pairs come directly from the map 3 | type Fields map[string]Valuer 4 | // Context returns f itself 5 | func (f Fields) Context() map[string]Valuer { 6 | return f 7 | } 8 | -------------------------------------------------------------------------------- /log/model_levels.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | // Level wraps a level name so it can be passed as a Context that sets the "level" key 4 | type Level string 5 | 6 | // Debug sets level=debug in the log output 7 | const Debug = Level("debug") 8 | 9 | // Info sets level=info in the log output 10 | const Info = Level("info") 11 | 12 | // Warn sets level=warn in the log output 13 | const Warn = Level("warn") 14 | 15 | // Error sets level=error in the log output 16 | const Error = Level("error") 17 | 18 | // Fatal sets level=fatal in the log output 19 | const Fatal = Level("fatal") 20 | 21 | // Context returns a single-entry map that sets the "level" key to l, e.g. level=info 22 | func (l Level) Context() map[string]Valuer { 23 | return map[string]Valuer{ 24 | "level": String(string(l)), 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /log/model_stacktrace.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strings" 7 | ) 8 | 9 | type st string 10 | // StackTrace is a Context that adds the current call stack to a log line as caller_0, caller_1, ... keys 11 | const StackTrace = st("stacktrace") 12 | 13 | // Context returns the call stack as caller_N = "file:line" entries, skipping leading frames from model_stacktrace.go and logger_impl.go 14 | func (s st) Context() map[string]Valuer { 15 | kv :=
map[string]Valuer{} 16 | 17 | i := 0 18 | c := 0 19 | _, file, line, ok := runtime.Caller(i) 20 | for ; ok; i++ { 21 | if c > 0 || (!strings.HasSuffix(file, "model_stacktrace.go") && !strings.HasSuffix(file, "logger_impl.go")) { 22 | key := fmt.Sprintf("caller_%d", c) 23 | value := fmt.Sprintf("%s:%d", file, line) 24 | kv[key] = String(value) 25 | c++ 26 | } 27 | _, file, line, ok = runtime.Caller(i + 1) 28 | } 29 | 30 | return kv 31 | } 32 | -------------------------------------------------------------------------------- /log/model_valuer.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "encoding/base64" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // Valuer is an interface to deal with typing problems of just having an interface{} as the acceptable parameters 12 | // Go-kit logging has a failure case if you attempt to throw any values into it. 13 | // This is a way to guard our developers from having to worry about error cases of the lower logging framework. 
14 | type Valuer interface { 15 | getValue() interface{} 16 | } 17 | 18 | type any struct { 19 | value interface{} 20 | } 21 | 22 | func (a *any) getValue() interface{} { 23 | return a.value 24 | } 25 | 26 | func String(s string) Valuer { 27 | return &any{s} 28 | } 29 | 30 | func StringOrNil(s *string) Valuer { 31 | if s == nil { 32 | return &any{nil} 33 | } 34 | return String(*s) 35 | } 36 | 37 | func Int(i int) Valuer { 38 | return &any{i} 39 | } 40 | 41 | func Int64(i int64) Valuer { 42 | return &any{i} 43 | } 44 | 45 | // Int64OrNil returns a Valuer holding a copy of *i, or a null value when i is nil, matching StringOrNil and TimeOrNil. 46 | func Int64OrNil(i *int64) Valuer { 47 | if i == nil { 48 | return &any{nil} 49 | } 50 | return Int64(*i) 51 | } 52 | 53 | func Uint32(i uint32) Valuer { 54 | return &any{i} 55 | } 56 | 57 | func Uint64(i uint64) Valuer { 58 | return &any{i} 59 | } 60 | 61 | func Float32(f float32) Valuer { 62 | return &any{f} 63 | } 64 | 65 | func Float64(f float64) Valuer { 66 | return &any{f} 67 | } 68 | 69 | func Bool(b bool) Valuer { 70 | return &any{b} 71 | } 72 | 73 | func TimeDuration(d time.Duration) Valuer { 74 | return &any{d.String()} 75 | } 76 | 77 | func Time(t time.Time) Valuer { 78 | return TimeFormatted(t, time.RFC3339Nano) 79 | } 80 | 81 | func TimeOrNil(t *time.Time) Valuer { 82 | if t == nil { 83 | return &any{nil} 84 | } 85 | return Time(*t) 86 | } 87 | 88 | func TimeFormatted(t time.Time, format string) Valuer { 89 | return String(t.Format(format)) 90 | } 91 | 92 | func ByteString(b []byte) Valuer { 93 | return String(string(b)) 94 | } 95 | 96 | func ByteBase64(b []byte) Valuer { 97 | return String(base64.RawURLEncoding.EncodeToString(b)) 98 | } 99 | 100 | // Stringer returns a Valuer holding s.String(), or a null value when s is a nil interface or a typed nil pointer (calling String on either could panic). 101 | func Stringer(s fmt.Stringer) Valuer { 102 | if s == nil { 103 | return &any{nil} 104 | } 105 | if v := reflect.ValueOf(s); v.Kind() == reflect.Pointer && v.IsNil() { 106 | return &any{nil} 107 | } 108 | 109 | return &any{s.String()} 110 | } 111 | 112 | func Strings(vals []string) Valuer { 113 | out := fmt.Sprintf("[%s]", strings.Join(vals, ", ")) 114 | return String(out) 115 | } 116 | -------------------------------------------------------------------------------- /log/model_valuer_test.go:
-------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/moov-io/base/log" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | type Item struct { 11 | Value string 12 | } 13 | 14 | type Foo struct { 15 | Name *Item 16 | } 17 | 18 | func (f Foo) Context() map[string]log.Valuer { 19 | return log.Fields{ 20 | "name": log.String( 21 | f.Name.Value, 22 | ), 23 | } 24 | } 25 | 26 | func TestValuer__String(t *testing.T) { 27 | logger := log.NewTestLogger() 28 | 29 | foo := Foo{} 30 | logger.With(foo).Log("shouldn't panic") 31 | } 32 | 33 | type Mode int 34 | 35 | func (m Mode) String() string { 36 | switch m { 37 | case 1: 38 | return "SANDBOX" 39 | case 2: 40 | return "PRODUCTION" 41 | } 42 | return "UNSPECIFIED" 43 | } 44 | 45 | func TestValuer_Stringer(t *testing.T) { 46 | out, logger := log.NewBufferLogger() 47 | 48 | m := Mode(2) 49 | 50 | logger.With(log.Fields{ 51 | "mode": log.Stringer(m), 52 | }).Log("log with .String() key/value pair") 53 | 54 | require.Contains(t, out.String(), `mode=PRODUCTION`) 55 | 56 | out, logger = log.NewBufferLogger() 57 | 58 | var n *Mode 59 | 60 | logger.With(log.Fields{ 61 | "mode": log.Stringer(n), 62 | }).Log("log with nil .String() key/value pair") 63 | 64 | require.Contains(t, out.String(), `mode=null`) 65 | } 66 | -------------------------------------------------------------------------------- /log/struct_context.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "slices" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // StructContextOption defines options for StructContext 12 | type StructContextOption func(*structContext) 13 | 14 | // WithPrefix prepends prefix and a "." to every extracted field name, nested names included 15 | func WithPrefix(prefix string) StructContextOption { 16 | return func(sc *structContext) { 17 | sc.prefix = prefix 18 | } 19 | } 20 | 21 | // WithTag replaces the struct tag key
that StructContext reads from fields (default "log") 22 | func WithTag(tag string) StructContextOption { 23 | return func(sc *structContext) { 24 | sc.tag = tag 25 | } 26 | } 27 | 28 | // structContext implements the Context interface for struct fields 29 | type structContext struct { 30 | fields map[string]Valuer 31 | prefix string 32 | tag string 33 | } 34 | 35 | // Context returns a map of field names to Valuer implementations 36 | func (sc *structContext) Context() map[string]Valuer { 37 | return sc.fields 38 | } 39 | 40 | // StructContext creates a Context from a struct, extracting fields tagged with `log` (or the key set by WithTag). 41 | // It recurses into nested struct fields and honors the omitempty tag option. Nil or non-struct input yields an empty Context. 42 | func StructContext(v interface{}, opts ...StructContextOption) Context { 43 | sc := &structContext{ 44 | fields: make(map[string]Valuer), 45 | prefix: "", 46 | tag: "log", 47 | } 48 | 49 | // Apply options 50 | for _, opt := range opts { 51 | opt(sc) 52 | } 53 | 54 | if v == nil { 55 | return sc 56 | } 57 | 58 | value := reflect.ValueOf(v) 59 | extractFields(value, sc, "") 60 | 61 | return sc 62 | } 63 | 64 | // extractFields recursively extracts fields from a struct value 65 | func extractFields(value reflect.Value, sc *structContext, path string) { 66 | // If it's a pointer, dereference it 67 | if value.Kind() == reflect.Ptr { 68 | if value.IsNil() { 69 | return 70 | } 71 | value = value.Elem() 72 | } 73 | 74 | // Only process structs 75 | if value.Kind() != reflect.Struct { 76 | return 77 | } 78 | 79 | typ := value.Type() 80 | for i := range typ.NumField() { 81 | field := typ.Field(i) 82 | fieldValue := value.Field(i) 83 | 84 | // Skip unexported fields 85 | if !field.IsExported() { 86 | continue 87 | } 88 | 89 | // Get the log tag 90 | tag := field.Tag.Get(sc.tag) 91 | if tag == "" { 92 | // Skip fields without log tag 93 | continue 94 | } 95 | 96 | // Parse the tag 97 | tagParts := strings.Split(tag, ",") 98 | fieldName := tagParts[0] 99 | if fieldName == "" { 100 | fieldName = field.Name 101
| } 102 | 103 | // Handle omitempty 104 | omitEmpty := slices.Contains(tagParts, "omitempty") 105 | 106 | // Build the full field name with path and prefix 107 | fullName := fieldName 108 | if path != "" { 109 | fullName = path + "." + fieldName 110 | } 111 | 112 | // Add the prefix only at the top level; nested names already carry it via path 113 | if path == "" && sc.prefix != "" { 114 | fullName = sc.prefix + "." + fullName 115 | } 116 | 117 | // Check if field should be omitted due to empty value 118 | if omitEmpty && fieldValue.IsZero() { 119 | continue 120 | } 121 | 122 | // Store the field value (a struct field is also stored as a single value here before being expanded below) 123 | valuer := valueToValuer(fieldValue) 124 | if valuer != nil { 125 | sc.fields[fullName] = valuer 126 | } 127 | 128 | // Recurse into struct and non-nil pointer-to-struct fields; untagged fields were already skipped above 129 | if fieldValue.Kind() == reflect.Struct || 130 | (fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() && fieldValue.Elem().Kind() == reflect.Struct) { 131 | extractFields(fieldValue, sc, fullName) 132 | } 133 | } 134 | } 135 | 136 | // valueToValuer converts a reflect.Value to a Valuer 137 | func valueToValuer(v reflect.Value) Valuer { 138 | if !v.IsValid() { 139 | return nil 140 | } 141 | 142 | //nolint:exhaustive 143 | switch v.Kind() { 144 | case reflect.Bool: 145 | return Bool(v.Bool()) 146 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 147 | return Int64(v.Int()) 148 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 149 | return Uint64(v.Uint()) 150 | case reflect.Float32: 151 | return Float32(float32(v.Float())) 152 | case reflect.Float64: 153 | return Float64(v.Float()) 154 | case reflect.String: 155 | return String(v.String()) 156 | case reflect.Ptr: 157 | if v.IsNil() { 158 | return &any{nil} 159 | } 160 | return valueToValuer(v.Elem()) 161 | case reflect.Struct: 162 | // Check if it's a time.Time 163 | if v.Type().String() == "time.Time" { 164 | if v.CanInterface() { 165 | t, ok := v.Interface().(time.Time)
166 | if ok { 167 | return Time(t) 168 | } 169 | } 170 | } 171 | } 172 | 173 | // Try to use Stringer for complex types 174 | if v.CanInterface() { 175 | if stringer, ok := v.Interface().(fmt.Stringer); ok { 176 | return Stringer(stringer) 177 | } 178 | } 179 | 180 | // Return as string representation for other types 181 | return String(fmt.Sprintf("%v", v.Interface())) 182 | } 183 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY: setup 2 | setup: gen-certs 3 | docker compose up -d --force-recreate --remove-orphans 4 | 5 | .PHONY: check 6 | check: 7 | ifeq ($(OS),Windows_NT) 8 | go test ./... -short 9 | else 10 | @wget -O lint-project.sh https://raw.githubusercontent.com/moov-io/infra/master/go/lint-project.sh 11 | @chmod +x ./lint-project.sh 12 | GOCYCLO_LIMIT=26 COVER_THRESHOLD=50.0 GOLANGCI_LINTERS=gosec GITLEAKS_EXCLUDE=testcerts ./lint-project.sh 13 | endif 14 | 15 | .PHONY: clean 16 | clean: 17 | @rm -rf ./bin/ ./tmp/ coverage.txt misspell* staticcheck lint-project.sh 18 | 19 | .PHONY: cover-test cover-web 20 | cover-test: 21 | go test -coverprofile=cover.out ./... 22 | cover-web: 23 | go tool cover -html=cover.out 24 | 25 | .PHONY: teardown 26 | teardown: 27 | -docker compose down --remove-orphans 28 | -docker compose rm -f -v 29 | 30 | .PHONY: gen-certs 31 | gen-certs: 32 | ./database/testdata/gencerts.sh 33 | -------------------------------------------------------------------------------- /mask/password.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file.
4 | 5 | package mask 6 | 7 | import ( 8 | "fmt" 9 | "unicode/utf8" 10 | ) 11 | 12 | // Password masks s for display: values with fewer than 5 runes become "*****", longer ones keep only their first and last rune (e.g. "password" -> "p*****d"). 13 | func Password(s string) string { 14 | if utf8.RuneCountInString(s) < 5 { 15 | return "*****" // too short, we can't mask anything 16 | } 17 | // Slice on rune boundaries so a multi-byte first or last character is not cut into invalid UTF-8. 18 | _, firstSize := utf8.DecodeRuneInString(s) 19 | _, lastSize := utf8.DecodeLastRuneInString(s) 20 | return fmt.Sprintf("%s*****%s", s[:firstSize], s[len(s)-lastSize:]) 21 | } 22 | -------------------------------------------------------------------------------- /mask/password_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file. 4 | 5 | package mask 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestMaskPassword(t *testing.T) { 14 | cases := []struct { 15 | input, expected string 16 | }{ 17 | {"", "*****"}, 18 | {"ab", "*****"}, 19 | {"abcde", "a*****e"}, 20 | {"123456", "1*****6"}, 21 | {"password", "p*****d"}, 22 | {"äbcdö", "ä*****ö"}, 23 | } 24 | for i := range cases { 25 | output := Password(cases[i].input) 26 | require.Equal(t, cases[i].expected, output) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /migrations/001_create_tests.up.sql: -------------------------------------------------------------------------------- 1 | create table tests (id varchar(10)) 2 | -------------------------------------------------------------------------------- /migrations/002_create_tests.up.mysql.sql: -------------------------------------------------------------------------------- 1 | /* nada */ -------------------------------------------------------------------------------- /migrations/002_create_tests.up.postgres.sql: -------------------------------------------------------------------------------- 1 | create table mig_test(name varchar(10));
-------------------------------------------------------------------------------- /migrations/002_create_tests.up.spanner.sql: -------------------------------------------------------------------------------- 1 | -- comment1 2 | /* comment2 */ 3 | CREATE TABLE MigrationTest ( 4 | Id STRING(36) 5 | ) PRIMARY KEY(Id) 6 | -------------------------------------------------------------------------------- /migrations/002_create_tests.up.sqlite.sql: -------------------------------------------------------------------------------- 1 | /* nada */ -------------------------------------------------------------------------------- /mysql/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mysql:9.3-oracle 2 | 3 | RUN mkdir /var/lib/mysql-volume 4 | CMD ["--datadir", "/var/lib/mysql-volume"] 5 | -------------------------------------------------------------------------------- /mysql/README.md: -------------------------------------------------------------------------------- 1 | # MySQL docker image for testing 2 | 3 | To speed up testing with MySQL we created a custom Docker image for MySQL that stores 4 | data inside the container (not outside in a volume). 5 | 6 | # Update Image 7 | 8 | To update the image, follow these steps: 9 | 10 | Update `Dockerfile` if needed and then build the image (note: in the current `makefile`, `make build` also runs the container, waits for MySQL, stops it, and commits and tags the result, covering the run/stop/image steps below): 11 | 12 | ``` 13 | make build 14 | ``` 15 | 16 | Run a container from this image to initialize the database and create all necessary files inside the container: 17 | 18 | ``` 19 | make run 20 | ``` 21 | 22 | After the container is ready (you will see "...ready for connections..."), stop it with: 23 | 24 | ``` 25 | make stop 26 | ``` 27 | 28 | At this point we have a container with an initialized MySQL database. 29 | 30 | Now create an image based on this container: 31 | 32 | ``` 33 | make image 34 | ``` 35 | 36 | The final step is to push the image to Docker Hub: 37 | 38 | ``` 39 | make push 40 | ``` 41 | 42 | That's it!
43 | -------------------------------------------------------------------------------- /mysql/makefile: -------------------------------------------------------------------------------- 1 | DOCKER_CONTAINER_LIST := $(shell docker ps -aq -f status=exited -f name=mysql-volumeless) 2 | 3 | build: clean 4 | docker build -t mysql-volumeless:8.0 . 5 | docker run -d \ 6 | --name mysql-volumeless \ 7 | -e MYSQL_USER=moov \ 8 | -e MYSQL_PASSWORD=secret \ 9 | -e MYSQL_ROOT_PASSWORD=secret \ 10 | -e MYSQL_DATABASE=test \ 11 | -p 3306:3306 \ 12 | mysql-volumeless:8.0 13 | ./test-mysql-is-ready.sh 14 | docker stop mysql-volumeless 15 | docker commit mysql-volumeless mysql-volumeless:8.0 16 | docker tag mysql-volumeless:8.0 moov/mysql-volumeless:8.0 17 | push: 18 | docker push moov/mysql-volumeless:8.0 19 | clean: 20 | @if [ -n "$(DOCKER_CONTAINER_LIST)" ]; then docker rm "$(DOCKER_CONTAINER_LIST)"; fi; 21 | -------------------------------------------------------------------------------- /mysql/test-mysql-is-ready.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | ATTEMPTS=0 5 | 6 | until docker exec mysql-volumeless mysql -h localhost -u moov -psecret --protocol=TCP -e "SELECT VERSION();SELECT NOW()" test || [ $ATTEMPTS -ge 10 ] 7 | do 8 | ((ATTEMPTS+=1)) 9 | echo "Waiting for database connection...
($ATTEMPTS)" 10 | # wait for 3 seconds before checking again 11 | sleep 3 12 | done 13 | -------------------------------------------------------------------------------- /package.go: -------------------------------------------------------------------------------- 1 | package base 2 | 3 | import "embed" 4 | 5 | //go:embed configs/config.default.yml 6 | var ConfigDefaults embed.FS 7 | 8 | //go:embed migrations/*.up.sql migrations/*.up.mysql.sql 9 | var MySQLMigrations embed.FS 10 | 11 | //go:embed migrations/*.up.spanner.sql 12 | var SpannerMigrations embed.FS 13 | 14 | //go:embed migrations/*.up.postgres.sql 15 | var PostgresMigrations embed.FS 16 | -------------------------------------------------------------------------------- /randx/randx.go: -------------------------------------------------------------------------------- 1 | package randx 2 | 3 | import ( 4 | "crypto/rand" 5 | "math/big" 6 | ) 7 | 8 | // Between returns a uniformly random integer in the inclusive range [lower, upper]. It panics if upper < lower, because crypto/rand.Int panics when its max argument is <= 0. 9 | func Between(lower, upper int) (int64, error) { 10 | n, err := rand.Int(rand.Reader, big.NewInt(int64(upper)-int64(lower)+1)) 11 | if err != nil { 12 | return 0, err 13 | } 14 | return n.Int64() + int64(lower), nil 15 | } 16 | 17 | // Must is a helper that wraps a call to Between and panics if the error is non-nil.
18 | func Must(n int64, err error) int64 { 19 | if err != nil { 20 | panic(err) //nolint:forbidigo 21 | } 22 | return n 23 | } 24 | -------------------------------------------------------------------------------- /randx/randx_test.go: -------------------------------------------------------------------------------- 1 | package randx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestBetween(t *testing.T) { 10 | lower, upper := 100, 250 11 | 12 | n, err := Between(lower, upper) 13 | require.NoError(t, err) 14 | 15 | if n < int64(lower) || n > int64(upper) { 16 | t.Fatalf("%d falls outside of %d and %d", n, lower, upper) 17 | } 18 | } 19 | 20 | func TestMust(t *testing.T) { 21 | lower, upper := 1000, 25000 22 | 23 | n := Must(Between(lower, upper)) 24 | 25 | if n < int64(lower) || n > int64(upper) { 26 | t.Fatalf("%d falls outside of %d and %d", n, lower, upper) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "config:base" 4 | ], 5 | "groupName": "all", 6 | "packageRules": [ 7 | { 8 | "matchUpdateTypes": ["minor", "patch", "pin", "digest"], 9 | "automerge": true 10 | } 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /sql/db.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | gosql "database/sql" 6 | "time" 7 | 8 | "github.com/moov-io/base/log" 9 | ) 10 | 11 | type DB struct { 12 | *gosql.DB 13 | 14 | logger log.Logger 15 | slowQueryThresholdMs int64 16 | 17 | id string 18 | stopTimer context.CancelFunc 19 | } 20 | 21 | func ObserveDB(innerDB *gosql.DB, logger log.Logger, id string) (*DB, error) { 22 | cancel := MonitorSQLDriver(innerDB, id) 23 | 24 | return &DB{ 25 | DB: innerDB, 26 | id: id, 27 | stopTimer: 
cancel, 28 | logger: logger, 29 | 30 | slowQueryThresholdMs: (time.Second * 2).Milliseconds(), 31 | }, nil 32 | } 33 | 34 | func (w *DB) lazyLogger() log.Logger { 35 | return w.logger 36 | } 37 | 38 | func (w *DB) start(op string, qry string, args int) func() int64 { 39 | return MeasureQuery(w.lazyLogger, w.slowQueryThresholdMs, w.id, op, qry, args) 40 | } 41 | 42 | func (w *DB) error(err error) error { 43 | return MeasureError(w.id, err) 44 | } 45 | 46 | func (w *DB) Close() error { 47 | return w.DB.Close() 48 | } 49 | 50 | func (w *DB) SetSlowQueryThreshold(d time.Duration) { 51 | w.slowQueryThresholdMs = d.Milliseconds() 52 | } 53 | 54 | func (w *DB) Prepare(query string) (*Stmt, error) { 55 | done := w.start("prepare", query, 0) 56 | defer done() 57 | 58 | return newStmt(context.Background(), w.logger, w.DB, query, w.id, w.slowQueryThresholdMs) 59 | } 60 | 61 | func (w *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { 62 | done := w.start("prepare", query, 0) 63 | defer done() 64 | 65 | return newStmt(ctx, w.logger, w.DB, query, w.id, w.slowQueryThresholdMs) 66 | } 67 | 68 | func (w *DB) Exec(query string, args ...interface{}) (gosql.Result, error) { 69 | done := w.start("exec", query, len(args)) 70 | defer done() 71 | 72 | r, err := w.DB.Exec(query, args...) 73 | return r, w.error(err) 74 | } 75 | 76 | func (w *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error) { 77 | done := w.start("exec", query, len(args)) 78 | ctx, end := span(ctx, w.id, "exec", query, len(args)) 79 | defer func() { 80 | end() 81 | done() 82 | }() 83 | 84 | r, err := w.DB.ExecContext(ctx, query, args...) 85 | return r, w.error(err) 86 | } 87 | 88 | func (w *DB) Query(query string, args ...interface{}) (*gosql.Rows, error) { 89 | done := w.start("query", query, len(args)) 90 | defer done() 91 | 92 | //nolint:sqlclosecheck 93 | r, err := w.DB.Query(query, args...) 
94 | return r, w.error(err) 95 | } 96 | 97 | func (w *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error) { 98 | done := w.start("query", query, len(args)) 99 | ctx, end := span(ctx, w.id, "query", query, len(args)) 100 | defer func() { 101 | end() 102 | done() 103 | }() 104 | 105 | r, err := w.DB.QueryContext(ctx, query, args...) //nolint:sqlclosecheck 106 | return r, w.error(err) 107 | } 108 | 109 | func (w *DB) QueryRow(query string, args ...interface{}) *gosql.Row { 110 | done := w.start("query-row", query, len(args)) 111 | defer done() 112 | 113 | r := w.DB.QueryRow(query, args...) 114 | w.error(r.Err()) 115 | 116 | return r 117 | } 118 | 119 | func (w *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row { 120 | done := w.start("query-row", query, len(args)) 121 | ctx, end := span(ctx, w.id, "query-row", query, len(args)) 122 | defer func() { 123 | end() 124 | done() 125 | }() 126 | 127 | r := w.DB.QueryRowContext(ctx, query, args...) 
128 | w.error(r.Err()) 129 | 130 | return r 131 | } 132 | 133 | func (w *DB) Begin() (*Tx, error) { 134 | t, err := w.DB.Begin() 135 | if err != nil { 136 | return nil, w.error(err) 137 | } 138 | 139 | tx := &Tx{ 140 | Tx: t, 141 | logger: w.logger, 142 | id: w.id, 143 | ctx: context.Background(), 144 | slowQueryThresholdMs: w.slowQueryThresholdMs, 145 | } 146 | 147 | tx.done = MeasureQuery(tx.lazyLogger, w.slowQueryThresholdMs, tx.id, "tx", "Transaction", 0) 148 | 149 | return tx, nil 150 | } 151 | 152 | type TxOptions = gosql.TxOptions 153 | 154 | func (w *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { 155 | ctx, end := span(ctx, w.id, "tx", "BEGIN TRANSACTION", 0) 156 | 157 | t, err := w.DB.BeginTx(ctx, opts) 158 | if err != nil { 159 | return nil, w.error(err) 160 | } 161 | 162 | tx := &Tx{ 163 | Tx: t, 164 | logger: w.logger, 165 | id: w.id, 166 | ctx: ctx, 167 | slowQueryThresholdMs: w.slowQueryThresholdMs, 168 | } 169 | 170 | done := MeasureQuery(tx.lazyLogger, w.slowQueryThresholdMs, tx.id, "tx", "Transaction", 0) 171 | 172 | tx.done = func() int64 { 173 | end() 174 | return done() 175 | } 176 | 177 | return tx, nil 178 | } 179 | -------------------------------------------------------------------------------- /sql/sql.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | "github.com/moov-io/base/log" 11 | "github.com/moov-io/base/telemetry" 12 | 13 | "github.com/prometheus/client_golang/prometheus" 14 | "github.com/prometheus/client_golang/prometheus/promauto" 15 | "go.opentelemetry.io/otel/attribute" 16 | "go.opentelemetry.io/otel/trace" 17 | ) 18 | 19 | var ( 20 | statusLock = &sync.Mutex{} 21 | 22 | sqlConnections = promauto.NewGaugeVec(prometheus.GaugeOpts{ 23 | Name: "sql_connections", 24 | Help: "How many MySQL connections and what status they're in.", 25 | }, []string{"state", "id"}) 26 | 27 | 
sqlConnectionsCounters = promauto.NewGaugeVec(prometheus.GaugeOpts{ 28 | Name: "sql_connections_counters", 29 | Help: `Counters specific to the sql connections. 30 | wait_count: The total number of connections waited for. 31 | wait_duration: The total time blocked waiting for a new connection. 32 | max_idle_closed: The total number of connections closed due to SetMaxIdleConns. 33 | max_idle_time_closed: The total number of connections closed due to SetConnMaxIdleTime. 34 | max_lifetime_closed: The total number of connections closed due to SetConnMaxLifetime. 35 | `, 36 | }, []string{"counter", "id"}) 37 | 38 | sqlQueries = promauto.NewHistogramVec(prometheus.HistogramOpts{ 39 | Name: "sql_queries", 40 | Help: `Histogram that measures the time in milliseconds queries take`, 41 | Buckets: []float64{10, 25, 50, 100, 250, 500, 1000, 2500, 5000}, 42 | }, []string{"operation", "id"}) 43 | 44 | sqlErrors = promauto.NewCounterVec(prometheus.CounterOpts{ 45 | Name: "sql_errors", 46 | Help: `Histogram that measures the time in milliseconds queries take`, 47 | }, []string{"id"}) 48 | 49 | // Adding in aliases for the usual error cases 50 | ErrNoRows = sql.ErrNoRows 51 | ErrConnDone = sql.ErrConnDone 52 | ErrTxDone = sql.ErrTxDone 53 | ) 54 | 55 | func MonitorSQLDriver(db *sql.DB, id string) context.CancelFunc { 56 | ctx, cancel := context.WithCancel(context.Background()) 57 | 58 | // Setup metrics after the database is setup 59 | go func(db *sql.DB, id string) { 60 | t := time.NewTicker(60 * time.Second) 61 | for { 62 | select { 63 | case <-ctx.Done(): 64 | return 65 | case <-t.C: 66 | MeasureStats(db, id) 67 | } 68 | } 69 | }(db, id) 70 | 71 | return cancel 72 | } 73 | 74 | func MeasureStats(db *sql.DB, id string) error { 75 | statusLock.Lock() 76 | defer statusLock.Unlock() 77 | 78 | stats := db.Stats() 79 | 80 | sqlConnections.With(prometheus.Labels{"state": "idle", "id": id}).Set(float64(stats.Idle)) 81 | sqlConnections.With(prometheus.Labels{"state": "inuse", "id": 
id}).Set(float64(stats.InUse)) 82 | sqlConnections.With(prometheus.Labels{"state": "open", "id": id}).Set(float64(stats.OpenConnections)) 83 | 84 | sqlConnectionsCounters.With(prometheus.Labels{"counter": "wait_count", "id": id}).Set(float64(stats.WaitCount)) 85 | sqlConnectionsCounters.With(prometheus.Labels{"counter": "wait_ms", "id": id}).Set(float64(stats.WaitDuration.Milliseconds())) 86 | sqlConnectionsCounters.With(prometheus.Labels{"counter": "max_idle_closed", "id": id}).Set(float64(stats.MaxIdleClosed)) 87 | sqlConnectionsCounters.With(prometheus.Labels{"counter": "max_idle_time_closed", "id": id}).Set(float64(stats.MaxIdleTimeClosed)) 88 | sqlConnectionsCounters.With(prometheus.Labels{"counter": "max_lifetime_closed", "id": id}).Set(float64(stats.MaxLifetimeClosed)) 89 | 90 | return nil 91 | } 92 | 93 | type LazyLogger func() log.Logger 94 | 95 | func MeasureQuery(logger LazyLogger, slowQueryThresholdMs int64, id string, op string, qry string, args int) func() int64 { 96 | s := time.Now().UnixMilli() 97 | 98 | once := sync.Once{} 99 | 100 | return func() int64 { 101 | d := int64(-1) 102 | 103 | once.Do(func() { 104 | d = time.Now().UnixMilli() - s 105 | 106 | sqlQueries.With(prometheus.Labels{"id": id, "operation": op}).Observe(float64(d)) 107 | 108 | if d >= slowQueryThresholdMs { 109 | logger().Warn().With(log.Fields{ 110 | "query": log.String(CleanQuery(qry)), 111 | "query_id": log.String(id), 112 | "query_op": log.String(op), 113 | "query_time_ms": log.Int64(d), 114 | "query_args": log.Int(args), 115 | }).Log("slow query detected") 116 | } 117 | 118 | // Lazy loggers could self reference, so lets nil it out. 
119 | logger = nil 120 | }) 121 | 122 | return d 123 | } 124 | } 125 | 126 | func MeasureError(id string, err error) error { 127 | if err != nil && err != ErrNoRows { 128 | sqlErrors.With(prometheus.Labels{"id": id}).Inc() 129 | } 130 | return err 131 | } 132 | 133 | func CleanQuery(s string) string { 134 | cleaner := strings.ReplaceAll(s, "\n", " ") 135 | cleaner = strings.ReplaceAll(cleaner, "\t", " ") 136 | cleaner = strings.Trim(cleaner, "\n\t ") 137 | 138 | for { 139 | spaces := strings.ReplaceAll(cleaner, " ", " ") 140 | 141 | // Check if it didn't change after the last replace 142 | if spaces == cleaner { 143 | break 144 | } 145 | 146 | cleaner = spaces 147 | } 148 | 149 | return cleaner 150 | } 151 | 152 | func span(ctx context.Context, id string, op string, query string, args int) (context.Context, func()) { 153 | start := time.Now() 154 | 155 | ctx, span := telemetry.StartSpan(ctx, "sql "+op, 156 | trace.WithSpanKind(trace.SpanKindInternal), 157 | trace.WithAttributes( 158 | attribute.String("sql.query", CleanQuery(query)), 159 | attribute.String("sql.query_id", id), 160 | attribute.String("sql.query_op", op), 161 | attribute.Int("sql.query_args", args), 162 | ), 163 | ) 164 | 165 | return ctx, func() { 166 | took := time.Since(start) 167 | span.SetAttributes(attribute.Int64("sql.query_time_ms", took.Milliseconds())) 168 | 169 | span.End() 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /sql/stmt.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | gosql "database/sql" 6 | 7 | "github.com/moov-io/base/log" 8 | ) 9 | 10 | type Stmt struct { 11 | logger log.Logger 12 | 13 | id string 14 | 15 | slowQueryThresholdMs int64 16 | 17 | query string 18 | ss *gosql.Stmt 19 | } 20 | 21 | func newStmt(ctx context.Context, logger log.Logger, db *gosql.DB, query, id string, slowQueryThresholdMs int64) (*Stmt, error) { 22 | // This 
statement is closed by (*Stmt).Close() and the responsibility of callers. 23 | // We want to keep the *gosql.Stmt alive 24 | ss, err := db.PrepareContext(ctx, query) 25 | if err != nil { 26 | return nil, err 27 | } 28 | return newWrappedStmt(logger, ss, query, id, slowQueryThresholdMs) 29 | } 30 | 31 | func newTxStmt(ctx context.Context, logger log.Logger, tx *gosql.Tx, query, id string, slowQueryThresholdMs int64) (*Stmt, error) { 32 | // This statement is closed by (*Stmt).Close() and the responsibility of callers. 33 | // We want to keep the *gosql.Stmt alive 34 | ss, err := tx.PrepareContext(ctx, query) 35 | if err != nil { 36 | return nil, err 37 | } 38 | return newWrappedStmt(logger, ss, query, id, slowQueryThresholdMs) 39 | } 40 | 41 | func newWrappedStmt(logger log.Logger, ss *gosql.Stmt, query, id string, slowQueryThresholdMs int64) (*Stmt, error) { 42 | return &Stmt{ 43 | logger: logger, 44 | id: id, 45 | query: query, 46 | ss: ss, 47 | 48 | slowQueryThresholdMs: slowQueryThresholdMs, 49 | }, nil 50 | } 51 | 52 | func (s *Stmt) lazyLogger() log.Logger { 53 | return s.logger 54 | } 55 | 56 | func (s *Stmt) start(op string, qry string, args int) func() int64 { 57 | return MeasureQuery(s.lazyLogger, s.slowQueryThresholdMs, s.id, op, qry, args) 58 | } 59 | 60 | func (s *Stmt) error(err error) error { 61 | return MeasureError(s.id, err) 62 | } 63 | 64 | func (s *Stmt) Close() error { 65 | if s != nil && s.ss != nil { 66 | return s.ss.Close() 67 | } 68 | return nil 69 | } 70 | 71 | func (s *Stmt) Exec(args ...any) (gosql.Result, error) { 72 | done := s.start("exec", s.query, len(args)) 73 | defer done() 74 | 75 | r, err := s.ss.Exec(args...) 
76 | return r, s.error(err) 77 | } 78 | 79 | func (s *Stmt) ExecContext(ctx context.Context, args ...any) (gosql.Result, error) { 80 | done := s.start("exec", s.query, len(args)) 81 | ctx, end := span(ctx, s.id, "exec", s.query, len(args)) 82 | defer func() { 83 | end() 84 | done() 85 | }() 86 | 87 | r, err := s.ss.ExecContext(ctx, args...) 88 | return r, s.error(err) 89 | } 90 | 91 | func (s *Stmt) Query(args ...any) (*gosql.Rows, error) { 92 | done := s.start("query", s.query, len(args)) 93 | defer done() 94 | 95 | r, err := s.ss.Query(args...) //nolint:sqlclosecheck 96 | return r, s.error(err) 97 | } 98 | 99 | func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*gosql.Rows, error) { 100 | done := s.start("query", s.query, len(args)) 101 | ctx, end := span(ctx, s.id, "query", s.query, len(args)) 102 | defer func() { 103 | end() 104 | done() 105 | }() 106 | 107 | r, err := s.ss.QueryContext(ctx, args...) //nolint:sqlclosecheck 108 | return r, s.error(err) 109 | } 110 | 111 | func (s *Stmt) QueryRow(args ...any) *gosql.Row { 112 | done := s.start("query-row", s.query, len(args)) 113 | defer done() 114 | 115 | r := s.ss.QueryRow(args...) 116 | s.error(r.Err()) 117 | 118 | return r 119 | } 120 | 121 | func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *gosql.Row { 122 | done := s.start("query-row", s.query, len(args)) 123 | ctx, end := span(ctx, s.id, "query-row", s.query, len(args)) 124 | defer func() { 125 | end() 126 | done() 127 | }() 128 | 129 | r := s.ss.QueryRowContext(ctx, args...) 
130 | s.error(r.Err()) 131 | 132 | return r 133 | } 134 | -------------------------------------------------------------------------------- /sql/tx.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | gosql "database/sql" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/moov-io/base/log" 10 | ) 11 | 12 | type Tx struct { 13 | *gosql.Tx 14 | 15 | logger log.Logger 16 | 17 | id string 18 | done func() int64 19 | ctx context.Context 20 | 21 | slowQueryThresholdMs int64 22 | 23 | queries []ranQuery 24 | } 25 | 26 | type ranQuery struct { 27 | op string 28 | qry string 29 | dur int64 30 | args int 31 | } 32 | 33 | func (w *Tx) lazyLogger() log.Logger { 34 | return w.logger 35 | } 36 | 37 | func (w *Tx) Context() context.Context { 38 | return w.ctx 39 | } 40 | 41 | func (w *Tx) start(op string, query string, args int) func() int64 { 42 | _, end := span(w.ctx, w.id, op, query, args) 43 | 44 | s := time.Now().UnixMilli() 45 | return func() int64 { 46 | end() 47 | d := time.Now().UnixMilli() - s 48 | 49 | w.queries = append(w.queries, ranQuery{ 50 | op: op, 51 | qry: query, 52 | dur: d, 53 | args: args, 54 | }) 55 | 56 | return d 57 | } 58 | } 59 | 60 | func (w *Tx) error(err error) error { 61 | return MeasureError(w.id, err) 62 | } 63 | 64 | func (w *Tx) Commit() error { 65 | defer w.finished() 66 | return w.error(w.Tx.Commit()) 67 | } 68 | 69 | func (w *Tx) Rollback() error { 70 | defer w.finished() 71 | return w.error(w.Tx.Rollback()) 72 | } 73 | 74 | func (w *Tx) finished() { 75 | w.logger = w.logger.With(log.Fields{ 76 | "query_id": log.String(w.id), 77 | "query_cnt": log.Int(len(w.queries)), 78 | }) 79 | 80 | for i, q := range w.queries { 81 | if i < 7 { 82 | pre := fmt.Sprintf("%d_", i) 83 | w.logger = w.logger.With(log.Fields{ 84 | pre + "query": log.String(CleanQuery(q.qry)), 85 | pre + "query_op": log.String(q.op), 86 | pre + "query_time_ms": log.Int64(q.dur), 87 | pre + "query_args": 
log.Int(q.args), 88 | }) 89 | } 90 | } 91 | 92 | w.done() 93 | } 94 | 95 | func (w *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error) { 96 | done := w.start("exec", query, len(args)) 97 | defer done() 98 | 99 | r, err := w.Tx.ExecContext(ctx, query, args...) 100 | return r, w.error(err) 101 | } 102 | 103 | func (w *Tx) Exec(query string, args ...interface{}) (gosql.Result, error) { 104 | done := w.start("exec", query, len(args)) 105 | defer done() 106 | 107 | r, err := w.Tx.Exec(query, args...) 108 | return r, w.error(err) 109 | } 110 | 111 | func (w *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error) { 112 | done := w.start("query", query, len(args)) 113 | defer done() 114 | 115 | r, err := w.Tx.QueryContext(ctx, query, args...) //nolint:sqlclosecheck 116 | return r, w.error(err) 117 | } 118 | 119 | func (w *Tx) Query(query string, args ...interface{}) (*gosql.Rows, error) { 120 | done := w.start("query", query, len(args)) 121 | defer done() 122 | 123 | r, err := w.Tx.Query(query, args...) //nolint:sqlclosecheck 124 | return r, w.error(err) 125 | } 126 | 127 | func (w *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row { 128 | done := w.start("query-row", query, len(args)) 129 | defer done() 130 | 131 | r := w.Tx.QueryRowContext(ctx, query, args...) 132 | w.error(r.Err()) 133 | 134 | return r 135 | } 136 | 137 | func (w *Tx) QueryRow(query string, args ...interface{}) *gosql.Row { 138 | done := w.start("query-row", query, len(args)) 139 | defer done() 140 | 141 | r := w.Tx.QueryRow(query, args...) 
142 | w.error(r.Err()) 143 | 144 | return r 145 | } 146 | 147 | func (w *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { 148 | done := w.start("prepare", query, 0) 149 | defer done() 150 | 151 | return newTxStmt(ctx, w.logger, w.Tx, query, w.id, w.slowQueryThresholdMs) 152 | } 153 | -------------------------------------------------------------------------------- /stime/static_time.go: -------------------------------------------------------------------------------- 1 | package stime 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type StaticTimeService interface { 8 | Change(update time.Time) time.Time 9 | Add(d time.Duration) time.Time 10 | TimeService 11 | } 12 | 13 | type staticTimeService struct { 14 | time time.Time 15 | } 16 | 17 | func NewStaticTimeService() StaticTimeService { 18 | return &staticTimeService{ 19 | time: time.Now().In(time.UTC).Round(time.Second), 20 | } 21 | } 22 | 23 | func (s *staticTimeService) Now() time.Time { 24 | return s.time 25 | } 26 | 27 | func (s *staticTimeService) Change(update time.Time) time.Time { 28 | s.time = update 29 | return s.time 30 | } 31 | 32 | func (s *staticTimeService) Add(d time.Duration) time.Time { 33 | s.time = s.time.Add(d) 34 | return s.time 35 | } 36 | -------------------------------------------------------------------------------- /stime/system_time.go: -------------------------------------------------------------------------------- 1 | package stime 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type TimeService interface { 8 | Now() time.Time 9 | } 10 | 11 | type timeService struct{} 12 | 13 | func NewSystemTimeService() TimeService { 14 | return &timeService{} 15 | } 16 | 17 | func (s *timeService) Now() time.Time { 18 | return time.Now().In(time.UTC) 19 | } 20 | -------------------------------------------------------------------------------- /strx/strx.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of 
this source code is governed by an Apache License 3 // license that can be found in the LICENSE file. 4 | 5 | package strx 6 | 7 | import ( 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | // Or returns the first option that is non-empty after trimming surrounding whitespace, and returns it trimmed. It returns "" when no option qualifies. 13 | func Or(options ...string) string { 14 | for i := range options { 15 | if v := strings.TrimSpace(options[i]); v != "" { 16 | return v 17 | } 18 | } 19 | return "" 20 | } 21 | 22 | // Yes reports whether in, after trimming whitespace, is "yes" (case-insensitive) or a true value accepted by strconv.ParseBool (such as "true", "1" or "t"). It is used to parse config values. 23 | func Yes(in string) bool { 24 | in = strings.TrimSpace(in) 25 | if strings.EqualFold(in, "yes") { 26 | return true 27 | } 28 | v, _ := strconv.ParseBool(in) 29 | return v 30 | } 31 | -------------------------------------------------------------------------------- /strx/strx_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // Use of this source code is governed by an Apache License 3 | // license that can be found in the LICENSE file.
4 | 5 | package strx 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestOr(t *testing.T) { 14 | if v := Or(); v != "" { 15 | t.Errorf("got %q", v) 16 | } 17 | if v := Or("", "backup"); v != "backup" { 18 | t.Errorf("got %q", v) 19 | } 20 | if v := Or("primary", ""); v != "primary" { 21 | t.Errorf("got %q", v) 22 | } 23 | if v := Or("primary", "backup"); v != "primary" { 24 | t.Errorf("got %q", v) 25 | } 26 | } 27 | 28 | func TestYes(t *testing.T) { 29 | // accepted values 30 | require.True(t, Yes("yes")) 31 | require.True(t, Yes(" true ")) 32 | 33 | // common, but unsupported 34 | require.False(t, Yes("on")) 35 | require.False(t, Yes("y")) 36 | require.False(t, Yes("no")) 37 | 38 | // explicit no values 39 | require.False(t, Yes("no")) 40 | require.False(t, Yes("false")) 41 | require.False(t, Yes("")) 42 | } 43 | -------------------------------------------------------------------------------- /telemetry/collector.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace" 7 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 8 | 9 | // Add in gzip 10 | "google.golang.org/grpc/credentials" 11 | _ "google.golang.org/grpc/encoding/gzip" 12 | ) 13 | 14 | type OtelConfig struct { 15 | Host string 16 | TLS bool 17 | } 18 | 19 | func newOpenTelementyCollectorExporter(ctx context.Context, config OtelConfig) (*otlptrace.Exporter, error) { 20 | opts := []otlptracegrpc.Option{ 21 | otlptracegrpc.WithEndpoint(config.Host), 22 | } 23 | 24 | if config.TLS { 25 | opts = append(opts, otlptracegrpc.WithTLSCredentials(credentials.NewClientTLSFromCert(nil, ""))) 26 | } else { 27 | opts = append(opts, otlptracegrpc.WithInsecure()) 28 | } 29 | 30 | client := otlptracegrpc.NewClient(opts...) 
31 | return otlptrace.New(ctx, client) 32 | } 33 | -------------------------------------------------------------------------------- /telemetry/config.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "os" 7 | "time" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/propagation" 11 | "go.opentelemetry.io/otel/sdk/resource" 12 | tracesdk "go.opentelemetry.io/otel/sdk/trace" 13 | semconv "go.opentelemetry.io/otel/semconv/v1.7.0" 14 | "go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | type Config struct { 18 | ServiceName string 19 | ServiceNamespace *string 20 | Stdout bool 21 | OpenTelemetryCollector *OtelConfig 22 | Honeycomb *HoneycombConfig 23 | 24 | // Allows for testing of the output of telemetry without affecting use with config files 25 | testWriter io.Writer 26 | } 27 | 28 | // Allows for testing where the output of the traces are sent to a io.Writer instance. 29 | func TestConfig(w io.Writer) Config { 30 | return Config{ 31 | ServiceName: "test-service", 32 | testWriter: w, 33 | } 34 | } 35 | 36 | type ShutdownFunc func() error 37 | 38 | var NoopShutdown ShutdownFunc = func() error { 39 | return nil 40 | } 41 | 42 | func SetupTelemetry(ctx context.Context, config Config, version string) (ShutdownFunc, error) { 43 | var ( 44 | err error 45 | exp tracesdk.SpanExporter 46 | ) 47 | 48 | if config.testWriter != nil { 49 | exp, err = newJsonExporter(config.testWriter) 50 | if err != nil { 51 | return NoopShutdown, err 52 | } 53 | 54 | } else if isOtelEnvironmentSet() { 55 | exp, err = newOtelExporterFromEnvironment(ctx) 56 | if err != nil { 57 | return NoopShutdown, err 58 | } 59 | 60 | } else if isHoneycombEnvironmentSet() { 61 | exp, err = newHoneycombExporterFromEnvironment(ctx) 62 | if err != nil { 63 | return NoopShutdown, err 64 | } 65 | 66 | } else if config.Stdout { 67 | exp, err = newStdoutExporter() 68 | if err != nil { 69 | return 
NoopShutdown, err 70 | } 71 | 72 | } else if config.OpenTelemetryCollector != nil { 73 | exp, err = newOpenTelementyCollectorExporter(ctx, *config.OpenTelemetryCollector) 74 | if err != nil { 75 | return NoopShutdown, err 76 | } 77 | 78 | } else if config.Honeycomb != nil { 79 | exp, err = newHoneycombExporterFromConfig(ctx, *config.Honeycomb) 80 | if err != nil { 81 | return NoopShutdown, err 82 | } 83 | } 84 | 85 | // Make sure something is set for the exporter 86 | if exp == nil { 87 | exp, err = newDiscardExporter() 88 | if err != nil { 89 | return NoopShutdown, err 90 | } 91 | } 92 | 93 | tp := newTraceProvider(exp, config, version) 94 | 95 | otel.SetTracerProvider(tp) 96 | 97 | otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator( 98 | propagation.TraceContext{}, 99 | propagation.Baggage{}, 100 | )) 101 | 102 | return func() error { 103 | ctx := context.Background() 104 | tp.ForceFlush(ctx) 105 | return tp.Shutdown(ctx) 106 | }, nil 107 | } 108 | 109 | func newTraceProvider(exp tracesdk.SpanExporter, config Config, version string) TracerProvider { 110 | if config.ServiceName == "" { 111 | config.ServiceName = os.Getenv("MOOV_SERVICE_NAME") 112 | } 113 | 114 | if config.ServiceNamespace == nil || *config.ServiceNamespace == "" { 115 | ns := os.Getenv("MOOV_SERVICE_NAMESPACE") 116 | config.ServiceNamespace = &ns 117 | } 118 | 119 | // Wrap it so we can filter out useless traces from consuming 120 | exp = NewFilteredExporter(exp) 121 | 122 | batcher := tracesdk.WithBatcher(exp, 123 | tracesdk.WithMaxQueueSize(3*tracesdk.DefaultMaxQueueSize), 124 | tracesdk.WithMaxExportBatchSize(3*tracesdk.DefaultMaxExportBatchSize), 125 | tracesdk.WithBatchTimeout(5*time.Second), 126 | ) 127 | 128 | // If we're using the testWriter we want to make sure its not buffering anything in the background 129 | if config.testWriter != nil { 130 | batcher = tracesdk.WithSyncer(exp) 131 | } 132 | 133 | resource := resource.NewWithAttributes( 134 | semconv.SchemaURL, 135 | 
semconv.ServiceNameKey.String(config.ServiceName), 136 | semconv.ServiceVersionKey.String(version), 137 | semconv.ServiceNamespaceKey.String(*config.ServiceNamespace), 138 | ) 139 | 140 | tp := tracesdk.NewTracerProvider( 141 | tracesdk.WithSampler(tracesdk.ParentBased(tracesdk.AlwaysSample())), 142 | batcher, 143 | tracesdk.WithResource(resource), 144 | ) 145 | 146 | return &tracerProvider{ 147 | TracerProvider: tp, 148 | } 149 | } 150 | 151 | type TracerProvider interface { 152 | trace.TracerProvider 153 | 154 | ForceFlush(ctx context.Context) error 155 | Shutdown(ctx context.Context) error 156 | } 157 | 158 | type tracerProvider struct { 159 | *tracesdk.TracerProvider 160 | } 161 | -------------------------------------------------------------------------------- /telemetry/config_test.go: -------------------------------------------------------------------------------- 1 | package telemetry_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | "time" 8 | 9 | "github.com/moov-io/base/telemetry" 10 | 11 | "github.com/stretchr/testify/require" 12 | "go.opentelemetry.io/otel/trace" 13 | ) 14 | 15 | func Test_Setup_Honey(t *testing.T) { 16 | shutdown, err := telemetry.SetupTelemetry(context.Background(), telemetry.Config{ 17 | ServiceName: "test", 18 | Honeycomb: &telemetry.HoneycombConfig{ 19 | URL: "api.honeycomb.io:443", 20 | Team: "HoneycombAPIKey", 21 | }, 22 | }, "v0.0.1") 23 | 24 | require.NoError(t, err) 25 | 26 | err = shutdown() 27 | require.NoError(t, err) 28 | } 29 | 30 | func Test_Setup_Otel(t *testing.T) { 31 | shutdown, err := telemetry.SetupTelemetry(context.Background(), telemetry.Config{ 32 | ServiceName: "test", 33 | OpenTelemetryCollector: &telemetry.OtelConfig{ 34 | Host: "collector", 35 | }, 36 | }, "v0.0.1") 37 | 38 | require.NoError(t, err) 39 | 40 | err = shutdown() 41 | require.NoError(t, err) 42 | } 43 | 44 | func Test_Setup_Stdout(t *testing.T) { 45 | shutdown, err := telemetry.SetupTelemetry(context.Background(), 
telemetry.Config{ 46 | ServiceName: "test", 47 | Stdout: true, 48 | }, "v0.0.1") 49 | 50 | require.NoError(t, err) 51 | 52 | _, spn := telemetry.StartSpan(context.Background(), "test") 53 | spn.AddEvent("added an event!") 54 | spn.End() 55 | 56 | err = shutdown() 57 | require.NoError(t, err) 58 | } 59 | 60 | func Test_Keeping_Consumers(t *testing.T) { 61 | buf := setupTelemetry(t) 62 | 63 | _, spn := telemetry.StartSpan(context.Background(), "consuming", trace.WithSpanKind(trace.SpanKindConsumer)) 64 | time.Sleep(5 * time.Millisecond) 65 | spn.End() 66 | 67 | require.Contains(t, buf.String(), `"Name":"consuming"`) 68 | } 69 | 70 | func Test_Dropping_Empty_Consumers(t *testing.T) { 71 | buf := setupTelemetry(t) 72 | 73 | _, spn := telemetry.StartSpan(context.Background(), "consuming", trace.WithSpanKind(trace.SpanKindConsumer)) 74 | // instantaneously returns 75 | spn.End() 76 | 77 | require.NotContains(t, buf.String(), "consuming") 78 | } 79 | 80 | func setupTelemetry(t *testing.T) *bytes.Buffer { 81 | t.Helper() 82 | buf := &bytes.Buffer{} 83 | config := telemetry.TestConfig(buf) 84 | 85 | shutdown, err := telemetry.SetupTelemetry(context.Background(), config, "v0.0.1") 86 | t.Cleanup(func() { 87 | shutdown() 88 | }) 89 | 90 | require.NoError(t, err) 91 | 92 | return buf 93 | } 94 | -------------------------------------------------------------------------------- /telemetry/env.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace" 8 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 9 | ) 10 | 11 | func isOtelEnvironmentSet() bool { 12 | return os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT") != "" 13 | } 14 | 15 | // newOtelExporterFromEnvironment creates an OTLP gRPC trace exporter that is configured entirely through environment variables (no explicit client options are passed).
16 | // References: 17 | // - https://opentelemetry.io/docs/specs/otel/protocol/exporter/ 18 | // - https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/ 19 | func newOtelExporterFromEnvironment(ctx context.Context) (*otlptrace.Exporter, error) { 20 | client := otlptracegrpc.NewClient() 21 | return otlptrace.New(ctx, client) 22 | } 23 | -------------------------------------------------------------------------------- /telemetry/exporter.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "go.opentelemetry.io/otel/codes" 8 | tracesdk "go.opentelemetry.io/otel/sdk/trace" 9 | "go.opentelemetry.io/otel/trace" 10 | ) 11 | 12 | var _ tracesdk.SpanExporter = &filteredExporter{} 13 | 14 | func NewFilteredExporter(inner tracesdk.SpanExporter) tracesdk.SpanExporter { 15 | return &filteredExporter{inner: inner} 16 | } 17 | 18 | type filteredExporter struct { 19 | inner tracesdk.SpanExporter 20 | } 21 | 22 | func (fe *filteredExporter) Shutdown(ctx context.Context) error { 23 | return fe.inner.Shutdown(ctx) 24 | } 25 | 26 | func (fe *filteredExporter) ExportSpans(ctx context.Context, spans []tracesdk.ReadOnlySpan) error { 27 | in := []tracesdk.ReadOnlySpan{} 28 | 29 | for _, span := range spans { 30 | if fe.AlwaysInclude(span) { 31 | in = append(in, span) 32 | continue 33 | } 34 | 35 | if HasSpanDrop(span) { 36 | continue 37 | } 38 | 39 | if IsEmptyConsume(span) { 40 | continue 41 | } 42 | 43 | in = append(in, span) 44 | } 45 | 46 | return fe.inner.ExportSpans(ctx, in) 47 | } 48 | 49 | func (fe *filteredExporter) AlwaysInclude(s tracesdk.ReadOnlySpan) bool { 50 | return len(s.Links()) > 0 || 51 | len(s.Events()) > 0 || 52 | s.ChildSpanCount() > 0 || 53 | s.Status().Code == codes.Error 54 | } 55 | 56 | // Allows for services to just flag a span to be dropped. 
57 | func HasSpanDrop(s tracesdk.ReadOnlySpan) bool { 58 | for _, attr := range s.Attributes() { 59 | if attr.Key == DropSpanKey && attr.Value.AsBool() { 60 | return true 61 | } 62 | } 63 | 64 | return false 65 | } 66 | 67 | // IsEmptyConsume reports whether s is a consumer span that most likely received an event and ignored it. 68 | // Such spans clutter traces, so the filtered exporter drops them. 69 | func IsEmptyConsume(s tracesdk.ReadOnlySpan) bool { 70 | if s.SpanKind() == trace.SpanKindConsumer { 71 | 72 | // If it took less than a millisecond, the event was most likely ignored. Child spans are not checked here; ExportSpans keeps spans with children via AlwaysInclude before this runs. 73 | if s.EndTime().Sub(s.StartTime()) < time.Millisecond { 74 | return true 75 | } 76 | } 77 | 78 | return false 79 | } 80 | -------------------------------------------------------------------------------- /telemetry/honey.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "google.golang.org/grpc/credentials" 8 | 9 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace" 10 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 11 | 12 | // Add in gzip 13 | _ "google.golang.org/grpc/encoding/gzip" 14 | ) 15 | 16 | type HoneycombConfig struct { 17 | URL string 18 | Team string 19 | } 20 | 21 | func newHoneycombExporterFromConfig(ctx context.Context, config HoneycombConfig) (*otlptrace.Exporter, error) { 22 | return newHoneycombExporter(ctx, config.URL, config.Team) 23 | } 24 | 25 | func isHoneycombEnvironmentSet() bool { 26 | return os.Getenv("HONEYCOMB_API_KEY") != "" 27 | } 28 | 29 | func newHoneycombExporterFromEnvironment(ctx context.Context) (*otlptrace.Exporter, error) { 30 | return newHoneycombExporter(ctx, "api.honeycomb.io:443", os.Getenv("HONEYCOMB_API_KEY")) 31 | } 32 | 33 | func newHoneycombExporter(ctx context.Context, endpoint string, team string) (*otlptrace.Exporter, error) { 34 | // Configuration to export data to Honeycomb: 35 | // 36 | // 1.
The Honeycomb endpoint 37 | // 2. Your API key, set as the x-honeycomb-team header 38 | opts := []otlptracegrpc.Option{ 39 | otlptracegrpc.WithCompressor("gzip"), 40 | otlptracegrpc.WithEndpoint(endpoint), 41 | otlptracegrpc.WithHeaders(map[string]string{ 42 | "x-honeycomb-team": team, 43 | }), 44 | otlptracegrpc.WithTLSCredentials(credentials.NewClientTLSFromCert(nil, "")), 45 | } 46 | 47 | client := otlptracegrpc.NewClient(opts...) 48 | return otlptrace.New(ctx, client) 49 | } 50 | -------------------------------------------------------------------------------- /telemetry/linked.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/attribute" 7 | "go.opentelemetry.io/otel/codes" 8 | "go.opentelemetry.io/otel/trace" 9 | ) 10 | 11 | // StartLinkedRootSpan starts a new root span (the child) plus a span under ctx (the parent), and links each one to the other. This 12 | // is particularly useful in batch processing applications where separate spans are wanted for each subprocess 13 | // in the batch, but without cluttering the parent span. 14 | func StartLinkedRootSpan(ctx context.Context, name string, options ...trace.SpanStartOption) *LinkedSpan { 15 | 16 | // new root for the children 17 | childOpts := append([]trace.SpanStartOption{ 18 | trace.WithNewRoot(), 19 | trace.WithLinks(trace.LinkFromContext(ctx, attribute.String("link.name", "parent"))), // link to parent from child 20 | }, options...) 21 | 22 | childCtx, childSpan := StartSpan(ctx, name, childOpts...) 23 | 24 | // start a new span on the parent and link to the child span from the parent one. 25 | parentOpts := append([]trace.SpanStartOption{ 26 | trace.WithLinks(trace.LinkFromContext(childCtx, attribute.String("link.name", "child"))), // link to child from parent 27 | }, options...) 28 | 29 | parentCtx, parentSpan := StartSpan(ctx, name, parentOpts...)
30 | 31 | return &LinkedSpan{ 32 | childCtx: childCtx, 33 | childSpan: childSpan, 34 | parentCtx: parentCtx, 35 | parentSpan: parentSpan, 36 | } 37 | } 38 | 39 | type LinkedSpan struct { 40 | childCtx context.Context 41 | childSpan trace.Span 42 | 43 | parentCtx context.Context 44 | parentSpan trace.Span 45 | } 46 | 47 | func (l *LinkedSpan) End(options ...trace.SpanEndOption) { 48 | l.childSpan.End(options...) 49 | l.parentSpan.End(options...) 50 | } 51 | 52 | func (l *LinkedSpan) AddEvent(name string, options ...trace.EventOption) { 53 | l.childSpan.AddEvent(name, options...) 54 | l.parentSpan.AddEvent(name, options...) 55 | } 56 | 57 | func (l *LinkedSpan) RecordError(err error, options ...trace.EventOption) { 58 | l.childSpan.RecordError(err, options...) 59 | l.parentSpan.RecordError(err, options...) 60 | } 61 | 62 | func (l *LinkedSpan) SetStatus(code codes.Code, description string) { 63 | l.childSpan.SetStatus(code, description) 64 | l.parentSpan.SetStatus(code, description) 65 | } 66 | 67 | func (l *LinkedSpan) SetAttributes(kv ...attribute.KeyValue) { 68 | l.childSpan.SetAttributes(kv...) 69 | l.parentSpan.SetAttributes(kv...) 
70 | } 71 | 72 | func (l *LinkedSpan) SetName(name string) { 73 | l.childSpan.SetName(name) 74 | l.parentSpan.SetName(name) 75 | } 76 | 77 | func (l *LinkedSpan) ChildSpan() trace.Span { 78 | return l.childSpan 79 | } 80 | 81 | func (l *LinkedSpan) ChildContext() context.Context { 82 | return l.childCtx 83 | } 84 | 85 | func (l *LinkedSpan) ParentSpan() trace.Span { 86 | return l.parentSpan 87 | } 88 | 89 | func (l *LinkedSpan) ParentContext() context.Context { 90 | return l.parentCtx 91 | } 92 | -------------------------------------------------------------------------------- /telemetry/stdout.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "io" 5 | "os" 6 | 7 | "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" 8 | "go.opentelemetry.io/otel/sdk/trace" 9 | ) 10 | 11 | func newDiscardExporter() (trace.SpanExporter, error) { 12 | return newJsonExporter(io.Discard) 13 | } 14 | 15 | func newStdoutExporter() (trace.SpanExporter, error) { 16 | return newJsonExporter(os.Stdout) 17 | } 18 | 19 | // newJsonExporter returns a stdouttrace exporter that writes spans as JSON to w. 20 | func newJsonExporter(w io.Writer) (trace.SpanExporter, error) { 21 | return stdouttrace.New( 22 | stdouttrace.WithWriter(w), 23 | ) 24 | } 25 | -------------------------------------------------------------------------------- /telemetry/tracer.go: -------------------------------------------------------------------------------- 1 | package telemetry 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel" 7 | "go.opentelemetry.io/otel/attribute" 8 | "go.opentelemetry.io/otel/trace" 9 | ) 10 | 11 | const ( 12 | InstrumentationName = "moov.io" 13 | AttributeTag = "otel" 14 | MaxArrayAttributes = 10 15 | ) 16 | 17 | // GetTracer returns the Tracer named InstrumentationName from the global TracerProvider, for use by instrumentation code 18 | // to trace computational workflows.
19 | func GetTracer(opts ...trace.TracerOption) trace.Tracer { 20 | return otel.GetTracerProvider().Tracer(InstrumentationName, opts...) 21 | } 22 | 23 | // StartSpan creates a Span and returns it together with a context containing the newly created Span. 24 | // 25 | // If the context.Context provided contains a Span then the new span will be a child span, 26 | // otherwise the new span will be a root span. A nil ctx is treated as context.Background(). 27 | // 28 | // OTEL recommends creating all attributes via `WithAttributes()` SpanOption when the span is created. 29 | // 30 | // Created spans MUST be ended with `.End()`; doing so is the responsibility of callers. 31 | func StartSpan(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { 32 | if ctx == nil { 33 | ctx = context.Background() 34 | } 35 | return GetTracer().Start(ctx, spanName, opts...) 36 | } 37 | 38 | // SpanFromContext returns the current Span from ctx. 39 | // 40 | // If no Span is currently set in ctx, a no-op Span implementation is returned. 41 | func SpanFromContext(ctx context.Context) trace.Span { 42 | return trace.SpanFromContext(ctx) 43 | } 44 | 45 | // AddEvent adds an event to the Span in `ctx` with the provided name and options. 46 | func AddEvent(ctx context.Context, name string, options ...trace.EventOption) { 47 | SpanFromContext(ctx).AddEvent(name, options...) 48 | } 49 | 50 | // RecordError records err as an exception event, with a stack trace, on the Span in ctx and returns err unchanged. 51 | func RecordError(ctx context.Context, err error, options ...trace.EventOption) error { 52 | options = append(options, trace.WithStackTrace(true)) 53 | SpanFromContext(ctx).RecordError(err, options...) 54 | return err 55 | } 56 | 57 | // SetAttributes sets kv as attributes of the Span in ctx. If a key from kv already exists for an 58 | // attribute of the Span it will be overwritten with the value contained in kv.
59 | func SetAttributes(ctx context.Context, kv ...attribute.KeyValue) { 60 | SpanFromContext(ctx).SetAttributes(kv...) 61 | } 62 | -------------------------------------------------------------------------------- /telemetry/tracer_test.go: -------------------------------------------------------------------------------- 1 | package telemetry_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/moov-io/base/telemetry" 9 | 10 | "github.com/stretchr/testify/require" 11 | "go.opentelemetry.io/otel/attribute" 12 | "go.opentelemetry.io/otel/sdk/trace" 13 | ) 14 | 15 | func TestStartSpan__NoPanic(t *testing.T) { 16 | ctx, span := telemetry.StartSpan(nil, "no-panics") //nolint:staticcheck 17 | require.NotNil(t, ctx) 18 | require.NotNil(t, span) 19 | } 20 | 21 | func TestSpan_SetAttributes(t *testing.T) { 22 | var conf telemetry.Config 23 | shutdown, err := telemetry.SetupTelemetry(context.Background(), conf, "v0.0.0") 24 | require.NoError(t, err) 25 | t.Cleanup(func() { shutdown() }) 26 | 27 | ctx, span := telemetry.StartSpan(context.Background(), "set-attributes") 28 | defer span.End() 29 | 30 | // First Set 31 | span.SetAttributes(attribute.String("kafka.topic", "test.cmd.v1")) 32 | 33 | // Second Set 34 | span.SetAttributes(attribute.String("event.type", "my-favorite-event")) 35 | 36 | // Verify the attributes which are set 37 | ss := telemetry.SpanFromContext(ctx) 38 | require.Equal(t, "*trace.recordingSpan", fmt.Sprintf("%T", ss)) 39 | ro, ok := ss.(trace.ReadOnlySpan) 40 | require.True(t, ok) 41 | 42 | attrs := ro.Attributes() 43 | for i := range attrs { 44 | switch attrs[i].Key { 45 | case "kafka.topic", "event.type": 46 | // do nothing 47 | default: 48 | t.Errorf("attribute[%d]=%#v\n", i, attrs[i]) 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /time.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Moov Authors 2 | // 
Use of this source code is governed by an Apache License 3 // license that can be found in the LICENSE file. 4 | 5 | package base 6 | 7 | import ( 8 | "time" 9 | 10 | "github.com/rickar/cal/v2" 11 | "github.com/rickar/cal/v2/us" 12 | ) 13 | 14 | const ( 15 | // ISO8601Format represents an ISO 8601 format with timezone (the same layout as time.RFC3339) 16 | ISO8601Format = "2006-01-02T15:04:05Z07:00" 17 | ) 18 | 19 | // Time wraps time.Time and encodes to and decodes from JSON as ISO 8601. 20 | // 21 | // ISO 8601 is usable by a large array of libraries whereas RFC 3339 support 22 | // isn't often part of language standard libraries. 23 | // 24 | // Time also assists in calculating processing days that match the US Federal Reserve Banks' processing days. 25 | // 26 | // For holidays falling on Saturday, Federal Reserve Banks and Branches will be open the preceding Friday. 27 | // For holidays falling on Sunday, all Federal Reserve Banks and Branches will be closed the following Monday. 28 | // ACH and FedWire payments are not processed on weekends or on the US holidays listed in the schedule below. 29 | // 30 | // Holiday Schedule: https://www.frbservices.org/about/holiday-schedules 31 | // 32 | // All logic is based on ET (Eastern) time as defined by the Federal Reserve 33 | // https://www.frbservices.org/resources/resource-centers/same-day-ach/fedach-processing-schedule.html 34 | type Time struct { 35 | time.Time 36 | 37 | cal *cal.Calendar 38 | } 39 | 40 | // Now returns the current time in location, truncated to whole seconds, with the US holiday calendar attached. 41 | func Now(location *time.Location) Time { 42 | // Create our calendar to attach on Time 43 | calendar := &cal.Calendar{ 44 | Name: "moov-io/base", 45 | } 46 | calendar.AddHoliday(us.Holidays...) // TODO(adam): check for more? 47 | // calendar.Observed = cal.ObservedMonday // TODO(adam): 48 | return Time{ 49 | cal: calendar, 50 | Time: time.Now().In(location).Truncate(1 * time.Second), 51 | } 52 | } 53 | 54 | // NewTime wraps a time.Time value in Moov's base.Time struct.
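55 | // The underlying time.Time value is available through the embedded Time field. 56 | // 57 | // The location of t is kept as-is; NewTime does not convert it to UTC. 58 | func NewTime(t time.Time) Time { 59 | tt := Now(time.UTC) 60 | tt.Time = t // overwrite underlying Time 61 | return tt 62 | } 63 | 64 | // MarshalJSON encodes t as a quoted ISO 8601 string, truncated to whole seconds. 65 | func (t Time) MarshalJSON() ([]byte, error) { 66 | var bs []byte 67 | bs = append(bs, '"') 68 | 69 | t.Time = t.Time.Truncate(1 * time.Second) // drop sub-second precision 70 | bs = t.AppendFormat(bs, ISO8601Format) 71 | 72 | bs = append(bs, '"') 73 | return bs, nil 74 | } 75 | 76 | // UnmarshalJSON decodes a quoted ISO 8601 string into t, converted to UTC and truncated to whole seconds. JSON null leaves t unchanged, and parse errors are not reported (nil is always returned). 77 | func (t *Time) UnmarshalJSON(data []byte) error { 78 | // Ignore null, like in the main JSON package. 79 | if string(data) == "null" { 80 | return nil 81 | } 82 | tt, err := time.Parse(`"`+ISO8601Format+`"`, string(data)) 83 | if err != nil || tt.IsZero() { 84 | // Try in RFC3339 format (default Go time). NOTE(review): data still includes its JSON quotes and time.RFC3339 is the same layout as ISO8601Format, so this parse cannot succeed on a quoted string; tt ends up as the zero time and the error is discarded. 85 | tt, _ = time.Parse(time.RFC3339, string(data)) 86 | *t = NewTime(tt) 87 | } 88 | 89 | t.Time = tt.UTC().Truncate(1 * time.Second) // convert to UTC and drop sub-second precision. NOTE(review): on a successful parse *t is not re-created via NewTime, so t.cal keeps its previous value (nil for a zero Time); confirm IsHoliday and the other calendar methods are never called on such a value. 90 | 91 | return nil 92 | } 93 | 94 | // Equal compares two Time values. Time values are considered equal if they both truncate 95 | // to the same year/month/day and hour/minute/second.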
96 | func (t Time) Equal(other Time) bool { 97 | t1 := t.Time.Truncate(1 * time.Second) 98 | t2 := other.Time.Truncate(1 * time.Second) 99 | return t1.Equal(t2) 100 | } 101 | 102 | func (t Time) IsHoliday() bool { 103 | actual, observed, _ := t.cal.IsHoliday(t.Time) 104 | 105 | // The Federal Reserve does not observe the following holidays on the preceding Friday 106 | if (!actual && observed) && t.Time.Weekday() == time.Friday { 107 | return false 108 | } 109 | 110 | return actual || observed 111 | } 112 | 113 | func (t Time) GetHoliday() *cal.Holiday { 114 | _, _, holiday := t.cal.IsHoliday(t.Time) 115 | return holiday 116 | } 117 | 118 | // IsBusinessDay is defined as Mondays through Fridays except federal holidays. 119 | // Source: https://www.federalreserve.gov/Pubs/regcc/regcc.htm 120 | func (t Time) IsBusinessDay() bool { 121 | actual, _, _ := t.cal.IsHoliday(t.Time) 122 | return !t.IsWeekend() && !actual 123 | } 124 | 125 | // IsBankingDay checks the rules around holidays (i.e. weekends) to determine if the given day is a banking day. 126 | func (t Time) IsBankingDay() bool { 127 | // if date is not a weekend and not a holiday it is banking day. 128 | if t.IsWeekend() { 129 | return false 130 | } 131 | // and not a holiday 132 | if t.IsHoliday() { 133 | return false 134 | } 135 | // and not a monday after a holiday 136 | if t.Time.Weekday() == time.Monday { 137 | sun := t.Time.AddDate(0, 0, -1) 138 | 139 | actual, observed, _ := t.cal.IsHoliday(sun) 140 | return !actual && !observed 141 | } 142 | return true 143 | } 144 | 145 | // AddBusinessDay takes an integer for the number of valid business days to add and returns a Time. 146 | // Negative values and large values (over 500 days) will not modify the Time. 
147 | func (t Time) AddBusinessDay(d int) Time { 148 | if d < 1 || d > 500 { 149 | return t 150 | } 151 | 152 | t.Time = t.Time.AddDate(0, 0, 1) 153 | if t.IsBusinessDay() { 154 | return t.AddBusinessDay(d - 1) 155 | } 156 | 157 | return t.AddBusinessDay(d) 158 | } 159 | 160 | // AddBankingDay takes an integer for the number of valid banking days to add and returns a Time. 161 | // Negative values and large values (over 500 days) will not modify the Time. 162 | func (t Time) AddBankingDay(d int) Time { 163 | if d < 1 || d > 500 { 164 | return t 165 | } 166 | 167 | t.Time = t.Time.AddDate(0, 0, 1) 168 | if t.IsBankingDay() { 169 | return t.AddBankingDay(d - 1) 170 | } 171 | 172 | return t.AddBankingDay(d) 173 | } 174 | 175 | // IsWeekend reports whether the given date falls on a weekend. 176 | func (t Time) IsWeekend() bool { 177 | day := t.Time.Weekday() 178 | return day == time.Saturday || day == time.Sunday 179 | } 180 | 181 | // AddBankingTime increments t by the hours, minutes, and seconds provided 182 | // but keeps the final time within 9am to 5pm in t's Location. 
183 | func (t Time) AddBankingTime(hours, minutes, seconds int) Time { 184 | duration := time.Duration(hours) * time.Hour 185 | duration += time.Duration(minutes) * time.Minute 186 | duration += time.Duration(seconds) * time.Second 187 | 188 | return addBankingDuration(t, duration) 189 | } 190 | 191 | func addBankingDuration(start Time, duration time.Duration) Time { 192 | // If we're past the current day's banking hours advance forward one day 193 | if start.Hour() >= 17 { 194 | start = start.AddBankingDay(1) 195 | } 196 | 197 | // Start the day at 9am or later, but not past 5pm 198 | if start.Hour() < 9 || start.Hour() >= 17 { 199 | start.Time = time.Date(start.Year(), start.Month(), start.Day(), 9, start.Minute(), start.Second(), 0, start.Location()) 200 | } 201 | 202 | // Add banking hours as we can 203 | for duration > 0 { 204 | if start.IsBankingDay() { 205 | // Calculate the time remaining in the banking day 206 | endOfDay := time.Date(start.Year(), start.Month(), start.Day(), 17, 0, 0, 0, start.Location()) 207 | remainingToday := endOfDay.Sub(start.Time) 208 | if duration < remainingToday { 209 | start.Time = start.Time.Add(duration) 210 | return start 211 | } 212 | duration -= remainingToday 213 | } 214 | // Move to the next banking day starting at 9 AM 215 | start = start.AddBankingDay(1) 216 | start.Time = time.Date(start.Year(), start.Month(), start.Day(), 9, 0, 0, 0, start.Location()) 217 | } 218 | return start 219 | } 220 | -------------------------------------------------------------------------------- /time.rb: -------------------------------------------------------------------------------- 1 | require 'date' 2 | 3 | if ARGV.empty? 4 | puts "No time provided" 5 | exit 1 6 | end 7 | 8 | datetime = DateTime.iso8601(ARGV[0]) 9 | puts "Date: %s" % datetime.strftime('%Y-%m-%d') 10 | puts "Time: %s" % datetime.strftime('%H:%M:%S') 11 | --------------------------------------------------------------------------------