├── .gitattributes ├── .github └── workflows │ └── go.yml ├── .gitignore ├── Makefile ├── blockchain └── binance │ ├── api │ ├── client.go │ └── model.go │ ├── client.go │ ├── explorer │ ├── client.go │ └── model.go │ └── model.go ├── cache └── redis │ ├── client_interface.go │ ├── redis.go │ └── redis_test.go ├── client ├── api │ └── backend │ │ ├── client.go │ │ └── model.go ├── client.go ├── client_execute.go ├── client_metrics.go ├── client_metrics_test.go ├── client_test.go ├── client_wrapper.go ├── client_wrapper_test.go ├── clientcache.go ├── clientcache_test.go ├── jsonrpc.go ├── jsonrpc_batch.go ├── jsonrpc_batch_test.go ├── jsonrpc_test.go ├── path.go ├── path_test.go ├── request.go └── request_test.go ├── config └── viper │ └── viper.go ├── crypto ├── aes.go ├── aes_test.go ├── sign.go └── sign_test.go ├── ctask ├── do_all.go ├── do_all_test.go ├── doer.go └── doer_test.go ├── database ├── config.go ├── db.go ├── migrate.go ├── migration_runner_env.go └── mock_db.go ├── eventer ├── client.go └── log.go ├── gin ├── hmac.go ├── hmac_test.go └── setup.go ├── go.mod ├── go.sum ├── health ├── http.go └── http_test.go ├── httplib ├── downloader.go └── server.go ├── logging ├── README.md ├── formatter_strict_text.go ├── logger.go └── logger_test.go ├── metrics ├── README.md ├── handler.go ├── http_metrics.go ├── metrics.go ├── pusher.go └── register.go ├── middleware ├── cache.go ├── cache_control.go ├── cache_control_test.go ├── cache_test.go ├── logger.go ├── metrics.go ├── metrics_test.go ├── sentry.go ├── sentry_test.go └── shutdown.go ├── mock ├── mock.go ├── mock_test.go └── test.json ├── mq ├── consumer.go ├── exchange.go ├── mq.go ├── options.go └── queue.go ├── pkg └── nullable │ ├── primitives.go │ └── time.go ├── set ├── ordered.go ├── ordered_test.go ├── set.go └── set_test.go ├── slice ├── filter.go ├── filter_test.go ├── partition.go ├── partition_test.go ├── search.go └── search_test.go ├── testy ├── integration_test_suite.go ├── tagged.go └── 
tagged_test.go └── worker ├── metrics └── metricspusherworker.go ├── options.go ├── worker.go └── worker_test.go

/.gitattributes:
--------------------------------------------------------------------------------
*.json linguist-vendored

--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
name: Go

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:

  test:
    name: Test
    runs-on: ubuntu-latest
    steps:
    - name: Setup go
      uses: actions/setup-go@v2
      with:
        go-version: ^1.18
      id: go

    - name: Checkout
      uses: actions/checkout@v2

    - name: Test
      run: go test -v ./...

  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/setup-go@v4
        with:
          go-version: '1.19'
      - uses: actions/checkout@v3
      - uses: golangci/golangci-lint-action@v3
        with:
          version: v1.50.1

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Dependency directories (remove the comment below to include it)
vendor/

.idea/
.vscode/

bin/*
.DS_Store

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: test test-network

## test: Run all tests (after the optional ./network module tests).
test: test-network
	go test -v ./...

## test-network: Run the tests of the optional ./network module.
## Skipped with a notice when ./network does not exist. The previous
## `cd ./network; go test ...` used `;`, so a failed `cd` fell through and
## re-ran the whole root test suite a second time.
test-network:
	@if [ -d ./network ]; then \
		cd ./network && go test -v ./...; \
	else \
		echo " > Skipping test-network: ./network not found"; \
	fi

## golint: Run linter.
9 | lint: go-lint-install go-lint 10 | 11 | go-lint-install: 12 | ifeq (,$(shell which golangci-lint)) 13 | @echo " > Installing golint" 14 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- v1.50.1 15 | endif 16 | 17 | go-lint: 18 | @echo " > Running golint" 19 | golangci-lint run ./... 20 | -------------------------------------------------------------------------------- /blockchain/binance/api/client.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/trustwallet/go-libs/client" 11 | ) 12 | 13 | // Client is a binance API client 14 | type Client struct { 15 | req client.Request 16 | } 17 | 18 | func InitClient(url string, errorHandler client.HttpErrorHandler) Client { 19 | request := client.InitJSONClient(url, errorHandler) 20 | 21 | return Client{ 22 | req: request, 23 | } 24 | } 25 | 26 | func (c *Client) GetTransactionsByAddress(address string, limit int) ([]Tx, error) { 27 | startTime := strconv.Itoa(int(time.Now().AddDate(0, 0, -7).Unix() * 1000)) 28 | endTime := strconv.Itoa(int(time.Now().Unix() * 1000)) 29 | params := url.Values{ 30 | "address": {address}, 31 | "startTime": {startTime}, 32 | "endTime": {endTime}, 33 | "limit": {strconv.Itoa(limit)}, 34 | } 35 | 36 | var result TransactionsResponse 37 | 38 | _, err := c.req.Execute(context.TODO(), client.NewReqBuilder(). 39 | Method(http.MethodGet). 40 | PathStatic("bc/api/v1/txs"). 41 | Query(params). 42 | WriteTo(&result). 
43 | Build()) 44 | return result.Tx, err 45 | } 46 | -------------------------------------------------------------------------------- /blockchain/binance/api/model.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | type ( 4 | TransactionsResponse struct { 5 | Total int `json:"total"` 6 | Tx []Tx `json:"txs"` 7 | } 8 | 9 | Type string 10 | 11 | Tx struct { 12 | Hash string `json:"hash"` 13 | BlockHeight int `json:"blockHeight"` 14 | BlockTime int64 `json:"blockTime"` 15 | Type Type `json:"type"` 16 | Fee int `json:"fee"` 17 | Code int `json:"code"` 18 | Source int `json:"source"` 19 | Sequence int `json:"sequence"` 20 | Memo string `json:"memo"` 21 | Log string `json:"log"` 22 | Data string `json:"data"` 23 | Asset string `json:"asset"` 24 | Amount float64 `json:"amount"` 25 | FromAddr string `json:"fromAddr"` 26 | ToAddr string `json:"toAddr"` 27 | } 28 | ) 29 | -------------------------------------------------------------------------------- /blockchain/binance/client.go: -------------------------------------------------------------------------------- 1 | package binance 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/trustwallet/go-libs/client" 11 | ) 12 | 13 | // Client is a binance dex API client 14 | type Client struct { 15 | req client.Request 16 | } 17 | 18 | func InitClient(url, apiKey string, errorHandler client.HttpErrorHandler) Client { 19 | request := client.InitJSONClient(url, errorHandler, client.WithExtraHeader("apikey", apiKey)) 20 | return Client{ 21 | req: request, 22 | } 23 | } 24 | 25 | func (c Client) FetchNodeInfo() (result NodeInfoResponse, err error) { 26 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 27 | Method(http.MethodGet). 28 | PathStatic("/api/v1/node-info"). 29 | WriteTo(&result). 
30 | Build()) 31 | return result, err 32 | } 33 | 34 | func (c Client) FetchTransactionsInBlock(blockNumber int64) (result TransactionsInBlockResponse, err error) { 35 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 36 | Method(http.MethodGet). 37 | Pathf("api/v2/transactions-in-block/%d", blockNumber). 38 | WriteTo(&result). 39 | Build()) 40 | return result, err 41 | } 42 | 43 | func (c Client) FetchTransactionsByAddressAndTokenID(address, tokenID string, limit int) ([]Tx, error) { 44 | startTime := strconv.Itoa(int(time.Now().AddDate(0, -3, 0).Unix() * 1000)) 45 | params := url.Values{ 46 | "address": {address}, 47 | "txAsset": {tokenID}, 48 | "startTime": {startTime}, 49 | "limit": {strconv.Itoa(limit)}, 50 | } 51 | var result TransactionsInBlockResponse 52 | _, err := c.req.Execute(context.TODO(), client.NewReqBuilder(). 53 | Method(http.MethodGet). 54 | PathStatic("/api/v1/transactions"). 55 | Query(params). 56 | WriteTo(&result). 57 | Build()) 58 | return result.Tx, err 59 | } 60 | 61 | func (c Client) FetchAccountMeta(address string) (result AccountMeta, err error) { 62 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 63 | Method(http.MethodGet). 64 | Pathf("/api/v1/account/%s", address). 65 | WriteTo(&result). 66 | Build()) 67 | return result, err 68 | } 69 | 70 | func (c Client) FetchTokens(limit int) (result Tokens, err error) { 71 | params := url.Values{"limit": {strconv.Itoa(limit)}} 72 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 73 | Method(http.MethodGet). 74 | PathStatic("/api/v1/tokens"). 75 | Query(params). 76 | WriteTo(&result). 77 | Build()) 78 | return result, err 79 | } 80 | 81 | func (c Client) FetchMarketPairs(limit int) (pairs []MarketPair, err error) { 82 | params := url.Values{"limit": {strconv.Itoa(limit)}} 83 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 84 | Method(http.MethodGet). 85 | PathStatic("/api/v1/markets"). 86 | Query(params). 87 | WriteTo(&pairs). 
88 | Build()) 89 | return pairs, err 90 | } 91 | -------------------------------------------------------------------------------- /blockchain/binance/explorer/client.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | 9 | "github.com/trustwallet/go-libs/client" 10 | ) 11 | 12 | // Client is a binance explorer API client 13 | type Client struct { 14 | req client.Request 15 | } 16 | 17 | func InitClient(url string, errorHandler client.HttpErrorHandler) Client { 18 | request := client.InitJSONClient(url, errorHandler) 19 | 20 | return Client{ 21 | req: request, 22 | } 23 | } 24 | 25 | func (c Client) FetchBep2Assets(page, rows int) (assets Bep2Assets, err error) { 26 | params := url.Values{ 27 | "page": {strconv.Itoa(page)}, 28 | "rows": {strconv.Itoa(rows)}, 29 | } 30 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 31 | Method(http.MethodGet). 32 | PathStatic("/api/v1/assets"). 33 | Query(params). 34 | WriteTo(&assets). 
35 | Build()) 36 | 37 | return assets, err 38 | } 39 | -------------------------------------------------------------------------------- /blockchain/binance/explorer/model.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | type ( 4 | Bep2Asset struct { 5 | Asset string `json:"asset"` 6 | Name string `json:"name"` 7 | AssetImg string `json:"assetImg"` 8 | MappedAsset string `json:"mappedAsset"` 9 | Decimals int `json:"decimals"` 10 | } 11 | 12 | Bep2Assets struct { 13 | AssetInfoList []Bep2Asset `json:"assetInfoList"` 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /blockchain/binance/model.go: -------------------------------------------------------------------------------- 1 | package binance 2 | 3 | import "time" 4 | 5 | type ( 6 | NodeInfoResponse struct { 7 | SyncInfo struct { 8 | LatestBlockHeight int `json:"latest_block_height"` 9 | } `json:"sync_info"` 10 | } 11 | 12 | TransactionsInBlockResponse struct { 13 | BlockHeight int `json:"blockHeight"` 14 | Tx []Tx `json:"tx"` 15 | } 16 | 17 | TxType string 18 | 19 | Tx struct { 20 | TxHash string `json:"txHash"` 21 | BlockHeight int `json:"blockHeight"` 22 | TxType TxType `json:"txType"` 23 | TimeStamp time.Time `json:"timeStamp"` 24 | FromAddr interface{} `json:"fromAddr"` 25 | ToAddr interface{} `json:"toAddr"` 26 | Value string `json:"value"` 27 | TxAsset string `json:"txAsset"` 28 | TxFee string `json:"txFee"` 29 | OrderID string `json:"orderId,omitempty"` 30 | Code int `json:"code"` 31 | Data string `json:"data"` 32 | Memo string `json:"memo"` 33 | Source int `json:"source"` 34 | SubTransactions []SubTransactions `json:"subTransactions,omitempty"` 35 | Sequence int `json:"sequence"` 36 | } 37 | 38 | TransactionData struct { 39 | OrderData struct { 40 | Symbol string `json:"symbol"` 41 | OrderType string `json:"orderType"` 42 | Side string `json:"side"` 43 | Price string `json:"price"` 44 | 
Quantity string `json:"quantity"` 45 | TimeInForce string `json:"timeInForce"` 46 | OrderID string `json:"orderId"` 47 | } `json:"orderData"` 48 | } 49 | 50 | SubTransactions struct { 51 | TxHash string `json:"txHash"` 52 | BlockHeight int `json:"blockHeight"` 53 | TxType string `json:"txType"` 54 | FromAddr string `json:"fromAddr"` 55 | ToAddr string `json:"toAddr"` 56 | TxAsset string `json:"txAsset"` 57 | TxFee string `json:"txFee"` 58 | Value string `json:"value"` 59 | } 60 | 61 | AccountMeta struct { 62 | Balances []TokenBalance `json:"balances"` 63 | } 64 | 65 | TokenBalance struct { 66 | Free string `json:"free"` 67 | Frozen string `json:"frozen"` 68 | Locked string `json:"locked"` 69 | Symbol string `json:"symbol"` 70 | } 71 | 72 | Tokens []Token 73 | 74 | Token struct { 75 | ContractAddress string `json:"contract_address"` 76 | Name string `json:"name"` 77 | OriginalSymbol string `json:"original_symbol"` 78 | Owner string `json:"owner"` 79 | Symbol string `json:"symbol"` 80 | TotalSupply string `json:"total_supply"` 81 | } 82 | 83 | MarketPair struct { 84 | BaseAssetSymbol string `json:"base_asset_symbol"` 85 | LotSize string `json:"lot_size"` 86 | QuoteAssetSymbol string `json:"quote_asset_symbol"` 87 | TickSize string `json:"tick_size"` 88 | } 89 | ) 90 | -------------------------------------------------------------------------------- /cache/redis/client_interface.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/redis/go-redis/v9" 8 | ) 9 | 10 | // redisClient is the underlying redis client interface, go-redis library 11 | // we need this interface mainly to unify implementation of cluster and single instance mode 12 | type redisClient interface { 13 | Get(ctx context.Context, key string) *redis.StringCmd 14 | MGet(ctx context.Context, keys ...string) *redis.SliceCmd 15 | Set(ctx context.Context, key string, value interface{}, expiration 
time.Duration) *redis.StatusCmd 16 | Del(ctx context.Context, keys ...string) *redis.IntCmd 17 | Pipeline() redis.Pipeliner 18 | Watch(ctx context.Context, fn func(tx *redis.Tx) error, keys ...string) error 19 | SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd 20 | SetXX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd 21 | 22 | Ping(ctx context.Context) *redis.StatusCmd 23 | Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd 24 | Close() error 25 | 26 | redisClientForTest 27 | } 28 | 29 | // redisClientForTest lists functions used only in unit tests 30 | type redisClientForTest interface { 31 | TTL(ctx context.Context, key string) *redis.DurationCmd 32 | } 33 | -------------------------------------------------------------------------------- /client/api/backend/client.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/trustwallet/go-libs/client" 8 | ) 9 | 10 | type Client struct { 11 | req client.Request 12 | } 13 | 14 | func InitClient(url string, errorHandler client.HttpErrorHandler) Client { 15 | return Client{ 16 | req: client.InitJSONClient(url, errorHandler), 17 | } 18 | } 19 | 20 | func (c *Client) GetAssetInfo(assetID string) (result AssetInfoResp, err error) { 21 | _, err = c.req.Execute(context.TODO(), client.NewReqBuilder(). 22 | Method(http.MethodGet). 23 | Pathf("/v1/assets/%s", assetID). 24 | WriteTo(&result). 
25 | Build()) 26 | return result, err 27 | } 28 | -------------------------------------------------------------------------------- /client/api/backend/model.go: -------------------------------------------------------------------------------- 1 | package backend 2 | 3 | type ( 4 | AssetInfoResp struct { 5 | Name string `json:"name"` 6 | Symbol string `json:"symbol"` 7 | Type string `json:"type"` 8 | Decimals int `json:"decimals"` 9 | AssetID string `json:"asset_id"` 10 | } 11 | ) 12 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "time" 9 | 10 | "github.com/prometheus/client_golang/prometheus" 11 | 12 | log "github.com/sirupsen/logrus" 13 | ) 14 | 15 | const defaultTimeout = 5 * time.Second 16 | 17 | type Request struct { 18 | BaseURL string 19 | Headers map[string]string 20 | Host string 21 | HttpClient HTTPClient 22 | HttpErrorHandler HttpErrorHandler 23 | 24 | // Monitoring 25 | metricRegisterer prometheus.Registerer 26 | httpMetrics *httpClientMetrics 27 | } 28 | 29 | type HTTPClient interface { 30 | Do(req *http.Request) (*http.Response, error) 31 | } 32 | 33 | type HttpError struct { 34 | StatusCode int 35 | URL url.URL 36 | Body []byte 37 | } 38 | 39 | func (e *HttpError) Error() string { 40 | return fmt.Sprintf( 41 | "Failed request status %d for url: (%s), body: (%s)", 42 | e.StatusCode, 43 | e.URL.RequestURI(), 44 | string(e.Body), 45 | ) 46 | } 47 | 48 | type HttpErrorHandler func(res *http.Response, uri string) error 49 | 50 | type Option func(request *Request) error 51 | 52 | func InitClient(baseURL string, errorHandler HttpErrorHandler, options ...Option) Request { 53 | if errorHandler == nil { 54 | errorHandler = DefaultErrorHandler 55 | } 56 | 57 | client := Request{ 58 | Headers: make(map[string]string), 59 | HttpClient: 
&http.Client{ 60 | Timeout: defaultTimeout, 61 | }, 62 | HttpErrorHandler: errorHandler, 63 | BaseURL: baseURL, 64 | } 65 | 66 | for _, option := range options { 67 | err := option(&client) 68 | if err != nil { 69 | log.Fatal("Could not initialize http client", err) 70 | } 71 | } 72 | 73 | if client.metricsEnabled() { 74 | err := client.metricRegisterer.Register(client.httpMetrics) 75 | if err != nil { 76 | if _, ok := err.(*prometheus.AlreadyRegisteredError); ok { 77 | log.WithError(err).Warn("metric already registered") 78 | } else { 79 | log.WithError(err).Error("could not initialize http client metrics") 80 | } 81 | } 82 | } 83 | 84 | return client 85 | } 86 | 87 | func InitJSONClient(baseUrl string, errorHandler HttpErrorHandler, options ...Option) Request { 88 | jsonHeaders := map[string]string{ 89 | "Content-Type": "application/json", 90 | "Accept": "application/json", 91 | } 92 | 93 | client := InitClient( 94 | baseUrl, 95 | errorHandler, 96 | append(options, WithExtraHeaders(jsonHeaders))...) 
97 | return client 98 | } 99 | 100 | var DefaultErrorHandler = func(res *http.Response, uri string) error { 101 | return nil 102 | } 103 | 104 | // TimeoutOption is an option to set timeout for the http client calls 105 | // value unit is nanoseconds 106 | func TimeoutOption(timeout time.Duration) Option { 107 | return func(request *Request) error { 108 | httpClient, ok := request.HttpClient.(*http.Client) 109 | if !ok { 110 | return errors.New("unable to set timeout: httpclient is not *http.Client") 111 | } 112 | 113 | httpClient.Timeout = timeout 114 | return nil 115 | } 116 | } 117 | 118 | func ProxyOption(proxyURL string) Option { 119 | return func(request *Request) error { 120 | if proxyURL == "" { 121 | return nil 122 | } 123 | 124 | httpClient, ok := request.HttpClient.(*http.Client) 125 | if !ok { 126 | return errors.New("unable to set proxy: httpclient is not *http.Client") 127 | } 128 | 129 | return setHttpClientTransportProxy(httpClient, proxyURL) 130 | } 131 | } 132 | 133 | func WithHttpClient(httpClient HTTPClient) Option { 134 | return func(request *Request) error { 135 | request.HttpClient = httpClient 136 | return nil 137 | } 138 | } 139 | 140 | func WithExtraHeader(key, value string) Option { 141 | return func(request *Request) error { 142 | request.Headers[key] = value 143 | return nil 144 | } 145 | } 146 | 147 | func WithHost(host string) Option { 148 | return func(request *Request) error { 149 | request.Host = host 150 | return nil 151 | } 152 | } 153 | 154 | func WithExtraHeaders(headers map[string]string) Option { 155 | return func(request *Request) error { 156 | for k, v := range headers { 157 | request.Headers[k] = v 158 | } 159 | return nil 160 | } 161 | } 162 | 163 | func WithMetricsEnabled(reg prometheus.Registerer, constLabels prometheus.Labels) Option { 164 | return func(request *Request) error { 165 | request.httpMetrics = newHttpClientMetrics(constLabels) 166 | request.metricRegisterer = reg 167 | return nil 168 | } 169 | } 170 | 171 
| // Deprecated: Internal http.Client shouldn't be modified after construction. Use WithHttpClient instead 172 | func (r *Request) SetTimeout(timeout time.Duration) { 173 | r.HttpClient.(*http.Client).Timeout = timeout 174 | } 175 | 176 | // Deprecated: Internal http.Client shouldn't be modified after construction. Use WithHttpClient instead 177 | func (r *Request) SetProxy(proxyUrl string) error { 178 | if proxyUrl == "" { 179 | return errors.New("empty proxy url") 180 | } 181 | url, err := url.Parse(proxyUrl) 182 | if err != nil { 183 | return err 184 | } 185 | r.HttpClient.(*http.Client).Transport = &http.Transport{Proxy: http.ProxyURL(url)} 186 | return nil 187 | } 188 | 189 | // Deprecated: Headers shouldn't be modified after construction. Use WithExtraHeaders instead 190 | func (r *Request) AddHeader(key, value string) { 191 | r.Headers[key] = value 192 | } 193 | 194 | func setHttpClientTransportProxy(client *http.Client, proxyUrl string) error { 195 | if proxyUrl == "" { 196 | return errors.New("empty proxy url") 197 | } 198 | url, err := url.Parse(proxyUrl) 199 | if err != nil { 200 | return err 201 | } 202 | 203 | if client.Transport == nil { 204 | client.Transport = &http.Transport{Proxy: http.ProxyURL(url)} 205 | return nil 206 | } 207 | 208 | transport, ok := client.Transport.(*http.Transport) 209 | if !ok { 210 | return errors.New("http client transport is not *http.Transport") 211 | } 212 | transport.Proxy = http.ProxyURL(url) 213 | return nil 214 | } 215 | -------------------------------------------------------------------------------- /client/client_execute.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "strings" 12 | "time" 13 | ) 14 | 15 | // Execute executes http request as described in Req. 
16 | // 17 | // If Req.WriteTo is specified, it will also populate the resultContainer 18 | func (r *Request) Execute(ctx context.Context, req *Req) ([]byte, error) { 19 | request, err := r.constructHttpRequest(ctx, req) 20 | if err != nil { 21 | return nil, err 22 | } 23 | 24 | startTime := time.Now() 25 | res, err := r.HttpClient.Do(request) 26 | r.reportMonitoringMetricsIfEnabled(startTime, request, req, res, err) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | if req.rawResponseContainer != nil && res != nil { 32 | *req.rawResponseContainer = *res 33 | } 34 | 35 | err = r.HttpErrorHandler(res, request.URL.String()) 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | defer res.Body.Close() 41 | b, err := io.ReadAll(res.Body) 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { 47 | return nil, &HttpError{ 48 | StatusCode: res.StatusCode, 49 | URL: *request.URL, 50 | Body: b, 51 | } 52 | } 53 | 54 | err = populateResultContainer(b, req.resultContainer) 55 | if err != nil { 56 | return b, err 57 | } 58 | 59 | return b, nil 60 | } 61 | 62 | // constructHttpRequest constructs a http.Request object from description in Req and common headers in r. 
63 | func (r *Request) constructHttpRequest(ctx context.Context, req *Req) (*http.Request, error) { 64 | body, err := GetBody(req.body) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | request, err := http.NewRequestWithContext(ctx, req.method, r.GetURL(req.path.String(), req.query), body) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | r.setRequestHeaders(request, req) 75 | 76 | if r.Host != "" { 77 | request.Host = r.Host 78 | } 79 | return request, nil 80 | } 81 | 82 | func (r *Request) reportMonitoringMetricsIfEnabled( 83 | startTime time.Time, request *http.Request, 84 | req *Req, res *http.Response, resErr error, 85 | ) { 86 | if r.metricsEnabled() { 87 | url := r.GetURL(getMonitoredPathTemplateIfEnabled(req), nil) 88 | method := request.Method 89 | name := req.metricName 90 | status := getHttpRespMetricStatus(res, resErr) 91 | 92 | r.httpMetrics.observeDuration(url, method, name, startTime) 93 | r.httpMetrics.observeResult(url, method, name, status) 94 | } 95 | } 96 | 97 | // setRequestHeaders sets the given httpRequest with the common headers from the client, and headers specified in Req. 98 | // If there are duplicated headers, the headers specified in Req takes precedence. 
99 | func (r *Request) setRequestHeaders(httpRequest *http.Request, req *Req) { 100 | headersSlice := []map[string]string{r.Headers, req.headers} 101 | for _, headers := range headersSlice { 102 | for key, value := range headers { 103 | httpRequest.Header.Set(key, value) 104 | } 105 | } 106 | } 107 | 108 | // populateResultContainer populates the given resultContainer if it's not nil 109 | func populateResultContainer(b []byte, resultContainer any) error { 110 | if resultContainer != nil { 111 | err := json.Unmarshal(b, resultContainer) 112 | if err != nil { 113 | return err 114 | } 115 | } 116 | return nil 117 | } 118 | 119 | func getMonitoredPathTemplateIfEnabled(req *Req) string { 120 | if !req.pathMetricEnabled { 121 | return "" 122 | } 123 | return req.path.template 124 | } 125 | 126 | func (r *Request) GetBase(path string) string { 127 | baseURL := strings.TrimRight(r.BaseURL, "/") 128 | if path == "" { 129 | return baseURL 130 | } 131 | path = strings.TrimLeft(path, "/") 132 | return fmt.Sprintf("%s/%s", baseURL, path) 133 | } 134 | 135 | func (r *Request) GetURL(path string, query url.Values) string { 136 | baseURL := r.GetBase(path) 137 | if query == nil { 138 | return baseURL 139 | } 140 | queryStr := query.Encode() 141 | return fmt.Sprintf("%s?%s", baseURL, queryStr) 142 | } 143 | 144 | func (r *Request) metricsEnabled() bool { 145 | return r.httpMetrics != nil 146 | } 147 | 148 | func GetBody(body interface{}) (buf io.ReadWriter, err error) { 149 | if body != nil { 150 | buf = new(bytes.Buffer) 151 | err = json.NewEncoder(buf).Encode(body) 152 | } 153 | return 154 | } 155 | -------------------------------------------------------------------------------- /client/client_metrics.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/prometheus/client_golang/prometheus" 9 | ) 10 | 11 | const ( 12 | namespaceHttpClient = "httpclient" 13 | 
14 | metricNameRequestDurationSeconds = "request_duration_seconds" 15 | metricNameRequestTotal = "request_total" 16 | 17 | labelUrl = "url" 18 | labelMethod = "method" 19 | labelStatus = "status" 20 | labelName = "name" 21 | 22 | labelValueErr = "error" 23 | ) 24 | 25 | type httpClientMetrics struct { 26 | durationSeconds *prometheus.HistogramVec 27 | requestTotal *prometheus.CounterVec 28 | } 29 | 30 | func newHttpClientMetrics(constLabels prometheus.Labels) *httpClientMetrics { 31 | m := &httpClientMetrics{ 32 | durationSeconds: prometheus.NewHistogramVec(prometheus.HistogramOpts{ 33 | Namespace: namespaceHttpClient, 34 | Name: metricNameRequestDurationSeconds, 35 | Help: "Histogram of duration of outgoing http requests", 36 | ConstLabels: constLabels, 37 | }, []string{labelUrl, labelMethod, labelName}), 38 | requestTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ 39 | Namespace: namespaceHttpClient, 40 | Name: metricNameRequestTotal, 41 | Help: "Count of total outgoing http requests, with its result status in labels", 42 | ConstLabels: constLabels, 43 | }, []string{labelUrl, labelMethod, labelName, labelStatus}), 44 | } 45 | 46 | return m 47 | } 48 | 49 | func (metric *httpClientMetrics) observeDuration(url, method, name string, startTime time.Time) { 50 | metric.durationSeconds.WithLabelValues(url, method, name).Observe(time.Since(startTime).Seconds()) 51 | } 52 | 53 | func (metric *httpClientMetrics) observeResult(url, method, name, status string) { 54 | metric.requestTotal.WithLabelValues(url, method, name, status).Inc() 55 | } 56 | 57 | // Describe implements prometheus.Collector interface 58 | func (metric *httpClientMetrics) Describe(descs chan<- *prometheus.Desc) { 59 | metric.durationSeconds.Describe(descs) 60 | metric.requestTotal.Describe(descs) 61 | } 62 | 63 | // Collect implements prometheus.Collector interface 64 | func (metric *httpClientMetrics) Collect(metrics chan<- prometheus.Metric) { 65 | metric.durationSeconds.Collect(metrics) 66 | 
metric.requestTotal.Collect(metrics) 67 | } 68 | 69 | func getHttpRespMetricStatus(resp *http.Response, err error) string { 70 | if err != nil { 71 | return labelValueErr 72 | } 73 | firstDigit := resp.StatusCode / 100 74 | return fmt.Sprintf("%dxx", firstDigit) 75 | } 76 | -------------------------------------------------------------------------------- /client/client_wrapper.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/url" 7 | ) 8 | 9 | // Deprecated: Left as backwards-compatibility. Use Execute(NewReqBuilder()) for better APIs and monitoring 10 | func (r *Request) GetWithContext(ctx context.Context, result interface{}, path string, query url.Values) error { 11 | _, err := r.Execute(ctx, NewReqBuilder(). 12 | Method(http.MethodGet). 13 | PathStatic(path). 14 | Query(query). 15 | WriteTo(result). 16 | pathMetricEnabled(false). 17 | Build()) 18 | return err 19 | } 20 | 21 | // Deprecated: Left as backwards-compatibility. Use Execute(NewReqBuilder()) for better APIs and monitoring 22 | func (r *Request) Get(result interface{}, path string, query url.Values) error { 23 | return r.GetWithContext(context.Background(), result, path, query) 24 | } 25 | 26 | // Deprecated: Left as backwards-compatibility. Use Execute(NewReqBuilder()) for better APIs and monitoring 27 | func (r *Request) Post(result interface{}, path string, body interface{}) error { 28 | return r.PostWithContext(context.Background(), result, path, body) 29 | } 30 | 31 | // Deprecated: Left as backwards-compatibility. Use Execute(NewReqBuilder()) for better APIs and monitoring 32 | func (r *Request) GetRaw(path string, query url.Values) ([]byte, error) { 33 | return r.Execute(context.Background(), NewReqBuilder(). 34 | Method(http.MethodGet). 35 | PathStatic(path). 36 | Query(query). 37 | pathMetricEnabled(false). 38 | Build()) 39 | } 40 | 41 | // Deprecated: Left as backwards-compatibility. 
// Use Execute(NewReqBuilder()) for better APIs and monitoring
func (r *Request) PostRaw(path string, body interface{}) ([]byte, error) {
	return r.Execute(context.Background(), NewReqBuilder().
		Method(http.MethodPost).
		PathStatic(path).
		Body(body).
		pathMetricEnabled(false).
		Build())
}

// PostWithContext sends body in a POST to path using ctx and decodes the
// response body into result (via WriteTo).
//
// Deprecated: Left as backwards-compatibility. Use Execute(NewReqBuilder()) for better APIs and monitoring
func (r *Request) PostWithContext(ctx context.Context, result interface{}, path string, body interface{}) error {
	_, err := r.Execute(ctx, NewReqBuilder().
		Method(http.MethodPost).
		PathStatic(path).
		Body(body).
		WriteTo(result).
		pathMetricEnabled(false).
		Build())
	return err
}

--------------------------------------------------------------------------------
/client/client_wrapper_test.go:
--------------------------------------------------------------------------------
package client

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

// TestRequest_Get exercises the deprecated GET wrappers (GetWithContext and
// GetRaw) against a gin router: a plain path, a path with a query string, and
// a query the handler rejects with 400.
func TestRequest_Get(t *testing.T) {
	const aBaseURL = "http://www.example.com"

	var responses = []string{
		`{"status": "success"}`,
		`{"status": "success with data"}`,
	}

	router := gin.New()
	router.GET("/test", func(c *gin.Context) {
		c.Data(http.StatusOK, gin.MIMEJSON, []byte(responses[0]))
	})

	// Succeeds only for ?data=testdata; any other value aborts with 400.
	router.GET("/path/with/query", func(c *gin.Context) {
		queryData := c.Query("data")
		if queryData != "testdata" {
			_ = c.AbortWithError(http.StatusBadRequest, errors.New("ooops"))
			return
		}
		c.Data(http.StatusOK, gin.MIMEJSON, []byte(responses[1]))
	})

	// httpClientFromGinEngine is defined elsewhere in this package's tests;
	// presumably it serves router without a real network listener.
	httpClient := httpClientFromGinEngine(t, router, aBaseURL)
	c := InitClient(aBaseURL, nil, WithHttpClient(httpClient))

	tests := []struct {
		name         string
		path         string
		query        url.Values
		expectedResp string
		assertError  require.ErrorAssertionFunc
	}{
		{
			name:         "happy path simple",
			path:         "/test",
			query:        nil,
			expectedResp: responses[0],
			assertError:  require.NoError,
		},
		{
			name: "happy path with query string",
			path: "/path/with/query",
			query: func() url.Values {
				v := url.Values{}
				v.Set("data", "testdata")
				return v
			}(),
			expectedResp: responses[1],
			assertError:  require.NoError,
		},
		{
			// The handler answers 400, so both wrappers must return an error
			// and no decoded body ("{}" after normalisation below).
			name: "error path",
			path: "/path/with/query",
			query: func() url.Values {
				v := url.Values{}
				v.Set("data", "wrong_value")
				return v
			}(),
			expectedResp: "{}",
			assertError:  require.Error,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			t.Run("GetWithContext", func(t *testing.T) {
				resObj := map[string]string{}
				err := c.GetWithContext(context.Background(), &resObj, test.path, test.query)
				test.assertError(t, err)

				actualRespStr, err := json.Marshal(resObj)
				require.NoError(t, err)
				require.JSONEq(t, test.expectedResp, string(actualRespStr))
			})

			t.Run("GetRaw", func(t *testing.T) {
				bytes, err := c.GetRaw(test.path, test.query)
				test.assertError(t, err)

				// Execute returns nil bytes on error; normalise so JSONEq can compare.
				if string(bytes) == "" {
					bytes = []byte("{}")
				}
				require.JSONEq(t, test.expectedResp, string(bytes))
			})
		})
	}
}

func TestRequest_Post(t *testing.T) {
	const aBaseURL = "http://www.example.com"

	type reqStruct struct {
		Data string `json:"data"`
	}
	var responses = []string{
		`{"status": "success"}`,
		`{"status": "success with request"}`,
	}

	router := gin.New()
	router.POST("/test", func(c *gin.Context) {
		c.Data(http.StatusOK, gin.MIMEJSON, []byte(responses[0]))
	})

121 | router.POST("/a/very/long/path", func(c *gin.Context) { 122 | var req reqStruct 123 | _ = c.Bind(&req) 124 | if req.Data != "testdata" { 125 | _ = c.AbortWithError(http.StatusBadRequest, errors.New("ooops")) 126 | return 127 | } 128 | c.Data(http.StatusOK, gin.MIMEJSON, []byte(responses[1])) 129 | }) 130 | 131 | httpClient := httpClientFromGinEngine(t, router, aBaseURL) 132 | c := InitJSONClient(aBaseURL, nil, WithHttpClient(httpClient)) 133 | 134 | tests := []struct { 135 | name string 136 | path string 137 | body any 138 | expectedResp string 139 | assertError require.ErrorAssertionFunc 140 | }{ 141 | { 142 | name: "happy path no request", 143 | path: "/test", 144 | expectedResp: responses[0], 145 | body: nil, 146 | assertError: require.NoError, 147 | }, 148 | { 149 | name: "happy path - long path, with request", 150 | path: "/a/very/long/path", 151 | expectedResp: responses[1], 152 | body: reqStruct{Data: "testdata"}, 153 | assertError: require.NoError, 154 | }, 155 | { 156 | name: "error path", 157 | path: "/path/with/query", 158 | body: reqStruct{Data: "wrong_data"}, 159 | expectedResp: "{}", 160 | assertError: require.Error, 161 | }, 162 | } 163 | 164 | for _, test := range tests { 165 | t.Run(test.name, func(t *testing.T) { 166 | t.Run("PostWithContext", func(t *testing.T) { 167 | resObj := map[string]string{} 168 | err := c.PostWithContext(context.Background(), &resObj, test.path, test.body) 169 | test.assertError(t, err) 170 | 171 | actualRespStr, err := json.Marshal(resObj) 172 | require.NoError(t, err) 173 | require.JSONEq(t, test.expectedResp, string(actualRespStr)) 174 | }) 175 | 176 | t.Run("PostRaw", func(t *testing.T) { 177 | bytes, err := c.PostRaw(test.path, test.body) 178 | test.assertError(t, err) 179 | 180 | if string(bytes) == "" { 181 | bytes = []byte("{}") 182 | } 183 | require.JSONEq(t, test.expectedResp, string(bytes)) 184 | }) 185 | }) 186 | } 187 | } 188 | 189 | func httpClientFromGinEngine(t *testing.T, engine *gin.Engine, baseURL 
string) *http.Client { 190 | return &http.Client{ 191 | Transport: RoundTripperFunc(func(request *http.Request) (*http.Response, error) { 192 | require.Equal(t, baseURL, fmt.Sprintf("%s://%s", request.URL.Scheme, request.URL.Host)) 193 | 194 | w := httptest.NewRecorder() 195 | engine.ServeHTTP(w, request) 196 | res := w.Result() 197 | res.Request = request 198 | return res, nil 199 | }), 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /client/clientcache.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "crypto/sha1" 6 | "encoding/base64" 7 | "encoding/json" 8 | "errors" 9 | "net/url" 10 | "strings" 11 | "time" 12 | 13 | "github.com/patrickmn/go-cache" 14 | ) 15 | 16 | var memoryCache *memCache 17 | 18 | func init() { 19 | memoryCache = &memCache{cache: cache.New(5*time.Minute, 5*time.Minute)} 20 | } 21 | 22 | type memCache struct { 23 | cache *cache.Cache 24 | } 25 | 26 | func (r *Request) PostWithCache(result interface{}, path string, body interface{}, cache time.Duration) error { 27 | return r.PostWithCacheAndContext(context.Background(), result, path, body, cache) 28 | } 29 | 30 | func (r *Request) PostWithCacheAndContext(ctx context.Context, result interface{}, path string, body interface{}, cache time.Duration) error { 31 | key := r.generateKey(path, nil, body) 32 | err := memoryCache.getCache(key, result) 33 | if err == nil { 34 | return nil 35 | } 36 | 37 | err = r.PostWithContext(ctx, result, path, body) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | return memoryCache.setCache(key, result, cache) 43 | } 44 | 45 | func (r *Request) GetWithCache(result interface{}, path string, query url.Values, cache time.Duration) error { 46 | return r.GetWithCacheAndContext(context.Background(), result, path, query, cache) 47 | } 48 | 49 | func (r *Request) GetWithCacheAndContext(ctx context.Context, result interface{}, 
path string, query url.Values, cache time.Duration) error { 50 | key := r.generateKey(path, query, nil) 51 | err := memoryCache.getCache(key, result) 52 | if err == nil { 53 | return nil 54 | } 55 | 56 | err = r.GetWithContext(ctx, result, path, query) 57 | if err != nil { 58 | return err 59 | } 60 | 61 | return memoryCache.setCache(key, result, cache) 62 | } 63 | 64 | func (mc *memCache) setCache(key string, value interface{}, duration time.Duration) error { 65 | b, err := json.Marshal(value) 66 | if err != nil { 67 | return errors.New(err.Error() + " client cache cannot marshal cache object") 68 | } 69 | memoryCache.cache.Set(key, b, duration) 70 | return nil 71 | } 72 | 73 | func (mc *memCache) getCache(key string, value interface{}) error { 74 | c, ok := mc.cache.Get(key) 75 | if !ok { 76 | return errors.New("validator cache: invalid cache key") 77 | } 78 | r, ok := c.([]byte) 79 | if !ok { 80 | return errors.New("validator cache: failed to cast cache to bytes") 81 | } 82 | err := json.Unmarshal(r, value) 83 | if err != nil { 84 | return errors.New(err.Error() + " not found") 85 | } 86 | return nil 87 | } 88 | 89 | func (r *Request) generateKey(path string, query url.Values, body interface{}) string { 90 | var queryStr = "" 91 | if query != nil { 92 | queryStr = query.Encode() 93 | } 94 | requestUrl := strings.Join([]string{r.GetBase(path), queryStr}, "?") 95 | var b []byte 96 | if body != nil { 97 | b, _ = json.Marshal(body) 98 | } 99 | hash := sha1.Sum(append([]byte(requestUrl), b...)) 100 | return base64.URLEncoding.EncodeToString(hash[:]) 101 | } 102 | -------------------------------------------------------------------------------- /client/clientcache_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestRequest_generateKey(t *testing.T) { 10 | type args struct { 11 | baseURL string 12 | path string 13 | query url.Values 14 | 
body interface{} 15 | } 16 | tests := []struct { 17 | name string 18 | args args 19 | want string 20 | }{ 21 | { 22 | name: "test cosmos key without params", 23 | args: args{ 24 | baseURL: "https://raw.githubusercontent.com/trustwallet/assets/master/blockchains/cosmos", 25 | path: "validators/list.json", 26 | }, 27 | want: "ukpgy7t9m_vLHvyQL82smBoTov4=", 28 | }, 29 | { 30 | name: "test cosmos key with params", 31 | args: args{ 32 | baseURL: "https://raw.githubusercontent.com/trustwallet/assets/master/blockchains/cosmos", 33 | path: "validators/list.json", 34 | query: url.Values{"address": {"TQZskDJJRGAHifeKoQ7wLey42iGvwp3"}, "visible": {"false"}}, 35 | }, 36 | want: "jkkaXhzkelj5l3WE_B57Q1IY0Qo=", 37 | }, 38 | {name: "test tron key without params ", 39 | args: args{ 40 | baseURL: "https://api.trongrid.io", 41 | path: "wallet/getaccount", 42 | }, 43 | want: "PIoOx2azFYta4KMAtt0lttrqquM=", 44 | }, 45 | {name: "test tron key with params 1", 46 | args: args{ 47 | baseURL: "https://api.trongrid.io", 48 | path: "wallet/getaccount", 49 | body: struct { 50 | Address string `json:"address"` 51 | Visible bool `json:"visible"` 52 | }{Address: "TQZskDJJRGAHifeKoQ7wLC4QDyB2iGvwp2", Visible: true}, 53 | }, 54 | want: "h0noiR5a4M_RGQBH7805sgGl_HE=", 55 | }, 56 | {name: "test tron key with params 2", 57 | args: args{ 58 | baseURL: "https://api.trongrid.io", 59 | path: "wallet/getaccount", 60 | body: struct { 61 | Address string `json:"address"` 62 | Visible bool `json:"visible"` 63 | }{Address: "TQZskDJJRGAHifeKoQ7wLey42iGvwp3", Visible: false}, 64 | }, 65 | want: "Admv3wAXHkirPi4SaIXimDgLbow=", 66 | }, 67 | } 68 | for _, tt := range tests { 69 | t.Run(tt.name, func(t *testing.T) { 70 | r := &Request{BaseURL: tt.args.baseURL} 71 | if got := r.generateKey(tt.args.path, tt.args.query, tt.args.body); got != tt.want { 72 | t.Errorf("generateKey() = %v, want %v", got, tt.want) 73 | } 74 | }) 75 | } 76 | } 77 | 78 | type ( 79 | args struct { 80 | baseURL string 81 | path string 82 | 
query url.Values 83 | result interface{} 84 | } 85 | response struct { 86 | ID string `json:"id"` 87 | Name string `json:"name"` 88 | Description string `json:"description"` 89 | Website string `json:"website"` 90 | } 91 | test struct { 92 | name string 93 | args args 94 | } 95 | ) 96 | 97 | func testCollection() []test { 98 | return []test{ 99 | { 100 | name: "test cosmos key without params", 101 | args: args{ 102 | baseURL: "https://raw.githubusercontent.com/trustwallet/assets/master/blockchains/cosmos/", 103 | path: "validators/list.json", 104 | result: new([]response), 105 | }, 106 | }, 107 | { 108 | name: "test cosmos key with params", 109 | args: args{ 110 | baseURL: "https://raw.githubusercontent.com/trustwallet/assets/master/blockchains/cosmos/", 111 | path: "validators/list.json", 112 | query: url.Values{"address": {"TQZskDJJRGAHifeKoQ7wLey42iGvwp3"}, "visible": {"false"}}, 113 | result: new([]response), 114 | }, 115 | }, 116 | } 117 | } 118 | 119 | func TestRequest_GetWithCache(t *testing.T) { 120 | for _, tt := range testCollection() { 121 | t.Run(tt.name, func(t *testing.T) { 122 | r := InitClient(tt.args.baseURL, nil) 123 | if err := r.GetWithCache(tt.args.result, tt.args.path, tt.args.query, time.Duration(1*time.Second)); err != nil { 124 | t.Errorf("GetWithCache was failed for %v, error %v", tt.name, err) 125 | } 126 | 127 | key := r.generateKey(tt.args.path, tt.args.query, nil) 128 | 129 | _, ok := memoryCache.cache.Get(key) 130 | 131 | if !ok { 132 | t.Errorf("GetWithCache could not find cache for %v", tt.name) 133 | } 134 | }) 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /client/jsonrpc.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | ) 9 | 10 | var requestID = int64(0) 11 | 12 | const JsonRpcVersion = "2.0" 13 | 14 | type ( 15 | RpcRequests []*RpcRequest 16 
| 17 | RpcRequest struct { 18 | JsonRpc string `json:"jsonrpc"` 19 | Method string `json:"method"` 20 | Params interface{} `json:"params,omitempty"` 21 | Id int64 `json:"id,omitempty"` 22 | } 23 | 24 | RpcResponse struct { 25 | JsonRpc string `json:"jsonrpc"` 26 | Error *RpcError `json:"error,omitempty"` 27 | Result interface{} `json:"result,omitempty"` 28 | Id int64 `json:"id,omitempty"` 29 | } 30 | 31 | RpcResponseRaw struct { 32 | JsonRpc string `json:"jsonrpc"` 33 | Error *RpcError `json:"error,omitempty"` 34 | Result json.RawMessage `json:"result,omitempty"` 35 | Id int64 `json:"id,omitempty"` 36 | } 37 | 38 | RpcError struct { 39 | Code int `json:"code"` 40 | Message string `json:"message"` 41 | Data string `json:"data"` 42 | } 43 | ) 44 | 45 | func (r *Request) RpcCall(result interface{}, method string, params interface{}) error { 46 | req := &RpcRequest{JsonRpc: JsonRpcVersion, Method: method, Params: params, Id: genID()} 47 | var resp *RpcResponse 48 | _, err := r.Execute(context.Background(), NewReqBuilder(). 49 | Method(http.MethodPost). 50 | WriteTo(&resp). 51 | Body(req). 52 | MetricName(method). 53 | Build()) 54 | if err != nil { 55 | return err 56 | } 57 | if resp.Error != nil { 58 | return resp.Error 59 | } 60 | return resp.GetObject(result) 61 | } 62 | 63 | func (r *Request) RpcCallRaw(method string, params interface{}) ([]byte, error) { 64 | req := &RpcRequest{JsonRpc: JsonRpcVersion, Method: method, Params: params, Id: genID()} 65 | var resp *RpcResponseRaw 66 | _, err := r.Execute(context.Background(), NewReqBuilder(). 67 | Method(http.MethodPost). 68 | WriteTo(&resp). 69 | Body(req). 70 | MetricName(method). 
71 | Build()) 72 | if err != nil { 73 | return nil, err 74 | } 75 | if resp.Error != nil { 76 | return nil, resp.Error 77 | } 78 | return []byte(resp.Result), nil 79 | } 80 | 81 | func (r *Request) RpcBatchCall(requests RpcRequests) ([]RpcResponse, error) { 82 | var resp []RpcResponse 83 | _, err := r.Execute(context.Background(), NewReqBuilder(). 84 | Method(http.MethodPost). 85 | WriteTo(&resp). 86 | Body(requests.fillDefaultValues()). 87 | Build()) 88 | if err != nil { 89 | return nil, err 90 | } 91 | return resp, nil 92 | } 93 | 94 | func (e *RpcError) Error() string { 95 | return fmt.Sprintf("%s (%d)", e.Message, e.Code) 96 | } 97 | 98 | func (r *RpcResponse) GetObject(toType interface{}) error { 99 | js, err := json.Marshal(r.Result) 100 | if err != nil { 101 | return err 102 | } 103 | 104 | err = json.Unmarshal(js, toType) 105 | if err != nil { 106 | return err 107 | } 108 | return nil 109 | } 110 | 111 | func (rs RpcRequests) fillDefaultValues() RpcRequests { 112 | for _, r := range rs { 113 | r.JsonRpc = JsonRpcVersion 114 | r.Id = genID() 115 | } 116 | return rs 117 | } 118 | 119 | func genID() int64 { 120 | requestID++ 121 | return requestID 122 | } 123 | -------------------------------------------------------------------------------- /client/jsonrpc_batch.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | type RpcRequestMapper func(interface{}) RpcRequest 4 | 5 | func MakeBatchRequests(elements []interface{}, batchSize int, mapper RpcRequestMapper) (requests []RpcRequests) { 6 | batches := MakeBatches(elements, batchSize) 7 | for _, batch := range batches { 8 | var reqs RpcRequests 9 | for _, ele := range batch { 10 | mapped := mapper(ele) 11 | reqs = append(reqs, &mapped) 12 | } 13 | requests = append(requests, reqs) 14 | } 15 | return 16 | } 17 | 18 | func MakeBatches(elements []interface{}, batchSize int) (batches [][]interface{}) { 19 | batch := make([]interface{}, 0) 20 | size := 0 21 | 
for _, ele := range elements { 22 | if size >= batchSize && len(batch) > 0 { // len check: a batchSize <= 0 must not emit a leading empty batch 23 | batches = append(batches, batch) 24 | size = 0 25 | batch = make([]interface{}, 0) 26 | } 27 | size++ 28 | batch = append(batch, ele) 29 | } 30 | if len(batch) > 0 { // empty input yields no batches instead of a single empty batch 31 | batches = append(batches, batch) 32 | } 33 | return 34 | } 35 | -------------------------------------------------------------------------------- /client/jsonrpc_batch_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func mapHash(hash interface{}) RpcRequest { 9 | array := []interface{}{hash} 10 | return RpcRequest{ 11 | Method: "GetTransaction", 12 | Params: array, 13 | } 14 | } 15 | 16 | func Test_makeRequests(t *testing.T) { 17 | type args struct { 18 | hashes []interface{} 19 | perGroup int 20 | } 21 | tests := []struct { 22 | name string 23 | args args 24 | want []RpcRequests 25 | }{ 26 | { 27 | name: "test group size 1", 28 | args: args{ 29 | hashes: []interface{}{ 30 | "0x1", "0x2", "0x3", 31 | }, 32 | perGroup: 1, 33 | }, 34 | want: []RpcRequests{ 35 | { 36 | &RpcRequest{ 37 | Method: "GetTransaction", 38 | Params: []interface{}{"0x1"}, 39 | }, 40 | }, 41 | { 42 | &RpcRequest{ 43 | Method: "GetTransaction", 44 | Params: []interface{}{"0x2"}, 45 | }, 46 | }, 47 | { 48 | &RpcRequest{ 49 | Method: "GetTransaction", 50 | Params: []interface{}{"0x3"}, 51 | }, 52 | }, 53 | }, 54 | }, 55 | } 56 | for _, tt := range tests { 57 | t.Run(tt.name, func(t *testing.T) { 58 | if got := MakeBatchRequests(tt.args.hashes, tt.args.perGroup, mapHash); !reflect.DeepEqual(got, tt.want) { 59 | t.Errorf("makeBatchRequests() = %v, want %v", got, tt.want) 60 | } 61 | }) 62 | } 63 | } 64 | 65 | func Test_makeBatches(t *testing.T) { 66 | type args struct { 67 | hashes []interface{} 68 | batchSize int 69 | } 70 | tests := []struct { 71 | name string 72 | args args 73 | wantBatches [][]interface{} 74 | }{ 75 | { 76 | name: "Test batch size 4", 77 | args: args{ 78 |
hashes: []interface{}{ 79 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", 80 | }, 81 | batchSize: 4, 82 | }, 83 | wantBatches: [][]interface{}{ 84 | {"1", "2", "3", "4"}, 85 | {"5", "6", "7", "8"}, 86 | {"9", "10", "11"}, 87 | }, 88 | }, 89 | { 90 | name: "Test batch size 10", 91 | args: args{ 92 | hashes: []interface{}{ 93 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", 94 | }, 95 | batchSize: 10, 96 | }, 97 | wantBatches: [][]interface{}{ 98 | {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}, 99 | {"11"}, 100 | }, 101 | }, 102 | { 103 | name: "Test batch size 11", 104 | args: args{ 105 | hashes: []interface{}{ 106 | "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", 107 | }, 108 | batchSize: 11, 109 | }, 110 | wantBatches: [][]interface{}{ 111 | {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"}, 112 | }, 113 | }, 114 | } 115 | for _, tt := range tests { 116 | t.Run(tt.name, func(t *testing.T) { 117 | if gotBatches := MakeBatches(tt.args.hashes, tt.args.batchSize); !reflect.DeepEqual(gotBatches, tt.wantBatches) { 118 | t.Errorf("makeBatches() = %v, want %v", gotBatches, tt.wantBatches) 119 | } 120 | }) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /client/jsonrpc_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestRpcRequests_fillDefaultValues(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | rs RpcRequests 13 | want RpcRequests 14 | }{ 15 | { 16 | "test 1", 17 | RpcRequests{{Method: "method1", Params: "params1"}}, 18 | RpcRequests{{Method: "method1", Params: "params1", JsonRpc: JsonRpcVersion, Id: 1}}, 19 | }, { 20 | "test 2", 21 | RpcRequests{ 22 | {Method: "method1", Params: "params1"}, {Method: "method2", Params: "params2"}}, 23 | RpcRequests{ 24 | {Method: "method1", Params: "params1", JsonRpc: 
JsonRpcVersion, Id: 2}, 25 | {Method: "method2", Params: "params2", JsonRpc: JsonRpcVersion, Id: 3}, 26 | }, 27 | }, 28 | } 29 | for _, tt := range tests { 30 | t.Run(tt.name, func(t *testing.T) { 31 | got := tt.rs.fillDefaultValues() 32 | assert.Equal(t, tt.want, got) 33 | }) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /client/path.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import "fmt" 4 | 5 | type Path struct { 6 | template string 7 | values []any 8 | } 9 | 10 | func NewStaticPath(path string) Path { 11 | return Path{template: path} 12 | } 13 | 14 | func NewEmptyPath() Path { 15 | return Path{} 16 | } 17 | 18 | func NewPath(template string, values []any) Path { 19 | return Path{template: template, values: values} 20 | } 21 | 22 | func (p Path) String() string { 23 | return fmt.Sprintf(p.template, p.values...) 24 | } 25 | -------------------------------------------------------------------------------- /client/path_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestPath_String(t *testing.T) { 10 | type fields struct { 11 | template string 12 | values []any 13 | } 14 | tests := []struct { 15 | name string 16 | fields fields 17 | want string 18 | }{ 19 | { 20 | name: "empty template, empty values", 21 | fields: fields{ 22 | template: "", 23 | values: nil, 24 | }, 25 | want: "", 26 | }, 27 | { 28 | name: "empty template only", 29 | fields: fields{ 30 | template: "", 31 | values: []any{1, 2, 3}, 32 | }, 33 | want: "%!(EXTRA int=1, int=2, int=3)", 34 | }, 35 | { 36 | name: "empty values only", 37 | fields: fields{ 38 | template: "/api/v1/blocks", 39 | values: nil, 40 | }, 41 | want: "/api/v1/blocks", 42 | }, 43 | { 44 | name: "both exist", 45 | fields: fields{ 46 | template: 
"/nft/collections/%s/tokens", 47 | values: []any{"123"}, 48 | }, 49 | want: "/nft/collections/123/tokens", 50 | }, 51 | { 52 | name: "missing values", 53 | fields: fields{ 54 | template: "/nft/collections/%s/tokens/%d", 55 | values: []any{"123"}, 56 | }, 57 | want: "/nft/collections/123/tokens/%!d(MISSING)", 58 | }, 59 | { 60 | name: "multiple values", 61 | fields: fields{ 62 | template: "/nft/collections/%s/tokens/%s", 63 | values: []any{"123", "bnb"}, 64 | }, 65 | want: "/nft/collections/123/tokens/bnb", 66 | }, 67 | } 68 | for _, tt := range tests { 69 | t.Run(tt.name, func(t *testing.T) { 70 | p := Path{ 71 | template: tt.fields.template, 72 | values: tt.fields.values, 73 | } 74 | assert.Equalf(t, tt.want, p.String(), "String()") 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /client/request.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | ) 7 | 8 | // Req defines a http request. It is named `Req` instead of `Request` because the http client is named `Request` 9 | // Consider renaming the client to other name. 10 | // 11 | // To build this struct, use NewReqBuilder. 12 | type Req struct { 13 | headers map[string]string 14 | resultContainer any 15 | method string 16 | path Path 17 | query url.Values 18 | body any 19 | rawResponseContainer *http.Response 20 | 21 | metricName string 22 | pathMetricEnabled bool 23 | } 24 | 25 | type ReqBuilder struct { 26 | req *Req 27 | } 28 | 29 | func NewReqBuilder() *ReqBuilder { 30 | return &ReqBuilder{ 31 | req: &Req{ 32 | headers: map[string]string{}, 33 | pathMetricEnabled: true, 34 | }, 35 | } 36 | } 37 | 38 | // Headers sets the headers of the http request. 
Headers will be overwritten in case of duplicates 39 | func (builder *ReqBuilder) Headers(headers map[string]string) *ReqBuilder { 40 | for k, v := range headers { 41 | builder.req.headers[k] = v 42 | } 43 | return builder 44 | } 45 | 46 | func (builder *ReqBuilder) WriteTo(resultContainer any) *ReqBuilder { 47 | builder.req.resultContainer = resultContainer 48 | return builder 49 | } 50 | 51 | func (builder *ReqBuilder) WriteRawResponseTo(resp *http.Response) *ReqBuilder { 52 | builder.req.rawResponseContainer = resp 53 | return builder 54 | } 55 | 56 | func (builder *ReqBuilder) Method(method string) *ReqBuilder { 57 | builder.req.method = method 58 | return builder 59 | } 60 | 61 | // PathStatic sets the path for the request. 62 | // Use PathStatic ONLY if your path doesn't contain any parameters. Otherwise, use Pathf instead 63 | func (builder *ReqBuilder) PathStatic(path string) *ReqBuilder { 64 | builder.req.path = NewStaticPath(path) 65 | return builder 66 | } 67 | 68 | func (builder *ReqBuilder) Pathf(pathTemplate string, values ...any) *ReqBuilder { 69 | builder.req.path = NewPath(pathTemplate, values) 70 | return builder 71 | } 72 | 73 | func (builder *ReqBuilder) Query(query url.Values) *ReqBuilder { 74 | builder.req.query = query 75 | return builder 76 | } 77 | 78 | func (builder *ReqBuilder) Body(body any) *ReqBuilder { 79 | builder.req.body = body 80 | return builder 81 | } 82 | 83 | func (builder *ReqBuilder) MetricName(name string) *ReqBuilder { 84 | builder.req.metricName = name 85 | return builder 86 | } 87 | 88 | // pathMetricEnabled is only for internal use, where it is set to false 89 | // in deprecated wrapper functions such as Get, GetWithContext, Post, PostRaw 90 | func (builder *ReqBuilder) pathMetricEnabled(enabled bool) *ReqBuilder { 91 | builder.req.pathMetricEnabled = enabled 92 | return builder 93 | } 94 | 95 | func (builder *ReqBuilder) Build() *Req { 96 | copiedReq := *builder.req 97 | return &copiedReq 98 | } 99 | 
-------------------------------------------------------------------------------- /client/request_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestRequest_WriteRawResponseTo(t *testing.T) { 14 | const ( 15 | pathOk = "/ok" 16 | path5xx = "/5xx" 17 | ) 18 | 19 | tests := []struct { 20 | name string 21 | path string 22 | statusCode int 23 | headers http.Header 24 | }{ 25 | { 26 | name: "Test write raw response with statusOK", 27 | path: pathOk, 28 | statusCode: http.StatusOK, 29 | headers: http.Header{ 30 | "Content-Type": []string{"application/json"}, 31 | "x-aptos-block-height": []string{"73287085"}, 32 | }, 33 | }, 34 | { 35 | name: "Test write raw response with status5xx", 36 | path: path5xx, 37 | statusCode: http.StatusInternalServerError, 38 | headers: http.Header{ 39 | "Content-Type": []string{"application/json"}, 40 | }, 41 | }, 42 | } 43 | for _, tt := range tests { 44 | t.Run(tt.name, func(t *testing.T) { 45 | client := InitClient("http://www.example.com", nil, 46 | WithHttpClient(&http.Client{ 47 | Transport: RoundTripperFunc(func(request *http.Request) (*http.Response, error) { 48 | switch request.URL.Path { 49 | case pathOk: 50 | return &http.Response{ 51 | StatusCode: http.StatusOK, 52 | Body: io.NopCloser(strings.NewReader(`{"Data": "ok"}`)), 53 | Header: tt.headers, 54 | }, nil 55 | case path5xx: 56 | return &http.Response{ 57 | StatusCode: http.StatusInternalServerError, 58 | Request: request, 59 | Body: io.NopCloser(strings.NewReader(`{"Data": "5xx"}`)), 60 | Header: tt.headers, 61 | }, nil 62 | default: 63 | return nil, nil 64 | } 65 | }), 66 | }), 67 | ) 68 | var resp http.Response 69 | _, _ = client.Execute(context.Background(), NewReqBuilder().Method(http.MethodGet).PathStatic(tt.path).WriteRawResponseTo(&resp).Build()) 70 | 
require.Equal(t, tt.headers, resp.Header) 71 | require.Equal(t, tt.statusCode, resp.StatusCode) 72 | }) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /config/viper/viper.go: -------------------------------------------------------------------------------- 1 | package viper 2 | 3 | import ( 4 | "reflect" 5 | "strings" 6 | 7 | log "github.com/sirupsen/logrus" 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | func Load(confPath string, receiver interface{}) { 12 | viper.AutomaticEnv() 13 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 14 | 15 | viper.AddConfigPath(".") 16 | viper.SetConfigName("config") 17 | viper.SetConfigType("yml") 18 | 19 | configType := "default" 20 | if confPath != "" { 21 | viper.SetConfigFile(confPath) 22 | configType = "supplied" 23 | } 24 | 25 | err := viper.ReadInConfig() 26 | if err != nil { 27 | log.WithError(err).Fatalf("Read %s config", configType) 28 | } 29 | 30 | log.WithFields(log.Fields{"config": viper.ConfigFileUsed()}).Infof("Viper using %s config", configType) 31 | 32 | bindEnvs(reflect.ValueOf(receiver)) 33 | if err := viper.Unmarshal(receiver); err != nil { 34 | log.Panic(err, "Error Unmarshal Viper Config File") 35 | } 36 | } 37 | 38 | func bindEnvs(v reflect.Value, parts ...string) { 39 | if v.Kind() == reflect.Ptr { 40 | if v.IsNil() { 41 | return 42 | } 43 | 44 | bindEnvs(v.Elem(), parts...) 45 | return 46 | } 47 | 48 | ift := v.Type() 49 | for i := 0; i < ift.NumField(); i++ { 50 | val := v.Field(i) 51 | t := ift.Field(i) 52 | tv, ok := t.Tag.Lookup("mapstructure") 53 | if !ok { 54 | continue 55 | } 56 | switch val.Kind() { 57 | case reflect.Struct: 58 | bindEnvs(val, append(parts, tv)...) 
59 | default: 60 | if err := viper.BindEnv(strings.Join(append(parts, tv), ".")); err != nil { 61 | log.Fatal(err) 62 | } 63 | } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /crypto/aes.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "crypto/rand" 7 | "encoding/base64" 8 | "io" 9 | ) 10 | // AESEncrypt seals message with AES-GCM under key (which must be 16, 24 or 32 bytes) using a fresh random nonce and returns base64(nonce || ciphertext). 11 | func AESEncrypt(key []byte, message string) (string, error) { 12 | c, err := aes.NewCipher(key) 13 | if err != nil { 14 | return "", err 15 | } 16 | 17 | gcm, err := cipher.NewGCM(c) 18 | if err != nil { 19 | return "", err 20 | } 21 | 22 | nonce := make([]byte, gcm.NonceSize()) 23 | if _, err = io.ReadFull(rand.Reader, nonce); err != nil { 24 | return "", err 25 | } 26 | 27 | plaintext := []byte(message) 28 | ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) 29 | return base64.StdEncoding.EncodeToString(ciphertext), nil 30 | } 31 | // AESDecrypt reverses AESEncrypt: it base64-decodes secure, splits off the leading GCM nonce and opens (authenticates and decrypts) the remainder with key. 32 | func AESDecrypt(key []byte, secure string) (string, error) { 33 | ciphertext, err := base64.StdEncoding.DecodeString(secure) 34 | if err != nil { 35 | return "", err 36 | } 37 | 38 | c, err := aes.NewCipher(key) 39 | if err != nil { 40 | return "", err 41 | } 42 | gcm, err := cipher.NewGCM(c) 43 | if err != nil { 44 | return "", err 45 | } 46 | 47 | nonceSize := gcm.NonceSize() 48 | if len(ciphertext) < nonceSize { 49 | return "", io.ErrUnexpectedEOF // input is shorter than the nonce; err is nil here, so returning it would report success 50 | } 51 | 52 | nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] 53 | plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) 54 | if err != nil { 55 | return "", err 56 | } 57 | return string(plaintext), nil 58 | } 59 | -------------------------------------------------------------------------------- /crypto/aes_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestAES(t
*testing.T) { 9 | secret, message := []byte("RfUjXnZr4u7x!A%D*G-KaPdSgVkYp3s5"), "a plain text" 10 | 11 | _, err := AESEncrypt([]byte("not_a_valid_length"), message) 12 | assert.Error(t, err) 13 | 14 | encryptedMessage, err := AESEncrypt(secret, message) 15 | assert.NoError(t, err) 16 | assert.NotEmpty(t, encryptedMessage) 17 | assert.NotEqual(t, encryptedMessage, message) 18 | 19 | decryptedMessage, err := AESDecrypt(secret, encryptedMessage) 20 | assert.NoError(t, err) 21 | assert.Equal(t, decryptedMessage, message) 22 | 23 | failedMessage, err := AESDecrypt([]byte("this_is_an_invalid_secret_key___"), decryptedMessage) 24 | assert.Error(t, err) 25 | assert.Empty(t, failedMessage) 26 | } 27 | -------------------------------------------------------------------------------- /crypto/sign.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "crypto" 5 | "crypto/hmac" 6 | "crypto/rand" 7 | "crypto/rsa" 8 | "crypto/sha256" 9 | "crypto/x509" 10 | "encoding/pem" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "os" 15 | "strings" 16 | ) 17 | 18 | // Signer is an interface that can be used to sign messages 19 | type Signer interface { 20 | Sign(msg []byte) ([]byte, error) 21 | } 22 | 23 | // SignFunc is a wrapper of sign functions to implement Signer interface 24 | type SignFunc func(msg []byte) ([]byte, error) 25 | 26 | func (sf SignFunc) Sign(msg []byte) ([]byte, error) { 27 | return sf(msg) 28 | } 29 | 30 | // NewSHA256WithRSASigner returns Signer instance which signs msg using SHA256WithRSA and private key 31 | func NewSHA256WithRSASigner(privateKey *rsa.PrivateKey) SignFunc { 32 | return func(msg []byte) ([]byte, error) { 33 | return SHA256WithRSA(msg, privateKey) 34 | } 35 | } 36 | 37 | // NewHMACSHA256Signer returns Signer instance which signs msg using HMACSHA256S and key 38 | func NewHMACSHA256Signer(key string) SignFunc { 39 | return func(msg []byte) ([]byte, error) { 40 | return HMACSHA256(msg, key) 
41 | } 42 | } 43 | 44 | // SHA256WithRSA signs SHA256 hash of the message with RSA privateKey 45 | func SHA256WithRSA(msg []byte, privateKey *rsa.PrivateKey) ([]byte, error) { 46 | if privateKey == nil { 47 | return nil, errors.New("private key is empty") 48 | } 49 | 50 | h := sha256.New() 51 | _, err := h.Write(msg) 52 | if err != nil { 53 | return nil, fmt.Errorf("write bytes: %v", err) 54 | } 55 | 56 | res, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, h.Sum(nil)) 57 | if err != nil { 58 | return nil, fmt.Errorf("SignPKCS1v15: %v", err) 59 | } 60 | 61 | return res, nil 62 | } 63 | 64 | // HMACSHA256 signs message with HMAC SHA256 using key 65 | func HMACSHA256(msg []byte, key string) ([]byte, error) { 66 | h := hmac.New(sha256.New, []byte(key)) 67 | _, err := h.Write(msg) 68 | if err != nil { 69 | return nil, fmt.Errorf("hmac write: %v", err) 70 | } 71 | return h.Sum(nil), nil 72 | } 73 | 74 | // GetRSAPrivateKey reads RSA private key from the reader 75 | func GetRSAPrivateKey(reader io.Reader) (*rsa.PrivateKey, error) { 76 | bs, err := io.ReadAll(reader) 77 | if err != nil { 78 | return nil, fmt.Errorf("read bytes: %v", err) 79 | } 80 | 81 | privPem, _ := pem.Decode(bs) 82 | if privPem == nil { 83 | return nil, errors.New("decoded key is empty") 84 | } 85 | 86 | if privPem.Type != "RSA PRIVATE KEY" { 87 | return nil, errors.New("key type is not RSA private key") 88 | } 89 | 90 | var parsedKey interface{} 91 | parsedKey, err = x509.ParsePKCS1PrivateKey(privPem.Bytes) 92 | if err != nil { 93 | return nil, fmt.Errorf("parse PKCS1 private key: %v", err) 94 | } 95 | 96 | privateKey, ok := parsedKey.(*rsa.PrivateKey) 97 | if !ok { 98 | return nil, errors.New("parsed key is not RSA private key") 99 | } 100 | 101 | return privateKey, nil 102 | } 103 | 104 | // GetRSAPrivateKeyFromFile reads RSA private key from file 105 | func GetRSAPrivateKeyFromFile(fileName string) (*rsa.PrivateKey, error) { 106 | file, err := os.Open(fileName) 107 | if err != nil 
{ 108 | return nil, fmt.Errorf("open file: %v", err) 109 | } 110 | 111 | return GetRSAPrivateKey(file) 112 | } 113 | 114 | // GetRSAPrivateKeyFromString reads RSA private key from string 115 | func GetRSAPrivateKeyFromString(s string) (*rsa.PrivateKey, error) { 116 | return GetRSAPrivateKey(strings.NewReader(s)) 117 | } 118 | -------------------------------------------------------------------------------- /crypto/sign_test.go: -------------------------------------------------------------------------------- 1 | package crypto 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "encoding/base64" 8 | "encoding/pem" 9 | "fmt" 10 | "os" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var privateKey = `-----BEGIN RSA PRIVATE KEY----- 18 | MIIBOgIBAAJBAK1ASa283Iotdl+Sbp5IRNjumvuTs/r0ZSt1S/8dqe08WN2GiDXn 19 | f+U1UOJPDp5qN7d+AoQSMUg2bHXeLjrxxCUCAwEAAQJAcYfJQGKcmqfEBEju2CY/ 20 | h3CEewuFS5RPn7TTwi/sJJrtEkeha4CYgGJJusAr8K3J0O8EBnMtEz+KltYDWd6i 21 | AQIhANSWLwXtb0lUqemqoslj3RKirsHac30IyyiJ45NQWp5BAiEA0KGuouUQdNbL 22 | vso31iilbUnJJ54k1C8hREoEAqx9NOUCIQC5INByaQKw6XnOczqwBrdOsz1cs9A+ 23 | 4pmJBAubDi7cAQIgOIFx4SCVQm/iovv1/4TmuSDg4GAOrYFOS0aYq3i4OJkCIAQw 24 | PklhQYvKRwjm1jiktUyTyRHIDSVSmveZ/8N6zJSW 25 | -----END RSA PRIVATE KEY----- 26 | ` 27 | 28 | func testCompareKeys(t *testing.T, exp string, act *rsa.PrivateKey) { 29 | var buf bytes.Buffer 30 | assert.NoError(t, pem.Encode(&buf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(act)})) 31 | 32 | assert.Equal(t, exp, buf.String()) 33 | } 34 | 35 | func TestGetRSAPrivateKey(t *testing.T) { 36 | key, err := GetRSAPrivateKey(strings.NewReader(privateKey)) 37 | assert.NoError(t, err) 38 | 39 | testCompareKeys(t, privateKey, key) 40 | } 41 | 42 | func TestHMACSHA256(t *testing.T) { 43 | res, err := HMACSHA256([]byte("test"), "e9a9b09e-6dfb-455e-8c27-7b206bec08a1") 44 | assert.NoError(t, err) 45 | assert.Equal( 46 | t, 47 | 
"9e99537c0a09c501bb348bc12743707beee35eba0b1bd885de15f91bc9311047", 48 | fmt.Sprintf("%x", string(res)), 49 | ) 50 | } 51 | 52 | func TestSHA256WithRSA(t *testing.T) { 53 | key, err := GetRSAPrivateKey(strings.NewReader(privateKey)) 54 | assert.NoError(t, err) 55 | 56 | res, err := SHA256WithRSA([]byte("test"), key) 57 | assert.NoError(t, err) 58 | assert.Equal( 59 | t, 60 | "dUiTbTPRbMhL0GyTuAE+BAbSxfEwdbWdzQuF2r3esVKg0CMtEa2btCN7O0eQezQFDRIQVXmhKRccqWPQw/Zjbw==", 61 | base64.StdEncoding.EncodeToString(res), 62 | ) 63 | } 64 | 65 | func TestGetRSAPrivateKeyFromFile(t *testing.T) { 66 | f, err := os.CreateTemp("", "test") 67 | assert.NoError(t, err) 68 | defer os.Remove(f.Name()) 69 | 70 | _, err = f.Write([]byte(privateKey)) 71 | assert.NoError(t, err) 72 | assert.NoError(t, f.Sync()) 73 | 74 | key, err := GetRSAPrivateKeyFromFile(f.Name()) 75 | assert.NoError(t, err) 76 | 77 | testCompareKeys(t, privateKey, key) 78 | } 79 | 80 | func TestGetRSAPrivateKeyFromString(t *testing.T) { 81 | key, err := GetRSAPrivateKeyFromString(privateKey) 82 | assert.NoError(t, err) 83 | 84 | testCompareKeys(t, privateKey, key) 85 | } 86 | -------------------------------------------------------------------------------- /ctask/do_all.go: -------------------------------------------------------------------------------- 1 | package ctask 2 | 3 | import ( 4 | "context" 5 | "runtime" 6 | 7 | "golang.org/x/sync/errgroup" 8 | ) 9 | 10 | type DoAllOpt func(cfg *DoAllConfig) 11 | type DoAllConfig struct { 12 | WorkerNum int 13 | } 14 | 15 | type DoAllResp[R any] struct { 16 | Result R 17 | Error error 18 | } 19 | 20 | // DoAll execute tasks using the given executor function for all the given tasks. 21 | // It waits until all tasks are finished. 22 | // The return value is a slice of Result or Error 23 | // 24 | // The max number of goroutines can optionally be specified using the option WithWorkerNum. 
25 | // By default, it is set to runtime.NumCPU() 26 | func DoAll[Task any, Result any]( 27 | ctx context.Context, 28 | tasks []Task, 29 | executor func(ctx context.Context, t Task) (Result, error), 30 | opts ...DoAllOpt, 31 | ) []DoAllResp[Result] { 32 | cfg := getDoAllConfigWithOptions(opts...) 33 | 34 | g, ctx := errgroup.WithContext(ctx) 35 | g.SetLimit(cfg.WorkerNum) 36 | results := make([]DoAllResp[Result], len(tasks)) 37 | for idx, task := range tasks { 38 | idx, task := idx, task // retain current loop values to be used in goroutine 39 | g.Go(func() error { 40 | select { 41 | case <-ctx.Done(): 42 | results[idx] = DoAllResp[Result]{Error: ctx.Err()} 43 | return nil 44 | default: 45 | res, err := executor(ctx, task) 46 | results[idx] = DoAllResp[Result]{ 47 | Result: res, 48 | Error: err, 49 | } 50 | return nil 51 | } 52 | }) 53 | } 54 | _ = g.Wait() // impossible to have error here 55 | return results 56 | } 57 | 58 | func getDoAllConfigWithOptions(opts ...DoAllOpt) DoAllConfig { 59 | cfg := DoAllConfig{ 60 | WorkerNum: runtime.NumCPU(), 61 | } 62 | for _, opt := range opts { 63 | opt(&cfg) 64 | } 65 | return cfg 66 | } 67 | 68 | func WithDoAllWorkerNum(num int) DoAllOpt { 69 | return func(cfg *DoAllConfig) { 70 | cfg.WorkerNum = num 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /ctask/do_all_test.go: -------------------------------------------------------------------------------- 1 | package ctask 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestDoAll(t *testing.T) { 13 | type T = int // task type 14 | type R = int // result type 15 | 16 | type args struct { 17 | ctx context.Context 18 | ctxTimeout time.Duration 19 | tasks []T 20 | executor func(ctx context.Context, t T) (R, error) 21 | opts []DoAllOpt 22 | } 23 | tests := []struct { 24 | name string 25 | args args 26 | want []DoAllResp[R] 27 | }{ 28 | { 29 | 
name: "happy path", 30 | args: args{ 31 | ctx: context.Background(), 32 | tasks: []T{0, 1, 2, 3, 4, 5, 6}, 33 | executor: fibonacci, 34 | opts: nil, 35 | }, 36 | want: []DoAllResp[R]{ 37 | {Result: 1}, 38 | {Result: 1}, 39 | {Result: 2}, 40 | {Result: 3}, 41 | {Result: 5}, 42 | {Result: 8}, 43 | {Result: 13}, 44 | }, 45 | }, 46 | { 47 | name: "empty slice", 48 | args: args{ 49 | ctx: context.Background(), 50 | tasks: nil, 51 | executor: fibonacci, 52 | opts: nil, 53 | }, 54 | want: []DoAllResp[R]{}, 55 | }, 56 | { 57 | name: "error path", 58 | args: args{ 59 | ctx: context.Background(), 60 | tasks: []T{0, 1, 2, 1, -1, 5}, 61 | executor: fibonacci, 62 | opts: []DoAllOpt{WithDoAllWorkerNum(1)}, 63 | }, 64 | want: []DoAllResp[R]{ 65 | {Result: 1}, 66 | {Result: 1}, 67 | {Result: 2}, 68 | {Result: 1}, 69 | {Error: errors.New("negative")}, 70 | {Result: 8}, 71 | }, 72 | }, 73 | { 74 | name: "slow functions should return context deadline exceeded error", 75 | args: args{ 76 | ctx: context.Background(), 77 | ctxTimeout: 20 * time.Millisecond, 78 | tasks: []T{ 79 | 10, 1000, 10, 5000, 10, 80 | }, 81 | executor: func(ctx context.Context, t T) (R, error) { 82 | select { 83 | case <-ctx.Done(): 84 | return 0, ctx.Err() 85 | case <-time.After(time.Duration(t) * time.Millisecond): 86 | return 1, nil 87 | } 88 | }, 89 | opts: []DoAllOpt{WithDoAllWorkerNum(5)}, 90 | }, 91 | want: []DoAllResp[R]{ 92 | {Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1}, 93 | }, 94 | }, 95 | { 96 | name: "slow function with sleeps should run concurrently without context deadline error", 97 | args: args{ 98 | ctx: context.Background(), 99 | ctxTimeout: 50 * time.Millisecond, 100 | tasks: []T{ 101 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 102 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 103 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 104 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 105 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 106 | }, 107 | executor: 
func(ctx context.Context, t T) (R, error) { 108 | time.Sleep(time.Duration(t) * time.Millisecond) 109 | return 1, nil 110 | }, 111 | opts: []DoAllOpt{WithDoAllWorkerNum(20)}, 112 | }, 113 | want: []DoAllResp[R]{ 114 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 115 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 116 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 117 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 118 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 119 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 120 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 121 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 122 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 123 | {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, 124 | }, 125 | }, 126 | } 127 | for _, tt := range tests { 128 | t.Run(tt.name, func(t *testing.T) { 129 | ctx := tt.args.ctx 130 | if tt.args.ctxTimeout > 0 { 131 | var cancel context.CancelFunc 132 | ctx, cancel = context.WithTimeout(ctx, tt.args.ctxTimeout) 133 | defer cancel() 134 | } 135 | got := DoAll(ctx, tt.args.tasks, tt.args.executor, tt.args.opts...) 136 | require.Equal(t, tt.want, got) 137 | }) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /ctask/doer.go: -------------------------------------------------------------------------------- 1 | package ctask 2 | 3 | import ( 4 | "context" 5 | "runtime" 6 | 7 | "golang.org/x/sync/errgroup" 8 | ) 9 | 10 | type DoOpt func(cfg *DoConfig) 11 | type DoConfig struct { 12 | WorkerNum int 13 | } 14 | 15 | // Do execute tasks using the given executor function, 16 | // and return the results in the same order as the given tasks respectively. 17 | // It stops executing remaining tasks after any first error is encountered. 
18 | // 19 | // The max number of goroutines can optionally be specified using the option WithWorkerNum. 20 | // By default, it is set to runtime.NumCPU() 21 | func Do[Task any, Result any]( 22 | ctx context.Context, 23 | tasks []Task, 24 | executor func(ctx context.Context, t Task) (Result, error), 25 | opts ...DoOpt, 26 | ) ([]Result, error) { 27 | cfg := getConfigWithOptions(opts...) 28 | 29 | g, ctx := errgroup.WithContext(ctx) 30 | g.SetLimit(cfg.WorkerNum) 31 | results := make([]Result, len(tasks)) 32 | for idx, task := range tasks { 33 | idx, task := idx, task // retain current loop values to be used in goroutine 34 | g.Go(func() error { 35 | select { 36 | case <-ctx.Done(): 37 | return ctx.Err() 38 | default: 39 | res, err := executor(ctx, task) 40 | if err != nil { 41 | return err 42 | } 43 | results[idx] = res 44 | return nil 45 | } 46 | }) 47 | } 48 | if err := g.Wait(); err != nil { 49 | return nil, err 50 | } 51 | return results, nil 52 | } 53 | 54 | func getConfigWithOptions(opts ...DoOpt) DoConfig { 55 | cfg := DoConfig{ 56 | WorkerNum: runtime.NumCPU(), 57 | } 58 | for _, opt := range opts { 59 | opt(&cfg) 60 | } 61 | return cfg 62 | } 63 | 64 | func WithWorkerNum(num int) DoOpt { 65 | return func(cfg *DoConfig) { 66 | cfg.WorkerNum = num 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /ctask/doer_test.go: -------------------------------------------------------------------------------- 1 | package ctask 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestDo(t *testing.T) { 13 | type T = int // task type 14 | type R = int // result type 15 | 16 | type args struct { 17 | ctx context.Context 18 | ctxTimeout time.Duration 19 | tasks []T 20 | executor func(ctx context.Context, t T) (R, error) 21 | opts []DoOpt 22 | } 23 | tests := []struct { 24 | name string 25 | args args 26 | want []R 27 | requireErr 
require.ErrorAssertionFunc 28 | }{ 29 | { 30 | name: "happy path", 31 | args: args{ 32 | ctx: context.Background(), 33 | tasks: []T{0, 1, 2, 3, 4, 5, 6}, 34 | executor: fibonacci, 35 | opts: nil, 36 | }, 37 | want: []R{1, 1, 2, 3, 5, 8, 13}, 38 | requireErr: require.NoError, 39 | }, 40 | { 41 | name: "empty slice", 42 | args: args{ 43 | ctx: context.Background(), 44 | tasks: []T{0, 1, 2, 3, 4, 5, 6}, 45 | executor: fibonacci, 46 | opts: nil, 47 | }, 48 | want: []R{1, 1, 2, 3, 5, 8, 13}, 49 | requireErr: require.NoError, 50 | }, 51 | { 52 | name: "error path & ensure tasks after error aren't executed (1000th fibonacci is too slow to be computed)", 53 | args: args{ 54 | ctx: context.Background(), 55 | tasks: []T{0, 1, 2, 1, -1, 1000}, 56 | executor: fibonacci, 57 | opts: []DoOpt{WithWorkerNum(1)}, 58 | }, 59 | want: nil, 60 | requireErr: func(t require.TestingT, err error, i ...interface{}) { 61 | require.Equal(t, errors.New("negative"), err) 62 | }, 63 | }, 64 | { 65 | name: "slow function with sleeps should run concurrently without context deadline error", 66 | args: args{ 67 | ctx: context.Background(), 68 | ctxTimeout: 50 * time.Millisecond, 69 | tasks: []T{ 70 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 71 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 72 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 73 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 74 | 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 75 | }, 76 | executor: func(ctx context.Context, t T) (R, error) { 77 | time.Sleep(time.Duration(t) * time.Millisecond) 78 | return 1, nil 79 | }, 80 | opts: []DoOpt{WithWorkerNum(20)}, 81 | }, 82 | want: []R{ 83 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 84 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 85 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 86 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 87 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 88 | }, 89 | requireErr: require.NoError, 90 | }, 91 | } 92 | for _, tt := range tests { 93 | t.Run(tt.name, func(t *testing.T) { 94 | ctx := tt.args.ctx 95 | if tt.args.ctxTimeout > 0 { 96 | var cancel 
context.CancelFunc 97 | ctx, cancel = context.WithTimeout(ctx, tt.args.ctxTimeout) 98 | defer cancel() 99 | } 100 | got, err := Do(ctx, tt.args.tasks, tt.args.executor, tt.args.opts...) 101 | tt.requireErr(t, err) 102 | require.Equal(t, tt.want, got) 103 | }) 104 | } 105 | } 106 | 107 | func fibonacci(ctx context.Context, n int) (int, error) { 108 | if n < 0 { 109 | return 0, errors.New("negative") 110 | } 111 | if n < 2 { 112 | return 1, nil 113 | } 114 | r1, err := fibonacci(ctx, n-1) 115 | if err != nil { 116 | return 0, err 117 | } 118 | 119 | r2, err := fibonacci(ctx, n-2) 120 | if err != nil { 121 | return 0, err 122 | } 123 | 124 | return r1 + r2, nil 125 | } 126 | -------------------------------------------------------------------------------- /database/config.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | gormLogger "gorm.io/gorm/logger" 8 | ) 9 | 10 | type LogLevel string 11 | 12 | const ( 13 | LogLevelSilent LogLevel = "silent" 14 | LogLevelError LogLevel = "error" 15 | LogLevelWarn LogLevel = "warn" 16 | LogLevelInfo LogLevel = "info" 17 | ) 18 | 19 | func newLogLevelFromString(logLevel LogLevel) (gormLogger.LogLevel, error) { 20 | switch logLevel { 21 | case LogLevelSilent: 22 | return gormLogger.Silent, nil 23 | case LogLevelError: 24 | return gormLogger.Error, nil 25 | case LogLevelWarn: 26 | return gormLogger.Warn, nil 27 | case LogLevelInfo: 28 | return gormLogger.Info, nil 29 | default: 30 | return 0, fmt.Errorf("invalid log level") 31 | } 32 | } 33 | 34 | type DBConnPool struct { 35 | MaxIdleConns int `mapstructure:"max_idle_conns"` 36 | ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"` 37 | MaxOpenConns int `mapstructure:"max_open_conns"` 38 | ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"` 39 | } 40 | 41 | // DBConfig represents the configuration for a database connection. 
42 | type DBConfig struct { 43 | // Url is the URL of the read-write database instance to connect to. 44 | Url string `mapstructure:"url"` 45 | 46 | // ReadonlyUrl is the URL of the read-only database instances. 47 | // This is optional and can be set to nil if read-write splitting is not required. 48 | ReadonlyUrl *string `mapstructure:"readonly_url"` 49 | 50 | // LogLevel is the logging level for the database connection. 51 | // Possible values are "silent", "error", "warn", and "info". 52 | // This is optional and the default value is "error". 53 | LogLevel LogLevel `mapstructure:"log_level"` 54 | 55 | // ConnPool is the connection pool settings for the database connection. 56 | // This is optional and can be set to nil if the default connection pool settings are sufficient. 57 | ConnPool *DBConnPool `mapstructure:"conn_pool"` 58 | } 59 | 60 | var ( 61 | defaultMaxIdleConns = 2 62 | defaultMaxOpenConns = 0 63 | defaultConnMaxIdleTime = time.Duration(0) 64 | defaultConnMaxLifetime = time.Duration(0) 65 | ) 66 | 67 | func (cfg *DBConfig) applyDefaultValue() { 68 | if cfg.LogLevel == "" { 69 | cfg.LogLevel = LogLevelError 70 | } 71 | if cfg.ConnPool == nil { 72 | // match the default configuration in database/sql 73 | // https: //github.com/golang/go/blob/198074abd7ec36ee71198a109d98f1ccdb7c5533/src/database/sql/sql.go#L912 74 | cfg.ConnPool = &DBConnPool{ 75 | MaxIdleConns: defaultMaxIdleConns, 76 | ConnMaxIdleTime: defaultConnMaxIdleTime, 77 | MaxOpenConns: defaultMaxOpenConns, 78 | ConnMaxLifetime: defaultConnMaxLifetime, 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /database/db.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "time" 8 | 9 | "gorm.io/driver/postgres" 10 | gormLogger "gorm.io/gorm/logger" 11 | "gorm.io/plugin/dbresolver" 12 | 13 | "gorm.io/gorm" 14 | ) 15 | 16 | //go:generate 
mockgen -destination=./mock_db.go -package=database . DBContextGetter,TrxContextGetter 17 | 18 | type DBContextGetter interface { 19 | DBFrom(ctx context.Context) *gorm.DB 20 | } 21 | 22 | type TrxContextGetter interface { 23 | Transaction(ctx context.Context, fc func(ctx context.Context) error) error 24 | } 25 | 26 | type transactionKey struct{} 27 | 28 | var trxKey = &transactionKey{} 29 | 30 | // DBGetter implements the DBContextGetter interface, allowing retrieval of a read/write database connection. 31 | type DBGetter struct { 32 | db *gorm.DB 33 | } 34 | 35 | // NewDBGetter creates a new DBGetter instance with the specified database configuration. 36 | // If you are using a read-write splitting database connection, the `dbresolver` will automatically select 37 | // the appropriate connection based on the SQL to be executed. 38 | // When using database transactions, the read-write connection is used by default. 39 | // If you want to force the use of write or read connection, you can use the following method: 40 | // ``` 41 | // 42 | // getter.DBFrom(ctx).Clauses(dbresolver.Write/dbresolver.Read) 43 | // getter.GetSourceDB().Clauses(dbresolver.Write/dbresolver.Read) 44 | // 45 | // ``` 46 | // For more information, read https://gorm.io/docs/dbresolver.html#Read-x2F-Write-Splitting 47 | // *Note* that when using read-write splitting, there is a potential issue where a read operation immediately 48 | // following a write operation may not see the updated data if it is executed on a different read-only replica 49 | // that has not yet been updated with the new data. 
50 | func NewDBGetter(cfg DBConfig) (*DBGetter, error) { 51 | cfg.applyDefaultValue() 52 | 53 | logLevel, err := newLogLevelFromString(cfg.LogLevel) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | db, err := gorm.Open(postgres.Open(cfg.Url), &gorm.Config{ 59 | SkipDefaultTransaction: true, 60 | Logger: gormLogger.New( 61 | log.New(os.Stdout, "\r\n", log.LstdFlags), 62 | gormLogger.Config{ 63 | SlowThreshold: time.Second, 64 | LogLevel: logLevel, 65 | IgnoreRecordNotFoundError: true, 66 | Colorful: true, 67 | }), 68 | }) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | var replicas []gorm.Dialector 74 | if cfg.ReadonlyUrl != nil { 75 | replicas = append(replicas, postgres.Open(*cfg.ReadonlyUrl)) 76 | } 77 | 78 | resolver := dbresolver.Register( 79 | dbresolver.Config{ 80 | Sources: []gorm.Dialector{postgres.Open(cfg.Url)}, 81 | Replicas: replicas, 82 | TraceResolverMode: logLevel == gormLogger.Info, 83 | }).SetConnMaxIdleTime(cfg.ConnPool.ConnMaxIdleTime). 84 | SetConnMaxLifetime(cfg.ConnPool.ConnMaxLifetime). 85 | SetMaxIdleConns(cfg.ConnPool.MaxIdleConns). 
86 | SetMaxOpenConns(cfg.ConnPool.MaxOpenConns) 87 | if err := db.Use(resolver); err != nil { 88 | return nil, err 89 | } 90 | return &DBGetter{db: db}, nil 91 | } 92 | 93 | func NewDBGetterFromGormInstance(db *gorm.DB) *DBGetter { 94 | return &DBGetter{db: db} 95 | } 96 | 97 | func (getter *DBGetter) GetSourceDB() *gorm.DB { 98 | return getter.db 99 | } 100 | 101 | func (getter *DBGetter) HealthCheck() error { 102 | // gorm dbresolver doesn't support getting replica connection 103 | // https://github.com/go-gorm/dbresolver/issues/45 104 | sqlDB, err := getter.db.DB() 105 | if err != nil { 106 | return err 107 | } 108 | return sqlDB.Ping() 109 | } 110 | 111 | func (getter *DBGetter) DBFrom(ctx context.Context) *gorm.DB { 112 | if db, ok := ctx.Value(trxKey).(*gorm.DB); ok { 113 | return db 114 | } 115 | return getter.db 116 | } 117 | 118 | func (getter *DBGetter) Transaction(ctx context.Context, fc func(ctx context.Context) error) error { 119 | return getter.DBFrom(ctx).Transaction(func(tx *gorm.DB) error { 120 | return fc(context.WithValue(ctx, trxKey, tx)) 121 | }) 122 | } 123 | 124 | func (getter *DBGetter) Close() error { 125 | // gorm dbresolver doesn't support getting replica connection 126 | // https://github.com/go-gorm/dbresolver/issues/45 127 | sqlDB, err := getter.db.DB() 128 | if err != nil { 129 | return err 130 | } 131 | return sqlDB.Close() 132 | } 133 | -------------------------------------------------------------------------------- /database/migrate.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/golang-migrate/migrate/v4" 8 | // required by migrate.New... 
to parse migration files directory 9 | _ "github.com/golang-migrate/migrate/v4/source/file" 10 | 11 | _ "github.com/golang-migrate/migrate/v4/database/postgres" 12 | ) 13 | 14 | // supported migrate operations 15 | const ( 16 | defaultFilesDir = "dbmigrations" 17 | operationUp = "up" 18 | operationDown = "down" 19 | operationForce = "force" 20 | ) 21 | 22 | type operationFn func(*MigrationRunner, OperationData) error 23 | 24 | var supportedOperations = map[string]operationFn{ 25 | operationUp: runUp, 26 | operationDown: runDown, 27 | operationForce: runForce, 28 | } 29 | 30 | // OperationData contains information about the migration operation to be performed. 31 | type OperationData struct { 32 | ID string 33 | ForceVersion int 34 | } 35 | 36 | // MigrationRunner is responsible for managing and running database migrations. 37 | type MigrationRunner struct { 38 | mgr *migrate.Migrate 39 | filesDir string 40 | logger logger 41 | } 42 | 43 | // Option represents a function that configures a MigrationRunner. 44 | type Option func(runner *MigrationRunner) 45 | 46 | // WithLogger sets a custom logger for the MigrationRunner. 47 | // If not provided, a noopLogger will be used by default. 48 | func WithLogger(logger logger) Option { 49 | return func(runner *MigrationRunner) { 50 | runner.logger = logger 51 | } 52 | } 53 | 54 | // WithFilesDir sets a custom directory containing migration files for the MigrationRunner. 55 | // If not provided, the default directory "dbmigrations" will be used. 56 | func WithFilesDir(filesDir string) Option { 57 | return func(runner *MigrationRunner) { 58 | runner.filesDir = filesDir 59 | } 60 | } 61 | 62 | // NewMigrationRunner creates a new MigrationRunner with the given database connection string (dsn) and options. 
63 | func NewMigrationRunner(dsn string, opts ...Option) (*MigrationRunner, error) { 64 | runner := &MigrationRunner{ 65 | filesDir: defaultFilesDir, 66 | logger: &noopLogger{}, 67 | } 68 | 69 | for _, opt := range opts { 70 | opt(runner) 71 | } 72 | 73 | mgr, err := migrate.New("file://"+runner.filesDir, dsn) 74 | if err != nil { 75 | return nil, fmt.Errorf("creating Migrate object: %w", err) 76 | } 77 | 78 | mgr.Log = toMigrationsLogger(runner.logger) 79 | runner.mgr = mgr 80 | 81 | return runner, nil 82 | } 83 | 84 | // Run executes the migration operation specified by the OperationData. 85 | func (m *MigrationRunner) Run(operation OperationData) error { 86 | operationName := operation.ID 87 | operationFn, found := supportedOperations[operationName] 88 | if !found { 89 | return fmt.Errorf("unsupported migration operation: %s", operationName) 90 | } 91 | 92 | if err := operationFn(m, operation); err != nil { 93 | return fmt.Errorf("operation %s failed: %v", operationName, err) 94 | } 95 | 96 | return nil 97 | } 98 | 99 | // Version returns the current migration version, a dirty flag, and an error if any. 100 | func (m *MigrationRunner) Version() (version uint, dirty bool, err error) { 101 | return m.mgr.Version() 102 | } 103 | 104 | // runUp runs the "up" migration operation, applying new migrations to the database. 105 | func runUp(m *MigrationRunner, _ OperationData) error { 106 | m.logger.Info("running migrate UP") 107 | 108 | err := m.mgr.Up() 109 | if errors.Is(err, migrate.ErrNoChange) { 110 | m.logger.Info(fmt.Sprintf("no new migrations found in: %s", m.filesDir)) 111 | return nil 112 | } 113 | if err != nil { 114 | return fmt.Errorf("running migrations UP failed: %w", err) 115 | } 116 | return nil 117 | } 118 | 119 | // runDown runs the "down" migration operation, rolling back the latest applied migration. 
120 | func runDown(m *MigrationRunner, _ OperationData) error { 121 | m.logger.Info(fmt.Sprintf("running migrate DOWN with STEPS=%d", 1)) 122 | 123 | // always rollback the latest applied migration only 124 | err := m.mgr.Steps(-1) 125 | if err != nil { 126 | return fmt.Errorf("running migrations DOWN failed: %w", err) 127 | } 128 | return nil 129 | } 130 | 131 | // runForce runs the "force" migration operation, forcibly setting the migration version without running the actual migrations. 132 | func runForce(m *MigrationRunner, op OperationData) error { 133 | m.logger.Info(fmt.Sprintf("running FORCE with VERSION %d", op.ForceVersion)) 134 | 135 | err := m.mgr.Force(op.ForceVersion) 136 | if err != nil { 137 | return fmt.Errorf("running migrations FORCE with VERSION %d failed: %w", op.ForceVersion, err) 138 | } 139 | return nil 140 | } 141 | 142 | type logger interface { 143 | Info(args ...interface{}) 144 | Error(args ...interface{}) 145 | Printf(format string, v ...interface{}) 146 | } 147 | 148 | type noopLogger struct{} 149 | 150 | func (l *noopLogger) Info(...interface{}) {} 151 | func (l *noopLogger) Error(...interface{}) {} 152 | func (l *noopLogger) Printf(string, ...interface{}) {} 153 | func (l *noopLogger) Verbose() bool { return false } 154 | 155 | // adapter to use logger like logrus 156 | func toMigrationsLogger(logger logger) *migrationsLogger { 157 | return &migrationsLogger{logger: logger} 158 | } 159 | 160 | // to be able to log not only errors, but also Info level logs from golang-migrate, 161 | // we have to implement migrate.Logger interface 162 | type migrationsLogger struct { 163 | logger logger 164 | } 165 | 166 | // Printf is like fmt.Printf 167 | func (m *migrationsLogger) Printf(format string, v ...interface{}) { 168 | m.logger.Printf(format, v) 169 | } 170 | 171 | // Verbose should return true when verbose logging output is wanted 172 | func (m *migrationsLogger) Verbose() bool { 173 | return true 174 | } 175 | 176 | func (m 
*migrationsLogger) Info(args ...interface{}) { 177 | m.logger.Info(args) 178 | } 179 | 180 | func (m *migrationsLogger) Error(args ...interface{}) { 181 | m.logger.Error(args) 182 | } 183 | -------------------------------------------------------------------------------- /database/migration_runner_env.go: -------------------------------------------------------------------------------- 1 | package database 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strconv" 7 | ) 8 | 9 | // Constants representing environment variable keys for migration configuration. 10 | const ( 11 | envKeyDsn = "MIGRATION_DSN" 12 | envKeyOp = "MIGRATION_OPERATION" 13 | envKeyForceVersion = "MIGRATION_FORCE_VERSION" 14 | envKeyFilesDir = "MIGRATION_FILES_DIR" 15 | ) 16 | 17 | func readForceVersion() (int, error) { 18 | forceVersionRaw, ok := os.LookupEnv(envKeyForceVersion) 19 | if !ok { 20 | return 0, nil 21 | } 22 | 23 | forceVersion, err := strconv.Atoi(forceVersionRaw) 24 | if err != nil { 25 | return 0, fmt.Errorf("convert forceVersion: %v", err) 26 | } 27 | 28 | return forceVersion, nil 29 | } 30 | 31 | // RunMigrationsFromEnv reads migration configuration from environment variables, 32 | // creates a MigrationRunner, and runs the specified migration operation. 33 | func RunMigrationsFromEnv(logger logger) error { 34 | dsn, ok := os.LookupEnv(envKeyDsn) 35 | if !ok { 36 | return fmt.Errorf("missing env: %s", envKeyDsn) 37 | } 38 | 39 | operation, ok := os.LookupEnv(envKeyOp) 40 | if !ok { 41 | return fmt.Errorf("missing env: %s", envKeyOp) 42 | } 43 | 44 | forceVersion, err := readForceVersion() 45 | if err != nil { 46 | return fmt.Errorf("read forceVersion: %v", err) 47 | } 48 | 49 | opts := []Option{WithLogger(logger)} 50 | if filesDir, ok := os.LookupEnv(envKeyFilesDir); ok { 51 | opts = append(opts, WithFilesDir(filesDir)) 52 | } 53 | 54 | runner, err := NewMigrationRunner(dsn, opts...) 
55 | if err != nil { 56 | return fmt.Errorf("new migrations runner: %v", err) 57 | } 58 | 59 | // Get the current migration version and log it. 60 | version, dirty, err := runner.Version() 61 | if err != nil { 62 | logger.Error(fmt.Sprintf("getting current migration version: %v", err)) 63 | } else { 64 | logger.Info(fmt.Sprintf("migration version before operation: %d, dirty: %v", version, dirty)) 65 | } 66 | 67 | if err := runner.Run(OperationData{ 68 | ID: operation, 69 | ForceVersion: forceVersion, 70 | }); err != nil { 71 | return fmt.Errorf("run operation %s: %v", operation, err) 72 | } 73 | 74 | logger.Info("successfully finished migration") 75 | 76 | // Get the migration version after the operation and log it. 77 | version, dirty, err = runner.Version() 78 | if err != nil { 79 | return fmt.Errorf("getting migration version after operation: %v", err) 80 | } 81 | 82 | logger.Info(fmt.Sprintf("migration version after operation: %d, dirty: %v", version, dirty)) 83 | 84 | return nil 85 | } 86 | -------------------------------------------------------------------------------- /database/mock_db.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: github.com/trustwallet/go-libs/database (interfaces: DBContextGetter,TrxContextGetter) 3 | 4 | // Package database is a generated GoMock package. 5 | package database 6 | 7 | import ( 8 | context "context" 9 | reflect "reflect" 10 | 11 | gomock "github.com/golang/mock/gomock" 12 | gorm "gorm.io/gorm" 13 | ) 14 | 15 | // MockDBContextGetter is a mock of DBContextGetter interface. 16 | type MockDBContextGetter struct { 17 | ctrl *gomock.Controller 18 | recorder *MockDBContextGetterMockRecorder 19 | } 20 | 21 | // MockDBContextGetterMockRecorder is the mock recorder for MockDBContextGetter. 
22 | type MockDBContextGetterMockRecorder struct { 23 | mock *MockDBContextGetter 24 | } 25 | 26 | // NewMockDBContextGetter creates a new mock instance. 27 | func NewMockDBContextGetter(ctrl *gomock.Controller) *MockDBContextGetter { 28 | mock := &MockDBContextGetter{ctrl: ctrl} 29 | mock.recorder = &MockDBContextGetterMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use. 34 | func (m *MockDBContextGetter) EXPECT() *MockDBContextGetterMockRecorder { 35 | return m.recorder 36 | } 37 | 38 | // DBFrom mocks base method. 39 | func (m *MockDBContextGetter) DBFrom(arg0 context.Context) *gorm.DB { 40 | m.ctrl.T.Helper() 41 | ret := m.ctrl.Call(m, "DBFrom", arg0) 42 | ret0, _ := ret[0].(*gorm.DB) 43 | return ret0 44 | } 45 | 46 | // DBFrom indicates an expected call of DBFrom. 47 | func (mr *MockDBContextGetterMockRecorder) DBFrom(arg0 interface{}) *gomock.Call { 48 | mr.mock.ctrl.T.Helper() 49 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DBFrom", reflect.TypeOf((*MockDBContextGetter)(nil).DBFrom), arg0) 50 | } 51 | 52 | // MockTrxContextGetter is a mock of TrxContextGetter interface. 53 | type MockTrxContextGetter struct { 54 | ctrl *gomock.Controller 55 | recorder *MockTrxContextGetterMockRecorder 56 | } 57 | 58 | // MockTrxContextGetterMockRecorder is the mock recorder for MockTrxContextGetter. 59 | type MockTrxContextGetterMockRecorder struct { 60 | mock *MockTrxContextGetter 61 | } 62 | 63 | // NewMockTrxContextGetter creates a new mock instance. 64 | func NewMockTrxContextGetter(ctrl *gomock.Controller) *MockTrxContextGetter { 65 | mock := &MockTrxContextGetter{ctrl: ctrl} 66 | mock.recorder = &MockTrxContextGetterMockRecorder{mock} 67 | return mock 68 | } 69 | 70 | // EXPECT returns an object that allows the caller to indicate expected use. 
71 | func (m *MockTrxContextGetter) EXPECT() *MockTrxContextGetterMockRecorder { 72 | return m.recorder 73 | } 74 | 75 | // Transaction mocks base method. 76 | func (m *MockTrxContextGetter) Transaction(arg0 context.Context, arg1 func(context.Context) error) error { 77 | m.ctrl.T.Helper() 78 | ret := m.ctrl.Call(m, "Transaction", arg0, arg1) 79 | ret0, _ := ret[0].(error) 80 | return ret0 81 | } 82 | 83 | // Transaction indicates an expected call of Transaction. 84 | func (mr *MockTrxContextGetterMockRecorder) Transaction(arg0, arg1 interface{}) *gomock.Call { 85 | mr.mock.ctrl.T.Helper() 86 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockTrxContextGetter)(nil).Transaction), arg0, arg1) 87 | } 88 | -------------------------------------------------------------------------------- /eventer/client.go: -------------------------------------------------------------------------------- 1 | package eventer 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/trustwallet/go-libs/client" 8 | "github.com/trustwallet/go-libs/middleware" 9 | ) 10 | 11 | type Client struct { 12 | client.Request 13 | } 14 | 15 | type Status struct { 16 | Status bool `json:"status"` 17 | } 18 | 19 | type Event struct { 20 | Name string `json:"name"` 21 | CreatedAt int64 `json:"created_at"` 22 | Params map[string]string `json:"params"` 23 | } 24 | 25 | var senderClient *Client 26 | var batchLimit = 100 27 | 28 | func Init(url string, limit int) { 29 | senderClient = &Client{client.InitJSONClient(url, middleware.SentryErrorHandler)} 30 | batchLimit = limit 31 | } 32 | // SendBatch POSTs events through c itself (not the package-level client set by Init), with an empty request path, and decodes the response into status. 33 | func (c Client) SendBatch(events []Event) (status Status, err error) { 34 | _, err = c.Execute(context.TODO(), client.NewReqBuilder(). 35 | Method(http.MethodPost). 36 | PathStatic(""). 37 | Body(events). 38 | WriteTo(&status).
39 | Build()) 40 | return 41 | } 42 | -------------------------------------------------------------------------------- /eventer/log.go: -------------------------------------------------------------------------------- 1 | package eventer 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | log "github.com/sirupsen/logrus" 8 | ) 9 | 10 | var events []Event 11 | var eventsMux = sync.RWMutex{} 12 | 13 | func Log(event Event) { 14 | eventsMux.Lock() 15 | defer func() { 16 | eventsMux.Unlock() 17 | }() 18 | if event.CreatedAt == 0 { 19 | event.CreatedAt = time.Now().Unix() 20 | } 21 | events = append(events, event) 22 | 23 | if len(events) >= batchLimit { 24 | go sendEvents(events) 25 | events = nil 26 | } 27 | } 28 | 29 | func sendEvents(events []Event) { 30 | if senderClient == nil { 31 | return 32 | } 33 | _, err := senderClient.SendBatch(events) 34 | if err != nil { 35 | log.Error(err) 36 | return 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /gin/hmac.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "errors" 8 | "net/http" 9 | 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | var ErrInvalidSignature = errors.New("invalid signature") 14 | 15 | type StrFromCtx func(c *gin.Context) (string, error) 16 | 17 | // HmacDefaultSignatureHeader defines the default header name where clients should place the signature. 
18 | const HmacDefaultSignatureHeader = "X-REQ-SIG" 19 | 20 | type HmacVerifier struct { 21 | keys [][]byte 22 | sigFN StrFromCtx 23 | sigEncoder func([]byte) string 24 | } 25 | 26 | func NewHmacVerifier(options ...func(verifier *HmacVerifier)) *HmacVerifier { 27 | verifier := &HmacVerifier{ 28 | sigFN: func(c *gin.Context) (string, error) { 29 | return c.GetHeader(HmacDefaultSignatureHeader), nil 30 | }, 31 | sigEncoder: func(b []byte) string { 32 | return base64.StdEncoding.EncodeToString(b) 33 | }, 34 | } 35 | 36 | for _, o := range options { 37 | o(verifier) 38 | } 39 | return verifier 40 | } 41 | 42 | // WithHmacVerifierSigKeys is used to set the valid signature keys. 43 | func WithHmacVerifierSigKeys(keys ...string) func(*HmacVerifier) { 44 | return func(v *HmacVerifier) { 45 | keysB := make([][]byte, len(keys)) 46 | for i := range keys { 47 | keysB[i] = []byte(keys[i]) 48 | } 49 | v.keys = keysB 50 | } 51 | } 52 | 53 | // WithHmacVerifierSigFunction can be used to override signature location. 54 | // As a query string param for example: 55 | // 56 | // sigFn := func(c *gin.Context) (string, error) { 57 | // return c.Query("sig"), nil 58 | // } 59 | func WithHmacVerifierSigFunction(sigFN StrFromCtx) func(*HmacVerifier) { 60 | return func(v *HmacVerifier) { 61 | v.sigFN = sigFN 62 | } 63 | } 64 | 65 | // WithHmacVerifierSigEncoder can be used to override default signature encoder (base64). 66 | func WithHmacVerifierSigEncoder(e func(b []byte) string) func(*HmacVerifier) { 67 | return func(v *HmacVerifier) { 68 | v.sigEncoder = e 69 | } 70 | } 71 | 72 | // SignedHandler can be used to construct signed handlers. 73 | // plaintextFN defines the message that should be signed by clients. 
74 | // For example: 75 | // 76 | // func(c *gin.Context) (string, error) { 77 | // return c.Query("asset") + ":" + c.Query("coin"), nil 78 | // } 79 | func (v *HmacVerifier) SignedHandler(h gin.HandlerFunc, plaintextFN StrFromCtx) gin.HandlerFunc { 80 | return func(c *gin.Context) { 81 | plaintext, err := plaintextFN(c) 82 | if err != nil { 83 | _ = c.AbortWithError(http.StatusBadRequest, errors.New("cannot extract plaintext")) 84 | return 85 | } 86 | 87 | sig, err := v.sigFN(c) 88 | if err != nil { 89 | _ = c.AbortWithError(http.StatusBadRequest, errors.New("cannot extract signature")) 90 | return 91 | } 92 | 93 | if err := v.verifySignature([]byte(plaintext), sig); err != nil { 94 | _ = c.AbortWithError(http.StatusUnauthorized, errors.New("cannot verify signature")) 95 | return 96 | } 97 | 98 | h(c) 99 | } 100 | } 101 | // verifySignature returns nil if sig equals the encoded HMAC-SHA256 of msg under any configured key, and ErrInvalidSignature otherwise. 102 | func (v *HmacVerifier) verifySignature(msg []byte, sig string) error { 103 | for _, signatureKey := range v.keys { 104 | h := hmac.New(sha256.New, signatureKey) 105 | h.Write(msg) 106 | sum := h.Sum(nil) 107 | encoded := v.sigEncoder(sum) 108 | if hmac.Equal([]byte(sig), []byte(encoded)) { // constant-time: == on strings can leak via timing how many leading bytes matched 109 | return nil 110 | } 111 | } 112 | 113 | return ErrInvalidSignature 114 | } 115 | -------------------------------------------------------------------------------- /gin/hmac_test.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "errors" 8 | "fmt" 9 | "math/rand" 10 | "net/http" 11 | "net/http/httptest" 12 | "net/url" 13 | "testing" 14 | 15 | "github.com/gin-gonic/gin" 16 | "gotest.tools/assert" 17 | ) 18 | 19 | func createTestContext(t *testing.T, w *httptest.ResponseRecorder, rawURL string, headers map[string]string) *gin.Context { 20 | c, _ := gin.CreateTestContext(w) 21 | 22 | u, err := url.Parse(rawURL) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | c.Request = &http.Request{ 28 | URL: u, 29 | Header: make(http.Header), 30 | } 31 |
for k, v := range headers { 32 | c.Request.Header.Add(k, v) 33 | } 34 | 35 | return c 36 | } 37 | 38 | func TestHmacVerifier(t *testing.T) { 39 | verifier := NewHmacVerifier( 40 | WithHmacVerifierSigKeys("some-key", "some-other-key", "k"), 41 | ) 42 | 43 | someRouteHandler := func(c *gin.Context) { c.Status(http.StatusOK) } 44 | signedRouteHandler := verifier.SignedHandler(someRouteHandler, func(c *gin.Context) (string, error) { 45 | pt := c.Query("asset") + ":" + c.Query("coin") 46 | return pt, nil 47 | }) 48 | 49 | t.Run("unauthorized when signature not provided", func(t *testing.T) { 50 | w := httptest.NewRecorder() 51 | signedRouteHandler(createTestContext(t, w, "", nil)) 52 | assert.Equal(t, http.StatusUnauthorized, w.Code) 53 | }) 54 | 55 | t.Run("signature verification success for all keys", func(t *testing.T) { 56 | asset := "123" 57 | coin := "60" 58 | plainText := asset + ":" + coin 59 | rawURL := fmt.Sprintf("http://does.not.matter?asset=%s&coin=%s", asset, coin) 60 | 61 | for _, k := range verifier.keys { 62 | h := hmac.New(sha256.New, k) 63 | h.Write([]byte(plainText)) 64 | sig := base64.StdEncoding.EncodeToString(h.Sum(nil)) 65 | 66 | w := httptest.NewRecorder() 67 | signedRouteHandler(createTestContext(t, w, rawURL, map[string]string{ 68 | "X-REQ-SIG": sig, 69 | })) 70 | assert.Equal(t, http.StatusOK, w.Code) 71 | } 72 | }) 73 | 74 | t.Run("signature verification failure", func(t *testing.T) { 75 | asset := "123" 76 | coin := "60" 77 | rawURL := fmt.Sprintf("http://does.not.matter?asset=%s&coin=%s", asset, coin) 78 | 79 | w := httptest.NewRecorder() 80 | signedRouteHandler(createTestContext(t, w, rawURL, map[string]string{ 81 | "X-REQ-SIG": "JBdWTO5yR2GB0TOT8YcM7AjWaJaMtVrAFOYUlZRNlYg=", 82 | })) 83 | assert.Equal(t, http.StatusUnauthorized, w.Code) 84 | }) 85 | 86 | t.Run("bad request when signature cannot be extracted", func(t *testing.T) { 87 | h := NewHmacVerifier( 88 | WithHmacVerifierSigKeys("some-key", "some-other-key", "k"), 89 | 
WithHmacVerifierSigFunction(func(c *gin.Context) (string, error) { 90 | return "", errors.New("some error") 91 | }), 92 | ).SignedHandler(someRouteHandler, func(c *gin.Context) (string, error) { 93 | return "whatever", nil 94 | }) 95 | 96 | rawURL := "http://does.not.matter" 97 | w := httptest.NewRecorder() 98 | h(createTestContext(t, w, rawURL, map[string]string{})) 99 | assert.Equal(t, http.StatusBadRequest, w.Code) 100 | }) 101 | 102 | t.Run("bad request when plaintext cannot be extracted", func(t *testing.T) { 103 | h := NewHmacVerifier( 104 | WithHmacVerifierSigKeys("some-key", "some-other-key", "k"), 105 | ).SignedHandler(someRouteHandler, func(c *gin.Context) (string, error) { 106 | return "", errors.New("plaintext cannot be extracted") 107 | }) 108 | 109 | rawURL := "http://does.not.matter" 110 | w := httptest.NewRecorder() 111 | h(createTestContext(t, w, rawURL, map[string]string{})) 112 | assert.Equal(t, http.StatusBadRequest, w.Code) 113 | }) 114 | 115 | t.Run("override signature encoder", func(t *testing.T) { 116 | h := NewHmacVerifier( 117 | WithHmacVerifierSigKeys("some-key", "some-other-key", "k"), 118 | WithHmacVerifierSigEncoder(func(b []byte) string { 119 | return "some-static-sig" 120 | }), 121 | ).SignedHandler(someRouteHandler, func(c *gin.Context) (string, error) { 122 | return "whatever", nil 123 | }) 124 | 125 | rawURL := "http://does.not.matter" 126 | w := httptest.NewRecorder() 127 | h(createTestContext(t, w, rawURL, map[string]string{ 128 | HmacDefaultSignatureHeader: "some-static-sig", 129 | })) 130 | assert.Equal(t, http.StatusOK, w.Code) 131 | }) 132 | } 133 | 134 | func BenchmarkHmacVerifier_verifySignature(b *testing.B) { 135 | const msgByteSize = 512 136 | const numValidKeys = 100 137 | validKeys := make([]string, numValidKeys) 138 | for i := range validKeys { 139 | validKeys[i] = fmt.Sprintf("key-%d", i) 140 | } 141 | verifier := NewHmacVerifier(WithHmacVerifierSigKeys(validKeys...)) 142 | 143 | for i := 0; i < b.N; i++ { 144 | msg 
:= randBytes(msgByteSize) 145 | sig := string(randBytes(64)) 146 | err := verifier.verifySignature(msg, sig) 147 | if err == nil { 148 | b.Errorf("expected error") 149 | } 150 | } 151 | } 152 | 153 | func randBytes(n int) []byte { 154 | var charset = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 155 | b := make([]byte, n) 156 | for i := range b { 157 | b[i] = charset[rand.Intn(len(charset))] 158 | } 159 | return b 160 | } 161 | -------------------------------------------------------------------------------- /gin/setup.go: -------------------------------------------------------------------------------- 1 | package gin 2 | 3 | import ( 4 | "context" 5 | "net" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "syscall" 10 | 11 | "github.com/gin-gonic/gin" 12 | log "github.com/sirupsen/logrus" 13 | ) 14 | 15 | // Deprecated 16 | // SetupGracefulShutdown blocks execution until interruption command sent 17 | // Use httplib.Server instead 18 | func SetupGracefulShutdown(ctx context.Context, port string, engine *gin.Engine) { 19 | server := &http.Server{ 20 | Addr: ":" + port, 21 | Handler: engine, 22 | } 23 | 24 | defer func() { 25 | if err := server.Shutdown(ctx); err != nil { 26 | log.Info("Server Shutdown: ", err) 27 | } 28 | }() 29 | 30 | signalForExit := make(chan os.Signal, 1) 31 | signal.Notify(signalForExit, 32 | syscall.SIGHUP, 33 | syscall.SIGINT, 34 | syscall.SIGTERM, 35 | syscall.SIGQUIT) 36 | 37 | go func() { 38 | switch err := server.ListenAndServe(); err { 39 | case http.ErrServerClosed: 40 | log.Info("server closed") 41 | default: 42 | log.Error("Application failed ", err) 43 | } 44 | }() 45 | log.WithFields(log.Fields{"bind": port}).Info("Running application") 46 | 47 | stop := <-signalForExit 48 | log.Info("Stop signal Received ", stop) 49 | log.Info("Waiting for all jobs to stop") 50 | } 51 | 52 | // SetupGracefulServeWithUnixFile blocks execution until interruption command sent 53 | func SetupGracefulServeWithUnixFile(ctx context.Context, 
engine *gin.Engine, unixFile string) { 54 | err := os.WriteFile("/tmp/app-initialized", nil, 0o666) 55 | if err != nil { 56 | log.WithError(err).Error("failed to create file /tmp/app-initialized") 57 | return 58 | } 59 | 60 | defer func() { 61 | if err != nil { 62 | log.Error(err) 63 | } 64 | }() 65 | 66 | listener, err := net.Listen("unix", unixFile) 67 | if err != nil { 68 | return 69 | } 70 | 71 | defer func() { _ = listener.Close() }() 72 | defer func() { _ = os.Remove(unixFile) }() 73 | 74 | server := &http.Server{ 75 | Handler: engine, 76 | } 77 | 78 | defer func() { 79 | if err := server.Shutdown(ctx); err != nil { 80 | log.Info("Server Shutdown: ", err) 81 | } 82 | }() 83 | 84 | signalForExit := make(chan os.Signal, 1) 85 | signal.Notify(signalForExit, 86 | syscall.SIGHUP, 87 | syscall.SIGINT, 88 | syscall.SIGTERM, 89 | syscall.SIGQUIT) 90 | 91 | go func() { 92 | log.Debugf("Listening and serving HTTP on unix:/%s", unixFile) 93 | // Handle the Serve error locally: assigning the outer err from this goroutine raced with the deferred logger, which also reported the expected ErrServerClosed. 94 | if serveErr := server.Serve(listener); serveErr != nil && serveErr != http.ErrServerClosed { 95 | log.Error(serveErr) 96 | } 97 | }() 98 | 99 | stop := <-signalForExit 100 | log.Info("Stop signal Received ", stop) 101 | } 102 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/trustwallet/go-libs 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/alicebob/miniredis/v2 v2.14.3 7 | github.com/evalphobia/logrus_sentry v0.8.2 8 | github.com/getsentry/raven-go v0.2.0 9 | github.com/gin-gonic/gin v1.9.1 10 | github.com/go-redis/redis/v8 v8.8.2 11 | github.com/golang-migrate/migrate/v4 v4.15.2 12 | github.com/golang/mock v1.6.0 13 | github.com/heirko/go-contrib v0.0.0-20200825160048-11fc5e2235fa 14 | github.com/heralight/logrus_mate v1.0.1-0.20170807195635-969b6efb860e 15 | github.com/patrickmn/go-cache v2.1.0+incompatible 16 | github.com/prometheus/client_golang v1.12.0 17 | github.com/sirupsen/logrus v1.8.1 18 | github.com/spf13/viper v1.8.1 19 | github.com/streadway/amqp v1.0.0 20 |
github.com/stretchr/testify v1.8.3 21 | golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 22 | golang.org/x/net v0.10.0 23 | golang.org/x/sync v0.1.0 24 | gorm.io/driver/postgres v1.4.7 25 | gorm.io/gorm v1.24.3 26 | gorm.io/plugin/dbresolver v1.4.1 27 | gotest.tools v2.2.0+incompatible 28 | ) 29 | 30 | require ( 31 | github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect 32 | github.com/beorn7/perks v1.0.1 // indirect 33 | github.com/bytedance/sonic v1.9.1 // indirect 34 | github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect 35 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 36 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect 37 | github.com/davecgh/go-spew v1.1.1 // indirect 38 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 39 | github.com/fsnotify/fsnotify v1.4.9 // indirect 40 | github.com/gabriel-vasile/mimetype v1.4.2 // indirect 41 | github.com/gin-contrib/sse v0.1.0 // indirect 42 | github.com/go-logr/logr v1.2.2 // indirect 43 | github.com/go-logr/stdr v1.2.2 // indirect 44 | github.com/go-playground/locales v0.14.1 // indirect 45 | github.com/go-playground/universal-translator v0.18.1 // indirect 46 | github.com/go-playground/validator/v10 v10.14.0 // indirect 47 | github.com/goccy/go-json v0.10.2 // indirect 48 | github.com/gogap/env_json v0.0.0-20150503135429-86150085ddbe // indirect 49 | github.com/gogap/env_strings v0.0.1 // indirect 50 | github.com/golang/protobuf v1.5.2 // indirect 51 | github.com/google/go-cmp v0.5.8 // indirect 52 | github.com/hashicorp/errwrap v1.1.0 // indirect 53 | github.com/hashicorp/go-multierror v1.1.1 // indirect 54 | github.com/hashicorp/hcl v1.0.0 // indirect 55 | github.com/hoisie/redis v0.0.0-20160730154456-b5c6e81454e0 // indirect 56 | github.com/jackc/pgpassfile v1.0.0 // indirect 57 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 58 | github.com/jackc/pgx/v5 v5.2.0 // 
indirect 59 | github.com/jinzhu/inflection v1.0.0 // indirect 60 | github.com/jinzhu/now v1.1.5 // indirect 61 | github.com/json-iterator/go v1.1.12 // indirect 62 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect 63 | github.com/leodido/go-urn v1.2.4 // indirect 64 | github.com/lib/pq v1.10.0 // indirect 65 | github.com/magiconair/properties v1.8.5 // indirect 66 | github.com/mattn/go-isatty v0.0.19 // indirect 67 | github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect 68 | github.com/mitchellh/mapstructure v1.4.1 // indirect 69 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 70 | github.com/modern-go/reflect2 v1.0.2 // indirect 71 | github.com/pelletier/go-toml v1.9.3 // indirect 72 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect 73 | github.com/pkg/errors v0.9.1 // indirect 74 | github.com/pmezard/go-difflib v1.0.0 // indirect 75 | github.com/prometheus/client_model v0.2.0 // indirect 76 | github.com/prometheus/common v0.32.1 // indirect 77 | github.com/prometheus/procfs v0.7.3 // indirect 78 | github.com/redis/go-redis/v9 v9.1.0 // indirect 79 | github.com/spf13/afero v1.6.0 // indirect 80 | github.com/spf13/cast v1.3.1 // indirect 81 | github.com/spf13/jwalterweatherman v1.1.0 // indirect 82 | github.com/spf13/pflag v1.0.5 // indirect 83 | github.com/subosito/gotenv v1.2.0 // indirect 84 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 85 | github.com/ugorji/go/codec v1.2.11 // indirect 86 | github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da // indirect 87 | go.opentelemetry.io/otel v1.4.0 // indirect 88 | go.opentelemetry.io/otel/internal/metric v0.27.0 // indirect 89 | go.opentelemetry.io/otel/metric v0.27.0 // indirect 90 | go.opentelemetry.io/otel/trace v1.4.0 // indirect 91 | go.uber.org/atomic v1.10.0 // indirect 92 | golang.org/x/arch v0.3.0 // indirect 93 | golang.org/x/crypto v0.9.0 // indirect 94 | golang.org/x/sys v0.8.0 // indirect 95 | 
golang.org/x/text v0.9.0 // indirect 96 | google.golang.org/protobuf v1.30.0 // indirect 97 | gopkg.in/ini.v1 v1.62.0 // indirect 98 | gopkg.in/yaml.v2 v2.4.0 // indirect 99 | gopkg.in/yaml.v3 v3.0.1 // indirect 100 | ) 101 | -------------------------------------------------------------------------------- /health/http.go: -------------------------------------------------------------------------------- 1 | package health 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | log "github.com/sirupsen/logrus" 9 | ) 10 | 11 | const ( 12 | defaultHealthCheckRoute = "/health" 13 | defaultReadinessCheckRoute = "/ready" 14 | defaultPort = 4444 15 | ) 16 | 17 | type CheckFunc func() error 18 | 19 | type Option func(*server) 20 | 21 | type server struct { 22 | healthCheckRoute string 23 | readinessCheckRoute string 24 | port int 25 | healthChecks []CheckFunc 26 | readinessChecks []CheckFunc 27 | } 28 | 29 | func WithHealthCheckRoute(route string) Option { 30 | return func(s *server) { 31 | s.healthCheckRoute = route 32 | } 33 | } 34 | 35 | func WithReadinessCheckRoute(route string) Option { 36 | return func(s *server) { 37 | s.readinessCheckRoute = route 38 | } 39 | } 40 | 41 | func WithPort(port int) Option { 42 | return func(s *server) { 43 | s.port = port 44 | } 45 | } 46 | 47 | func WithHealthChecks(healthChecks ...CheckFunc) Option { 48 | return func(s *server) { 49 | s.healthChecks = healthChecks 50 | } 51 | } 52 | 53 | func WithReadinessChecks(readinessChecks ...CheckFunc) Option { 54 | return func(s *server) { 55 | s.readinessChecks = readinessChecks 56 | } 57 | } 58 | 59 | func handle(handler *http.ServeMux, route string, handleFuncs []CheckFunc) { 60 | handler.HandleFunc(route, func(w http.ResponseWriter, r *http.Request) { 61 | for _, handleFunc := range handleFuncs { 62 | if err := handleFunc(); err != nil { 63 | http.Error(w, err.Error(), http.StatusInternalServerError) 64 | return 65 | } 66 | } 67 | 68 | w.WriteHeader(http.StatusOK) 69 | }) 70 | } 71 | 
72 | // StartHealthCheckServer starts an HTTP server to handle health check and readiness check requests. It blocks until the server stops: it returns nil once ctx is cancelled and the server has shut down, and otherwise returns the error from ListenAndServe. 73 | func StartHealthCheckServer(ctx context.Context, opts ...Option) error { 74 | hcServer := &server{ 75 | healthCheckRoute: defaultHealthCheckRoute, 76 | readinessCheckRoute: defaultReadinessCheckRoute, 77 | port: defaultPort, 78 | } 79 | 80 | for _, opt := range opts { 81 | opt(hcServer) 82 | } 83 | 84 | handler := http.NewServeMux() 85 | handle(handler, hcServer.healthCheckRoute, hcServer.healthChecks) 86 | handle(handler, hcServer.readinessCheckRoute, hcServer.readinessChecks) 87 | 88 | srv := &http.Server{ 89 | Addr: fmt.Sprintf(":%d", hcServer.port), 90 | Handler: handler, 91 | } 92 | // ctx is already done when Shutdown runs, so give it a fresh context; passing ctx could make Shutdown return context.Canceled without waiting for in-flight requests. 93 | go func() { 94 | <-ctx.Done() 95 | if err := srv.Shutdown(context.Background()); err != nil { 96 | log.Info("server shutdown: ", err) 97 | } 98 | }() 99 | 100 | if err := srv.ListenAndServe(); err != http.ErrServerClosed { 101 | return err 102 | } 103 | 104 | return nil 105 | } 106 | -------------------------------------------------------------------------------- /health/http_test.go: -------------------------------------------------------------------------------- 1 | package health_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/assert" 12 | 13 | .
"github.com/trustwallet/go-libs/health" 14 | ) 15 | 16 | func TestStartHealthCheckServer(t *testing.T) { 17 | tests := []struct { 18 | name string 19 | healthChecks []CheckFunc 20 | readinessChecks []CheckFunc 21 | healthCheckRoute string 22 | readinessCheckRoute string 23 | port int 24 | expHealthy bool 25 | expReady bool 26 | }{ 27 | { 28 | name: "default case", 29 | expHealthy: true, 30 | expReady: true, 31 | }, 32 | { 33 | name: "not healthy", 34 | healthChecks: []CheckFunc{func() error { return errors.New("health check") }}, 35 | readinessChecks: []CheckFunc{func() error { return nil }}, 36 | expHealthy: false, 37 | expReady: true, 38 | }, 39 | { 40 | name: "multiple functions", 41 | healthChecks: []CheckFunc{func() error { return errors.New("health check") }, func() error { return nil }}, 42 | readinessChecks: []CheckFunc{func() error { return nil }, func() error { return nil }}, 43 | expHealthy: false, 44 | expReady: true, 45 | }, 46 | { 47 | name: "not ready", 48 | healthChecks: []CheckFunc{func() error { return nil }}, 49 | readinessChecks: []CheckFunc{func() error { return errors.New("health check") }}, 50 | expHealthy: true, 51 | expReady: false, 52 | }, 53 | { 54 | name: "custom routes and port", 55 | healthChecks: []CheckFunc{func() error { return nil }}, 56 | readinessChecks: []CheckFunc{func() error { return nil }}, 57 | healthCheckRoute: "/custom-health", 58 | readinessCheckRoute: "/custom-ready", 59 | port: 1111, 60 | expHealthy: true, 61 | expReady: true, 62 | }, 63 | } 64 | 65 | for _, test := range tests { 66 | t.Run(test.name, func(t *testing.T) { 67 | ctx, cancel := context.WithCancel(context.Background()) 68 | defer cancel() 69 | 70 | var opts []Option 71 | if test.healthChecks != nil { 72 | opts = append(opts, WithHealthChecks(test.healthChecks...)) 73 | } 74 | 75 | if test.readinessChecks != nil { 76 | opts = append(opts, WithReadinessChecks(test.readinessChecks...)) 77 | } 78 | 79 | if test.healthCheckRoute != "" { 80 | opts = append(opts, 
WithHealthCheckRoute(test.healthCheckRoute)) 81 | } 82 | 83 | if test.readinessCheckRoute != "" { 84 | opts = append(opts, WithReadinessCheckRoute(test.readinessCheckRoute)) 85 | } 86 | 87 | if test.port != 0 { 88 | opts = append(opts, WithPort(test.port)) 89 | } 90 | 91 | port := 4444 92 | if test.port != 0 { 93 | port = test.port 94 | } 95 | 96 | healthRoute := "/health" 97 | if test.healthCheckRoute != "" { 98 | healthRoute = test.healthCheckRoute 99 | } 100 | 101 | healthURL := fmt.Sprintf("http://:%d/%s", port, healthRoute) 102 | 103 | readinessRoute := "/ready" 104 | if test.readinessCheckRoute != "" { 105 | readinessRoute = test.readinessCheckRoute 106 | } 107 | 108 | readinessURL := fmt.Sprintf("http://:%d/%s", port, readinessRoute) 109 | 110 | go func() { 111 | assert.NoError(t, StartHealthCheckServer(ctx, opts...)) 112 | }() 113 | waitForServerToStart(t, healthURL, 20*time.Millisecond, 1*time.Second) 114 | 115 | resp, err := http.Get(healthURL) 116 | assert.NoError(t, err) 117 | assert.True(t, (test.expHealthy && resp.StatusCode == http.StatusOK) || (!test.expHealthy && resp.StatusCode != http.StatusOK)) 118 | 119 | resp, err = http.Get(readinessURL) 120 | assert.NoError(t, err) 121 | assert.True(t, (test.expReady && resp.StatusCode == http.StatusOK) || (!test.expReady && resp.StatusCode != http.StatusOK)) 122 | 123 | cancel() 124 | 125 | waitForServerToStop(t, healthURL, 20*time.Millisecond, 2*time.Second) 126 | }) 127 | } 128 | } 129 | 130 | func waitForServerToStart(t *testing.T, url string, interval time.Duration, timeout time.Duration) { 131 | waitForServer(t, func() bool { 132 | _, err := http.Get(url) 133 | return err == nil 134 | }, interval, timeout) 135 | } 136 | 137 | func waitForServerToStop(t *testing.T, url string, interval time.Duration, timeout time.Duration) { 138 | waitForServer(t, func() bool { 139 | _, err := http.Get(url) 140 | return err != nil 141 | }, interval, timeout) 142 | } 143 | 144 | func waitForServer(t *testing.T, checkFn 
func() bool, interval time.Duration, timeout time.Duration) { 145 | tick := time.NewTicker(interval) 146 | defer tick.Stop() 147 | now := time.Now() 148 | for { 149 | if time.Since(now) > timeout { 150 | t.Fatal("timeout to connect to server") 151 | return 152 | } 153 | 154 | <-tick.C 155 | if checkFn() { 156 | return 157 | } 158 | } 159 | } 160 | 161 | func TestServerClosedOnContextCancellation(t *testing.T) { 162 | ctx, cancel := context.WithCancel(context.Background()) 163 | defer cancel() 164 | 165 | go func() { 166 | assert.NoError(t, StartHealthCheckServer(ctx)) 167 | }() 168 | waitForServerToStart(t, "http://:4444/health", 20*time.Millisecond, 1*time.Second) 169 | 170 | cancel() 171 | time.Sleep(time.Millisecond * 100) 172 | _, err := http.Get("http://:4444/health") 173 | assert.Error(t, err) // server was shut down 174 | } 175 | -------------------------------------------------------------------------------- /httplib/downloader.go: -------------------------------------------------------------------------------- 1 | package httplib 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | 8 | log "github.com/sirupsen/logrus" 9 | ) 10 | 11 | type Downloader interface { 12 | Download(url string) ([]byte, error) 13 | } 14 | 15 | type downloader struct { 16 | client http.Client 17 | bytesSizeLimit int64 18 | } 19 | // Download GETs url and returns the body of a 2xx response, truncated to bytesSizeLimit bytes when that limit is positive; any other status is returned as an error. 20 | func (d *downloader) Download(url string) ([]byte, error) { 21 | resp, err := d.client.Get(url) 22 | if err != nil { 23 | return nil, err 24 | } 25 | defer func() { // close on every path, including non-2xx responses, which previously leaked the body and its connection 26 | if err := resp.Body.Close(); err != nil { 27 | log.WithField("error", err).Error("cannot close request body") 28 | } 29 | }() 30 | if !(resp.StatusCode >= 200 && resp.StatusCode <= 299) { 31 | return nil, fmt.Errorf("response status code: %d", resp.StatusCode) 32 | } 33 | 34 | var reader io.Reader = resp.Body 35 | if d.bytesSizeLimit > 0 { 36 | reader = &io.LimitedReader{R: resp.Body, N: d.bytesSizeLimit} 37 | } 38 | 39 | b, err := io.ReadAll(reader) 40 | if err != nil { 41 |
return nil, err 42 | } 43 | 44 | return b, nil 45 | } 46 | 47 | type DownloaderOption func(d *downloader) error 48 | 49 | func NewDownloader(opts ...DownloaderOption) (Downloader, error) { 50 | d := &downloader{ 51 | client: http.Client{}, 52 | } 53 | 54 | for _, opt := range opts { 55 | if err := opt(d); err != nil { 56 | return nil, err 57 | } 58 | } 59 | 60 | return d, nil 61 | } 62 | 63 | // DownloaderOptionBytesSizeLimit limits the downloaded file size to the provided number of bytes. 64 | func DownloaderOptionBytesSizeLimit(n int64) DownloaderOption { 65 | return func(d *downloader) error { 66 | d.bytesSizeLimit = n 67 | return nil 68 | } 69 | } 70 | 71 | // DownloaderOptionHttpClient sets a custom http client to perform the request. 72 | func DownloaderOptionHttpClient(client http.Client) DownloaderOption { 73 | return func(d *downloader) error { 74 | d.client = client 75 | return nil 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /httplib/server.go: -------------------------------------------------------------------------------- 1 | package httplib 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "sync" 8 | 9 | "golang.org/x/net/http2" 10 | "golang.org/x/net/http2/h2c" 11 | 12 | log "github.com/sirupsen/logrus" 13 | ) 14 | 15 | type Server interface { 16 | Run(ctx context.Context, wg *sync.WaitGroup) 17 | } 18 | 19 | type api struct { 20 | router http.Handler 21 | port string 22 | h2c bool 23 | } 24 | 25 | func NewHTTPServer(router http.Handler, port string) Server { 26 | return &api{ 27 | router: router, 28 | port: port, 29 | h2c: false, 30 | } 31 | } 32 | 33 | func NewH2CServer(router http.Handler, port string) Server { 34 | return &api{ 35 | router: router, 36 | port: port, 37 | h2c: true, 38 | } 39 | } 40 | 41 | func (a *api) Run(ctx context.Context, wg *sync.WaitGroup) { 42 | a.serve(ctx, wg) 43 | } 44 | 45 | func (a *api) serve(ctx context.Context, wg *sync.WaitGroup) { 46 | wg.Add(1) 
47 | 48 | h2s := &http2.Server{} 49 | h1d, h2d := a.router, h2c.NewHandler(a.router, h2s) 50 | 51 | server := &http.Server{ 52 | Addr: ":" + a.port, 53 | Handler: h1d, 54 | } 55 | if a.h2c { 56 | server.Handler = h2d 57 | } 58 | 59 | serverStopped := make(chan struct{}) 60 | 61 | go func() { 62 | if err := server.ListenAndServe(); err != nil && errors.Is(err, http.ErrServerClosed) { 63 | log.WithError(err).Debug("Server ListenAndServe") 64 | serverStopped <- struct{}{} 65 | } 66 | }() 67 | 68 | log.WithFields(log.Fields{"bind": a.port}).Info("Starting the API server") 69 | 70 | go func() { 71 | defer func() { wg.Done() }() 72 | 73 | select { 74 | case <-ctx.Done(): 75 | log.Info("Shutting down the server") 76 | 77 | if err := server.Shutdown(context.Background()); err != nil { 78 | log.Info("Server Shutdown: ", err) 79 | } 80 | 81 | return 82 | case <-serverStopped: 83 | return 84 | } 85 | }() 86 | } 87 | -------------------------------------------------------------------------------- /logging/README.md: -------------------------------------------------------------------------------- 1 | # logging package 2 | 3 | Add dependency to the project 4 | 5 | ```sh 6 | go get github.com/trustwallet/go-libs/logging 7 | ``` 8 | 9 | ## Features 10 | 11 | * [Logrus Wrapper](#logrus-wrapper) allows an easy acces to common logger instance as well as override it in tests 12 | * [Logging Configuration](#logging-configuration) allows for the logging configuration to be loaded with viper 13 | * [Strict Text Formatter](#strict-text-formatter) allows to unmarshall boolean `logrus` formatter options as **strings** 14 | 15 | ### Logrus Wrapper 16 | 17 | By default `logrus` operates agains global instance (common for `go` ), but such approach doesn't allow to replace the `logger` instance during testing. 
The `logging` package provides access to the current `logger` instance:

```go
log := logging.GetLogger()
```

There is also a helper that returns a log `Entry` with the `component` field preset:

```go
log := logging.GetLoggerForComponent("market")
log.Info("some log entry")
// time="2021-08-19T12:33:21Z" level=info msg="some log entry" component=market
```

For testing purposes the `logger` instance can be replaced:

```go
func TestMyService(t *testing.T) {
    testLogger, hook := test.NewNullLogger()
    testLogger.SetLevel(logrus.WarnLevel)
    logging.SetLogger(testLogger)

    // create an instance of the service which
    // uses logging.GetLogger() internally
    s := service.NewService()
    s.DoSomeWork()

    // all logged messages are available here
    for _, e := range hook.Entries {
        t.Log(e)
    }
}
```

### Logging Configuration

Utilizes [Logrus Mate](https://github.com/gogap/logrus_mate) and
[Logrus Helper](https://github.com/heirko/go-contrib/tree/master/logrusHelper) to load configuration with [Viper](https://github.com/spf13/viper) 🐍

This means the logging setup can easily be specified per environment via a config file (e.g. disable timestamps when deployed to Heroku).

Assuming the following `config.yml`:

```yaml
market:
  foo: bar

logging:
  level: debug
  formatter:
    name: text
    options:
      disable_timestamp: true
```

And the corresponding go `struct`:

```go
type Configuration struct {
    Market struct {
        Foo string `mapstructure:"foo"`
    } `mapstructure:"market"`
    Logging logging.Config `mapstructure:"logging"`
}
```

Once viper has unmarshalled the configuration taken from all sources:

```go
err = logging.SetLoggerConfig(config.Logging)
if err != nil {
    // ...
}

log := logging.GetLogger()
```

✨ It's fully backward compatible with code which uses `logrus` directly.

```go
import log "github.com/sirupsen/logrus"

func LogSomething() {
    // respects logging configuration set with
    // logging.SetLoggerConfig(...)
    log.Info("some log message")
}
```

### Strict Text Formatter

This package contains a `strict_text` formatter which replicates the
[Logrus Mate](https://github.com/gogap/logrus_mate) `text` formatter behaviour,
with the small difference that every boolean option **should** be passed as a string.
This allows the logging configuration to be overridden correctly from environment variables.

To demonstrate the issue, assume the following `config.yml`:

```yaml
logging:
  level: debug
  formatter:
    name: text
```

When the application is executed with the environment variable override `LOGGING_FORMATTER_OPTIONS_DISABLE_TIMESTAMP=true`, the
config is effectively equivalent to:

```yaml
logging:
  level: debug
  formatter:
    name: text
    options:
      disable_timestamp: "true"
```

Notice that the `disable_timestamp` option will be a `string` (held in an `interface {}`)
when unmarshalled by `viper`.
138 | The [Logrus Mate](https://github.com/gogap/logrus_mate) `text` formatter cannot 139 | handle it and throws an error: 140 | 141 | ```txt 142 | json: cannot unmarshal string into Go struct field TextFormatterConfig.disable_timestamp of type bool 143 | ``` 144 | -------------------------------------------------------------------------------- /logging/formatter_strict_text.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | mate "github.com/heralight/logrus_mate" 5 | "github.com/sirupsen/logrus" 6 | ) 7 | 8 | type TextFormatterConfig struct { 9 | ForceColors bool `json:"force_colors,string"` 10 | DisableColors bool `json:"disable_colors,string"` 11 | DisableTimestamp bool `json:"disable_timestamp,string"` 12 | FullTimestamp bool `json:"full_timestamp,string"` 13 | TimestampFormat string `json:"timestamp_format"` 14 | DisableSorting bool `json:"disable_sorting,string"` 15 | } 16 | 17 | func init() { 18 | mate.RegisterFormatter("strict_text", NewTextFormatter) 19 | } 20 | 21 | func NewTextFormatter(options mate.Options) (formatter logrus.Formatter, err error) { 22 | conf := TextFormatterConfig{} 23 | 24 | if err = options.ToObject(&conf); err != nil { 25 | return 26 | } 27 | 28 | formatter = &logrus.TextFormatter{ 29 | ForceColors: conf.ForceColors, 30 | DisableColors: conf.DisableColors, 31 | DisableTimestamp: conf.DisableTimestamp, 32 | FullTimestamp: conf.FullTimestamp, 33 | TimestampFormat: conf.TimestampFormat, 34 | DisableSorting: conf.DisableSorting, 35 | } 36 | return 37 | } 38 | -------------------------------------------------------------------------------- /logging/logger.go: -------------------------------------------------------------------------------- 1 | package logging 2 | 3 | import ( 4 | "github.com/heirko/go-contrib/logrusHelper" 5 | mate "github.com/heralight/logrus_mate" 6 | "github.com/sirupsen/logrus" 7 | ) 8 | 9 | var logger *logrus.Logger 10 | 11 | const 
FieldKeyComponent = "component" 12 | 13 | type Config mate.LoggerConfig 14 | 15 | func init() { 16 | logger = logrus.New() 17 | } 18 | 19 | func SetLoggerConfig(config Config) error { 20 | err := logrusHelper.SetConfig(logrus.StandardLogger(), mate.LoggerConfig(config)) 21 | if err != nil { 22 | return err 23 | } 24 | return logrusHelper.SetConfig(logger, mate.LoggerConfig(config)) 25 | } 26 | 27 | // GetLogger returns the logger instance. 28 | func GetLogger() *logrus.Logger { 29 | return logger 30 | } 31 | 32 | // GetLoggerForComponent returns the logger instance with component field set 33 | func GetLoggerForComponent(component string) *logrus.Entry { 34 | return GetLogger().WithField(FieldKeyComponent, component) 35 | } 36 | 37 | // SetLogger sets the logger instance 38 | // This is useful in testing as the logger can be overridden 39 | // with a test logger 40 | func SetLogger(l *logrus.Logger) { 41 | logger = l 42 | } 43 | -------------------------------------------------------------------------------- /logging/logger_test.go: -------------------------------------------------------------------------------- 1 | package logging_test 2 | 3 | import ( 4 | "bytes" 5 | "strconv" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/heirko/go-contrib/logrusHelper" 10 | "github.com/sirupsen/logrus" 11 | "github.com/sirupsen/logrus/hooks/test" 12 | "github.com/spf13/viper" 13 | "github.com/trustwallet/go-libs/logging" 14 | "gotest.tools/assert" 15 | ) 16 | 17 | func TestGetLogger(t *testing.T) { 18 | logger := logging.GetLogger() 19 | 20 | assert.Equal(t, logger.Level, logrus.InfoLevel, "default logger minimum level is Info") 21 | } 22 | 23 | func TestGetLoggerForComponent(t *testing.T) { 24 | logger1 := logging.GetLoggerForComponent("logger1") 25 | logger1 = logger1.WithField("custom", "logger1_only") 26 | 27 | logger2 := logging.GetLoggerForComponent("logger2") 28 | 29 | logAndAssertText(t, logger1, func(fields map[string]string) { 30 | assert.Equal(t, "logger1", 
fields[logging.FieldKeyComponent]) 31 | }) 32 | logAndAssertText(t, logger2, func(fields map[string]string) { 33 | assert.Equal(t, "logger2", fields[logging.FieldKeyComponent]) 34 | 35 | _, ok := fields["custom"] 36 | assert.Assert(t, !ok, "custom field should exist on logger1 only") 37 | }) 38 | } 39 | 40 | func TestParseConfigWithViper(t *testing.T) { 41 | yamlConfig := []byte(` 42 | logging: 43 | out: 44 | name: stdout 45 | level: debug 46 | formatter: 47 | name: text 48 | options: 49 | disable_colors: true 50 | full_timestamp: false 51 | hooks: 52 | - name: file 53 | options: 54 | filename: debug.log, 55 | maxsize: 5000, 56 | maxdays: 1, 57 | rotate: true, 58 | priority: LOG_INFO, 59 | tag: "" 60 | `) 61 | 62 | viper.SetConfigType("yaml") 63 | err := viper.ReadConfig(bytes.NewBuffer(yamlConfig)) 64 | assert.NilError(t, err) 65 | t.Logf("All keys: %#v", viper.AllSettings()) 66 | 67 | logger := logging.GetLogger() 68 | // Unmarshal configuration from Viper 69 | var c = logrusHelper.UnmarshalConfiguration(viper.Sub("logging")) 70 | err = logrusHelper.SetConfig(logger, c) 71 | assert.NilError(t, err) 72 | 73 | assert.Equal(t, logger.Level, logrus.DebugLevel, "logging level set to debug via config") 74 | } 75 | 76 | func TestSetLoggerConfig(t *testing.T) { 77 | yamlConfig := []byte(` 78 | logging: 79 | level: debug 80 | formatter: 81 | name: text 82 | options: 83 | disable_timestamp: true 84 | `) 85 | 86 | viper.SetConfigType("yaml") 87 | err := viper.ReadConfig(bytes.NewBuffer(yamlConfig)) 88 | assert.NilError(t, err) 89 | 90 | var config logging.Config 91 | err = viper.UnmarshalKey("logging", &config) 92 | assert.NilError(t, err) 93 | 94 | err = logging.SetLoggerConfig(config) 95 | assert.NilError(t, err) 96 | 97 | logger := logging.GetLogger() 98 | assert.Equal(t, logger.Level, logrus.DebugLevel, "logging level set to debug via config") 99 | assert.Equal(t, logger.Formatter.(*logrus.TextFormatter).DisableTimestamp, true) 100 | } 101 | 102 | func 
TestOverrideBoolOptionAsString(t *testing.T) { 103 | yamlConfig := []byte(` 104 | logging: 105 | level: debug 106 | formatter: 107 | name: strict_text 108 | options: 109 | disable_timestamp: "true" 110 | `) 111 | 112 | viper.SetConfigType("yaml") 113 | err := viper.ReadConfig(bytes.NewBuffer(yamlConfig)) 114 | assert.NilError(t, err) 115 | 116 | var config logging.Config 117 | err = viper.UnmarshalKey("logging", &config) 118 | assert.NilError(t, err) 119 | 120 | err = logging.SetLoggerConfig(config) 121 | assert.NilError(t, err) 122 | 123 | logger := logging.GetLogger() 124 | assert.Equal(t, logger.Level, logrus.DebugLevel, "logging level set to debug via config") 125 | assert.Equal(t, logger.Formatter.(*logrus.TextFormatter).DisableTimestamp, true) 126 | } 127 | 128 | func TestSetLoggerConfigForStandardLogger(t *testing.T) { 129 | // Not every component would be able to use logging.GetLogger() 130 | // This test makes sure the config loaded with viper is also 131 | // applied globally to standard logger 132 | 133 | yamlConfig := []byte(` 134 | logging: 135 | level: debug 136 | formatter: 137 | name: text 138 | options: 139 | disable_timestamp: true 140 | `) 141 | 142 | viper.SetConfigType("yaml") 143 | err := viper.ReadConfig(bytes.NewBuffer(yamlConfig)) 144 | assert.NilError(t, err) 145 | 146 | var config logging.Config 147 | err = viper.UnmarshalKey("logging", &config) 148 | assert.NilError(t, err) 149 | 150 | err = logging.SetLoggerConfig(config) 151 | assert.NilError(t, err) 152 | 153 | std := logrus.StandardLogger() 154 | assert.Equal(t, std.Formatter.(*logrus.TextFormatter).DisableTimestamp, true) 155 | } 156 | 157 | func TestSetLogger(t *testing.T) { 158 | testLogger, hook := test.NewNullLogger() 159 | testLogger.SetLevel(logrus.WarnLevel) 160 | logging.SetLogger(testLogger) 161 | 162 | logger1 := logging.GetLogger() 163 | logger1.Info("you should not see me printed") 164 | logger1.Warn("you should see this printed") 165 | 166 | logger2 := 
logging.GetLoggerForComponent("testing") 167 | logger2.Debug("you should not see me too") 168 | logger1.Error("you should see this printed") 169 | 170 | for _, e := range hook.Entries { 171 | t.Log(e) 172 | } 173 | assert.Equal(t, len(hook.Entries), 2) 174 | } 175 | 176 | func logAndAssertText(t *testing.T, entry *logrus.Entry, assertions func(fields map[string]string)) { 177 | var buffer bytes.Buffer 178 | entry.Logger.Out = &buffer 179 | entry.Logger.Formatter.(*logrus.TextFormatter).DisableColors = true 180 | entry.Info() 181 | 182 | fields := make(map[string]string) 183 | for _, kv := range strings.Split(strings.TrimRight(buffer.String(), "\n"), " ") { 184 | if !strings.Contains(kv, "=") { 185 | continue 186 | } 187 | kvArr := strings.Split(kv, "=") 188 | key := strings.TrimSpace(kvArr[0]) 189 | val := kvArr[1] 190 | if kvArr[1][0] == '"' { 191 | var err error 192 | val, err = strconv.Unquote(val) 193 | assert.NilError(t, err) 194 | } 195 | fields[key] = val 196 | } 197 | assertions(fields) 198 | } 199 | -------------------------------------------------------------------------------- /metrics/handler.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/prometheus/client_golang/prometheus" 6 | "github.com/prometheus/client_golang/prometheus/collectors" 7 | "github.com/prometheus/client_golang/prometheus/promhttp" 8 | 9 | "github.com/trustwallet/go-libs/httplib" 10 | ) 11 | 12 | func InitHandler(engine *gin.Engine, path string) { 13 | engine.GET(path, gin.WrapH(promhttp.Handler())) 14 | } 15 | 16 | func NewMetricsServer(appName string, port string, path string) httplib.Server { 17 | router := gin.Default() 18 | 19 | prometheus.DefaultRegisterer.Unregister(collectors.NewGoCollector()) 20 | prometheus.DefaultRegisterer.Unregister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) 21 | InitHandler(router, path) 22 | 23 | return 
httplib.NewHTTPServer(router, port) 24 | } 25 | -------------------------------------------------------------------------------- /metrics/http_metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | "time" 6 | ) 7 | 8 | const ( 9 | requestStartedKey = "request_started" 10 | requestDurationSecondsKey = "request_duration_seconds" 11 | requestSucceededTotalKey = "request_succeeded_total" 12 | requestClientErrTotalKey = "request_client_error_total" 13 | requestServerErrTotalKey = "request_server_error_total" 14 | ) 15 | 16 | type HttpServerMetric interface { 17 | Start(labelValues ...string) time.Time 18 | Duration(start time.Time, labelValues ...string) 19 | Success(labelValues ...string) 20 | ServerError(labelValues ...string) 21 | ClientError(labelValues ...string) 22 | } 23 | 24 | type httpServerMetric struct { 25 | requestStarted *prometheus.GaugeVec 26 | requestDurationSeconds *prometheus.HistogramVec 27 | requestSucceededTotal *prometheus.CounterVec 28 | requestClientErrTotal *prometheus.CounterVec 29 | requestServerErrTotal *prometheus.CounterVec 30 | } 31 | 32 | func NewHttpServerMetric( 33 | namespace string, 34 | labelNames []string, 35 | staticLabels prometheus.Labels, 36 | reg prometheus.Registerer, 37 | ) HttpServerMetric { 38 | requestStarted := prometheus.NewGaugeVec(prometheus.GaugeOpts{ 39 | Namespace: namespace, 40 | Name: requestStartedKey, 41 | Help: "Last Unix time when request started.", 42 | }, labelNames) 43 | 44 | requestDurationSeconds := prometheus.NewHistogramVec(prometheus.HistogramOpts{ 45 | Namespace: namespace, 46 | Name: requestDurationSecondsKey, 47 | Help: "Duration of the executions.", 48 | }, labelNames) 49 | 50 | requestSucceededTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ 51 | Namespace: namespace, 52 | Name: requestSucceededTotalKey, 53 | Help: "Total number of the 2xx requests which 
succeeded.", 54 | }, labelNames) 55 | 56 | requestClientErrTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ 57 | Namespace: namespace, 58 | Name: requestClientErrTotalKey, 59 | Help: "Total number of the 4xx requests.", 60 | }, labelNames) 61 | 62 | requestServerErrTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ 63 | Namespace: namespace, 64 | Name: requestServerErrTotalKey, 65 | Help: "Total number of the 5xx requests.", 66 | }, labelNames) 67 | 68 | Register(staticLabels, reg, requestStarted, requestDurationSeconds, requestSucceededTotal, requestClientErrTotal, requestServerErrTotal) 69 | 70 | return &httpServerMetric{ 71 | requestStarted: requestStarted, 72 | requestDurationSeconds: requestDurationSeconds, 73 | requestSucceededTotal: requestSucceededTotal, 74 | requestClientErrTotal: requestClientErrTotal, 75 | requestServerErrTotal: requestServerErrTotal, 76 | } 77 | } 78 | 79 | func (m *httpServerMetric) Start(labelValues ...string) time.Time { 80 | start := time.Now() 81 | m.requestStarted.WithLabelValues(labelValues...).SetToCurrentTime() 82 | return start 83 | } 84 | 85 | func (m *httpServerMetric) Duration(start time.Time, labelValues ...string) { 86 | duration := time.Since(start) 87 | m.requestDurationSeconds.WithLabelValues(labelValues...).Observe(duration.Seconds()) 88 | } 89 | 90 | func (m *httpServerMetric) Success(labelValues ...string) { 91 | m.requestSucceededTotal.WithLabelValues(labelValues...).Inc() 92 | m.requestServerErrTotal.WithLabelValues(labelValues...).Add(0) 93 | m.requestClientErrTotal.WithLabelValues(labelValues...).Add(0) 94 | } 95 | 96 | func (m *httpServerMetric) ServerError(labelValues ...string) { 97 | m.requestSucceededTotal.WithLabelValues(labelValues...).Add(0) 98 | m.requestServerErrTotal.WithLabelValues(labelValues...).Inc() 99 | m.requestClientErrTotal.WithLabelValues(labelValues...).Add(0) 100 | } 101 | 102 | func (m *httpServerMetric) ClientError(labelValues ...string) { 103 | 
m.requestSucceededTotal.WithLabelValues(labelValues...).Add(0) 104 | m.requestServerErrTotal.WithLabelValues(labelValues...).Add(0) 105 | m.requestClientErrTotal.WithLabelValues(labelValues...).Inc() 106 | } 107 | -------------------------------------------------------------------------------- /metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/prometheus/client_golang/prometheus" 7 | ) 8 | 9 | const ( 10 | executionStartedKey = "execution_started" 11 | executionDurationSecondsKey = "execution_duration_seconds" 12 | executionSucceededTotalKey = "execution_succeeded_total" 13 | executionFailedTotalKey = "execution_failed_total" 14 | ) 15 | 16 | type Collectors map[string]prometheus.Collector 17 | 18 | type PerformanceMetric interface { 19 | Start(labelValues ...string) time.Time 20 | Duration(start time.Time, labelValues ...string) 21 | Success(labelValues ...string) 22 | Failure(labelValues ...string) 23 | } 24 | 25 | type performanceMetric struct { 26 | executionStarted *prometheus.GaugeVec 27 | executionDurationSeconds *prometheus.HistogramVec 28 | executionSucceededTotal *prometheus.CounterVec 29 | executionFailedTotal *prometheus.CounterVec 30 | } 31 | 32 | func NewPerformanceMetric( 33 | namespace string, 34 | labelNames []string, 35 | staticLabels prometheus.Labels, 36 | reg prometheus.Registerer, 37 | ) PerformanceMetric { 38 | executionStarted := prometheus.NewGaugeVec(prometheus.GaugeOpts{ 39 | Namespace: namespace, 40 | Name: executionStartedKey, 41 | Help: "Last Unix time when execution started.", 42 | }, labelNames) 43 | 44 | executionDurationSeconds := prometheus.NewHistogramVec(prometheus.HistogramOpts{ 45 | Namespace: namespace, 46 | Name: executionDurationSecondsKey, 47 | Help: "Duration of the executions.", 48 | }, labelNames) 49 | 50 | executionSucceededTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ 51 | Namespace: 
namespace, 52 | Name: executionSucceededTotalKey, 53 | Help: "Total number of the executions which succeeded.", 54 | }, labelNames) 55 | 56 | executionFailedTotal := prometheus.NewCounterVec(prometheus.CounterOpts{ 57 | Namespace: namespace, 58 | Name: executionFailedTotalKey, 59 | Help: "Total number of the executions which failed.", 60 | }, labelNames) 61 | 62 | Register(staticLabels, reg, executionStarted, executionDurationSeconds, executionSucceededTotal, executionFailedTotal) 63 | 64 | return &performanceMetric{ 65 | executionStarted: executionStarted, 66 | executionDurationSeconds: executionDurationSeconds, 67 | executionSucceededTotal: executionSucceededTotal, 68 | executionFailedTotal: executionFailedTotal, 69 | } 70 | } 71 | 72 | func (m *performanceMetric) Start(labelValues ...string) time.Time { 73 | start := time.Now() 74 | m.executionStarted.WithLabelValues(labelValues...).SetToCurrentTime() 75 | return start 76 | } 77 | 78 | func (m *performanceMetric) Duration(start time.Time, labelValues ...string) { 79 | duration := time.Since(start) 80 | m.executionDurationSeconds.WithLabelValues(labelValues...).Observe(duration.Seconds()) 81 | } 82 | 83 | func (m *performanceMetric) Success(labelValues ...string) { 84 | m.executionSucceededTotal.WithLabelValues(labelValues...).Inc() 85 | m.executionFailedTotal.WithLabelValues(labelValues...).Add(0) 86 | } 87 | 88 | func (m *performanceMetric) Failure(labelValues ...string) { 89 | m.executionFailedTotal.WithLabelValues(labelValues...).Inc() 90 | m.executionSucceededTotal.WithLabelValues(labelValues...).Add(0) 91 | } 92 | 93 | type NullablePerformanceMetric struct{} 94 | 95 | func (NullablePerformanceMetric) Start(_ ...string) time.Time { 96 | // NullablePerformanceMetric is a no-op, so returning empty value 97 | return time.Time{} 98 | } 99 | func (NullablePerformanceMetric) Duration(_ time.Time, _ ...string) {} 100 | func (NullablePerformanceMetric) Success(_ ...string) {} 101 | func (NullablePerformanceMetric) 
Failure(_ ...string) {} 102 | -------------------------------------------------------------------------------- /metrics/pusher.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | 7 | "github.com/prometheus/client_golang/prometheus" 8 | "github.com/prometheus/client_golang/prometheus/push" 9 | 10 | "github.com/trustwallet/go-libs/client" 11 | ) 12 | 13 | type MetricsPusherClient struct { 14 | client client.Request 15 | } 16 | 17 | func NewMetricsPusherClient(pushURL, key string, errorHandler client.HttpErrorHandler) *MetricsPusherClient { 18 | client := client.InitClient(pushURL, errorHandler, client.WithExtraHeader("X-API-Key", key)) 19 | 20 | return &MetricsPusherClient{ 21 | client: client, 22 | } 23 | } 24 | 25 | func (c *MetricsPusherClient) Do(req *http.Request) (*http.Response, error) { 26 | for key, value := range c.client.Headers { 27 | req.Header.Set(key, value) 28 | } 29 | return c.client.HttpClient.Do(req) 30 | } 31 | 32 | type Pusher interface { 33 | Push() error 34 | Close() error 35 | } 36 | 37 | type pusher struct { 38 | pusher *push.Pusher 39 | } 40 | 41 | func NewPusher(pushgatewayURL, jobName string) Pusher { 42 | return &pusher{ 43 | pusher: push.New(pushgatewayURL, jobName). 44 | Grouping("instance", instanceID()). 45 | Gatherer(prometheus.DefaultGatherer), 46 | } 47 | } 48 | 49 | func NewPusherWithCustomClient(pushgatewayURL, jobName string, client client.HTTPClient) Pusher { 50 | return &pusher{ 51 | pusher: push.New(pushgatewayURL, string(jobName)). 52 | Grouping("instance", instanceID()). 53 | Gatherer(prometheus.DefaultGatherer). 
54 | Client(client), 55 | } 56 | } 57 | 58 | func (p *pusher) Push() error { 59 | return p.pusher.Push() 60 | } 61 | 62 | func (p *pusher) Close() error { 63 | return p.pusher.Delete() 64 | } 65 | 66 | func instanceID() string { 67 | envKeysToTry := []string{"DYNO", "INSTANCE_ID", "HOSTNAME"} 68 | for _, key := range envKeysToTry { 69 | curr := os.Getenv(key) 70 | if curr != "" { 71 | return curr 72 | } 73 | } 74 | return "local" 75 | } 76 | -------------------------------------------------------------------------------- /metrics/register.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | 6 | "github.com/trustwallet/go-libs/logging" 7 | ) 8 | 9 | func Register(labels prometheus.Labels, reg prometheus.Registerer, collectors ...prometheus.Collector) { 10 | registerer := prometheus.WrapRegistererWith(labels, reg) 11 | for _, c := range collectors { 12 | err := registerer.Register(c) 13 | if err != nil { 14 | if _, ok := err.(*prometheus.AlreadyRegisteredError); !ok { 15 | logging.GetLogger().WithError(err). 
16 | Error("failed to register job duration metrics with prometheus") 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /middleware/cache.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "crypto/sha1" 6 | "encoding/base64" 7 | "encoding/json" 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net/http" 12 | "sync" 13 | "time" 14 | 15 | "github.com/gin-gonic/gin" 16 | "github.com/patrickmn/go-cache" 17 | log "github.com/sirupsen/logrus" 18 | ) 19 | 20 | var memoryCache *memCache 21 | 22 | func init() { 23 | memoryCache = &memCache{cache: cache.New(5*time.Minute, 5*time.Minute)} 24 | } 25 | 26 | type memCache struct { 27 | sync.RWMutex 28 | cache *cache.Cache 29 | } 30 | 31 | type cacheResponse struct { 32 | Status int 33 | Header http.Header 34 | Data []byte 35 | } 36 | 37 | type cachedWriter struct { 38 | gin.ResponseWriter 39 | status int 40 | written bool 41 | expire time.Duration 42 | key string 43 | } 44 | 45 | func newCachedWriter(expire time.Duration, writer gin.ResponseWriter, key string) *cachedWriter { 46 | return &cachedWriter{writer, 0, false, expire, key} 47 | } 48 | 49 | func (w *cachedWriter) WriteHeader(code int) { 50 | w.status = code 51 | w.written = true 52 | w.ResponseWriter.WriteHeader(code) 53 | } 54 | 55 | func (w *cachedWriter) Status() int { 56 | return w.ResponseWriter.Status() 57 | } 58 | 59 | func (w *cachedWriter) Written() bool { 60 | return w.ResponseWriter.Written() 61 | } 62 | 63 | func (w *cachedWriter) Write(data []byte) (int, error) { 64 | ret, err := w.ResponseWriter.Write(data) 65 | if err != nil { 66 | return 0, err 67 | } 68 | if w.Status() != 200 { 69 | return 0, nil 70 | } 71 | val := cacheResponse{ 72 | w.Status(), 73 | w.Header(), 74 | data, 75 | } 76 | b, err := json.Marshal(val) 77 | if err != nil { 78 | return 0, errors.New("validator cache: failed to marshal cache object") 79 | 
} 80 | memoryCache.cache.Set(w.key, b, w.expire) 81 | return ret, nil 82 | } 83 | 84 | func (w *cachedWriter) WriteString(data string) (n int, err error) { 85 | ret, err := w.ResponseWriter.WriteString(data) 86 | if err != nil { 87 | return 0, errors.New(err.Error() + " fail to cache write string") 88 | } 89 | if w.Status() != 200 { 90 | return 0, errors.New("WriteString: invalid cache status") 91 | } 92 | val := cacheResponse{ 93 | w.Status(), 94 | w.Header(), 95 | []byte(data), 96 | } 97 | b, err := json.Marshal(val) 98 | if err != nil { 99 | return 0, errors.New("validator cache: failed to marshal cache object") 100 | } 101 | memoryCache.setCache(w.key, b, w.expire) 102 | return ret, err 103 | } 104 | 105 | func (mc *memCache) deleteCache(key string) { 106 | mc.RLock() 107 | defer mc.RUnlock() 108 | memoryCache.cache.Delete(key) 109 | } 110 | 111 | func (mc *memCache) setCache(k string, x interface{}, d time.Duration) { 112 | b, err := json.Marshal(x) 113 | if err != nil { 114 | log.Error(errors.New(err.Error() + " client cache cannot marshal cache object")) 115 | return 116 | } 117 | mc.RLock() 118 | defer mc.RUnlock() 119 | memoryCache.cache.Set(k, b, d) 120 | } 121 | 122 | func (mc *memCache) getCache(key string) (cacheResponse, error) { 123 | var result cacheResponse 124 | c, ok := mc.cache.Get(key) 125 | if !ok { 126 | return result, fmt.Errorf("gin-cache: invalid cache key %s", key) 127 | } 128 | r, ok := c.([]byte) 129 | if !ok { 130 | return result, errors.New("validator cache: failed to cast cache to bytes") 131 | } 132 | err := json.Unmarshal(r, &result) 133 | if err != nil { 134 | return result, errors.New(err.Error() + "not found") 135 | } 136 | return result, nil 137 | } 138 | 139 | func generateKey(c *gin.Context) string { 140 | url := c.Request.URL.String() 141 | var b []byte 142 | if c.Request.Body != nil { 143 | b, _ = io.ReadAll(c.Request.Body) 144 | // Restore the io.ReadCloser to its original state 145 | c.Request.Body = 
io.NopCloser(bytes.NewBuffer(b)) 146 | } 147 | hash := sha1.Sum(append([]byte(url), b...)) 148 | return base64.URLEncoding.EncodeToString(hash[:]) 149 | } 150 | 151 | // CacheMiddleware encapsulates a gin handler function and caches the model with an expiration time. 152 | func CacheMiddleware(expiration time.Duration, handle gin.HandlerFunc) gin.HandlerFunc { 153 | return func(c *gin.Context) { 154 | defer c.Next() 155 | key := generateKey(c) 156 | cacheControlValue := uint(expiration.Seconds()) 157 | mc, err := memoryCache.getCache(key) 158 | if err != nil || mc.Data == nil { 159 | writer := newCachedWriter(expiration, c.Writer, key) 160 | 161 | writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", cacheControlValue)) 162 | 163 | c.Writer = writer 164 | handle(c) 165 | if c.IsAborted() { 166 | memoryCache.deleteCache(key) 167 | } 168 | return 169 | } 170 | 171 | c.Writer.WriteHeader(mc.Status) 172 | for k, vals := range mc.Header { 173 | for _, v := range vals { 174 | c.Writer.Header().Set(k, v) 175 | } 176 | } 177 | 178 | c.Writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", cacheControlValue)) 179 | 180 | _, err = c.Writer.Write(mc.Data) 181 | if err != nil { 182 | memoryCache.deleteCache(key) 183 | log.Error(err, "cannot write data", mc) 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /middleware/cache_control.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func CacheControl(duration time.Duration, handle gin.HandlerFunc) gin.HandlerFunc { 11 | return func(c *gin.Context) { 12 | defer c.Next() 13 | cacheControlValue := uint(duration.Seconds()) 14 | c.Writer.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", cacheControlValue)) 15 | handle(c) 16 | } 17 | } 18 | 
--------------------------------------------------------------------------------
/middleware/cache_control_test.go:
--------------------------------------------------------------------------------
package middleware

import (
	"fmt"
	"net/http"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// init switches gin to test mode for this package's tests.
func init() {
	gin.SetMode(gin.TestMode)
}

// TestCacheControl checks that two requests through CacheControl both get
// status 200 and the same "max-age=30" Cache-Control header.
func TestCacheControl(t *testing.T) {
	router := gin.New()
	router.GET("/cache_ping_control", CacheControl(time.Second*30, func(c *gin.Context) {
		c.JSON(http.StatusOK, "pong "+fmt.Sprint(time.Now().UnixNano()))
	}))

	w1 := performRequest("GET", "/cache_ping_control", router)
	w1CacheControl := w1.Header().Get("Cache-Control")
	assert.NotEqual(t, "no-cache", w1CacheControl)
	// NOTE(review): CacheControl only sets a header and stores nothing, so
	// this sleep looks unnecessary and just slows the suite — confirm.
	time.Sleep(time.Second * 1)
	w2 := performRequest("GET", "/cache_ping_control", router)
	w2CacheControl := w2.Header().Get("Cache-Control")

	assert.Equal(t, w1CacheControl, w2CacheControl)
	assert.Equal(t, "max-age=30", w2CacheControl)

	assert.Equal(t, http.StatusOK, w1.Code)
	assert.Equal(t, http.StatusOK, w2.Code)

}

--------------------------------------------------------------------------------
/middleware/cache_test.go:
--------------------------------------------------------------------------------
package middleware

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)

// init switches gin to test mode for this package's tests.
func init() {
	gin.SetMode(gin.TestMode)
}

// performRequest serves one in-process request against router and returns
// the recorded response.
func performRequest(method, target string, router *gin.Engine) *httptest.ResponseRecorder {
	r := httptest.NewRequest(method, target, nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, r)
	return w
}

// TestWrite checks that a cachedWriter passes the status and body through to
// the underlying writer.
func TestWrite(t *testing.T) {
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	writer := newCachedWriter(time.Second*3, c.Writer, "mykey")
	c.Writer = writer

	c.Writer.WriteHeader(http.StatusNoContent)
	c.Writer.WriteHeaderNow()
	_, _ = c.Writer.Write([]byte("foo")) // nolint
	assert.Equal(t, http.StatusNoContent, c.Writer.Status())
	assert.Equal(t, "foo", w.Body.String())
	assert.True(t, c.Writer.Written())
}

// TestCachePage checks that a second request within the expiration window
// gets the same body as the first. The handler embeds a nanosecond timestamp,
// so an uncached body would differ.
func TestCachePage(t *testing.T) {
	router := gin.New()
	router.GET("/cache_ping", CacheMiddleware(time.Second*3, func(c *gin.Context) {
		c.JSON(http.StatusOK, "pong "+fmt.Sprint(time.Now().UnixNano()))
	}))

	w1 := performRequest("GET", "/cache_ping", router)
	w2 := performRequest("GET", "/cache_ping", router)

	assert.Equal(t, http.StatusOK, w1.Code)
	assert.Equal(t, http.StatusOK, w2.Code)
	assert.Equal(t, w1.Body.String(), w2.Body.String())
}

// TestCachePageExpire checks that a request made after the 1s expiration
// gets a freshly generated body.
func TestCachePageExpire(t *testing.T) {
	router := gin.New()
	router.GET("/cache_ping", CacheMiddleware(time.Second, func(c *gin.Context) {
		c.JSON(http.StatusOK, "pong "+fmt.Sprint(time.Now().UnixNano()))
	}))

	w1 := performRequest("GET", "/cache_ping", router)
	time.Sleep(time.Second * 3)
	w2 := performRequest("GET", "/cache_ping", router)

	assert.Equal(t, http.StatusOK, w1.Code)
	assert.Equal(t, http.StatusOK, w2.Code)
	assert.NotEqual(t, w1.Body.String(), w2.Body.String())
}

// TestCacheControlMemory exercises the Cache-Control header set by CacheMiddleware.
func TestCacheControlMemory(t *testing.T) {
	router := gin.New()
	router.GET("/cache_ping_control", CacheMiddleware(time.Second*30, func(c *gin.Context) {
		c.JSON(http.StatusOK, "pong "+fmt.Sprint(time.Now().UnixNano()))
	}))

	w1 := performRequest("GET", "/cache_ping_control", router)
	w1CacheControl := w1.Header().Get("Cache-Control")
	assert.NotEqual(t, "no-cache", w1CacheControl)
	time.Sleep(time.Second * 1)
	w2 := performRequest("GET", "/cache_ping_control", router)
w2CacheControl := w2.Header().Get("Cache-Control") 81 | 82 | assert.Equal(t, w1CacheControl, w2CacheControl) 83 | assert.Equal(t, w1.Body.String(), w2.Body.String()) 84 | 85 | assert.Equal(t, http.StatusOK, w1.Code) 86 | assert.Equal(t, http.StatusOK, w2.Code) 87 | 88 | } 89 | -------------------------------------------------------------------------------- /middleware/logger.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | func Logger(skipPaths ...string) gin.HandlerFunc { 9 | return gin.LoggerWithConfig(gin.LoggerConfig{ 10 | Formatter: LoggerFormatter(), 11 | SkipPaths: skipPaths, 12 | }) 13 | } 14 | 15 | func LoggerFormatter() gin.LogFormatter { 16 | return func(param gin.LogFormatterParams) string { 17 | return fmt.Sprintf("%s - \"%s %s %s %d %s \"%s\" %s\"\n", 18 | param.ClientIP, 19 | param.Method, 20 | param.Path, 21 | param.Request.Proto, 22 | param.StatusCode, 23 | param.Latency, 24 | param.Request.UserAgent(), 25 | param.ErrorMessage, 26 | ) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /middleware/metrics.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/gin-gonic/gin" 7 | "github.com/prometheus/client_golang/prometheus" 8 | 9 | "github.com/trustwallet/go-libs/metrics" 10 | ) 11 | 12 | const labelPath = "path" 13 | const labelMethod = "method" 14 | const labelStatus = "status" 15 | 16 | const ( 17 | _ = iota 18 | _ 19 | labelStatusIndex 20 | ) 21 | 22 | func MetricsMiddleware(namespace string, labels prometheus.Labels, reg prometheus.Registerer) gin.HandlerFunc { 23 | perfMetric := metrics.NewHttpServerMetric(namespace, []string{labelPath, labelMethod, labelStatus}, labels, reg) 24 | 25 | return func(c *gin.Context) { 26 | path := c.FullPath() 27 | method := 
c.Request.Method 28 | 29 | // route not found, call next and immediately return 30 | if path == "" { 31 | c.Next() 32 | return 33 | } 34 | 35 | labelValues := []string{path, method, "none"} 36 | 37 | startTime := perfMetric.Start(labelValues...) 38 | 39 | c.Next() 40 | 41 | var ( 42 | statusCode = c.Writer.Status() 43 | statusCodeStr = strconv.FormatInt(int64(statusCode), 10) 44 | ) 45 | labelValues[labelStatusIndex] = statusCodeStr 46 | 47 | // record duration with status code 48 | perfMetric.Duration(startTime, labelValues...) 49 | 50 | switch { 51 | case 200 <= statusCode && statusCode <= 299: 52 | perfMetric.Success(labelValues...) 53 | case 500 <= statusCode && statusCode <= 599: 54 | perfMetric.ServerError(labelValues...) 55 | case 400 <= statusCode && statusCode <= 499: 56 | perfMetric.ClientError(labelValues...) 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /middleware/metrics_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/prometheus/client_golang/prometheus" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestMetricsMiddleware(t *testing.T) { 14 | r := prometheus.NewRegistry() 15 | router := gin.New() 16 | router.Use(MetricsMiddleware("", nil, r)) 17 | 18 | successGroup := router.Group("/success") 19 | successGroup.GET("/:test", func(c *gin.Context) { 20 | c.JSON(http.StatusOK, struct{}{}) 21 | }) 22 | 23 | successGroup.GET("", func(c *gin.Context) { 24 | c.JSON(http.StatusOK, struct{}{}) 25 | }) 26 | 27 | router.GET("/error", func(c *gin.Context) { 28 | _ = c.AbortWithError(http.StatusInternalServerError, errors.New("oops error")) 29 | }) 30 | router.GET("/404", func(c *gin.Context) { 31 | _ = c.AbortWithError(http.StatusNotFound, errors.New("404")) 32 | }) 33 | 34 | // 2 successes, 1 errors 35 | _ = 
performRequest("GET", "/success?haha=1&hoho=2", router) 36 | _ = performRequest("GET", "/error?hehe=1&huhu=3", router) 37 | _ = performRequest("GET", "/success/hihi", router) 38 | _ = performRequest("GET", "/404", router) 39 | 40 | metricFamilies, err := r.Gather() 41 | require.NoError(t, err) 42 | const ( 43 | requestSucceededTotalKey = "request_succeeded_total" 44 | requestClientErrTotalKey = "request_client_error_total" 45 | requestServerErrTotalKey = "request_server_error_total" 46 | ) 47 | // metricFamily.Name --> label --> counter value 48 | expected := map[string]map[string]int{ 49 | requestSucceededTotalKey: { 50 | "/success": 1, 51 | "/success/:test": 1, 52 | "/error": 0, 53 | "/404": 0, 54 | }, 55 | requestServerErrTotalKey: { 56 | "/success": 0, 57 | "/success/:test": 0, 58 | "/error": 1, 59 | "/404": 0, 60 | }, 61 | requestClientErrTotalKey: { 62 | "/success": 0, 63 | "/success/:test": 0, 64 | "/error": 0, 65 | "/404": 1, 66 | }, 67 | } 68 | for _, metricFamily := range metricFamilies { 69 | expectedLabelCounterMap, ok := expected[*metricFamily.Name] 70 | if !ok { 71 | continue 72 | } 73 | require.Len(t, metricFamily.Metric, len(expectedLabelCounterMap)) 74 | for _, metric := range metricFamily.Metric { 75 | require.Len(t, metric.Label, 3) 76 | labelIndexes := map[string]int{ 77 | labelMethod: -1, 78 | labelPath: -1, 79 | labelStatus: -1, 80 | } 81 | for idx, label := range metric.Label { 82 | labelIndexes[*label.Name] = idx 83 | } 84 | require.Equal(t, len(labelIndexes), 3) 85 | for _, labelIdx := range labelIndexes { 86 | require.NotEqual(t, -1, labelIdx) 87 | } 88 | pathIdx := labelIndexes[labelPath] 89 | path := *metric.Label[pathIdx].Value 90 | expectedPathMetric := float64(expectedLabelCounterMap[path]) 91 | require.Equal(t, expectedPathMetric, *metric.Counter.Value) 92 | } 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /middleware/sentry.go: 
-------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "strconv" 8 | 9 | "github.com/evalphobia/logrus_sentry" 10 | "github.com/getsentry/raven-go" 11 | log "github.com/sirupsen/logrus" 12 | ) 13 | 14 | type SentryOption func(hook *logrus_sentry.SentryHook) error 15 | type SentryCondition func(res *http.Response, url string) bool 16 | 17 | func SetupSentry(dsn string, opts ...SentryOption) error { 18 | hook, err := logrus_sentry.NewSentryHook(dsn, []log.Level{ 19 | log.PanicLevel, 20 | log.FatalLevel, 21 | log.ErrorLevel, 22 | }) 23 | if err != nil { 24 | return err 25 | } 26 | hook.Timeout = 0 27 | hook.StacktraceConfiguration.Enable = true 28 | hook.StacktraceConfiguration.IncludeErrorBreadcrumb = true 29 | hook.StacktraceConfiguration.Context = 10 30 | hook.StacktraceConfiguration.SendExceptionType = true 31 | hook.StacktraceConfiguration.SwitchExceptionTypeAndMessage = true 32 | 33 | for _, o := range opts { 34 | err = o(hook) 35 | if err != nil { 36 | return err 37 | } 38 | } 39 | 40 | log.AddHook(hook) 41 | return nil 42 | } 43 | 44 | func WithDefaultLoggerName(name string) SentryOption { 45 | return func(hook *logrus_sentry.SentryHook) error { 46 | hook.SetDefaultLoggerName(name) 47 | return nil 48 | } 49 | } 50 | 51 | func WithEnvironment(env string) SentryOption { 52 | return func(hook *logrus_sentry.SentryHook) error { 53 | hook.SetEnvironment(env) 54 | return nil 55 | } 56 | } 57 | 58 | func WithHttpContext(h *raven.Http) SentryOption { 59 | return func(hook *logrus_sentry.SentryHook) error { 60 | hook.SetHttpContext(h) 61 | return nil 62 | } 63 | } 64 | 65 | func WithIgnoreErrors(errs ...string) SentryOption { 66 | return func(hook *logrus_sentry.SentryHook) error { 67 | return hook.SetIgnoreErrors(errs...) 
68 | } 69 | } 70 | 71 | func WithIncludePaths(p []string) SentryOption { 72 | return func(hook *logrus_sentry.SentryHook) error { 73 | hook.SetIncludePaths(p) 74 | return nil 75 | } 76 | } 77 | 78 | func WithRelease(release string) SentryOption { 79 | return func(hook *logrus_sentry.SentryHook) error { 80 | hook.SetRelease(release) 81 | return nil 82 | } 83 | } 84 | 85 | func WithSampleRate(rate float32) SentryOption { 86 | return func(hook *logrus_sentry.SentryHook) error { 87 | return hook.SetSampleRate(rate) 88 | } 89 | } 90 | 91 | func WithTagsContext(t map[string]string) SentryOption { 92 | return func(hook *logrus_sentry.SentryHook) error { 93 | hook.SetTagsContext(t) 94 | return nil 95 | } 96 | } 97 | 98 | func WithUserContext(u *raven.User) SentryOption { 99 | return func(hook *logrus_sentry.SentryHook) error { 100 | hook.SetUserContext(u) 101 | return nil 102 | } 103 | } 104 | 105 | func WithServerName(serverName string) SentryOption { 106 | return func(hook *logrus_sentry.SentryHook) error { 107 | hook.SetServerName(serverName) 108 | return nil 109 | } 110 | } 111 | 112 | var SentryErrorHandler = func(res *http.Response, url string) error { 113 | statusCode := res.StatusCode 114 | // Improve ways to identify if worth logging the error 115 | if statusCode != http.StatusOK && statusCode != http.StatusNotFound { 116 | log.WithFields(log.Fields{ 117 | "tags": raven.Tags{ 118 | {Key: "status_code", Value: strconv.Itoa(res.StatusCode)}, 119 | {Key: "host", Value: res.Request.URL.Host}, 120 | {Key: "path", Value: res.Request.URL.Path}, 121 | {Key: "body", Value: getBody(res)}, 122 | }, 123 | "url": url, 124 | "fingerprint": []string{"client_errors"}, 125 | }).Error("Client Errors") 126 | } 127 | 128 | return nil 129 | } 130 | 131 | // GetSentryErrorHandler initializes sentry logger for http response errors 132 | // Responses to be logged are defined via passed conditions 133 | func GetSentryErrorHandler(conditions ...SentryCondition) func(res *http.Response, url 
string) error { 134 | return func(res *http.Response, url string) error { 135 | for _, condition := range conditions { 136 | if condition(res, url) { 137 | log.WithFields(log.Fields{ 138 | "tags": raven.Tags{ 139 | {Key: "status_code", Value: strconv.Itoa(res.StatusCode)}, 140 | {Key: "host", Value: res.Request.URL.Host}, 141 | {Key: "path", Value: res.Request.URL.Path}, 142 | {Key: "body", Value: getBody(res)}, 143 | }, 144 | "url": url, 145 | "fingerprint": []string{"client_errors"}, 146 | }).Error("Client Errors") 147 | 148 | break 149 | } 150 | } 151 | 152 | return nil 153 | } 154 | } 155 | 156 | func getBody(res *http.Response) string { 157 | bodyBytes, _ := io.ReadAll(res.Body) 158 | _ = res.Body.Close() // must close 159 | res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) 160 | 161 | return string(bodyBytes) 162 | } 163 | 164 | var ( 165 | // SentryConditionAnd returns true only when all conditions are satisfied 166 | SentryConditionAnd = func(conditions ...SentryCondition) SentryCondition { 167 | return func(res *http.Response, url string) bool { 168 | result := true 169 | for _, condition := range conditions { 170 | if !condition(res, url) { 171 | result = false 172 | break 173 | } 174 | } 175 | 176 | return result 177 | } 178 | } 179 | 180 | // SentryConditionOr return true when any of conditions is satisfied 181 | SentryConditionOr = func(conditions ...SentryCondition) SentryCondition { 182 | return func(res *http.Response, url string) bool { 183 | for _, condition := range conditions { 184 | if condition(res, url) { 185 | return true 186 | } 187 | } 188 | 189 | return false 190 | } 191 | } 192 | 193 | SentryConditionNotStatusOk = func(res *http.Response, _ string) bool { 194 | return res.StatusCode < 200 || res.StatusCode > 299 195 | } 196 | 197 | SentryConditionNotStatusBadRequest = func(res *http.Response, _ string) bool { 198 | return res.StatusCode != http.StatusBadRequest 199 | } 200 | 201 | SentryConditionNotStatusNotFound = func(res 
*http.Response, _ string) bool { 202 | return res.StatusCode != http.StatusNotFound 203 | } 204 | ) 205 | -------------------------------------------------------------------------------- /middleware/sentry_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSentryConditionAnd(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | conditions []SentryCondition 14 | expected bool 15 | }{ 16 | { 17 | name: "all conditions satisfied", 18 | conditions: []SentryCondition{ 19 | func(res *http.Response, url string) bool { 20 | return true 21 | }, 22 | func(res *http.Response, url string) bool { 23 | return true 24 | }, 25 | }, 26 | expected: true, 27 | }, 28 | { 29 | name: "all conditions unsatisfied", 30 | conditions: []SentryCondition{ 31 | func(res *http.Response, url string) bool { 32 | return false 33 | }, 34 | func(res *http.Response, url string) bool { 35 | return false 36 | }, 37 | }, 38 | expected: false, 39 | }, 40 | { 41 | name: "first of two conditions is satisfied", 42 | conditions: []SentryCondition{ 43 | func(res *http.Response, url string) bool { 44 | return true 45 | }, 46 | func(res *http.Response, url string) bool { 47 | return false 48 | }, 49 | }, 50 | expected: false, 51 | }, 52 | { 53 | name: "second of two conditions is satisfied", 54 | conditions: []SentryCondition{ 55 | func(res *http.Response, url string) bool { 56 | return false 57 | }, 58 | func(res *http.Response, url string) bool { 59 | return true 60 | }, 61 | }, 62 | expected: false, 63 | }, 64 | } 65 | 66 | for _, tc := range tests { 67 | t.Run(tc.name, func(t *testing.T) { 68 | condition := SentryConditionAnd(tc.conditions...) 
69 | actual := condition(nil, "") 70 | 71 | assert.Equal(t, tc.expected, actual) 72 | }) 73 | } 74 | } 75 | 76 | func TestSentryConditionOr(t *testing.T) { 77 | tests := []struct { 78 | name string 79 | conditions []SentryCondition 80 | expected bool 81 | }{ 82 | { 83 | name: "two out of two are satisfied", 84 | conditions: []SentryCondition{ 85 | func(res *http.Response, url string) bool { 86 | return true 87 | }, 88 | func(res *http.Response, url string) bool { 89 | return true 90 | }, 91 | }, 92 | expected: true, 93 | }, 94 | { 95 | name: "first out of two is satisfied", 96 | conditions: []SentryCondition{ 97 | func(res *http.Response, url string) bool { 98 | return true 99 | }, 100 | func(res *http.Response, url string) bool { 101 | return false 102 | }, 103 | }, 104 | expected: true, 105 | }, 106 | { 107 | name: "second out of two is satisfied", 108 | conditions: []SentryCondition{ 109 | func(res *http.Response, url string) bool { 110 | return false 111 | }, 112 | func(res *http.Response, url string) bool { 113 | return true 114 | }, 115 | }, 116 | expected: true, 117 | }, 118 | { 119 | name: "none of conditions is satisfied", 120 | conditions: []SentryCondition{ 121 | func(res *http.Response, url string) bool { 122 | return false 123 | }, 124 | func(res *http.Response, url string) bool { 125 | return false 126 | }, 127 | }, 128 | expected: false, 129 | }, 130 | } 131 | 132 | for _, tc := range tests { 133 | t.Run(tc.name, func(t *testing.T) { 134 | condition := SentryConditionOr(tc.conditions...) 
135 | actual := condition(nil, "") 136 | 137 | assert.Equal(t, tc.expected, actual) 138 | }) 139 | } 140 | } 141 | 142 | func TestSentryConditionNotStatusOk(t *testing.T) { 143 | tests := []struct { 144 | name string 145 | resp *http.Response 146 | expected bool 147 | }{ 148 | { 149 | name: "response code is below 200", 150 | resp: &http.Response{StatusCode: 100}, 151 | expected: true, 152 | }, 153 | { 154 | name: "response code is 200", 155 | resp: &http.Response{StatusCode: 200}, 156 | expected: false, 157 | }, 158 | { 159 | name: "response code is between 200 and 300", 160 | resp: &http.Response{StatusCode: 201}, 161 | expected: false, 162 | }, 163 | { 164 | name: "response code is between 300", 165 | resp: &http.Response{StatusCode: 300}, 166 | expected: true, 167 | }, 168 | { 169 | name: "response code is above 300", 170 | resp: &http.Response{StatusCode: 303}, 171 | expected: true, 172 | }, 173 | } 174 | 175 | for _, tc := range tests { 176 | t.Run(tc.name, func(t *testing.T) { 177 | actual := SentryConditionNotStatusOk(tc.resp, "") 178 | assert.Equal(t, tc.expected, actual) 179 | }) 180 | } 181 | } 182 | 183 | func TestSentryConditionNotStatusBadRequest(t *testing.T) { 184 | tests := []struct { 185 | name string 186 | resp *http.Response 187 | expected bool 188 | }{ 189 | { 190 | name: "response code is bad request", 191 | resp: &http.Response{StatusCode: http.StatusBadRequest}, 192 | expected: false, 193 | }, 194 | { 195 | name: "response code is not bad request", 196 | resp: &http.Response{StatusCode: http.StatusOK}, 197 | expected: true, 198 | }, 199 | } 200 | 201 | for _, tc := range tests { 202 | t.Run(tc.name, func(t *testing.T) { 203 | actual := SentryConditionNotStatusBadRequest(tc.resp, "") 204 | assert.Equal(t, tc.expected, actual) 205 | }) 206 | } 207 | } 208 | 209 | func TestSentryConditionNotStatusNotFound(t *testing.T) { 210 | tests := []struct { 211 | name string 212 | resp *http.Response 213 | expected bool 214 | }{ 215 | { 216 | name: 
"response code is not found", 217 | resp: &http.Response{StatusCode: http.StatusNotFound}, 218 | expected: false, 219 | }, 220 | { 221 | name: "response code is not \"not found\"", 222 | resp: &http.Response{StatusCode: http.StatusBadRequest}, 223 | expected: true, 224 | }, 225 | } 226 | 227 | for _, tc := range tests { 228 | t.Run(tc.name, func(t *testing.T) { 229 | actual := SentryConditionNotStatusNotFound(tc.resp, "") 230 | assert.Equal(t, tc.expected, actual) 231 | }) 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /middleware/shutdown.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "os" 5 | "os/signal" 6 | "syscall" 7 | "time" 8 | 9 | log "github.com/sirupsen/logrus" 10 | ) 11 | // SetupGracefulShutdown blocks until the process receives SIGINT or SIGTERM, logs the timeout, sleeps for timeout and then returns. 12 | func SetupGracefulShutdown(timeout time.Duration) { 13 | quit := make(chan os.Signal, 1) 14 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 15 | <-quit 16 | log.Infof("Shutdown timeout: %s ...", timeout) 17 | time.Sleep(timeout) 18 | log.Info("Exiting gracefully") 19 | } 20 | -------------------------------------------------------------------------------- /mock/mock.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | "os" 8 | ) 9 | // JsonModelFromFilePath reads the JSON file at path file and decodes it into intoStruct, which must be a non-nil pointer. 10 | func JsonModelFromFilePath(file string, intoStruct interface{}) error { 11 | jsonFile, err := os.Open(file) 12 | if err != nil { 13 | return err 14 | } 15 | defer jsonFile.Close() 16 | 17 | byteValue, err := io.ReadAll(jsonFile) 18 | if err != nil { 19 | return err 20 | } 21 | err = json.Unmarshal(byteValue, intoStruct) // pass the caller's pointer directly: a nil/non-pointer target now errors instead of silently decoding into a discarded copy 22 | if err != nil { 23 | return err 24 | } 25 | return nil 26 | } 27 | // JsonStringFromFilePath returns the raw contents of the file at path file as a string. 28 | func JsonStringFromFilePath(file string) (string, error) { 29 | jsonFile, err := os.Open(file) 30 | if err != nil { 31 | return "", err 32 | } 33 | defer jsonFile.Close() 34 | 35 | byteValue, err := 
io.ReadAll(jsonFile) 36 | if err != nil { 37 | return "", err 38 | } 39 | 40 | return string(byteValue), nil 41 | } 42 | 43 | func CreateMockedAPI(funcsMap map[string]func(http.ResponseWriter, *http.Request)) http.Handler { 44 | r := http.NewServeMux() 45 | for pattern, f := range funcsMap { 46 | r.HandleFunc(pattern, f) 47 | } 48 | return r 49 | } 50 | -------------------------------------------------------------------------------- /mock/mock_test.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | 12 | "github.com/trustwallet/go-libs/client" 13 | ) 14 | 15 | type response struct { 16 | Status bool 17 | } 18 | 19 | func TestCreateMockedAPI(t *testing.T) { 20 | 21 | data := make(map[string]func(http.ResponseWriter, *http.Request)) 22 | data["/1"] = func(w http.ResponseWriter, r *http.Request) { 23 | w.WriteHeader(http.StatusOK) 24 | if _, err := fmt.Fprint(w, `{"status": true}`); err != nil { 25 | panic(err) 26 | } 27 | } 28 | 29 | server := httptest.NewServer(CreateMockedAPI(data)) 30 | defer server.Close() 31 | cli := client.InitClient(server.URL, nil) 32 | 33 | var resp response 34 | _, err := cli.Execute(context.TODO(), client.NewReqBuilder(). 35 | Method(http.MethodGet). 36 | PathStatic("1"). 37 | WriteTo(&resp). 
38 | Build()) 39 | 40 | assert.Nil(t, err) 41 | assert.True(t, resp.Status) 42 | } 43 | 44 | func TestParseJsonFromFilePath(t *testing.T) { 45 | var s response 46 | err := JsonModelFromFilePath("test.json", &s) 47 | 48 | assert.Nil(t, err) 49 | assert.True(t, s.Status) 50 | } 51 | 52 | func TestJsonStringFromFilePath(t *testing.T) { 53 | data, err := JsonStringFromFilePath("test.json") 54 | assert.Nil(t, err) 55 | assert.Equal(t, `{ 56 | "status": true 57 | }`, data) 58 | } 59 | -------------------------------------------------------------------------------- /mock/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "status": true 3 | } -------------------------------------------------------------------------------- /mq/consumer.go: -------------------------------------------------------------------------------- 1 | package mq 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/trustwallet/go-libs/metrics" 9 | "github.com/trustwallet/go-libs/pkg/nullable" 10 | 11 | log "github.com/sirupsen/logrus" 12 | "github.com/streadway/amqp" 13 | ) 14 | 15 | const headerRemainingRetries = "x-remaining-retries" 16 | 17 | type consumer struct { 18 | client *Client 19 | 20 | queue Queue 21 | messageProcessor MessageProcessor 22 | options *ConsumerOptions 23 | 24 | messages <-chan amqp.Delivery 25 | stopChan chan struct{} 26 | } 27 | 28 | type Consumer interface { 29 | Start(ctx context.Context) error 30 | Reconnect(ctx context.Context) error 31 | HealthCheck() error 32 | } 33 | 34 | type MessageProcessor interface { 35 | Process(Message) error 36 | } 37 | 38 | // MessageProcessorFunc is an adapter to allow to use 39 | // an ordinary functions as mq MessageProcessor. 
40 | type MessageProcessorFunc func(message Message) error 41 | 42 | func (f MessageProcessorFunc) Process(m Message) error { 43 | return f(m) 44 | } 45 | 46 | func (c *consumer) Start(ctx context.Context) error { 47 | c.stopChan = make(chan struct{}) 48 | 49 | var err error 50 | c.messages, err = c.messageChannel() 51 | if err != nil { 52 | return fmt.Errorf("get message channel: %v", err) 53 | } 54 | for w := 1; w <= c.options.Workers; w++ { 55 | go c.consume(ctx) 56 | } 57 | 58 | log.Infof("Started %d MQ consumer workers for queue %s", c.options.Workers, c.queue.Name()) 59 | 60 | return nil 61 | } 62 | 63 | func (c *consumer) Reconnect(ctx context.Context) error { 64 | c.messages = nil 65 | if c.stopChan != nil { 66 | close(c.stopChan) 67 | } 68 | 69 | err := c.Start(ctx) 70 | if err != nil { 71 | return err 72 | } 73 | 74 | return nil 75 | } 76 | 77 | func (c *consumer) consume(ctx context.Context) { 78 | queueName := string(c.queue.Name()) 79 | 80 | for { 81 | select { 82 | case <-ctx.Done(): 83 | log.Infof("Finished consuming queue %s", queueName) 84 | return 85 | case <-c.stopChan: 86 | log.Infof("Force stopped consuming queue %s", queueName) 87 | return 88 | case msg := <-c.messages: 89 | if msg.Body == nil { 90 | continue 91 | } 92 | 93 | err := c.process(queueName, msg.Body) 94 | if err != nil { 95 | log.Error(err) 96 | } 97 | 98 | if err != nil && c.options.RetryOnError { 99 | time.Sleep(c.options.RetryDelay) 100 | remainingRetries := c.getRemainingRetries(msg) 101 | 102 | switch { 103 | case remainingRetries > 0: 104 | if err := c.queue.PublishWithConfig(msg.Body, PublishConfig{ 105 | MaxRetries: nullable.Int(int(remainingRetries - 1)), 106 | }); err != nil { 107 | log.Error(err) 108 | } 109 | case remainingRetries == 0: 110 | break 111 | default: 112 | if err := msg.Reject(true); err != nil { 113 | log.Error(err) 114 | } 115 | continue 116 | } 117 | } 118 | 119 | if err := msg.Ack(false); err != nil { 120 | log.Error(err) 121 | } 122 | } 123 | } 124 | 
} 125 | // process runs messageProcessor on body and records duration and success/failure on options.PerformanceMetric, falling back to metrics.NullablePerformanceMetric when none is set. 126 | func (c *consumer) process(queueName string, body []byte) error { 127 | metric := c.options.PerformanceMetric 128 | if metric == nil { 129 | metric = &metrics.NullablePerformanceMetric{} 130 | } 131 | 132 | defer metric.Duration(metric.Start()) 133 | err := c.messageProcessor.Process(body) 134 | 135 | if err != nil { 136 | metric.Failure() 137 | } else { 138 | metric.Success() 139 | } 140 | 141 | return err 142 | } 143 | 144 | // messageChannel will create a new dedicated channel for this consumer to use 145 | func (c *consumer) messageChannel() (<-chan amqp.Delivery, error) { 146 | mqChan, err := c.client.conn.Channel() 147 | if err != nil { 148 | return nil, fmt.Errorf("MQ issue. queue: %s, err: %w", string(c.queue.Name()), err) 149 | } 150 | 151 | err = mqChan.Qos(c.getSanitizedPrefetchCount(), 0, true) 152 | if err != nil { 153 | return nil, fmt.Errorf("MQ issue. queue: %s, err: %w", string(c.queue.Name()), err) 154 | } 155 | 156 | messageChannel, err := mqChan.Consume( 157 | string(c.queue.Name()), 158 | "", 159 | false, 160 | false, 161 | false, 162 | false, 163 | nil, 164 | ) 165 | if err != nil { 166 | return nil, fmt.Errorf("MQ issue. queue: %s, err: %w", string(c.queue.Name()), err) 167 | } 168 | 169 | return messageChannel, nil 170 | } 171 | // getSanitizedPrefetchCount returns options.Prefetch, raised to options.Workers when it is lower so every worker can hold a message. 172 | func (c *consumer) getSanitizedPrefetchCount() int { 173 | if c.options.Prefetch < c.options.Workers { 174 | return c.options.Workers 175 | } 176 | 177 | return c.options.Prefetch 178 | } 179 | // getRemainingRetries returns the int32 value of the x-remaining-retries header, or options.MaxRetries when the header is absent or not an int32. 180 | func (c *consumer) getRemainingRetries(delivery amqp.Delivery) int32 { 181 | remainingRetriesRaw, exists := delivery.Headers[headerRemainingRetries] 182 | if !exists { 183 | return int32(c.options.MaxRetries) 184 | } 185 | 186 | remainingRetries, ok := remainingRetriesRaw.(int32) 187 | if !ok { 188 | return int32(c.options.MaxRetries) 189 | } 190 | 191 | return remainingRetries 192 | } 193 | // HealthCheck returns an error annotated with "client health check" when the underlying MQ client health check fails. 194 | func (c *consumer) HealthCheck() error { 195 | if err := c.client.HealthCheck(); err != nil { 196
| return fmt.Errorf("client health check: %w", err) 197 | } 198 | 199 | return nil 200 | } 201 | -------------------------------------------------------------------------------- /mq/exchange.go: -------------------------------------------------------------------------------- 1 | package mq 2 | 3 | import "fmt" 4 | 5 | type exchange struct { 6 | name ExchangeName 7 | client *Client 8 | } 9 | 10 | type Exchange interface { 11 | Declare(kind string) error 12 | Bind(queues []Queue) error 13 | BindWithKey(queues []Queue, key ExchangeKey) error 14 | Publish(body []byte) error 15 | PublishWithKey(body []byte, key ExchangeKey) error 16 | } 17 | // Declare declares the exchange with the given kind as durable, not auto-deleted and not internal. 18 | func (e *exchange) Declare(kind string) error { 19 | return e.client.amqpChan.ExchangeDeclare(string(e.name), kind, true, false, false, false, nil) 20 | } 21 | // Bind binds every queue to the exchange with an empty routing key, stopping at the first error. 22 | func (e *exchange) Bind(queues []Queue) error { 23 | for _, q := range queues { 24 | err := e.client.amqpChan.QueueBind(string(q.Name()), "", string(e.name), false, nil) 25 | if err != nil { 26 | return err 27 | } 28 | } 29 | 30 | return nil 31 | } 32 | // BindWithKey binds every queue to the exchange using key as the routing key, stopping at the first error. 33 | func (e *exchange) BindWithKey(queues []Queue, key ExchangeKey) error { 34 | for _, q := range queues { 35 | err := e.client.amqpChan.QueueBind(string(q.Name()), string(key), string(e.name), false, nil) 36 | if err != nil { 37 | return err 38 | } 39 | } 40 | 41 | return nil 42 | } 43 | // Publish publishes body to the exchange with an empty routing key. 44 | func (e *exchange) Publish(body []byte) error { 45 | return publish(e.client.amqpChan, e.name, "", body) 46 | } 47 | // PublishWithKey publishes body to the exchange with the given routing key. 48 | func (e *exchange) PublishWithKey(body []byte, key ExchangeKey) error { 49 | return publish(e.client.amqpChan, e.name, key, body) 50 | } 51 | // HealthCheck returns an error wrapping the underlying client's health check failure, if any. 52 | func (e *exchange) HealthCheck() error { 53 | if err := e.client.HealthCheck(); err != nil { 54 | return fmt.Errorf("client health check: %w", err) 55 | } 56 | 57 | return nil 58 | } 59 | -------------------------------------------------------------------------------- /mq/options.go: 
-------------------------------------------------------------------------------- 1 | package mq 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/trustwallet/go-libs/metrics" 7 | ) 8 | 9 | type ConsumerOptions struct { 10 | Workers int 11 | Prefetch int 12 | RetryOnError bool 13 | RetryDelay time.Duration 14 | PerformanceMetric metrics.PerformanceMetric 15 | 16 | // MaxRetries specifies the default number of retries for consuming a message. 17 | // A negative value is equal to infinite retries. 18 | MaxRetries int 19 | } 20 | 21 | func DefaultConsumerOptions(workers int) *ConsumerOptions { 22 | return &ConsumerOptions{ 23 | Workers: workers, 24 | Prefetch: 2 * workers, 25 | RetryOnError: true, 26 | RetryDelay: time.Second, 27 | MaxRetries: -1, 28 | PerformanceMetric: &metrics.NullablePerformanceMetric{}, 29 | } 30 | } 31 | 32 | // Deprecated: We should not put prefetch limit at channel level. We need to set limit at consumer level 33 | // This option no longer works to limit QoS globally. 34 | // 35 | // From rabbitMQ doc https://www.rabbitmq.com/consumer-prefetch.html 36 | // Unfortunately the channel is not the ideal scope for this - since a single channel may consume from multiple queues, 37 | // the channel and the queue(s) need to coordinate with each other for every message sent to ensure they don't go over 38 | // the limit. This is slow on a single machine, and very slow when consuming across a cluster. 
39 | func OptionPrefetchLimit(limit int) Option { 40 | return func(m *Client) error { 41 | err := m.amqpChan.Qos( 42 | limit, 43 | 0, 44 | true, 45 | ) 46 | if err != nil { 47 | return err 48 | } 49 | 50 | return nil 51 | } 52 | } 53 | // OptionConnCheckTimeout returns an Option that sets the client's connection check timeout. 54 | func OptionConnCheckTimeout(timeout time.Duration) Option { 55 | return func(m *Client) error { 56 | m.connCheckTimeout = timeout 57 | return nil 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /mq/queue.go: -------------------------------------------------------------------------------- 1 | package mq 2 | 3 | import "fmt" 4 | 5 | type queue struct { 6 | name QueueName 7 | client *Client 8 | } 9 | 10 | type Queue interface { 11 | Declare() error 12 | DeclareWithConfig(cfg DeclareConfig) error 13 | Publish(body []byte) error 14 | PublishWithConfig(body []byte, cfg PublishConfig) error 15 | Name() QueueName 16 | } 17 | // Name returns the queue's name. 18 | func (q *queue) Name() QueueName { 19 | return q.name 20 | } 21 | // Declare declares the queue as durable, leaving every other DeclareConfig option at its zero value. 22 | func (q *queue) Declare() error { 23 | return q.DeclareWithConfig(DeclareConfig{Durable: true}) 24 | } 25 | // DeclareWithConfig declares the queue on the client's AMQP channel using the options in cfg. 26 | func (q *queue) DeclareWithConfig(cfg DeclareConfig) error { 27 | _, err := q.client.amqpChan.QueueDeclare( 28 | string(q.name), 29 | cfg.Durable, 30 | cfg.AutoDelete, 31 | cfg.Exclusive, 32 | cfg.NoWait, 33 | cfg.Args, 34 | ) 35 | return err 36 | } 37 | // Publish publishes body with an empty exchange name and the queue name as routing key. 38 | func (q *queue) Publish(body []byte) error { 39 | return publish(q.client.amqpChan, "", ExchangeKey(q.name), body) 40 | } 41 | // PublishWithConfig is like Publish but also passes cfg through to publishWithConfig. 42 | func (q *queue) PublishWithConfig(body []byte, cfg PublishConfig) error { 43 | return publishWithConfig(q.client.amqpChan, "", ExchangeKey(q.name), body, cfg) 44 | } 45 | // HealthCheck returns an error wrapping the underlying client's health check failure, if any. 46 | func (q *queue) HealthCheck() error { 47 | if err := q.client.HealthCheck(); err != nil { 48 | return fmt.Errorf("client health check: %w", err) 49 | } 50 | 51 | return nil 52 | } 53 | // DeclareConfig holds the options DeclareWithConfig passes to QueueDeclare. 54 | type DeclareConfig struct { 55 | Durable bool 56 | AutoDelete bool 57 | Exclusive bool 58 | NoWait bool 59 | Args 
map[string]interface{} 60 | } 61 | 62 | type DeliveryMode uint8 63 | 64 | const ( 65 | DeliveryModeTransient DeliveryMode = 1 66 | DeliveryModePersistent DeliveryMode = 2 67 | ) 68 | 69 | type PublishConfig struct { 70 | // MaxRetries defines the maximum number of retries after processing failures. 71 | // Overrides the value of consumer's config. 72 | MaxRetries *int 73 | DeliveryMode DeliveryMode 74 | } 75 | -------------------------------------------------------------------------------- /pkg/nullable/primitives.go: -------------------------------------------------------------------------------- 1 | package nullable 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func String(s string) *string { 8 | return &s 9 | } 10 | 11 | func Stringf(s string, args ...interface{}) *string { 12 | s = fmt.Sprintf(s, args...) 13 | return &s 14 | } 15 | 16 | func Int(i int) *int { 17 | return &i 18 | } 19 | 20 | func Int8(i int8) *int8 { 21 | return &i 22 | } 23 | 24 | func Int16(i int16) *int16 { 25 | return &i 26 | } 27 | 28 | func Int32(i int32) *int32 { 29 | return &i 30 | } 31 | 32 | func Int64(i int64) *int64 { 33 | return &i 34 | } 35 | 36 | func Uint(i uint) *uint { 37 | return &i 38 | } 39 | 40 | func Uint8(i uint8) *uint8 { 41 | return &i 42 | } 43 | 44 | func Uint16(i uint16) *uint16 { 45 | return &i 46 | } 47 | 48 | func Uint32(i uint32) *uint32 { 49 | return &i 50 | } 51 | 52 | func Uint64(i uint64) *uint64 { 53 | return &i 54 | } 55 | 56 | func Float32(f float32) *float32 { 57 | return &f 58 | } 59 | 60 | func Float64(f float64) *float64 { 61 | return &f 62 | } 63 | 64 | func Bool(b bool) *bool { 65 | return &b 66 | } 67 | -------------------------------------------------------------------------------- /pkg/nullable/time.go: -------------------------------------------------------------------------------- 1 | package nullable 2 | 3 | import "time" 4 | 5 | func Time(t time.Time) *time.Time { 6 | return &t 7 | } 8 | 
-------------------------------------------------------------------------------- /set/ordered.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | type OrderedSet[T comparable] struct { 4 | valuesSet map[T]struct{} 5 | values []T 6 | } 7 | 8 | func NewOrderedSet[T comparable]() *OrderedSet[T] { 9 | return &OrderedSet[T]{ 10 | valuesSet: make(map[T]struct{}), 11 | values: make([]T, 0), 12 | } 13 | } 14 | 15 | func (u *OrderedSet[T]) Add(val T) { 16 | if _, exists := u.valuesSet[val]; !exists { 17 | u.valuesSet[val] = struct{}{} 18 | u.values = append(u.values, val) 19 | } 20 | } 21 | 22 | func (u *OrderedSet[T]) Contains(val T) bool { 23 | _, contains := u.valuesSet[val] 24 | return contains 25 | } 26 | 27 | func (u *OrderedSet[T]) Values() []T { 28 | return u.values 29 | } 30 | 31 | func (u *OrderedSet[T]) Size() int { 32 | return len(u.valuesSet) 33 | } 34 | 35 | // ValueAt assumes the provided idx is inside bounds 36 | func (u *OrderedSet[T]) ValueAt(idx int) T { 37 | return u.values[idx] 38 | } 39 | -------------------------------------------------------------------------------- /set/ordered_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestContains(t *testing.T) { 10 | t.Run("int", func(t *testing.T) { 11 | u := NewOrderedSet[int]() 12 | u.Add(1) 13 | u.Add(5) 14 | u.Add(5) 15 | u.Add(8) 16 | assert.Equal(t, []int{1, 5, 8}, u.Values()) 17 | 18 | assert.Equal(t, 3, u.Size()) 19 | 20 | assert.Equal(t, 1, u.ValueAt(0)) 21 | assert.Equal(t, 5, u.ValueAt(1)) 22 | assert.Equal(t, 8, u.ValueAt(2)) 23 | }) 24 | 25 | t.Run("str", func(t *testing.T) { 26 | u := NewOrderedSet[string]() 27 | u.Add("foo") 28 | u.Add("bar") 29 | u.Add("baz") 30 | u.Add("foo") 31 | assert.Equal(t, []string{"foo", "bar", "baz"}, u.Values()) 32 | 33 | assert.Equal(t, 3, u.Size()) 34 | 35 
| assert.Equal(t, "foo", u.ValueAt(0)) 36 | assert.Equal(t, "bar", u.ValueAt(1)) 37 | assert.Equal(t, "baz", u.ValueAt(2)) 38 | }) 39 | } 40 | -------------------------------------------------------------------------------- /set/set.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "encoding/json" 5 | ) 6 | 7 | type Set[T comparable] struct { 8 | values map[T]struct{} 9 | } 10 | 11 | func New[T comparable]() *Set[T] { 12 | return &Set[T]{ 13 | values: make(map[T]struct{}), 14 | } 15 | } 16 | 17 | func NewFromValues[T comparable](values ...T) *Set[T] { 18 | s := New[T]() 19 | s.Add(values...) 20 | return s 21 | } 22 | 23 | func (s *Set[T]) Add(values ...T) { 24 | for _, val := range values { 25 | s.values[val] = struct{}{} 26 | } 27 | } 28 | 29 | func (s *Set[T]) Clear() { 30 | s.values = make(map[T]struct{}) 31 | } 32 | 33 | func (s *Set[T]) Remove(val T) { 34 | delete(s.values, val) 35 | } 36 | 37 | func (s *Set[T]) Contains(val T) bool { 38 | _, contains := s.values[val] 39 | return contains 40 | } 41 | 42 | func (s *Set[T]) ContainsAny(values ...T) bool { 43 | for _, val := range values { 44 | if s.Contains(val) { 45 | return true 46 | } 47 | } 48 | 49 | return false 50 | } 51 | 52 | func (s *Set[T]) ContainsAll(values ...T) bool { 53 | for _, val := range values { 54 | if !s.Contains(val) { 55 | return false 56 | } 57 | } 58 | 59 | return true 60 | } 61 | 62 | func (s *Set[T]) Extend(s2 *Set[T]) { 63 | for v := range s2.Values() { 64 | s.Add(v) 65 | } 66 | } 67 | 68 | func (s *Set[T]) Size() int { 69 | return len(s.values) 70 | } 71 | 72 | func (s *Set[T]) Values() map[T]struct{} { 73 | return s.values 74 | } 75 | 76 | func (s *Set[T]) ToSlice() []T { 77 | sl := make([]T, 0, len(s.values)) 78 | 79 | for v := range s.values { 80 | sl = append(sl, v) 81 | } 82 | return sl 83 | } 84 | 85 | func (s *Set[T]) MarshalJSON() ([]byte, error) { 86 | return json.Marshal(s.ToSlice()) 87 | } 88 | 89 | 
func (s *Set[T]) UnmarshalJSON(data []byte) error { 90 | values := make([]T, 0) 91 | 92 | err := json.Unmarshal(data, &values) 93 | if err != nil { 94 | return err 95 | } 96 | 97 | s.Clear() 98 | for _, v := range values { 99 | s.Add(v) 100 | } 101 | 102 | return nil 103 | } 104 | -------------------------------------------------------------------------------- /set/set_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestSetFromValues(t *testing.T) { 11 | s := NewFromValues("foo", "bar") 12 | assert.Equal(t, 2, s.Size()) 13 | assert.True(t, s.Contains("foo")) 14 | assert.True(t, s.Contains("bar")) 15 | assert.False(t, s.Contains("baz")) 16 | } 17 | 18 | func TestSet(t *testing.T) { 19 | s := New[string]() 20 | assert.Empty(t, s.ToSlice()) 21 | 22 | s.Add("foo") 23 | s.Add("bar") 24 | assert.Equal(t, 2, s.Size()) 25 | 26 | s.Add("bar", "baz") 27 | assert.True(t, s.Contains("foo")) 28 | assert.True(t, s.Contains("bar")) 29 | assert.True(t, s.Contains("baz")) 30 | assert.Equal(t, 3, s.Size()) 31 | 32 | s.Remove("baz") 33 | assert.Equal(t, 2, s.Size()) 34 | 35 | s.Remove("foo") 36 | assert.Equal(t, 1, s.Size()) 37 | 38 | assert.False(t, s.Contains("foo")) 39 | assert.True(t, s.Contains("bar")) 40 | assert.Equal(t, []string{"bar"}, s.ToSlice()) 41 | 42 | s.Clear() 43 | assert.False(t, s.Contains("bar")) 44 | assert.Empty(t, s.ToSlice()) 45 | assert.Equal(t, 0, s.Size()) 46 | 47 | t.Run("unmarshal", func(t *testing.T) { 48 | mySet := New[string]() 49 | mySet.Add("whatever") 50 | 51 | values := []string{"a", "b", "foo", "c", "b"} 52 | jsb, _ := json.Marshal(values) 53 | assert.NoError(t, json.Unmarshal(jsb, &mySet)) 54 | assert.Equal(t, 4, mySet.Size()) 55 | for _, v := range values { 56 | assert.True(t, mySet.Contains(v)) 57 | } 58 | }) 59 | 60 | } 61 | 
-------------------------------------------------------------------------------- /slice/filter.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | // Filter returns a sub-slice with all the elements that satisfy the fn condition. 4 | func Filter[T any](s []T, fn func(T) bool) []T { 5 | filtered := make([]T, 0, len(s)) 6 | 7 | for _, el := range s { 8 | if fn(el) { 9 | filtered = append(filtered, el) 10 | } 11 | } 12 | 13 | return filtered 14 | } 15 | -------------------------------------------------------------------------------- /slice/filter_test.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestFilter(t *testing.T) { 10 | t.Run("logic", func(t *testing.T) { 11 | assert.Equal(t, 12 | []string{"a", "c"}, 13 | Filter( 14 | []string{"a", "b", "c"}, 15 | func(s string) bool { return s != "b" }, 16 | )) 17 | assert.Equal(t, 18 | []int{1, 1}, 19 | Filter( 20 | []int{1, 10, 100, 1000, 100, 10, 1}, 21 | func(i int) bool { return i < 10 }, 22 | )) 23 | 24 | type item struct { 25 | price int 26 | condition string 27 | onDiscount bool 28 | } 29 | 30 | assert.Len(t, 31 | Filter( 32 | []item{ 33 | { 34 | price: 115, 35 | condition: "new", 36 | onDiscount: true, 37 | }, 38 | { 39 | price: 225, 40 | condition: "used", 41 | onDiscount: false, 42 | }, 43 | { 44 | price: 335, 45 | condition: "mint", 46 | onDiscount: true, 47 | }, 48 | }, 49 | func(i item) bool { return i.onDiscount && i.condition != "used" && i.price < 300 }, 50 | ), 1, 51 | ) 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /slice/partition.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | // Partition creates partitions of a standard maximum size. 
4 | func Partition[T any](s []T, partitionSize int) [][]T { 5 | if len(s) == 0 || partitionSize <= 0 { 6 | return [][]T{} 7 | } 8 | 9 | partitions := make([][]T, 0, (len(s)+partitionSize-1)/partitionSize) 10 | 11 | for { 12 | left := len(partitions) * partitionSize 13 | if left >= len(s) { 14 | break 15 | } 16 | 17 | right := Min(left+partitionSize, len(s)) 18 | 19 | part := s[left:right] 20 | partition := make([]T, len(part)) 21 | copy(partition, part) 22 | 23 | partitions = append(partitions, partition) 24 | } 25 | 26 | return partitions 27 | } 28 | -------------------------------------------------------------------------------- /slice/partition_test.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestPartition(t *testing.T) { 10 | t.Run("logic", func(t *testing.T) { 11 | testCasesStr := []struct { 12 | s []string 13 | exp [][]string 14 | pSize int 15 | }{ 16 | {s: []string{"a", "b", "c", "d"}, exp: [][]string{{"a", "b"}, {"c", "d"}}, pSize: 2}, 17 | {s: []string{"a", "b", "c", "d"}, exp: [][]string{{"a", "b", "c", "d"}}, pSize: 8}, 18 | {s: []string{"a", "b", "c", "d"}, exp: [][]string{{"a"}, {"b"}, {"c"}, {"d"}}, pSize: 1}, 19 | {s: []string{"a"}, exp: [][]string{{"a"}}, pSize: 1}, 20 | {s: []string{"a"}, exp: [][]string{{"a"}}, pSize: 100}, 21 | {s: []string{"a"}, exp: [][]string{}, pSize: 0}, 22 | {s: []string{"a", "b", "c", "d", "d", "da"}, exp: [][]string{{"a", "b", "c", "d", "d"}, {"da"}}, pSize: 5}, 23 | } 24 | 25 | for _, tc := range testCasesStr { 26 | act := Partition(tc.s, tc.pSize) 27 | assert.Equal(t, tc.exp, act) 28 | } 29 | }) 30 | 31 | t.Run("generics int", func(t *testing.T) { 32 | act := Partition([]int{10, 20, 30, 40, 50}, 3) 33 | assert.Equal(t, [][]int{{10, 20, 30}, {40, 50}}, act) 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- 
/slice/search.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import "golang.org/x/exp/constraints" 4 | 5 | // Contains returns true if the provided slice contains the target value. 6 | func Contains[T comparable](sl []T, val T) bool { 7 | for _, v := range sl { 8 | if v == val { 9 | return true 10 | } 11 | } 12 | 13 | return false 14 | } 15 | 16 | // ValueAt returns if exists values[idx] else the default value. 17 | func ValueAt[T any](idx int, values []T, defaultValue T) T { 18 | if (idx < 0) || (idx >= len(values)) { 19 | return defaultValue 20 | } 21 | 22 | return values[idx] 23 | } 24 | 25 | // Min returns the minimum of the provided values. 26 | func Min[T constraints.Ordered](values ...T) T { 27 | if len(values) == 0 { 28 | return *new(T) 29 | } 30 | 31 | res := values[0] 32 | for _, v := range values[1:] { 33 | if v < res { 34 | res = v 35 | } 36 | } 37 | 38 | return res 39 | } 40 | -------------------------------------------------------------------------------- /slice/search_test.go: -------------------------------------------------------------------------------- 1 | package slice 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestSliceContains(t *testing.T) { 10 | assert.True(t, Contains([]int{1, 5, 100, 1000}, 100)) 11 | assert.True(t, Contains([]string{"abc", "z", "fe", "ll"}, "fe")) 12 | assert.False(t, Contains([]string{"abc", "z", "fe", "ll"}, "")) 13 | assert.False(t, Contains([]int{1, 5, 100, 1000}, -1)) 14 | assert.False(t, Contains([]bool{false, false, false}, true)) 15 | } 16 | 17 | func TestValueAt(t *testing.T) { 18 | t.Run("logic", func(t *testing.T) { 19 | type args struct { 20 | idx int 21 | values []string 22 | defaultValue string 23 | } 24 | 25 | tests := []struct { 26 | name string 27 | args args 28 | want string 29 | }{ 30 | { 31 | name: "empty slice negative index", 32 | args: args{idx: -1, values: nil, defaultValue: ""}, 33 | 
want: "", 34 | }, 35 | { 36 | name: "empty slice zero index", 37 | args: args{idx: 0, values: nil, defaultValue: ""}, 38 | want: "", 39 | }, 40 | { 41 | name: "index above bounds", 42 | args: args{idx: 4, values: []string{"a", "b", "c", "d"}, defaultValue: "not found"}, 43 | want: "not found", 44 | }, 45 | { 46 | name: "negative index", 47 | args: args{idx: -1, values: []string{"a", "b", "c", "d"}, defaultValue: "not found"}, 48 | want: "not found", 49 | }, 50 | { 51 | name: "zero index", 52 | args: args{idx: 0, values: []string{"a", "b", "c", "d"}, defaultValue: "not found"}, 53 | want: "a", 54 | }, 55 | { 56 | name: "fourth element", 57 | args: args{idx: 3, values: []string{"a", "b", "c", "d"}, defaultValue: "not found"}, 58 | want: "d", 59 | }, 60 | } 61 | 62 | for _, tt := range tests { 63 | t.Run(tt.name, func(t *testing.T) { 64 | if got := ValueAt(tt.args.idx, tt.args.values, tt.args.defaultValue); got != tt.want { 65 | t.Errorf("StrSliceValueAt() = %v, want %v", got, tt.want) 66 | } 67 | }) 68 | } 69 | }) 70 | 71 | t.Run("generics", func(t *testing.T) { 72 | assert.Equal(t, "foo", ValueAt(5, []string{"a", "B"}, "foo")) 73 | assert.Equal(t, "B", ValueAt(1, []string{"a", "B", "c"}, "foo")) 74 | assert.Equal(t, -1, ValueAt(1111, []int{10, 20, 30}, -1)) 75 | assert.Equal(t, 10, ValueAt(0, []int{10, 20, 30}, -1)) 76 | }) 77 | } 78 | 79 | func TestMin(t *testing.T) { 80 | assert.Equal(t, 5, Min(10, 30, 5, 123, 99)) 81 | assert.Equal(t, "b", Min("z", "b", "d")) 82 | } 83 | -------------------------------------------------------------------------------- /testy/integration_test_suite.go: -------------------------------------------------------------------------------- 1 | package testy 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | 8 | "gorm.io/driver/postgres" 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/logger" 11 | 12 | "github.com/trustwallet/go-libs/cache/redis" 13 | ) 14 | 15 | const ( 16 | testDbDsnEnvKey = "TEST_DB_DSN" 17 | testRedisUrlEnvKey = 
"TEST_REDIS_URL" 18 | ) 19 | 20 | // IntegrationTestSuite is an integration testing suite with methods 21 | // for retrieving the real database and redis connection. 22 | // Just absorb the built-in IntegrationTestSuite by defining your own suite, 23 | // you can also use it along with `testify`'s suite. 24 | // Example: 25 | // 26 | // type SomeTestSuite struct { 27 | // suite.Suite 28 | // IntegrationTestSuite 29 | // } 30 | type IntegrationTestSuite struct { 31 | db *gorm.DB 32 | redis *redis.Redis 33 | } 34 | 35 | // GetDb retrieves the current *gorm.DB connection, and it's lazy loaded. 36 | func (s *IntegrationTestSuite) GetDb() *gorm.DB { 37 | if s.db == nil { 38 | db, err := NewIntegrationTestDb() 39 | if err != nil { 40 | log.Fatalln("can not connect integration test db", err) 41 | } 42 | s.db = db 43 | } 44 | return s.db 45 | } 46 | 47 | // GetRedis retrieves the current *redis.Redis connection, and it's lazy loaded. 48 | func (s *IntegrationTestSuite) GetRedis() *redis.Redis { 49 | if s.redis == nil { 50 | r, err := NewIntegrationTestRedis() 51 | if err != nil { 52 | log.Fatalln("can not connect integration redis db", err) 53 | } 54 | s.redis = r 55 | } 56 | return s.redis 57 | } 58 | 59 | // NewIntegrationTestDb creates a *gorm.DB connection to a real database which is only for integration test. 60 | // The DSN for test database connection should be set by defining the TEST_DB_DSN env. 61 | func NewIntegrationTestDb() (*gorm.DB, error) { 62 | return gorm.Open( 63 | postgres.Open(MustGetTestDbDSN()), 64 | &gorm.Config{ 65 | Logger: logger.Default.LogMode(logger.Info), 66 | SkipDefaultTransaction: true, 67 | }, 68 | ) 69 | } 70 | 71 | // NewIntegrationTestRedis creates a *redis.Redis connection to a real redis pool which is only for integration test. 72 | // The url for test redis connection should be set by defining the TEST_REDIS_URL env. 
73 | func NewIntegrationTestRedis() (*redis.Redis, error) { 74 | url, ok := os.LookupEnv(testRedisUrlEnvKey) 75 | if !ok { 76 | log.Fatalln(testRedisUrlEnvKey, "env not found") 77 | } 78 | return redis.Init(context.Background(), url) 79 | } 80 | 81 | func MustGetTestDbDSN() string { 82 | dsn, ok := os.LookupEnv(testDbDsnEnvKey) 83 | if !ok { 84 | log.Fatal(testDbDsnEnvKey, "env not found") 85 | } 86 | return dsn 87 | } 88 | -------------------------------------------------------------------------------- /testy/tagged.go: -------------------------------------------------------------------------------- 1 | package testy 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | const ( 10 | TagUnit = "unit" 11 | TagIntegration = "integration" 12 | TagPostgres = "postgres" 13 | TagRabbit = "rabbit" 14 | ) 15 | 16 | // TaggedTestsEnvVar defines the name of the environment variable for the tagged tests. 17 | // Example: 18 | // TaggedTestsEnvVar="TEST_TAGS" 19 | // env TEST_TAGS="unit" go test ./... 20 | var TaggedTestsEnvVar = "TEST_TAGS" 21 | 22 | // RequireTestTag runs the test if the provided tag matches at least one runtime tag. 23 | // Example: 24 | // func TestSomething(t *testing.T) { 25 | // RequireTestTag(t, "unit") 26 | // ... 27 | // } 28 | // Run with: 29 | // env TEST_TAGS="unit,integration" go test ./... 30 | func RequireTestTag(t *testing.T, testTag string) { 31 | if !getRuntimeTags().contains(testTag) { 32 | t.Skipf("skipping test '%s', requires '%s' tag", t.Name(), testTag) 33 | } 34 | } 35 | 36 | // RequireOneOfTestTags runs the test if any of the provided test tags matches one of the runtime tags. 37 | func RequireOneOfTestTags(t *testing.T, testTags ...string) { 38 | if !getRuntimeTags().containsAny(testTags...) 
{ 39 | t.Skipf("skipping test '%s', requires at least one of the following tags: '%s'", 40 | t.Name(), strings.Join(testTags, ", ")) 41 | } 42 | } 43 | 44 | // RequireAllTestTags runs the test if all the provided test tags appear in runtime tags. 45 | func RequireAllTestTags(t *testing.T, testTags ...string) { 46 | if !getRuntimeTags().containsAll(testTags...) { 47 | t.Skipf("skipping test '%s', requires all of the following tags: '%s'", 48 | t.Name(), strings.Join(testTags, ", ")) 49 | } 50 | } 51 | 52 | type runtimeTags []string 53 | 54 | func getRuntimeTags() runtimeTags { 55 | return parseTags(os.Getenv(TaggedTestsEnvVar)) 56 | } 57 | 58 | func parseTags(rawTags string) runtimeTags { 59 | rawTags = strings.ReplaceAll(rawTags, " ", "") 60 | if rawTags == "" { 61 | return nil 62 | } 63 | return strings.Split(rawTags, ",") 64 | } 65 | 66 | func (rt runtimeTags) contains(targetTag string) bool { 67 | for _, tag := range rt { 68 | if tag == targetTag { 69 | return true 70 | } 71 | } 72 | return false 73 | } 74 | 75 | func (rt runtimeTags) containsAny(targetTags ...string) bool { 76 | for _, targetTag := range targetTags { 77 | if rt.contains(targetTag) { 78 | return true 79 | } 80 | } 81 | return false 82 | } 83 | 84 | func (rt runtimeTags) containsAll(targetTags ...string) bool { 85 | for _, targetTag := range targetTags { 86 | if !rt.contains(targetTag) { 87 | return false 88 | } 89 | } 90 | return true 91 | } 92 | -------------------------------------------------------------------------------- /testy/tagged_test.go: -------------------------------------------------------------------------------- 1 | package testy 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestContainsMethods(t *testing.T) { 10 | rt := parseTags("unit,integration") 11 | 12 | assert.True(t, rt.contains("unit")) 13 | assert.True(t, rt.contains("integration")) 14 | assert.False(t, rt.contains("")) 15 | assert.False(t, rt.contains("UNIT")) 16 | 17 | 
assert.True(t, rt.containsAll("unit")) 18 | assert.True(t, rt.containsAll("unit", "integration")) 19 | assert.False(t, rt.containsAll("unit", "integration", "something-else")) 20 | assert.False(t, rt.containsAll("unit", "integration", "")) 21 | 22 | assert.True(t, rt.containsAny("unit", "something-else")) 23 | assert.True(t, rt.containsAny("whatever", "unit", "something-else")) 24 | assert.True(t, rt.containsAny("whatever", "unit", "something-else", "integration")) 25 | assert.False(t, rt.containsAny("whatever", "", "something-else")) 26 | } 27 | -------------------------------------------------------------------------------- /worker/metrics/metricspusherworker.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "github.com/trustwallet/go-libs/metrics" 5 | "github.com/trustwallet/go-libs/worker" 6 | ) 7 | 8 | func NewMetricsPusherWorker(options *worker.WorkerOptions, pusher metrics.Pusher) worker.Worker { 9 | return worker.NewWorkerBuilder("metrics_pusher", pusher.Push). 10 | WithOptions(options). 11 | WithStop(pusher.Close). 
12 | Build() 13 | } 14 | -------------------------------------------------------------------------------- /worker/options.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/trustwallet/go-libs/metrics" 7 | ) 8 | 9 | type WorkerOptions struct { 10 | Interval time.Duration 11 | RunImmediately bool 12 | RunConsequently bool 13 | PerformanceMetric metrics.PerformanceMetric 14 | } 15 | 16 | func DefaultWorkerOptions(interval time.Duration) *WorkerOptions { 17 | return &WorkerOptions{ 18 | Interval: interval, 19 | RunImmediately: true, 20 | RunConsequently: false, 21 | PerformanceMetric: &metrics.NullablePerformanceMetric{}, 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /worker/worker.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "time" 7 | 8 | log "github.com/sirupsen/logrus" 9 | 10 | "github.com/trustwallet/go-libs/metrics" 11 | ) 12 | 13 | type Builder interface { 14 | WithOptions(options *WorkerOptions) Builder 15 | WithStop(func() error) Builder 16 | Build() Worker 17 | } 18 | 19 | type builder struct { 20 | worker *worker 21 | } 22 | 23 | func NewWorkerBuilder(name string, workerFn func() error) Builder { 24 | return &builder{ 25 | worker: &worker{ 26 | name: name, 27 | workerFn: workerFn, 28 | options: DefaultWorkerOptions(1 * time.Minute), 29 | }, 30 | } 31 | } 32 | 33 | func (b *builder) WithOptions(options *WorkerOptions) Builder { 34 | b.worker.options = options 35 | return b 36 | } 37 | 38 | func (b *builder) WithStop(stopFn func() error) Builder { 39 | b.worker.stopFn = stopFn 40 | return b 41 | } 42 | 43 | func (b *builder) Build() Worker { 44 | return b.worker 45 | } 46 | 47 | // Worker interface can be constructed using worker.NewBuilder("worker_name", workerFn).Build() 48 | // or allows custom 
implementation (e.g. one-off jobs) 49 | type Worker interface { 50 | Name() string 51 | Start(ctx context.Context, wg *sync.WaitGroup) 52 | } 53 | 54 | type worker struct { 55 | name string 56 | workerFn func() error 57 | stopFn func() error 58 | options *WorkerOptions 59 | } 60 | 61 | func (w *worker) Name() string { 62 | return w.name 63 | } 64 | 65 | func (w *worker) Start(ctx context.Context, wg *sync.WaitGroup) { 66 | if w.options.Interval == -1 { 67 | w.hold(ctx, wg) 68 | return 69 | } 70 | w.start(ctx, wg) 71 | } 72 | 73 | func (w *worker) start(ctx context.Context, wg *sync.WaitGroup) { 74 | wg.Add(1) 75 | go func() { 76 | defer wg.Done() 77 | 78 | ticker := time.NewTicker(w.options.Interval) 79 | defer ticker.Stop() 80 | 81 | if w.options.RunImmediately { 82 | log.WithField("worker", w.name).Info("run immediately") 83 | w.invoke() 84 | } 85 | 86 | for { 87 | select { 88 | case <-ctx.Done(): 89 | if w.stopFn != nil { 90 | log.WithField("worker", w.name).Info("stopping...") 91 | if err := w.stopFn(); err != nil { 92 | log.WithField("worker", w.name).WithError(err).Warn("error occurred while stopping the worker") 93 | } 94 | } 95 | log.WithField("worker", w.name).Info("stopped") 96 | return 97 | case <-ticker.C: 98 | if w.options.RunConsequently { 99 | ticker.Stop() 100 | } 101 | 102 | log.WithField("worker", w.name).Info("processing") 103 | w.invoke() 104 | 105 | if w.options.RunConsequently { 106 | ticker = time.NewTicker(w.options.Interval) 107 | } 108 | } 109 | } 110 | }() 111 | } 112 | 113 | func (w *worker) hold(ctx context.Context, wg *sync.WaitGroup) { 114 | wg.Add(1) 115 | 116 | logger := log.WithField("worker", w.name) 117 | logger.Info("worker started, but won't be executed") 118 | 119 | go func() { 120 | defer wg.Done() 121 | 122 | <-ctx.Done() 123 | 124 | if w.stopFn != nil { 125 | logger.Info("stopping...") 126 | if err := w.stopFn(); err != nil { 127 | logger.WithError(err).Warn("error occurred while stopping the worker") 128 | } 129 | } 130 | 
logger.Info("stopped") 131 | }() 132 | } 133 | 134 | func (w *worker) invoke() { 135 | metric := w.options.PerformanceMetric 136 | if metric == nil { 137 | metric = &metrics.NullablePerformanceMetric{} 138 | } 139 | 140 | defer metric.Duration(metric.Start()) 141 | err := w.workerFn() 142 | 143 | if err != nil { 144 | metric.Failure() 145 | log.WithField("worker", w.name).Error(err) 146 | } else { 147 | metric.Success() 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /worker/worker_test.go: -------------------------------------------------------------------------------- 1 | package worker_test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "gotest.tools/assert" 10 | 11 | "github.com/trustwallet/go-libs/worker" 12 | ) 13 | 14 | func TestWorkerWithDefaultOptions(t *testing.T) { 15 | counter := 0 16 | worker := worker.NewWorkerBuilder("test", func() error { 17 | counter++ 18 | return nil 19 | }).WithOptions(worker.DefaultWorkerOptions(100 * time.Millisecond)).Build() 20 | 21 | wg := &sync.WaitGroup{} 22 | ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond) 23 | defer cancel() 24 | 25 | worker.Start(ctx, wg) 26 | 27 | wg.Wait() 28 | 29 | assert.Equal(t, 4, counter, "Should execute 4 times - 1st immediately, and 3 after") 30 | } 31 | 32 | func TestWorkerStartsConsequently(t *testing.T) { 33 | counter := 0 34 | options := worker.DefaultWorkerOptions(100 * time.Millisecond) 35 | options.RunConsequently = true 36 | 37 | worker := worker.NewWorkerBuilder("test", func() error { 38 | time.Sleep(100 * time.Millisecond) 39 | counter++ 40 | return nil 41 | }).WithOptions(options).Build() 42 | 43 | wg := &sync.WaitGroup{} 44 | ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond) 45 | defer cancel() 46 | 47 | worker.Start(ctx, wg) 48 | 49 | wg.Wait() 50 | 51 | assert.Equal(t, 3, counter, "Should execute 3 times - 1st immediately, and 2 after with 
delay between runs") 52 | } 53 | 54 | func TestWorkerStartsWithoutExecution(t *testing.T) { 55 | counter := 0 56 | options := worker.DefaultWorkerOptions(100 * time.Millisecond) 57 | options.Interval = -1 58 | 59 | worker := worker.NewWorkerBuilder("test", func() error { 60 | counter++ 61 | return nil 62 | }).WithOptions(options).Build() 63 | 64 | wg := &sync.WaitGroup{} 65 | ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 66 | defer cancel() 67 | 68 | worker.Start(ctx, wg) 69 | 70 | wg.Wait() 71 | 72 | assert.Equal(t, 0, counter, "Should never be executed") 73 | } 74 | --------------------------------------------------------------------------------