├── .github └── workflows │ └── coverage.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── assets └── logo.png ├── consistenthash ├── hash_calculate.go ├── hashring.go └── hashring_test.go ├── docker-compose.yml ├── examples └── main.go ├── go.mod ├── go.sum ├── logger └── zap.go ├── memcached ├── constants.go ├── constants_test.go ├── errors.go ├── memcached.go ├── memcached_test.go ├── metrics.go ├── metrics_test.go ├── node_provider.go ├── node_provider_test.go ├── options.go ├── options_test.go ├── requests.go ├── requests_test.go ├── responses.go ├── responses_test.go └── transport.go ├── pool ├── pool.go └── pool_test.go └── utils ├── addr.go ├── addr_test.go ├── math.go ├── math_test.go ├── stringer.go └── stringer_test.go /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Coverage 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ${{ matrix.os }} 10 | permissions: 11 | contents: write 12 | services: 13 | memcached: 14 | image: memcached:latest 15 | ports: 16 | - 11211:11211 17 | strategy: 18 | matrix: 19 | os: [ ubuntu-latest ] 20 | go: [ 1.21.x ] 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - name: Set up Go 25 | uses: actions/setup-go@v2 26 | with: 27 | go-version: ${{ matrix.go }} 28 | 29 | - name: Test 30 | run: | 31 | go test -v -cover ./... 
-coverprofile coverage.out 32 | go tool cover -func coverage.out -o coverage.out 33 | 34 | - name: Go Coverage Badge 35 | uses: tj-actions/coverage-badge-go@v1 36 | if: ${{ runner.os == 'linux' }} 37 | with: 38 | green: 80 39 | filename: coverage.out 40 | 41 | - uses: stefanzweifel/git-auto-commit-action@v5 42 | id: auto-commit-action 43 | with: 44 | commit_message: 'chore: apply code coverage badge' 45 | file_pattern: ./README.md 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Go template 2 | # If you prefer the allow list template instead of the deny list, see community template: 3 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 4 | # 5 | # Binaries for programs and plugins 6 | *.exe 7 | *.exe~ 8 | *.dll 9 | *.so 10 | *.dylib 11 | /bin 12 | 13 | # IDEs 14 | .idea 15 | .vscode 16 | 17 | # Test binary, built with `go test -c` 18 | *.test 19 | 20 | # Output of the go coverage tool, specifically when used with LiteIDE 21 | *.out 22 | 23 | # macos 24 | .DS_STORE 25 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # More info on config here: https://github.com/golangci/golangci-lint#config-file 2 | run: 3 | deadline: 60s 4 | issues-exit-code: 1 5 | tests: true 6 | 7 | output: 8 | format: colored-line-number 9 | print-issued-lines: true 10 | print-linter-name: true 11 | 12 | linters-settings: 13 | govet: 14 | check-shadowing: true 15 | golint: 16 | min-confidence: 0 17 | dupl: 18 | threshold: 100 19 | goconst: 20 | min-len: 2 21 | min-occurrences: 2 22 | 23 | linters: 24 | disable-all: true 25 | enable: 26 | - govet 27 | - gosimple 28 | - errcheck 29 | - ineffassign 30 | - typecheck 31 | - goconst 32 | - gosec 33 | - nilnil 34 | - goimports 35 | - megacheck 
36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AliExpress Russia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | LOCAL_BIN=$(CURDIR)/bin 2 | 3 | GOENV=PATH=$(LOCAL_BIN):$(PATH) 4 | 5 | GOLANGCI_BIN=$(LOCAL_BIN)/golangci-lint 6 | $(GOLANGCI_BIN): 7 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(LOCAL_BIN) v1.55.2 8 | 9 | .PHONY: lint 10 | lint: $(GOLANGCI_BIN) 11 | $(GOENV) $(GOLANGCI_BIN) run --fix -v ./... 12 | 13 | .PHONY: test 14 | test: 15 | $(GOENV) go test -race ./... 
16 | 17 | .PHONY: test-cover 18 | test-cover: 19 | $(GOENV) go test ./... -cover 20 | 21 | .PHONY: test-cover-html 22 | test-cover-html: 23 | $(GOENV) go test ./... -coverprofile=prof.out 24 | $(GOENV) go tool cover -html=prof.out 25 | 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gomemcached 2 | 3 | --- 4 |
5 | 6 | 7 | [![License](https://img.shields.io/github/license/gogf/gf.svg?style=flat)](https://github.com/aliexpressru/gomemcached) 8 | [![Tag](https://img.shields.io/github/v/tag/aliexpressru/gomemcached?color=%23ff8936&logo=fitbit)](https://github.com/aliexpressru/gomemcached/tags) 9 | [![Godoc](https://godoc.org/github.com/aliexpressru/gomemcached?status.svg)](https://pkg.go.dev/github.com/aliexpressru/gomemcached) 10 | 11 | [![Gomemcached](https://goreportcard.com/badge/github.com/aliexpressru/gomemcached)](https://goreportcard.com/report/github.com/aliexpressru/gomemcached) 12 | ![Coverage](https://img.shields.io/badge/Coverage-91.0%25-brightgreen) 13 | 14 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go?tab=readme-ov-file#nosql-database-drivers) 15 |
16 | 17 | ___ 18 | `Gomemcached` is a Golang Memcached client designed to interact with multiple instances as shards. It implements sharding using consistent hashing. 19 | ___ 20 | 21 | ### Configuration 22 | 23 | ```yaml 24 | - name: MEMCACHED_HEADLESS_SERVICE_ADDRESS 25 | value: "my-memchached-service-headless.namespace.svc.cluster.local" 26 | ``` 27 | 28 | `MEMCACHED_HEADLESS_SERVICE_ADDRESS` groups all Memcached instances by IP address, resolved through a DNS lookup. 29 | 30 | The default Memcached port is `11211`, but you can also specify it in the config. 31 | 32 | ```yaml 33 | - name: MEMCACHED_PORT 34 | value: "12345" 35 | ``` 36 | 37 | For a local run, or if you have a fixed number and layout of pods, you can specify Servers (a comma-separated list of addresses, each including the port) manually instead of setting the 38 | HeadlessServiceAddress: 39 | 40 | ```yaml 41 | - name: MEMCACHED_SERVERS 42 | value: "127.0.0.1:11211,192.168.0.1:1234" 43 | ``` 44 | 45 | ___ 46 | 47 | ### Usage 48 | 49 | Initialize the client and connect to the Memcached servers. 50 | 51 | ```go 52 | mcl, err := memcached.InitFromEnv() 53 | mustInit(err) 54 | a.AddCloser(func () error { 55 | mcl.CloseAllConns() 56 | return nil 57 | }) 58 | ``` 59 | [More examples](examples/main.go) 60 | 61 | To use SASL authentication, pass the corresponding option to InitFromEnv: 62 | 63 | ```go 64 | memcached.InitFromEnv(memcached.WithAuthentication("", "")) 65 | ``` 66 | 67 | You can pass Options to InitFromEnv to customize the client to suit your needs. However, for basic use, it is recommended 68 | to use the default client implementation.
69 | 70 | --- 71 | 72 | ### Recommended Versions 73 | 74 | This project is developed and tested with the following recommended versions: 75 | 76 | - Go: 1.21 or higher 77 | - [Download Go](https://golang.org/dl/) 78 | 79 | - Memcached: 1.6.9 or higher 80 | - [Memcached Releases](https://memcached.org/downloads) 81 | 82 | --- 83 | 84 | ### Dependencies 85 | 86 | This project utilizes the following third-party libraries, each governed by its own license as listed below: 87 | 88 | 1. [gomemcache](https://github.com/bradfitz/gomemcache) 89 | - Description: A Go client library for the memcached server. 90 | - Used for: Primary client methods for interacting with the library. 91 | - License: Apache License 2.0 92 | 93 | 2. [go-zero](https://github.com/zeromicro/go-zero) 94 | - Description: A cloud-native Go microservices framework with cli tool for productivity. 95 | - Used for: Implementation of Consistent Hash. 96 | - License: MIT License 97 | 98 | 3. [gomemcached](https://github.com/dustin/gomemcached) 99 | - Description: A memcached binary protocol toolkit for go. 100 | - Used for: Implementation of a binary client for Memcached. 101 | - License: MIT License 102 | 103 | Please review the respective license files in the linked repositories for more details. 104 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliexpressru/gomemcached/b7e131624a1b957e510aa03bf6d7f466ebcabe0d/assets/logo.png -------------------------------------------------------------------------------- /consistenthash/hash_calculate.go: -------------------------------------------------------------------------------- 1 | package consistenthash 2 | 3 | import "github.com/cespare/xxhash" 4 | 5 | // Hash returns the hash value of data.
6 | func Hash(data []byte) uint64 { 7 | return xxhash.Sum64(data) 8 | } 9 | -------------------------------------------------------------------------------- /consistenthash/hashring.go: -------------------------------------------------------------------------------- 1 | package consistenthash 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "strconv" 7 | "sync" 8 | 9 | "github.com/aliexpressru/gomemcached/utils" 10 | ) 11 | 12 | const ( 13 | // TopWeight is the top weight that one entry might set. 14 | TopWeight = 100 15 | 16 | minReplicas = 256 17 | prime = 18245165 18 | ) 19 | 20 | var _ ConsistentHash = (*HashRing)(nil) 21 | 22 | type ( 23 | ConsistentHash interface { 24 | Add(node any) 25 | AddWithReplicas(node any, replicas int) 26 | AddWithWeight(node any, weight int) 27 | Get(v any) (any, bool) 28 | GetAllNodes() []any 29 | Remove(node any) 30 | GetNodesCount() int 31 | } 32 | 33 | // Func defines the hash method. 34 | Func func(data []byte) uint64 35 | 36 | // A HashRing is implementation of consistent hash. 37 | HashRing struct { 38 | hashFunc Func 39 | replicas int 40 | keys []uint64 41 | ring map[uint64][]any 42 | nodes map[string]struct{} 43 | lock sync.RWMutex 44 | } 45 | ) 46 | 47 | // NewHashRing returns a HashRing. 48 | func NewHashRing() *HashRing { 49 | return NewCustomHashRing(minReplicas, Hash) 50 | } 51 | 52 | // NewCustomHashRing returns a HashRing with given replicas and hash func. 53 | func NewCustomHashRing(replicas int, fn Func) *HashRing { 54 | if replicas < minReplicas { 55 | replicas = minReplicas 56 | } 57 | 58 | if fn == nil { 59 | fn = Hash 60 | } 61 | 62 | return &HashRing{ 63 | hashFunc: fn, 64 | replicas: replicas, 65 | ring: make(map[uint64][]any), 66 | nodes: make(map[string]struct{}), 67 | } 68 | } 69 | 70 | // Add adds the node with the number of h.replicas, 71 | // the later call will overwrite the replicas of the former calls. 
72 | func (h *HashRing) Add(node any) { 73 | h.AddWithReplicas(node, h.replicas) 74 | } 75 | 76 | // AddWithReplicas adds the node with the number of replicas, 77 | // replicas will be truncated to h.replicas if it's larger than h.replicas, 78 | // the later call will overwrite the replicas of the former calls. 79 | func (h *HashRing) AddWithReplicas(node any, replicas int) { 80 | h.Remove(node) 81 | 82 | if replicas > h.replicas { 83 | replicas = h.replicas 84 | } 85 | 86 | nodeRepr := repr(node) 87 | h.lock.Lock() 88 | defer h.lock.Unlock() 89 | h.addNode(nodeRepr) 90 | 91 | for i := 0; i < replicas; i++ { 92 | hash := h.hashFunc([]byte(replicaRepr(nodeRepr, i))) 93 | h.keys = append(h.keys, hash) 94 | h.ring[hash] = append(h.ring[hash], node) 95 | } 96 | 97 | sort.Slice(h.keys, func(i, j int) bool { 98 | return h.keys[i] < h.keys[j] 99 | }) 100 | } 101 | 102 | // AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent, 103 | // the later call will overwrite the replicas of the former calls. 104 | func (h *HashRing) AddWithWeight(node any, weight int) { 105 | // don't need to make sure weight not larger than TopWeight, 106 | // because AddWithReplicas makes sure replicas cannot be larger than h.replicas 107 | replicas := h.replicas * weight / TopWeight 108 | h.AddWithReplicas(node, replicas) 109 | } 110 | 111 | // Get returns the corresponding node from h base on the given v. 
112 | func (h *HashRing) Get(v any) (any, bool) { 113 | h.lock.RLock() 114 | defer h.lock.RUnlock() 115 | 116 | if len(h.ring) == 0 { 117 | return nil, false 118 | } 119 | 120 | hash := h.hashFunc([]byte(repr(v))) 121 | index := sort.Search(len(h.keys), func(i int) bool { 122 | return h.keys[i] >= hash 123 | }) % len(h.keys) 124 | 125 | nodes := h.ring[h.keys[index]] 126 | switch len(nodes) { 127 | case 0: 128 | return nil, false 129 | case 1: 130 | return nodes[0], true 131 | default: 132 | innerIndex := h.hashFunc([]byte(innerRepr(v))) 133 | pos := int(innerIndex % uint64(len(nodes))) 134 | return nodes[pos], true 135 | } 136 | } 137 | 138 | // GetAllNodes returns all nodes used in hash ring 139 | // 140 | // return a slice with a string representation of the nodes 141 | func (h *HashRing) GetAllNodes() []any { 142 | h.lock.RLock() 143 | defer h.lock.RUnlock() 144 | 145 | if len(h.ring) == 0 { 146 | return nil 147 | } 148 | 149 | var ( 150 | allNodes = make([]any, 0, len(h.nodes)) 151 | uqNodes = make(map[any]struct{}, len(h.nodes)) 152 | ) 153 | 154 | for _, nodes := range h.ring { 155 | for _, node := range nodes { 156 | if _, ok := uqNodes[node]; !ok { 157 | allNodes = append(allNodes, node) 158 | uqNodes[node] = struct{}{} 159 | } 160 | } 161 | } 162 | 163 | return allNodes 164 | } 165 | 166 | // Remove removes the given node from h. 167 | func (h *HashRing) Remove(node any) { 168 | nodeRepr := repr(node) 169 | 170 | h.lock.Lock() 171 | defer h.lock.Unlock() 172 | 173 | if !h.containsNode(nodeRepr) { 174 | return 175 | } 176 | 177 | for i := 0; i < h.replicas; i++ { 178 | hash := h.hashFunc([]byte(replicaRepr(nodeRepr, i))) 179 | index := sort.Search(len(h.keys), func(i int) bool { 180 | return h.keys[i] >= hash 181 | }) 182 | if index < len(h.keys) && h.keys[index] == hash { 183 | h.keys = append(h.keys[:index], h.keys[index+1:]...) 
184 | } 185 | h.removeRingNode(hash, nodeRepr) 186 | } 187 | 188 | h.removeNode(nodeRepr) 189 | } 190 | 191 | // GetNodesCount returns the current number of nodes 192 | func (h *HashRing) GetNodesCount() int { 193 | h.lock.RLock() 194 | defer h.lock.RUnlock() 195 | return len(h.nodes) 196 | } 197 | 198 | func (h *HashRing) removeRingNode(hash uint64, nodeRepr string) { 199 | if nodes, ok := h.ring[hash]; ok { 200 | newNodes := nodes[:0] 201 | for _, x := range nodes { 202 | if repr(x) != nodeRepr { 203 | newNodes = append(newNodes, x) 204 | } 205 | } 206 | if len(newNodes) > 0 { 207 | h.ring[hash] = newNodes 208 | } else { 209 | delete(h.ring, hash) 210 | } 211 | } 212 | } 213 | 214 | func (h *HashRing) addNode(nodeRepr string) { 215 | h.nodes[nodeRepr] = struct{}{} 216 | } 217 | 218 | func (h *HashRing) containsNode(nodeRepr string) bool { 219 | _, ok := h.nodes[nodeRepr] 220 | return ok 221 | } 222 | 223 | func (h *HashRing) removeNode(nodeRepr string) { 224 | delete(h.nodes, nodeRepr) 225 | } 226 | 227 | func innerRepr(node any) string { 228 | return fmt.Sprintf("%d:%v", prime, node) 229 | } 230 | 231 | func repr(node any) string { 232 | return utils.Repr(node) 233 | } 234 | 235 | func replicaRepr(nodeRepr string, replicaNumber int) string { 236 | return fmt.Sprintf("%s_virtual%s", nodeRepr, strconv.Itoa(replicaNumber)) 237 | } 238 | -------------------------------------------------------------------------------- /consistenthash/hashring_test.go: -------------------------------------------------------------------------------- 1 | //nolint:goconst 2 | package consistenthash 3 | 4 | import ( 5 | "fmt" 6 | "strconv" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | 11 | "github.com/aliexpressru/gomemcached/utils" 12 | ) 13 | 14 | const ( 15 | keySize = 20 16 | requestSize = 1000 17 | ) 18 | 19 | func BenchmarkHashRingGet(b *testing.B) { 20 | ch := NewHashRing() 21 | for i := 0; i < keySize; i++ { 22 | ch.Add("localhost:" + strconv.Itoa(i)) 23 | } 24 | 
25 | for i := 0; i < b.N; i++ { 26 | ch.Get(i) 27 | } 28 | } 29 | 30 | func TestHashRing_GetAllNodes(t *testing.T) { 31 | ch := NewHashRing() 32 | 33 | allNodes := ch.GetAllNodes() 34 | assert.Nil(t, allNodes, "GetAllNodes: without added nodes") 35 | 36 | for i := 0; i < keySize; i++ { 37 | ch.Add("localhost:" + strconv.Itoa(i)) 38 | } 39 | count := ch.GetNodesCount() 40 | assert.Equalf(t, keySize, count, "GetNodesCount: have - %d; want - %d", count, keySize) 41 | 42 | allNodes = ch.GetAllNodes() 43 | assert.Equal(t, keySize, len(allNodes)) 44 | 45 | for i := 0; i < keySize; i++ { 46 | node := "localhost:" + strconv.Itoa(i) 47 | found := false 48 | for _, n := range allNodes { 49 | if n == node { 50 | found = true 51 | break 52 | } 53 | } 54 | assert.True(t, found, "Node not found in GetAllNodes: "+node) 55 | } 56 | } 57 | 58 | func TestHashRingWithEntropy(t *testing.T) { 59 | ch := NewCustomHashRing(0, nil) 60 | val, ok := ch.Get("any") 61 | assert.False(t, ok) 62 | assert.Nil(t, val) 63 | 64 | for i := 0; i < keySize; i++ { 65 | ch.AddWithReplicas("localhost:"+strconv.Itoa(i), minReplicas<<1) 66 | } 67 | 68 | keys := make(map[string]int) 69 | for i := 0; i < requestSize; i++ { 70 | key, ok := ch.Get(requestSize + i) 71 | assert.True(t, ok) 72 | keys[key.(string)]++ 73 | } 74 | 75 | mi := make(map[any]int, len(keys)) 76 | for k, v := range keys { 77 | mi[k] = v 78 | } 79 | entropy := utils.CalcEntropy(mi) 80 | assert.True(t, entropy > .95) 81 | } 82 | 83 | func TestHashRingIncrementalTransfer(t *testing.T) { 84 | prefix := "anything" 85 | create := func() *HashRing { 86 | ch := NewHashRing() 87 | for i := 0; i < keySize; i++ { 88 | ch.Add(prefix + strconv.Itoa(i)) 89 | } 90 | return ch 91 | } 92 | 93 | originCh := create() 94 | keys := make(map[int]string, requestSize) 95 | for i := 0; i < requestSize; i++ { 96 | key, ok := originCh.Get(requestSize + i) 97 | assert.True(t, ok) 98 | assert.NotNil(t, key) 99 | keys[i] = key.(string) 100 | } 101 | 102 | node := 
fmt.Sprintf("%s%d", prefix, keySize) 103 | for i := 0; i < 10; i++ { 104 | laterCh := create() 105 | laterCh.AddWithWeight(node, 10*(i+1)) 106 | 107 | for j := 0; j < requestSize; j++ { 108 | key, ok := laterCh.Get(requestSize + j) 109 | assert.True(t, ok) 110 | assert.NotNil(t, key) 111 | value := key.(string) 112 | assert.True(t, value == keys[j] || value == node) 113 | } 114 | } 115 | } 116 | 117 | func TestHashRingTransferOnFailure(t *testing.T) { 118 | index := 41 119 | keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index) 120 | var transferred int 121 | for k, v := range newKeys { 122 | if v != keys[k] { 123 | transferred++ 124 | } 125 | } 126 | 127 | ratio := float32(transferred) / float32(requestSize) 128 | assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio)) 129 | } 130 | 131 | func TestHashRingLeastTransferOnFailure(t *testing.T) { 132 | prefix := "localhost:" 133 | index := 41 134 | keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index) 135 | for k, v := range keys { 136 | newV := newKeys[k] 137 | if v != prefix+strconv.Itoa(index) { 138 | assert.Equal(t, v, newV) 139 | } 140 | } 141 | } 142 | 143 | func TestHashRing_Remove(t *testing.T) { 144 | ch := NewHashRing() 145 | ch.Add("first") 146 | ch.Add("second") 147 | ch.Remove("first") 148 | for i := 0; i < 100; i++ { 149 | val, ok := ch.Get(i) 150 | assert.True(t, ok) 151 | assert.Equal(t, "second", val) 152 | } 153 | } 154 | 155 | func TestHashRing_RemoveInterface(t *testing.T) { 156 | const key = "any" 157 | ch := NewHashRing() 158 | node1 := newMockNode(key, 1) 159 | node2 := newMockNode(key, 2) 160 | ch.AddWithWeight(node1, 80) 161 | ch.AddWithWeight(node2, 50) 162 | assert.Equal(t, 1, len(ch.nodes)) 163 | node, ok := ch.Get(1) 164 | assert.True(t, ok) 165 | assert.Equal(t, key, node.(*mockNode).addr) 166 | assert.Equal(t, 2, node.(*mockNode).id) 167 | } 168 | 169 | func Test_innerRepr(t *testing.T) { 170 | type args struct { 171 | node any 172 | } 
173 | tests := []struct { 174 | name string 175 | args args 176 | want string 177 | }{ 178 | { 179 | name: "localhost:11211", 180 | args: args{node: "localhost:11211"}, 181 | want: fmt.Sprintf("%d:%v", prime, "localhost:11211"), 182 | }, 183 | { 184 | name: "127.0.0.1:11211", 185 | args: args{node: "127.0.0.1:11211"}, 186 | want: fmt.Sprintf("%d:%v", prime, "127.0.0.1:11211"), 187 | }, 188 | } 189 | for _, tt := range tests { 190 | t.Run(tt.name, func(t *testing.T) { 191 | out := innerRepr(tt.args.node) 192 | assert.Equal(t, tt.want, out, "innerRepr: returned wrong format") 193 | }) 194 | } 195 | } 196 | 197 | func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[int]string, map[int]string) { 198 | ch := NewHashRing() 199 | for i := 0; i < keySize; i++ { 200 | ch.Add(prefix + strconv.Itoa(i)) 201 | } 202 | 203 | keys := make(map[int]string, requestSize) 204 | for i := 0; i < requestSize; i++ { 205 | key, ok := ch.Get(requestSize + i) 206 | assert.True(t, ok) 207 | assert.NotNil(t, key) 208 | keys[i] = key.(string) 209 | } 210 | 211 | remove := fmt.Sprintf("%s%d", prefix, index) 212 | ch.Remove(remove) 213 | newKeys := make(map[int]string, requestSize) 214 | for i := 0; i < requestSize; i++ { 215 | key, ok := ch.Get(requestSize + i) 216 | assert.True(t, ok) 217 | assert.NotNil(t, key) 218 | assert.NotEqual(t, remove, key) 219 | newKeys[i] = key.(string) 220 | } 221 | 222 | return keys, newKeys 223 | } 224 | 225 | type mockNode struct { 226 | addr string 227 | id int 228 | } 229 | 230 | func newMockNode(addr string, id int) *mockNode { 231 | return &mockNode{ 232 | addr: addr, 233 | id: id, 234 | } 235 | } 236 | 237 | func (n *mockNode) String() string { 238 | return n.addr 239 | } 240 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | memcached: 4 | image: memcached:latest 5 
| ports: 6 | - "11211:11211" 7 | -------------------------------------------------------------------------------- /examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | 6 | "golang.org/x/exp/maps" 7 | 8 | "github.com/aliexpressru/gomemcached/memcached" 9 | ) 10 | 11 | func main() { 12 | _ = os.Setenv("MEMCACHED_SERVERS", "localhost:11211") 13 | 14 | mcl, err := memcached.InitFromEnv( 15 | memcached.WithMaxIdleConns(10), 16 | memcached.WithAuthentication("admin", "mysecretpassword"), 17 | memcached.WithDisableLogger(), 18 | memcached.WithDisableMemcachedDiagnostic(), 19 | ) 20 | mustInit(err) 21 | defer mcl.CloseAllConns() 22 | 23 | _, err = mcl.Store(memcached.Set, "foo", 10, []byte("bar")) 24 | mustInit(err) 25 | 26 | _, err = mcl.Get("foo") 27 | mustInit(err) 28 | 29 | _, err = mcl.Delete("foo") 30 | mustInit(err) 31 | 32 | _, err = mcl.Delta(memcached.Increment, "incappend", 1, 1, 0) 33 | mustInit(err) 34 | 35 | _, err = mcl.Append(memcached.Append, "incappend", []byte("add")) 36 | mustInit(err) 37 | 38 | items := map[string][]byte{ 39 | "foo": []byte("bar"), 40 | "gopher": []byte("golang"), 41 | "answer": []byte("42"), 42 | } 43 | 44 | err = mcl.MultiStore(memcached.Add, items, 0) 45 | mustInit(err) 46 | 47 | _, err = mcl.MultiGet(maps.Keys(items)) 48 | mustInit(err) 49 | 50 | err = mcl.MultiDelete(maps.Keys(items)) 51 | mustInit(err) 52 | 53 | err = mcl.FlushAll(0) 54 | mustInit(err) 55 | } 56 | 57 | func mustInit(e error) { 58 | if e != nil { 59 | panic(e) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/aliexpressru/gomemcached 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/cespare/xxhash v1.1.0 7 | github.com/kelseyhightower/envconfig v1.4.0 8 | github.com/prometheus/client_golang v1.18.0 9 | 
github.com/stretchr/testify v1.8.4 10 | go.uber.org/zap v1.26.0 11 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a 12 | golang.org/x/sync v0.6.0 13 | ) 14 | 15 | require ( 16 | github.com/beorn7/perks v1.0.1 // indirect 17 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 18 | github.com/davecgh/go-spew v1.1.1 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/prometheus/client_model v0.5.0 // indirect 21 | github.com/prometheus/common v0.46.0 // indirect 22 | github.com/prometheus/procfs v0.12.0 // indirect 23 | github.com/stretchr/objx v0.5.0 // indirect 24 | go.uber.org/multierr v1.11.0 // indirect 25 | golang.org/x/sys v0.16.0 // indirect 26 | google.golang.org/protobuf v1.32.0 // indirect 27 | gopkg.in/yaml.v3 v3.0.1 // indirect 28 | ) 29 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= 2 | github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= 3 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 4 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 5 | github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= 6 | github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= 7 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= 8 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 9 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 11 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 12 | github.com/google/go-cmp v0.5.9 
h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 13 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 14 | github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= 15 | github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= 16 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 17 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 18 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 19 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= 23 | github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= 24 | github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= 25 | github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= 26 | github.com/prometheus/common v0.46.0 h1:doXzt5ybi1HBKpsZOL0sSkaNHJJqkyfEWZGGqqScV0Y= 27 | github.com/prometheus/common v0.46.0/go.mod h1:Tp0qkxpb9Jsg54QMe+EAmqXkSV7Evdy1BTn+g2pa/hQ= 28 | github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= 29 | github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= 30 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= 31 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= 32 | github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= 33 | github.com/spaolacci/murmur3 
v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= 34 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 35 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 36 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 37 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 38 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 39 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 40 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 41 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 42 | go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= 43 | go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= 44 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= 45 | go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 46 | go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= 47 | go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= 48 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= 49 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= 50 | golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= 51 | golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 52 | golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= 53 | golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 54 | google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= 55 | google.golang.org/protobuf v1.32.0/go.mod 
h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 56 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 57 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 58 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 59 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 60 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 61 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 62 | -------------------------------------------------------------------------------- /logger/zap.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "os" 5 | "sync/atomic" 6 | 7 | "go.uber.org/zap" 8 | "go.uber.org/zap/zapcore" 9 | ) 10 | 11 | var ( 12 | // global logger instance. 13 | global *zap.SugaredLogger 14 | disableLogger atomic.Bool 15 | defaultLevel = zap.NewAtomicLevelAt(zap.DebugLevel) 16 | generationArgs = []any{"@gen", "1"} 17 | ) 18 | 19 | func init() { 20 | SetLogger(newSugaredLogger(defaultLevel)) 21 | } 22 | 23 | // SetLogger sets to global logger a new *zap.SugaredLogger. 24 | func SetLogger(l *zap.SugaredLogger) { 25 | global = l 26 | } 27 | 28 | // GetLogger returns current global logger. 29 | func GetLogger() *zap.SugaredLogger { 30 | return global 31 | } 32 | 33 | // DisableLogger turn off all logs, globally. 
34 | func DisableLogger() { 35 | disableLogger.Store(true) 36 | } 37 | 38 | // LoggerIsDisable checks the status of the logger (true - disabled, false - enabled) 39 | func LoggerIsDisable() bool { 40 | return disableLogger.Load() 41 | } 42 | 43 | func newSugaredLogger(level zapcore.LevelEnabler, options ...zap.Option) *zap.SugaredLogger { 44 | if level == nil { 45 | level = defaultLevel 46 | } 47 | return zap.New( 48 | zapcore.NewCore( 49 | zapcore.NewJSONEncoder(zapcore.EncoderConfig{ 50 | TimeKey: "ts", 51 | LevelKey: "level", 52 | NameKey: "logger", 53 | CallerKey: "caller", 54 | MessageKey: "message", 55 | StacktraceKey: "stacktrace", 56 | LineEnding: zapcore.DefaultLineEnding, 57 | EncodeLevel: capitalLevelEncoder, 58 | EncodeTime: zapcore.ISO8601TimeEncoder, 59 | EncodeDuration: zapcore.SecondsDurationEncoder, 60 | EncodeCaller: zapcore.ShortCallerEncoder, 61 | }), 62 | zapcore.AddSync(os.Stdout), 63 | level, 64 | ), 65 | options..., 66 | ).Sugar().With(generationArgs...) 67 | } 68 | 69 | func capitalLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { 70 | level := "" 71 | switch l { 72 | case zapcore.ErrorLevel: 73 | level = "ERR" 74 | case zapcore.WarnLevel: 75 | level = "WARNING" 76 | default: 77 | level = l.CapitalString() 78 | } 79 | enc.AppendString(level) 80 | } 81 | 82 | // Debug ... 83 | func Debug(args ...any) { 84 | if log := GetLogger(); !LoggerIsDisable() { 85 | log.Debug(args...) 86 | } 87 | } 88 | 89 | // Debugf ... 90 | func Debugf(format string, args ...any) { 91 | if log := GetLogger(); !LoggerIsDisable() { 92 | log.Debugf(format, args...) 93 | } 94 | } 95 | 96 | // Info ... 97 | func Info(args ...any) { 98 | if log := GetLogger(); !LoggerIsDisable() { 99 | log.Info(args...) 100 | } 101 | } 102 | 103 | // Infof ... 104 | func Infof(format string, args ...any) { 105 | if log := GetLogger(); !LoggerIsDisable() { 106 | log.Infof(format, args...) 107 | } 108 | } 109 | 110 | // Warn ... 
111 | func Warn(args ...any) { 112 | if log := GetLogger(); !LoggerIsDisable() { 113 | log.Warn(args...) 114 | } 115 | } 116 | 117 | // Warnf ... 118 | func Warnf(format string, args ...any) { 119 | if log := GetLogger(); !LoggerIsDisable() { 120 | log.Warnf(format, args...) 121 | } 122 | } 123 | 124 | // Error ... 125 | func Error(args ...any) { 126 | if log := GetLogger(); !LoggerIsDisable() { 127 | log.Error(args...) 128 | } 129 | } 130 | 131 | // Errorf ... 132 | func Errorf(format string, args ...any) { 133 | if log := GetLogger(); !LoggerIsDisable() { 134 | log.Errorf(format, args...) 135 | } 136 | } 137 | 138 | // Fatal ... 139 | func Fatal(args ...any) { 140 | if log := GetLogger(); !LoggerIsDisable() { 141 | log.Fatal(args...) 142 | } 143 | } 144 | 145 | // Fatalf ... 146 | func Fatalf(format string, args ...any) { 147 | if log := GetLogger(); !LoggerIsDisable() { 148 | log.Fatalf(format, args...) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /memcached/constants.go: -------------------------------------------------------------------------------- 1 | // Package memcached binary protocol packet formats and constants. 
2 | package memcached 3 | 4 | import ( 5 | "fmt" 6 | ) 7 | 8 | const ( 9 | REQ_MAGIC = 0x80 10 | RES_MAGIC = 0x81 11 | ) 12 | 13 | const ( 14 | SaslMechanism = "PLAIN" 15 | ) 16 | 17 | type OpCode uint8 18 | 19 | const ( 20 | GET = OpCode(0x00) 21 | SET = OpCode(0x01) 22 | ADD = OpCode(0x02) 23 | REPLACE = OpCode(0x03) 24 | DELETE = OpCode(0x04) 25 | INCREMENT = OpCode(0x05) 26 | DECREMENT = OpCode(0x06) 27 | QUIT = OpCode(0x07) 28 | FLUSH = OpCode(0x08) 29 | GETQ = OpCode(0x09) 30 | NOOP = OpCode(0x0a) 31 | VERSION = OpCode(0x0b) 32 | GETK = OpCode(0x0c) 33 | GETKQ = OpCode(0x0d) 34 | APPEND = OpCode(0x0e) 35 | PREPEND = OpCode(0x0f) 36 | STAT = OpCode(0x10) 37 | SETQ = OpCode(0x11) 38 | ADDQ = OpCode(0x12) 39 | REPLACEQ = OpCode(0x13) 40 | DELETEQ = OpCode(0x14) 41 | INCREMENTQ = OpCode(0x15) 42 | DECREMENTQ = OpCode(0x16) 43 | QUITQ = OpCode(0x17) 44 | FLUSHQ = OpCode(0x18) 45 | APPENDQ = OpCode(0x19) 46 | PREPENDQ = OpCode(0x1a) 47 | 48 | SASL_LIST_MECHS = OpCode(0x20) 49 | SASL_AUTH = OpCode(0x21) 50 | SASL_STEP = OpCode(0x22) 51 | ) 52 | 53 | type Status uint16 54 | 55 | const ( 56 | // SUCCESS - Successful operation. 57 | SUCCESS = Status(0x00) 58 | // KEY_ENOENT - Key not found. 59 | KEY_ENOENT = Status(0x01) 60 | // KEY_EEXISTS -Key already exists. 61 | KEY_EEXISTS = Status(0x02) 62 | // E2BIG - Data size exceeds limit. 63 | E2BIG = Status(0x03) 64 | // EINVAL - Invalid arguments or operation parameters. 65 | EINVAL = Status(0x04) 66 | // NOT_STORED - Operation was not performed because the data was not stored. 67 | NOT_STORED = Status(0x05) 68 | // DELTA_BADVAL - Invalid value specified for increment/decrement. 69 | DELTA_BADVAL = Status(0x06) 70 | // AUTHFAIL - Authentication required / Not Successful. 71 | AUTHFAIL = Status(0x20) 72 | // FURTHER_AUTH - Further authentication steps required. 73 | FURTHER_AUTH = Status(0x21) 74 | // UNKNOWN_COMMAND - Unknown command. 
75 | UNKNOWN_COMMAND = Status(0x81) 76 | // ENOMEM - Insufficient memory for the operation. 77 | ENOMEM = Status(0x82) 78 | // TMPFAIL - Temporary failure, the operation cannot be performed at the moment. 79 | TMPFAIL = Status(0x86) 80 | 81 | // UNKNOWN_STATUS is not a Memcached status 82 | UNKNOWN_STATUS = Status(0xffff) 83 | ) 84 | 85 | const ( 86 | // HDR_LEN is a number of bytes in a binary protocol header. 87 | HDR_LEN = 24 88 | BODY_LEN = 128 89 | ) 90 | 91 | // Mapping of OpCode -> name of command (not exhaustive) 92 | var CommandNames map[OpCode]string 93 | 94 | var StatusNames map[Status]string 95 | 96 | // nolint:goconst 97 | func init() { 98 | CommandNames = make(map[OpCode]string) 99 | CommandNames[GET] = "GET" 100 | CommandNames[SET] = "SET" 101 | CommandNames[ADD] = "ADD" 102 | CommandNames[REPLACE] = "REPLACE" 103 | CommandNames[DELETE] = "DELETE" 104 | CommandNames[INCREMENT] = "INCREMENT" 105 | CommandNames[DECREMENT] = "DECREMENT" 106 | CommandNames[QUIT] = "QUIT" 107 | CommandNames[FLUSH] = "FLUSH" 108 | CommandNames[GETQ] = "GETQ" 109 | CommandNames[NOOP] = "NOOP" 110 | CommandNames[VERSION] = "VERSION" 111 | CommandNames[GETK] = "GETK" 112 | CommandNames[GETKQ] = "GETKQ" 113 | CommandNames[APPEND] = "APPEND" 114 | CommandNames[PREPEND] = "PREPEND" 115 | CommandNames[STAT] = "STAT" 116 | CommandNames[SETQ] = "SETQ" 117 | CommandNames[ADDQ] = "ADDQ" 118 | CommandNames[REPLACEQ] = "REPLACEQ" 119 | CommandNames[DELETEQ] = "DELETEQ" 120 | CommandNames[INCREMENTQ] = "INCREMENTQ" 121 | CommandNames[DECREMENTQ] = "DECREMENTQ" 122 | CommandNames[QUITQ] = "QUITQ" 123 | CommandNames[FLUSHQ] = "FLUSHQ" 124 | CommandNames[APPENDQ] = "APPENDQ" 125 | CommandNames[PREPENDQ] = "PREPENDQ" 126 | 127 | CommandNames[SASL_LIST_MECHS] = "SASL_LIST_MECHS" 128 | CommandNames[SASL_AUTH] = "SASL_AUTH" 129 | CommandNames[SASL_STEP] = "SASL_STEP" 130 | 131 | StatusNames = make(map[Status]string) 132 | StatusNames[SUCCESS] = "SUCCESS" 133 | StatusNames[KEY_ENOENT] = 
"KEY_ENOENT" 134 | StatusNames[KEY_EEXISTS] = "KEY_EEXISTS" 135 | StatusNames[E2BIG] = "E2BIG" 136 | StatusNames[EINVAL] = "EINVAL" 137 | StatusNames[NOT_STORED] = "NOT_STORED" 138 | StatusNames[DELTA_BADVAL] = "DELTA_BADVAL" 139 | StatusNames[AUTHFAIL] = "AUTHFAIL" 140 | StatusNames[FURTHER_AUTH] = "FURTHER_AUTH" 141 | StatusNames[UNKNOWN_COMMAND] = "UNKNOWN_COMMAND" 142 | StatusNames[ENOMEM] = "ENOMEM" 143 | StatusNames[TMPFAIL] = "TMPFAIL" 144 | } 145 | 146 | // String an op code. 147 | func (o OpCode) String() (rv string) { 148 | rv = CommandNames[o] 149 | if rv == "" { 150 | rv = fmt.Sprintf("0x%02x", int(o)) 151 | } 152 | return rv 153 | } 154 | 155 | // String an op code. 156 | func (s Status) String() (rv string) { 157 | rv = StatusNames[s] 158 | if rv == "" { 159 | rv = fmt.Sprintf("0x%02x", int(s)) 160 | } 161 | return rv 162 | } 163 | 164 | // IsQuiet return true if a command is a "quiet" command. 165 | func (o OpCode) IsQuiet() bool { 166 | switch o { 167 | case GETQ, 168 | GETKQ, 169 | SETQ, 170 | ADDQ, 171 | REPLACEQ, 172 | DELETEQ, 173 | INCREMENTQ, 174 | DECREMENTQ, 175 | QUITQ, 176 | FLUSHQ, 177 | APPENDQ, 178 | PREPENDQ: 179 | return true 180 | } 181 | return false 182 | } 183 | 184 | func (o OpCode) changeOnQuiet(def OpCode) OpCode { 185 | if o.IsQuiet() { 186 | return o 187 | } 188 | switch o { 189 | case GET: 190 | return GETQ 191 | case SET: 192 | return SETQ 193 | case ADD: 194 | return ADDQ 195 | case REPLACE: 196 | return REPLACEQ 197 | case DELETE: 198 | return DELETEQ 199 | case INCREMENT: 200 | return INCREMENTQ 201 | case DECREMENT: 202 | return DECREMENTQ 203 | case FLUSH: 204 | return FLUSHQ 205 | case APPEND: 206 | return APPENDQ 207 | case PREPEND: 208 | return PREPENDQ 209 | default: 210 | return def 211 | } 212 | } 213 | 214 | func prepareAuthData(user, pass string) []byte { 215 | return []byte(fmt.Sprintf("\x00%s\x00%s", user, pass)) 216 | } 217 | -------------------------------------------------------------------------------- 
/memcached/constants_test.go: -------------------------------------------------------------------------------- 1 | // nolint 2 | package memcached 3 | 4 | import ( 5 | "strings" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestCommandCodeString(t *testing.T) { 12 | if GET.String() != "GET" { 13 | t.Fatalf("Expected \"GET\" for GET, got \"%v\"", GET.String()) 14 | } 15 | 16 | cc := OpCode(0x80) 17 | if cc.String() != "0x80" { 18 | t.Fatalf("Expected \"0x80\" for 0x80, got \"%v\"", cc.String()) 19 | } 20 | } 21 | 22 | func TestStatusNameString(t *testing.T) { 23 | if SUCCESS.String() != "SUCCESS" { 24 | t.Fatalf("Expected \"SUCCESS\" for SUCCESS, got \"%v\"", 25 | SUCCESS.String()) 26 | } 27 | 28 | s := Status(0x80) 29 | if s.String() != "0x80" { 30 | t.Fatalf("Expected \"0x80\" for 0x80, got \"%v\"", s.String()) 31 | } 32 | } 33 | 34 | func TestIsQuiet(t *testing.T) { 35 | for v, k := range CommandNames { 36 | isq := strings.HasSuffix(k, "Q") 37 | if v.IsQuiet() != isq { 38 | t.Errorf("Expected quiet=%v for %v, got %v", 39 | isq, v, v.IsQuiet()) 40 | } 41 | } 42 | } 43 | 44 | func Test_prepareAuthData(t *testing.T) { 45 | type args struct { 46 | user string 47 | pass string 48 | } 49 | tests := []struct { 50 | name string 51 | args args 52 | want []byte 53 | }{ 54 | { 55 | name: "1", args: args{ 56 | user: "testuser", 57 | pass: "testpass", 58 | }, 59 | want: []byte("\x00testuser\x00testpass"), 60 | }, 61 | { 62 | name: "2", args: args{ 63 | user: "anotheruser", 64 | pass: "anotherpass", 65 | }, 66 | want: []byte("\x00anotheruser\x00anotherpass"), 67 | }, 68 | } 69 | for _, tt := range tests { 70 | t.Run(tt.name, func(t *testing.T) { 71 | assert.Equalf(t, tt.want, prepareAuthData(tt.args.user, tt.args.pass), "prepareAuthData(%v, %v)", tt.args.user, tt.args.pass) 72 | }) 73 | } 74 | } 75 | 76 | func TestOpCode_changeOnQuiet(t *testing.T) { 77 | type args struct { 78 | def OpCode 79 | } 80 | tests := []struct { 81 | name string 82 | 
o OpCode 83 | args args 84 | want OpCode 85 | }{ 86 | { 87 | name: GET.String(), 88 | o: GET, 89 | args: args{def: GETQ}, 90 | want: GETQ, 91 | }, 92 | { 93 | name: GETQ.String(), 94 | o: GETQ, 95 | args: args{def: GETQ}, 96 | want: GETQ, 97 | }, 98 | { 99 | name: "unknown opcode", 100 | o: OpCode(0x1b), 101 | args: args{def: GETQ}, 102 | want: GETQ, 103 | }, 104 | { 105 | name: SET.String(), 106 | o: SET, 107 | args: args{def: GETQ}, 108 | want: SETQ, 109 | }, 110 | { 111 | name: ADD.String(), 112 | o: ADD, 113 | args: args{def: GETQ}, 114 | want: ADDQ, 115 | }, 116 | { 117 | name: REPLACE.String(), 118 | o: REPLACE, 119 | args: args{def: GETQ}, 120 | want: REPLACEQ, 121 | }, 122 | { 123 | name: DELETE.String(), 124 | o: DELETE, 125 | args: args{def: GETQ}, 126 | want: DELETEQ, 127 | }, 128 | { 129 | name: INCREMENT.String(), 130 | o: INCREMENT, 131 | args: args{def: GETQ}, 132 | want: INCREMENTQ, 133 | }, 134 | { 135 | name: DECREMENT.String(), 136 | o: DECREMENT, 137 | args: args{def: GETQ}, 138 | want: DECREMENTQ, 139 | }, 140 | { 141 | name: FLUSH.String(), 142 | o: FLUSH, 143 | args: args{def: GETQ}, 144 | want: FLUSHQ, 145 | }, 146 | { 147 | name: APPEND.String(), 148 | o: APPEND, 149 | args: args{def: GETQ}, 150 | want: APPENDQ, 151 | }, 152 | { 153 | name: PREPEND.String(), 154 | o: PREPEND, 155 | args: args{def: GETQ}, 156 | want: PREPENDQ, 157 | }, 158 | } 159 | for _, tt := range tests { 160 | t.Run(tt.name, func(t *testing.T) { 161 | assert.Equalf(t, tt.want, tt.o.changeOnQuiet(tt.args.def), "changeOnQuiet(%v)", tt.args.def) 162 | }) 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /memcached/errors.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | const libPrefix = "gomemcached" 9 | 10 | var ( 11 | // ErrCacheMiss means that a Get failed because the item wasn't present. 
12 | ErrCacheMiss = errors.New("gomemcached: cache miss") 13 | 14 | // ErrCASConflict means that a CompareAndSwap call failed due to the 15 | // cached value being modified between the Get and the CompareAndSwap. 16 | // If the cached value was simply evicted rather than replaced, 17 | // ErrNotStored will be returned instead. 18 | ErrCASConflict = errors.New("gomemcached: compare-and-swap conflict") 19 | 20 | // ErrNotStored means that a conditional write operation (i.e. Add or 21 | // CompareAndSwap) failed because the condition was not satisfied. 22 | ErrNotStored = errors.New("gomemcached: item not stored") 23 | 24 | // ErrServerError means that a server error occurred. 25 | ErrServerError = errors.New("gomemcached: server error") 26 | 27 | // ErrNoStats means that no statistics were available. 28 | ErrNoStats = errors.New("gomemcached: no statistics available") 29 | 30 | // ErrMalformedKey is returned when an invalid key is used. 31 | // Keys must be at maximum 250 bytes long and not 32 | // contain whitespace or control characters. 33 | ErrMalformedKey = errors.New("gomemcached: key is too long or contains invalid characters") 34 | 35 | // ErrNoServers is returned when no servers are configured or available. 36 | ErrNoServers = errors.New("gomemcached: no servers configured or available") 37 | 38 | // ErrInvalidAddr means that an incorrect address was passed and could not be cast to net.Addr 39 | ErrInvalidAddr = errors.New("gomemcached: invalid address for server") 40 | 41 | // ErrServerNotAvailable means that one of the nodes is currently unavailable 42 | ErrServerNotAvailable = errors.New("gomemcached: server(s) is not available") 43 | 44 | // ErrNotConfigured means that some required parameter is not set in the configuration 45 | ErrNotConfigured = errors.New("gomemcached: not complete configuration") 46 | 47 | // ErrUnknownCommand means that in request consumer use unknown command for memcached. 
48 | ErrUnknownCommand = errors.New("gomemcached: Unknown command") 49 | 50 | // ErrDataSizeExceedsLimit means that memcached cannot process the request data due to its size. 51 | ErrDataSizeExceedsLimit = errors.New("gomemcached: Data size exceeds limit") 52 | 53 | // ErrInvalidArguments indicates invalid arguments or operation parameters (non-user request error). 54 | ErrInvalidArguments = errors.New("gomemcached: Invalid arguments or operation parameters") 55 | 56 | // ErrAuthFail indicates that an authorization attempt was made, but it did not work 57 | ErrAuthFail = errors.New("gomemcached: authentication enabled but operation failed") 58 | ) 59 | 60 | // resumableError returns true if err is only a protocol-level cache error. 61 | // This is used to determine whether a server connection should 62 | // be re-used or not. If an error occurs, by default we don't reuse the 63 | // connection, unless it was just a cache error. 64 | func resumableError(err error) bool { 65 | switch { 66 | case errors.Is(err, ErrCacheMiss), errors.Is(err, ErrCASConflict), 67 | errors.Is(err, ErrNotStored), errors.Is(err, ErrMalformedKey): 68 | return true 69 | } 70 | return false 71 | } 72 | 73 | func wrapMemcachedResp(resp *Response) error { 74 | switch resp.Status { 75 | case SUCCESS: 76 | return nil 77 | case KEY_ENOENT: 78 | return fmt.Errorf("%w. %w", ErrCacheMiss, resp) 79 | case NOT_STORED, KEY_EEXISTS: 80 | return fmt.Errorf("%w. %w", ErrNotStored, resp) 81 | case EINVAL, DELTA_BADVAL: 82 | return fmt.Errorf("%w. %w", ErrInvalidArguments, resp) 83 | case ENOMEM: 84 | return fmt.Errorf("%w. %w", ErrServerError, resp) 85 | case TMPFAIL: 86 | return fmt.Errorf("%w. %w", ErrServerNotAvailable, resp) 87 | case UNKNOWN_COMMAND: 88 | return fmt.Errorf("%w. %w", ErrUnknownCommand, resp) 89 | case E2BIG: 90 | return fmt.Errorf("%w. %w", ErrDataSizeExceedsLimit, resp) 91 | default: 92 | return fmt.Errorf("%w. 
%w", ErrServerError, resp) 93 | } 94 | } 95 | 96 | func errStatus(e error) Status { 97 | status := UNKNOWN_STATUS 98 | var res *Response 99 | if errors.As(e, &res) { 100 | status = res.Status 101 | } 102 | return status 103 | } 104 | -------------------------------------------------------------------------------- /memcached/memcached.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/binary" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "math" 11 | "net" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | 16 | "github.com/kelseyhightower/envconfig" 17 | "golang.org/x/exp/maps" 18 | 19 | "github.com/aliexpressru/gomemcached/consistenthash" 20 | "github.com/aliexpressru/gomemcached/logger" 21 | "github.com/aliexpressru/gomemcached/pool" 22 | "github.com/aliexpressru/gomemcached/utils" 23 | ) 24 | 25 | const ( 26 | // DefaultTimeout is the default socket read/write timeout. 27 | DefaultTimeout = 500 * time.Millisecond 28 | 29 | // DefaultMaxIdleConns is the default maximum number of idle connections 30 | // kept for any single address. 
31 | DefaultMaxIdleConns = 100 32 | 33 | // DefaultNodeHealthCheckPeriod is the default time period for start check available nods 34 | DefaultNodeHealthCheckPeriod = 15 * time.Second 35 | // DefaultRebuildingNodePeriod is the default time period for rebuilds the nodes in hash ring using freshly discovered 36 | DefaultRebuildingNodePeriod = 15 * time.Second 37 | 38 | // DefaultRetryCountForConn is a default number of connection retries before return i/o timeout error 39 | DefaultRetryCountForConn = uint8(3) 40 | 41 | // DefaultOfNumberConnsToDestroyPerRBPeriod is number of connections in pool whose needed close in every rebuild node cycle 42 | DefaultOfNumberConnsToDestroyPerRBPeriod = 1 43 | 44 | // DefaultSocketPoolingTimeout Amount of time to acquire socket from pool 45 | DefaultSocketPoolingTimeout = 50 * time.Millisecond 46 | ) 47 | 48 | var _ Memcached = (*Client)(nil) 49 | 50 | type ( 51 | Memcached interface { 52 | Store(storeMode StoreMode, key string, exp uint32, body []byte) (*Response, error) 53 | Get(key string) (*Response, error) 54 | Delete(key string) (*Response, error) 55 | Delta(deltaMode DeltaMode, key string, delta, initial uint64, exp uint32) (newValue uint64, err error) 56 | Append(appendMode AppendMode, key string, data []byte) (*Response, error) 57 | FlushAll(exp uint32) error 58 | MultiDelete(keys []string) error 59 | MultiStore(storeMode StoreMode, items map[string][]byte, exp uint32) error 60 | MultiGet(keys []string) (map[string][]byte, error) 61 | 62 | CloseAllConns() 63 | CloseAvailableConnsInAllShardPools(numOfClose int) int 64 | } 65 | 66 | // Client is a memcached client. 67 | // It is safe for unlocked use by multiple concurrent goroutines. 68 | Client struct { 69 | ctx context.Context 70 | nw *network 71 | cfg *config 72 | 73 | // opaque - a unique identifier for the request, used to associate the request with its corresponding response. 74 | opaque *uint32 75 | 76 | // timeout specifies the socket read/write timeout. 
77 | // If zero, DefaultTimeout is used. 78 | timeout time.Duration 79 | 80 | // maxIdleConns specifies the maximum number of idle connections that will 81 | // be maintained per address. If less than one, DefaultMaxIdleConns will be 82 | // used. 83 | // 84 | // Consider your expected traffic rates and latency carefully. This should 85 | // be set to a number higher than your peak parallel requests. 86 | maxIdleConns int 87 | 88 | // hr - hash ring implementation (can be a custom consistenthash.NewCustomHashRing) 89 | hr consistenthash.ConsistentHash 90 | 91 | // disableMemcachedDiagnostic - is flag for turn off write metrics from lib. 92 | disableMemcachedDiagnostic bool 93 | // disableNodeProvider - is flag for turn off rebuild and health check nodes. 94 | disableNodeProvider bool 95 | // disableRefreshConns - is flag for turn off to refresh conns in the pool. 96 | disableRefreshConns bool 97 | // nodeHCPeriod - period for execute nodes health checker 98 | // if zero, DefaultNodeHealthCheckPeriod is used. 99 | nodeHCPeriod time.Duration 100 | // nodeRBPeriod - period for execute rebuilding nodes 101 | // if zero, DefaultNodeHealthCheckPeriod is used. 
102 | nodeRBPeriod time.Duration 103 | 104 | // fmu - mutex for freeConns 105 | fmu sync.RWMutex 106 | // freeConns hashmap with nodes and their open dial connections 107 | freeConns map[string]*pool.Pool 108 | // dmu - mutex for deadNodes 109 | dmu sync.RWMutex 110 | // deadNodes hashmap with nodes that did not respond to health check 111 | deadNodes map[string]struct{} 112 | 113 | authEnable bool 114 | // authData ready body for authentication request 115 | authData []byte 116 | } 117 | 118 | network struct { 119 | dial func(network string, address string) (net.Conn, error) 120 | dialTimeout func(network string, address string, timeout time.Duration) (net.Conn, error) 121 | lookupHost func(host string) (addrs []string, err error) 122 | } 123 | 124 | config struct { 125 | // HeadlessServiceAddress Headless service to lookup all the memcached ip addresses. 126 | HeadlessServiceAddress string `envconfig:"MEMCACHED_HEADLESS_SERVICE_ADDRESS"` 127 | // Servers List of servers with hosted memcached 128 | Servers []string `envconfig:"MEMCACHED_SERVERS"` 129 | // MemcachedPort The optional port override for cases when memcached IP addresses are obtained from headless service. 130 | MemcachedPort int `envconfig:"MEMCACHED_PORT" default:"11211"` 131 | } 132 | conn struct { 133 | rc io.ReadCloser 134 | addr net.Addr 135 | c *Client 136 | hdrBuf []byte 137 | healthy bool 138 | wrtBuf *bufio.Writer 139 | authed bool 140 | } 141 | ) 142 | 143 | // InitFromEnv returns a memcached client using the config.HeadlessServiceAddress or config.Servers 144 | // with equal weight. If a server is listed multiple times, 145 | // it gets a proportional amount of weight. 
146 | func InitFromEnv(opts ...Option) (*Client, error) { 147 | var ( 148 | op = new(options) 149 | cfg = new(config) 150 | ) 151 | if err := envconfig.Process("", cfg); err != nil { 152 | return nil, fmt.Errorf("%s: client init err: %s", libPrefix, err.Error()) 153 | } 154 | 155 | op.cfg = cfg 156 | 157 | for _, opt := range opts { 158 | opt(op) 159 | } 160 | 161 | if op.Client.nw == nil { 162 | op.Client.nw = &network{ 163 | dial: net.Dial, 164 | dialTimeout: net.DialTimeout, 165 | lookupHost: net.LookupHost, 166 | } 167 | } 168 | if op.Client.hr == nil { 169 | op.Client.hr = consistenthash.NewHashRing() 170 | } 171 | if op.Client.ctx == nil { 172 | op.Client.ctx = context.Background() 173 | } 174 | if op.Client.opaque == nil { 175 | op.Client.opaque = new(uint32) 176 | } 177 | if op.disableLogger { 178 | logger.DisableLogger() 179 | } 180 | 181 | return newFromConfig(op) 182 | } 183 | 184 | func newForTests(servers ...string) (*Client, error) { 185 | hr := consistenthash.NewHashRing() 186 | for _, s := range servers { 187 | addr, err := utils.AddrRepr(s) 188 | if err != nil { 189 | return nil, fmt.Errorf("%w: %s", ErrInvalidAddr, err.Error()) 190 | } 191 | hr.Add(addr) 192 | } 193 | cm := &Client{ 194 | ctx: context.Background(), 195 | opaque: new(uint32), 196 | hr: hr, 197 | disableMemcachedDiagnostic: true, 198 | nw: &network{ 199 | dial: net.Dial, 200 | dialTimeout: net.DialTimeout, 201 | lookupHost: net.LookupHost, 202 | }, 203 | } 204 | 205 | return cm, nil 206 | } 207 | 208 | func newFromConfig(op *options) (*Client, error) { 209 | if op.cfg != nil && !(op.cfg.HeadlessServiceAddress != "" || len(op.cfg.Servers) != 0) { 210 | return nil, fmt.Errorf("%w, you must fill in either MEMCACHED_HEADLESS_SERVICE_ADDRESS or MEMCACHED_SERVERS", ErrNotConfigured) 211 | } 212 | nodes, err := getNodes(op.nw.lookupHost, op.cfg) 213 | if err != nil { 214 | return nil, fmt.Errorf("%w, %s", ErrInvalidAddr, err.Error()) 215 | } 216 | 217 | mc := &op.Client 218 | 219 | for _, 
n := range nodes { 220 | addr, err := utils.AddrRepr(n) 221 | if err != nil { 222 | return nil, fmt.Errorf("%w: %s", ErrInvalidAddr, err.Error()) 223 | } 224 | mc.hr.Add(addr) 225 | } 226 | 227 | if !mc.disableNodeProvider { 228 | mc.initNodesProvider() 229 | } 230 | return mc, nil 231 | } 232 | 233 | // release returns this connection back to the client's free pool 234 | func (cn *conn) release() { 235 | cn.c.putFreeConn(cn) 236 | } 237 | 238 | func (cn *conn) close() { 239 | if p, ok := cn.c.safeGetFreeConn(cn.addr); ok { 240 | p.Close(cn) 241 | } else { 242 | _ = cn.rc.Close() 243 | } 244 | } 245 | 246 | // condRelease releases this connection if the error pointed to by err 247 | // is nil (not an error) or is only a protocol level error (e.g. a 248 | // cache miss). The purpose is to not recycle TCP connections that 249 | // are bad. 250 | func (cn *conn) condRelease(err *error) { 251 | if (*err == nil || resumableError(*err)) && cn.healthy { 252 | cn.release() 253 | } else { 254 | cn.close() 255 | } 256 | } 257 | 258 | func (c *Client) getOpaque() uint32 { 259 | atomic.CompareAndSwapUint32(c.opaque, math.MaxUint32, uint32(0)) 260 | return atomic.AddUint32(c.opaque, uint32(1)) 261 | } 262 | 263 | func (c *Client) safeGetFreeConn(addr net.Addr) (*pool.Pool, bool) { 264 | c.fmu.RLock() 265 | defer c.fmu.RUnlock() 266 | connPool, ok := c.freeConns[addr.String()] 267 | return connPool, ok 268 | } 269 | 270 | func (c *Client) safeGetOrInitFreeConn(addr net.Addr) *pool.Pool { 271 | c.fmu.Lock() 272 | defer c.fmu.Unlock() 273 | 274 | connPool, ok := c.freeConns[addr.String()] 275 | if ok { 276 | return connPool 277 | } 278 | 279 | dialConn := func() (any, error) { 280 | nc, err := c.dial(addr) 281 | if err != nil { 282 | return nil, err 283 | } 284 | return &conn{ 285 | rc: nc, 286 | addr: addr, 287 | c: c, 288 | hdrBuf: make([]byte, HDR_LEN), 289 | wrtBuf: bufio.NewWriter(nc), 290 | healthy: true, 291 | }, nil 292 | } 293 | 294 | closeConn := func(cn any) { 295 | _ = 
cn.(*conn).rc.Close() 296 | } 297 | 298 | newPool := pool.New(c.ctx, int32(c.getMaxIdleConns()), DefaultSocketPoolingTimeout, dialConn, closeConn) 299 | 300 | if c.freeConns == nil { 301 | c.freeConns = make(map[string]*pool.Pool) 302 | } 303 | c.freeConns[addr.String()] = newPool 304 | 305 | return newPool 306 | } 307 | 308 | func (c *Client) freeConnsIsNil() bool { 309 | c.fmu.RLock() 310 | defer c.fmu.RUnlock() 311 | return c.freeConns == nil 312 | } 313 | 314 | func (c *Client) putFreeConn(cn *conn) { 315 | connPool, ok := c.safeGetFreeConn(cn.addr) 316 | if ok { 317 | connPool.Put(cn) 318 | } else { 319 | _ = cn.rc.Close() 320 | } 321 | } 322 | 323 | func (c *Client) getFreeConn(addr net.Addr) (*conn, error) { 324 | connPool := c.safeGetOrInitFreeConn(addr) 325 | 326 | connRaw, err := connPool.Get() 327 | if err != nil { 328 | return nil, fmt.Errorf("%s: Get from pool error - %w", libPrefix, err) 329 | } 330 | 331 | cn := connRaw.(*conn) 332 | 333 | if c.authEnable && !cn.authed { 334 | if c.authenticate(cn) { 335 | cn.authed = true 336 | return cn, nil 337 | } else { 338 | return nil, ErrAuthFail 339 | } 340 | } 341 | 342 | return connRaw.(*conn), nil 343 | } 344 | 345 | func (c *Client) removeFromFreeConns(addr net.Addr) { 346 | if c.freeConnsIsNil() { 347 | return 348 | } 349 | connPool, ok := c.safeGetFreeConn(addr) 350 | 351 | c.fmu.Lock() 352 | defer c.fmu.Unlock() 353 | if ok { 354 | connPool.Destroy() 355 | } 356 | delete(c.freeConns, addr.String()) 357 | } 358 | 359 | func (c *Client) netTimeout() time.Duration { 360 | if c.timeout != 0 { 361 | return c.timeout 362 | } 363 | return DefaultTimeout 364 | } 365 | 366 | func (c *Client) getMaxIdleConns() int { 367 | if c.maxIdleConns > 0 { 368 | return c.maxIdleConns 369 | } 370 | return DefaultMaxIdleConns 371 | } 372 | 373 | func (c *Client) getHCPeriod() time.Duration { 374 | if c.nodeHCPeriod > 0 { 375 | return c.nodeHCPeriod 376 | } 377 | return DefaultNodeHealthCheckPeriod 378 | } 379 | 380 | func 
(c *Client) getRBPeriod() time.Duration { 381 | if c.nodeRBPeriod > 0 { 382 | return c.nodeRBPeriod 383 | } 384 | return DefaultRebuildingNodePeriod 385 | } 386 | 387 | // ConnectTimeoutError is the error type used when it takes 388 | // too long to connect to the desired host. This level of 389 | // detail can generally be ignored. 390 | type ConnectTimeoutError struct { 391 | Addr net.Addr 392 | } 393 | 394 | func (cte *ConnectTimeoutError) Error() string { 395 | return "connect timeout to " + cte.Addr.String() 396 | } 397 | 398 | func (c *Client) dial(addr net.Addr) (net.Conn, error) { 399 | if c.netTimeout() > 0 { 400 | nc, err := c.nw.dialTimeout(addr.Network(), addr.String(), c.netTimeout()) 401 | if err != nil { 402 | var ne net.Error 403 | if errors.As(err, &ne) && ne.Timeout() { 404 | return nil, &ConnectTimeoutError{addr} 405 | } 406 | return nil, err 407 | } 408 | return nc, nil 409 | } 410 | return c.nw.dial(addr.Network(), addr.String()) 411 | } 412 | 413 | func (c *Client) getConnForNode(node any) (*conn, error) { 414 | addr, ok := node.(net.Addr) 415 | if !ok { 416 | return nil, ErrInvalidAddr 417 | } 418 | cn, err := c.getFreeConn(addr) 419 | if err != nil { 420 | return nil, err 421 | } 422 | 423 | return cn, nil 424 | } 425 | 426 | // Store is a wrote the provided item with expiration. 
427 | func (c *Client) Store(storeMode StoreMode, key string, exp uint32, body []byte) (_ *Response, err error) { 428 | timer := time.Now() 429 | defer c.writeMethodDiagnostics("Store", timer, &err) 430 | 431 | if !legalKey(key) { 432 | return nil, ErrMalformedKey 433 | } 434 | 435 | node, find := c.hr.Get(key) 436 | if !find { 437 | return nil, ErrNoServers 438 | } 439 | 440 | cn, err := c.getConnForNode(node) 441 | if err != nil { 442 | return nil, err 443 | } 444 | return c.store(cn, storeMode.Resolve(), key, exp, c.getOpaque(), body) 445 | } 446 | 447 | func (c *Client) store(cn *conn, opcode OpCode, key string, exp, opaque uint32, body []byte) (*Response, error) { 448 | req := &Request{ 449 | Opcode: opcode, 450 | Key: []byte(key), 451 | Opaque: opaque, 452 | Body: body, 453 | } 454 | req.prepareExtras(exp, 0, 0) 455 | return c.send(cn, req) 456 | } 457 | 458 | func (c *Client) send(cn *conn, req *Request) (resp *Response, err error) { 459 | defer cn.condRelease(&err) 460 | _, err = transmitRequest(cn.wrtBuf, req) 461 | if err != nil { 462 | cn.healthy = false 463 | return 464 | } 465 | 466 | if err = cn.wrtBuf.Flush(); err != nil { 467 | return nil, err 468 | } 469 | 470 | resp, _, err = getResponse(cn.rc, cn.hdrBuf) 471 | cn.healthy = !isFatal(err) 472 | return resp, err 473 | } 474 | 475 | // Get is return an item for provided key. 
476 | func (c *Client) Get(key string) (_ *Response, err error) { 477 | timer := time.Now() 478 | defer c.writeMethodDiagnostics("Get", timer, &err) 479 | 480 | if !legalKey(key) { 481 | return nil, ErrMalformedKey 482 | } 483 | 484 | node, find := c.hr.Get(key) 485 | if !find { 486 | return nil, ErrNoServers 487 | } 488 | 489 | cn, err := c.getConnForNode(node) 490 | if err != nil { 491 | return nil, err 492 | } 493 | 494 | req := &Request{ 495 | Opcode: GET, 496 | Opaque: c.getOpaque(), 497 | Key: []byte(key), 498 | } 499 | req.prepareExtras(0, 0, 0) 500 | 501 | return c.send(cn, req) 502 | } 503 | 504 | // Delete is a deletes the element with the provided key. 505 | // If the element does not exist, an ErrCacheMiss error is returned. 506 | func (c *Client) Delete(key string) (_ *Response, err error) { 507 | timer := time.Now() 508 | defer c.writeMethodDiagnostics("Delete", timer, &err) 509 | 510 | if !legalKey(key) { 511 | return nil, ErrMalformedKey 512 | } 513 | 514 | node, find := c.hr.Get(key) 515 | if !find { 516 | return nil, ErrNoServers 517 | } 518 | 519 | cn, err := c.getConnForNode(node) 520 | if err != nil { 521 | return nil, err 522 | } 523 | 524 | req := &Request{ 525 | Opcode: DELETE, 526 | Opaque: c.getOpaque(), 527 | Key: []byte(key), 528 | } 529 | req.prepareExtras(0, 0, 0) 530 | 531 | return c.send(cn, req) 532 | } 533 | 534 | // Delta is an atomically increments/decrements value by delta. The return value is 535 | // the new value after being incremented/decrements or an error. 
536 | func (c *Client) Delta(deltaMode DeltaMode, key string, delta, initial uint64, exp uint32) (newValue uint64, err error) { 537 | timer := time.Now() 538 | defer c.writeMethodDiagnostics("Delta", timer, &err) 539 | 540 | if !legalKey(key) { 541 | return 0, ErrMalformedKey 542 | } 543 | 544 | node, find := c.hr.Get(key) 545 | if !find { 546 | return 0, ErrNoServers 547 | } 548 | 549 | cn, err := c.getConnForNode(node) 550 | if err != nil { 551 | return 0, err 552 | } 553 | 554 | req := &Request{ 555 | Opcode: deltaMode.Resolve(), 556 | Key: []byte(key), 557 | } 558 | req.prepareExtras(exp, delta, initial) 559 | 560 | resp, err := c.send(cn, req) 561 | if err != nil { 562 | return 0, err 563 | } 564 | 565 | return binary.BigEndian.Uint64(resp.Body), nil 566 | } 567 | 568 | // Append is an appends/prepends the given item to the existing item, if a value already 569 | // exists for its key. ErrNotStored is returned if that condition is not met. 570 | func (c *Client) Append(appendMode AppendMode, key string, data []byte) (_ *Response, err error) { 571 | timer := time.Now() 572 | defer c.writeMethodDiagnostics("Append", timer, &err) 573 | 574 | if !legalKey(key) { 575 | return nil, ErrMalformedKey 576 | } 577 | 578 | node, find := c.hr.Get(key) 579 | if !find { 580 | return nil, ErrNoServers 581 | } 582 | 583 | cn, err := c.getConnForNode(node) 584 | if err != nil { 585 | return nil, err 586 | } 587 | 588 | req := &Request{ 589 | Opcode: appendMode.Resolve(), 590 | Opaque: c.getOpaque(), 591 | Key: []byte(key), 592 | Body: data, 593 | } 594 | req.prepareExtras(0, 0, 0) 595 | 596 | return c.send(cn, req) 597 | } 598 | 599 | // FlushAll is a deletes all items in the cache. 
600 | func (c *Client) FlushAll(exp uint32) (err error) { 601 | timerMethod := time.Now() 602 | defer c.writeMethodDiagnostics("FlushAll", timerMethod, &err) 603 | 604 | var ( 605 | wg sync.WaitGroup 606 | mu sync.Mutex 607 | multiErr error 608 | 609 | nodes = c.hr.GetAllNodes() 610 | ) 611 | 612 | addToMultiErr := func(e error) { 613 | mu.Lock() 614 | defer mu.Unlock() 615 | multiErr = errors.Join(multiErr, e) 616 | } 617 | 618 | for _, node := range nodes { 619 | wg.Add(1) 620 | go func(node any) { 621 | defer wg.Done() 622 | 623 | var cn *conn 624 | cn, err = c.getConnForNode(node) 625 | if err != nil { 626 | addToMultiErr(err) 627 | return 628 | } 629 | defer cn.condRelease(&err) 630 | 631 | req := &Request{ 632 | Opcode: FLUSH, 633 | } 634 | req.prepareExtras(exp, 0, 0) 635 | 636 | _, err = transmitRequest(cn.wrtBuf, req) 637 | if err != nil { 638 | cn.healthy = false 639 | return 640 | } 641 | 642 | if err = cn.wrtBuf.Flush(); err != nil { 643 | logger.Errorf("%s. %s", ErrServerError.Error(), err.Error()) 644 | return 645 | } 646 | 647 | _, _, err = getResponse(cn.rc, cn.hdrBuf) 648 | if err != nil { 649 | if isFatal(err) { 650 | cn.healthy = false 651 | return 652 | } 653 | addToMultiErr(err) 654 | } 655 | }(node) 656 | } 657 | 658 | wg.Wait() 659 | 660 | return multiErr 661 | } 662 | 663 | // MultiGet is a batch version of Get. The returned map from keys to 664 | // items may have fewer elements than the input slice, due to memcached 665 | // cache misses. Each key must be at most 250 bytes in length. 666 | // If no error is returned, the returned map will also be non-nil. 
667 | func (c *Client) MultiGet(keys []string) (_ map[string][]byte, err error) { 668 | var ( 669 | wg sync.WaitGroup 670 | mu sync.Mutex 671 | 672 | ret = make(map[string][]byte, len(keys)) 673 | ) 674 | if len(keys) == 0 { 675 | return ret, nil 676 | } 677 | 678 | timerMethod := time.Now() 679 | defer c.writeMethodDiagnostics("MultiGet", timerMethod, &err) 680 | 681 | if len(keys) == 1 { 682 | var res *Response 683 | res, err = c.Get(keys[0]) 684 | if res != nil { 685 | if res.Status == SUCCESS { 686 | ret[keys[0]] = res.Body 687 | } else if res.Status == KEY_ENOENT { 688 | // MultiGet never returns a ENOENT 689 | err = nil 690 | } 691 | } 692 | return ret, err 693 | } 694 | 695 | var ( 696 | once sync.Once 697 | singleError error 698 | ) 699 | 700 | addToRet := func(key string, body []byte) { 701 | mu.Lock() 702 | defer mu.Unlock() 703 | ret[key] = body 704 | } 705 | 706 | nodes, err := getNodesForKeys(c.hr, keys) 707 | if err != nil { 708 | return ret, err 709 | } 710 | 711 | for node, ks := range nodes { 712 | wg.Add(1) 713 | go func(node any, keys []string) { 714 | defer wg.Done() 715 | 716 | var cnErr error 717 | 718 | cn, nErr := c.getConnForNode(node) 719 | if nErr != nil { 720 | once.Do(func() { 721 | singleError = nErr 722 | }) 723 | return 724 | } 725 | defer cn.condRelease(&cnErr) 726 | 727 | idToKey := make(map[uint32]string, len(keys)) 728 | 729 | for _, key := range keys { 730 | opaqueGet := c.getOpaque() 731 | req := &Request{ 732 | Opcode: GETQ, 733 | Opaque: opaqueGet, 734 | Key: []byte(key), 735 | } 736 | req.prepareExtras(0, 0, 0) 737 | 738 | _, cnErr = transmitRequest(cn.wrtBuf, req) 739 | if cnErr != nil { 740 | cn.healthy = false 741 | return 742 | } 743 | 744 | idToKey[opaqueGet] = key 745 | } 746 | 747 | opaqueNOOP := c.getOpaque() 748 | req := &Request{ 749 | Opcode: NOOP, 750 | Opaque: opaqueNOOP, 751 | } 752 | req.prepareExtras(0, 0, 0) 753 | 754 | _, cnErr = transmitRequest(cn.wrtBuf, req) 755 | if cnErr != nil { 756 | cn.healthy = 
false 757 | return 758 | } 759 | 760 | if cnErr = cn.wrtBuf.Flush(); err != nil { 761 | logger.Errorf("%s. %s", ErrServerError.Error(), cnErr.Error()) 762 | return 763 | } 764 | 765 | for { 766 | var resp *Response 767 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf) 768 | if isFatal(cnErr) { 769 | cn.healthy = false 770 | return 771 | } 772 | 773 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP { 774 | break 775 | } 776 | 777 | if key, ok := idToKey[resp.Opaque]; ok && cnErr == nil { 778 | addToRet(key, resp.Body) 779 | } 780 | } 781 | }(node, ks) 782 | } 783 | 784 | wg.Wait() 785 | 786 | return ret, singleError 787 | } 788 | 789 | // MultiStore is a batch version of Store. 790 | // Writes the provided items with expiration. 791 | func (c *Client) MultiStore(storeMode StoreMode, items map[string][]byte, exp uint32) (err error) { 792 | if len(items) == 0 { 793 | return nil 794 | } 795 | 796 | timerMethod := time.Now() 797 | defer c.writeMethodDiagnostics("MultiStore", timerMethod, &err) 798 | 799 | var ( 800 | wg sync.WaitGroup 801 | muMErr sync.Mutex 802 | multiErr error 803 | ) 804 | 805 | addToMultiErr := func(e error) { 806 | muMErr.Lock() 807 | defer muMErr.Unlock() 808 | multiErr = errors.Join(multiErr, e) 809 | } 810 | 811 | var muItems sync.RWMutex 812 | safeGetItems := func(key string) []byte { 813 | muItems.RLock() 814 | defer muItems.RUnlock() 815 | return items[key] 816 | } 817 | 818 | quietCode := storeMode.Resolve().changeOnQuiet(SETQ) 819 | 820 | keys := maps.Keys(items) 821 | nodes, err := getNodesForKeys(c.hr, keys) 822 | if err != nil { 823 | return err 824 | } 825 | 826 | for node, ks := range nodes { 827 | wg.Add(1) 828 | go func(node any, keys []string, exp uint32) { 829 | defer wg.Done() 830 | 831 | var cnErr error 832 | 833 | cn, nErr := c.getConnForNode(node) 834 | if nErr != nil { 835 | addToMultiErr(nErr) 836 | return 837 | } 838 | defer cn.condRelease(&cnErr) 839 | 840 | idToKey := make(map[uint32]string, len(keys)) 841 | 842 | for 
_, key := range keys { 843 | opaqueStore := c.getOpaque() 844 | req := &Request{ 845 | Opcode: quietCode, 846 | Opaque: opaqueStore, 847 | Key: []byte(key), 848 | Body: safeGetItems(key), 849 | } 850 | req.prepareExtras(exp, 0, 0) 851 | 852 | _, cnErr = transmitRequest(cn.wrtBuf, req) 853 | if cnErr != nil { 854 | cn.healthy = false 855 | return 856 | } 857 | 858 | idToKey[opaqueStore] = key 859 | } 860 | 861 | opaqueNOOP := c.getOpaque() 862 | req := &Request{ 863 | Opcode: NOOP, 864 | Opaque: opaqueNOOP, 865 | } 866 | req.prepareExtras(0, 0, 0) 867 | 868 | _, cnErr = transmitRequest(cn.wrtBuf, req) 869 | if cnErr != nil { 870 | cn.healthy = false 871 | return 872 | } 873 | 874 | if cnErr = cn.wrtBuf.Flush(); err != nil { 875 | logger.Errorf("%s. %s", ErrServerError.Error(), cnErr.Error()) 876 | return 877 | } 878 | 879 | for { 880 | var resp *Response 881 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf) 882 | if isFatal(cnErr) { 883 | cn.healthy = false 884 | return 885 | } 886 | 887 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP { 888 | break 889 | } 890 | 891 | if key, ok := idToKey[resp.Opaque]; ok { 892 | if resp.Status != SUCCESS { 893 | addToMultiErr(fmt.Errorf("%w. Error for key - %s", cnErr, key)) 894 | } 895 | } 896 | } 897 | }(node, ks, exp) 898 | } 899 | 900 | wg.Wait() 901 | 902 | return multiErr 903 | } 904 | 905 | // MultiDelete is a batch version of Delete. 906 | // Deletes the items with the provided keys. 907 | // If there is a key in the provided keys that is missing in the cache, 908 | // the ErrCacheMiss error is ignored. 
909 | func (c *Client) MultiDelete(keys []string) (err error) { 910 | if len(keys) == 0 { 911 | return nil 912 | } 913 | 914 | timerMethod := time.Now() 915 | defer c.writeMethodDiagnostics("MultiDelete", timerMethod, &err) 916 | 917 | var ( 918 | wg sync.WaitGroup 919 | mu sync.Mutex 920 | multiErr error 921 | ) 922 | 923 | addToMultiErr := func(e error) { 924 | mu.Lock() 925 | defer mu.Unlock() 926 | multiErr = errors.Join(multiErr, e) 927 | } 928 | 929 | nodes, err := getNodesForKeys(c.hr, keys) 930 | if err != nil { 931 | return err 932 | } 933 | 934 | for node, ks := range nodes { 935 | wg.Add(1) 936 | go func(node any, keys []string) { 937 | defer wg.Done() 938 | 939 | var cnErr error 940 | 941 | cn, nErr := c.getConnForNode(node) 942 | if nErr != nil { 943 | addToMultiErr(nErr) 944 | return 945 | } 946 | defer cn.condRelease(&cnErr) 947 | 948 | idToKey := make(map[uint32]string, len(keys)) 949 | 950 | for _, key := range keys { 951 | opaqueDel := c.getOpaque() 952 | req := &Request{ 953 | Opcode: DELETEQ, 954 | Opaque: opaqueDel, 955 | Key: []byte(key), 956 | } 957 | req.prepareExtras(0, 0, 0) 958 | 959 | _, cnErr = transmitRequest(cn.wrtBuf, req) 960 | if cnErr != nil { 961 | cn.healthy = false 962 | return 963 | } 964 | 965 | idToKey[opaqueDel] = key 966 | } 967 | 968 | opaqueNOOP := c.getOpaque() 969 | req := &Request{ 970 | Opcode: NOOP, 971 | Opaque: opaqueNOOP, 972 | } 973 | req.prepareExtras(0, 0, 0) 974 | 975 | _, cnErr = transmitRequest(cn.wrtBuf, req) 976 | if cnErr != nil { 977 | cn.healthy = false 978 | return 979 | } 980 | 981 | if cnErr = cn.wrtBuf.Flush(); err != nil { 982 | logger.Errorf("%s. 
%s", ErrServerError.Error(), cnErr.Error()) 983 | return 984 | } 985 | 986 | for { 987 | var resp *Response 988 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf) 989 | if isFatal(cnErr) { 990 | cn.healthy = false 991 | return 992 | } 993 | 994 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP { 995 | break 996 | } 997 | 998 | if key, ok := idToKey[resp.Opaque]; ok { 999 | if resp.Status != SUCCESS && resp.Status != KEY_ENOENT { 1000 | addToMultiErr(fmt.Errorf("%w. Error for key - %s", cnErr, key)) 1001 | } 1002 | } 1003 | } 1004 | }(node, ks) 1005 | } 1006 | 1007 | wg.Wait() 1008 | 1009 | return multiErr 1010 | } 1011 | 1012 | // CloseAllConns is close all opened connection per shards. 1013 | // Once closed, resources should be released. 1014 | func (c *Client) CloseAllConns() { 1015 | c.fmu.Lock() 1016 | defer c.fmu.Unlock() 1017 | 1018 | for addr, connPool := range c.freeConns { 1019 | connPool.Destroy() 1020 | delete(c.freeConns, addr) 1021 | } 1022 | } 1023 | 1024 | // CloseAvailableConnsInAllShardPools - removes the specified number of connections from the pools of all shards. 
1025 | func (c *Client) CloseAvailableConnsInAllShardPools(numOfClose int) int { 1026 | var closed int 1027 | 1028 | c.fmu.Lock() 1029 | defer c.fmu.Unlock() 1030 | 1031 | for _, p := range c.freeConns { 1032 | for i := 0; i < numOfClose; i++ { 1033 | if connRaw, ok := p.Pop(); ok { 1034 | p.Close(connRaw) 1035 | closed++ 1036 | } 1037 | } 1038 | } 1039 | 1040 | return closed 1041 | } 1042 | 1043 | func (c *Client) writeMethodDiagnostics(methodName string, timer time.Time, err *error) { 1044 | if methodName == "" || c.disableMemcachedDiagnostic { 1045 | return 1046 | } 1047 | 1048 | observeMethodDurationSeconds(methodName, time.Since(timer).Seconds(), *err == nil) 1049 | } 1050 | 1051 | func (c *Client) authenticate(cn *conn) (ok bool) { 1052 | req := &Request{ 1053 | Key: []byte(SaslMechanism), 1054 | Body: c.authData, 1055 | } 1056 | 1057 | req.Opcode = SASL_AUTH 1058 | _, err := transmitRequest(cn.wrtBuf, req) 1059 | if err != nil { 1060 | return 1061 | } 1062 | 1063 | if err = cn.wrtBuf.Flush(); err != nil { 1064 | return 1065 | } 1066 | 1067 | resp, _, err := getResponse(cn.rc, cn.hdrBuf) 1068 | if err == nil { 1069 | return true 1070 | } 1071 | if err != nil && resp.Status != FURTHER_AUTH { 1072 | logger.Errorf("%s: Error from sasl auth - %v", libPrefix, resp) 1073 | return 1074 | } 1075 | 1076 | req.Opcode = SASL_STEP 1077 | _, err = transmitRequest(cn.wrtBuf, req) 1078 | if err != nil { 1079 | return 1080 | } 1081 | 1082 | resp, _, err = getResponse(cn.rc, cn.hdrBuf) 1083 | if err != nil { 1084 | logger.Errorf("%s: Error from sasl step - %v", libPrefix, resp) 1085 | return 1086 | } 1087 | 1088 | if err = cn.wrtBuf.Flush(); err != nil { 1089 | return 1090 | } 1091 | 1092 | return true 1093 | } 1094 | 1095 | func legalKey(key string) bool { 1096 | if len(key) > 250 { 1097 | return false 1098 | } 1099 | for i := 0; i < len(key); i++ { 1100 | if key[i] <= ' ' || key[i] == 0x7f { 1101 | return false 1102 | } 1103 | } 1104 | return true 1105 | } 1106 | 1107 | // 
getNodesForKeys return a map where key is a node and value is a suitable keys 1108 | func getNodesForKeys(hr consistenthash.ConsistentHash, keys []string) (map[any][]string, error) { 1109 | resp := make(map[any][]string, hr.GetNodesCount()) 1110 | 1111 | for _, key := range keys { 1112 | if !legalKey(key) { 1113 | return nil, fmt.Errorf("%w. Invalid key - %v", ErrMalformedKey, key) 1114 | } 1115 | if node, found := hr.Get(key); found { 1116 | resp[node] = append(resp[node], key) 1117 | } 1118 | } 1119 | 1120 | return resp, nil 1121 | } 1122 | -------------------------------------------------------------------------------- /memcached/memcached_test.go: -------------------------------------------------------------------------------- 1 | // nolint 2 | package memcached 3 | 4 | import ( 5 | "bufio" 6 | "bytes" 7 | "errors" 8 | "fmt" 9 | "io/ioutil" 10 | "net" 11 | "reflect" 12 | "strconv" 13 | "sync" 14 | "testing" 15 | "time" 16 | 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | "golang.org/x/exp/maps" 20 | 21 | "github.com/aliexpressru/gomemcached/consistenthash" 22 | "github.com/aliexpressru/gomemcached/utils" 23 | ) 24 | 25 | func TestTransmitReq(t *testing.T) { 26 | b := bytes.NewBuffer([]byte{}) 27 | buf := bufio.NewWriter(b) 28 | 29 | req := Request{ 30 | Opcode: SET, 31 | Cas: 938424885, 32 | Opaque: 7242, 33 | Extras: []byte{}, 34 | Key: []byte("somekey"), 35 | Body: []byte("somevalue"), 36 | } 37 | 38 | // Verify nil transmit is OK 39 | _, err := transmitRequest(nil, &req) 40 | if !errors.Is(err, ErrNoServers) { 41 | t.Errorf("Expected errNoConn with no conn, got %v", err) 42 | } 43 | 44 | _, err = transmitRequest(buf, &req) 45 | if err != nil { 46 | t.Fatalf("Error transmitting request: %v", err) 47 | } 48 | 49 | buf.Flush() 50 | 51 | expected := []byte{ 52 | REQ_MAGIC, byte(SET), 53 | 0x0, 0x7, // length of key 54 | 0x0, // extra length 55 | 0x0, // reserved 56 | 0x0, 0x0, // reserved 57 | 0x0, 0x0, 0x0, 0x10, // 
Length of value 58 | 0x0, 0x0, 0x1c, 0x4a, // opaque 59 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 60 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 61 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 62 | } 63 | 64 | if len(b.Bytes()) != req.Size() { 65 | t.Fatalf("Expected %v bytes, got %v", req.Size(), 66 | len(b.Bytes())) 67 | } 68 | 69 | if !reflect.DeepEqual(b.Bytes(), expected) { 70 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 71 | expected, b.Bytes()) 72 | } 73 | } 74 | 75 | func BenchmarkTransmitReq(b *testing.B) { 76 | bout := bytes.NewBuffer([]byte{}) 77 | 78 | req := Request{ 79 | Opcode: SET, 80 | Cas: 938424885, 81 | Opaque: 7242, 82 | Extras: []byte{}, 83 | Key: []byte("somekey"), 84 | Body: []byte("somevalue"), 85 | } 86 | 87 | b.SetBytes(int64(req.Size())) 88 | 89 | for i := 0; i < b.N; i++ { 90 | bout.Reset() 91 | buf := bufio.NewWriterSize(bout, req.Size()*2) 92 | _, err := transmitRequest(buf, &req) 93 | if err != nil { 94 | b.Fatalf("Error transmitting request: %v", err) 95 | } 96 | } 97 | } 98 | 99 | func BenchmarkTransmitReqLarge(b *testing.B) { 100 | bout := bytes.NewBuffer([]byte{}) 101 | 102 | req := Request{ 103 | Opcode: SET, 104 | Cas: 938424885, 105 | Opaque: 7242, 106 | Extras: []byte{}, 107 | Key: []byte("somekey"), 108 | Body: make([]byte, 24*1024), 109 | } 110 | 111 | b.SetBytes(int64(req.Size())) 112 | 113 | for i := 0; i < b.N; i++ { 114 | bout.Reset() 115 | buf := bufio.NewWriterSize(bout, req.Size()*2) 116 | _, err := transmitRequest(buf, &req) 117 | if err != nil { 118 | b.Fatalf("Error transmitting request: %v", err) 119 | } 120 | } 121 | } 122 | 123 | func BenchmarkTransmitReqNull(b *testing.B) { 124 | req := Request{ 125 | Opcode: SET, 126 | Cas: 938424885, 127 | Opaque: 7242, 128 | Extras: []byte{}, 129 | Key: []byte("somekey"), 130 | Body: []byte("somevalue"), 131 | } 132 | 133 | b.SetBytes(int64(req.Size())) 134 | 135 | for i := 0; i < b.N; i++ { 136 | _, err := transmitRequest(ioutil.Discard, &req) 137 | if err != nil { 
138 | b.Fatalf("Error transmitting request: %v", err) 139 | } 140 | } 141 | } 142 | 143 | /* 144 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 145 | +---------------+---------------+---------------+---------------+ 146 | 0| 0x81 | 0x00 | 0x00 | 0x00 | 147 | +---------------+---------------+---------------+---------------+ 148 | 4| 0x04 | 0x00 | 0x00 | 0x00 | 149 | +---------------+---------------+---------------+---------------+ 150 | 8| 0x00 | 0x00 | 0x00 | 0x09 | 151 | +---------------+---------------+---------------+---------------+ 152 | 12| 0x00 | 0x00 | 0x00 | 0x00 | 153 | +---------------+---------------+---------------+---------------+ 154 | 16| 0x00 | 0x00 | 0x00 | 0x00 | 155 | +---------------+---------------+---------------+---------------+ 156 | 20| 0x00 | 0x00 | 0x00 | 0x01 | 157 | +---------------+---------------+---------------+---------------+ 158 | 24| 0xde | 0xad | 0xbe | 0xef | 159 | +---------------+---------------+---------------+---------------+ 160 | 28| 0x57 ('W') | 0x6f ('o') | 0x72 ('r') | 0x6c ('l') | 161 | +---------------+---------------+---------------+---------------+ 162 | 32| 0x64 ('d') | 163 | +---------------+ 164 | 165 | Field (offset) (value) 166 | Magic (0) : 0x81 167 | Opcode (1) : 0x00 168 | Key length (2,3) : 0x0000 169 | Extra length (4) : 0x04 170 | Data type (5) : 0x00 171 | Status (6,7) : 0x0000 172 | Total body (8-11) : 0x00000009 173 | Opaque (12-15): 0x00000000 174 | CAS (16-23): 0x0000000000000001 175 | Extras : 176 | Flags (24-27): 0xdeadbeef 177 | Key : None 178 | Value (28-32): The textual string "World" 179 | 180 | */ 181 | 182 | func TestDecodeSpecSample(t *testing.T) { 183 | data := []byte{ 184 | 0x81, 0x00, 0x00, 0x00, // 0 185 | 0x04, 0x00, 0x00, 0x00, // 4 186 | 0x00, 0x00, 0x00, 0x09, // 8 187 | 0x00, 0x00, 0x00, 0x00, // 12 188 | 0x00, 0x00, 0x00, 0x00, // 16 189 | 0x00, 0x00, 0x00, 0x01, // 20 190 | 0xde, 0xad, 0xbe, 0xef, // 24 191 | 0x57, 0x6f, 0x72, 0x6c, // 28 192 | 0x64, // 32 
193 | } 194 | 195 | buf := make([]byte, HDR_LEN) 196 | res, _, err := getResponse(bytes.NewReader(data), buf) 197 | if err != nil { 198 | t.Fatalf("Error parsing response: %v", err) 199 | } 200 | 201 | expected := &Response{ 202 | Opcode: GET, 203 | Status: 0, 204 | Opaque: 0, 205 | Cas: 1, 206 | Extras: []byte{0xde, 0xad, 0xbe, 0xef}, 207 | Body: []byte("World"), 208 | } 209 | 210 | if !reflect.DeepEqual(res, expected) { 211 | t.Fatalf("Expected\n%#v -- got --\n%#v", expected, res) 212 | } 213 | assert.Nil(t, UnwrapMemcachedError(err), "UnwrapMemcachedError: should be return nil for success getResponse") 214 | } 215 | 216 | func TestNilReader(t *testing.T) { 217 | res, _, err := getResponse(nil, nil) 218 | if !errors.Is(err, ErrNoServers) { 219 | t.Fatalf("Expected error reading from nil, got %#v", res) 220 | } 221 | } 222 | 223 | func TestNilConfig(t *testing.T) { 224 | mcl, err := InitFromEnv() 225 | assert.Nil(t, mcl, "InitFromEnv without config should be return nil client") 226 | assert.ErrorIs(t, err, ErrNotConfigured, "InitFromEnv without config should be return error == ErrNotConfigured") 227 | } 228 | 229 | func TestErrWrap(t *testing.T) { 230 | type args struct { 231 | resp *Response 232 | } 233 | tests := []struct { 234 | name string 235 | args args 236 | wantErr error 237 | }{ 238 | { 239 | name: ENOMEM.String(), 240 | args: args{resp: &Response{ 241 | Status: ENOMEM, 242 | }}, 243 | wantErr: ErrServerError, 244 | }, 245 | { 246 | name: TMPFAIL.String(), 247 | args: args{resp: &Response{ 248 | Status: TMPFAIL, 249 | }}, 250 | wantErr: ErrServerNotAvailable, 251 | }, 252 | { 253 | name: UNKNOWN_COMMAND.String(), 254 | args: args{resp: &Response{ 255 | Status: UNKNOWN_COMMAND, 256 | }}, 257 | wantErr: ErrUnknownCommand, 258 | }, 259 | } 260 | for _, tt := range tests { 261 | t.Run(tt.name, func(t *testing.T) { 262 | wrapErr := wrapMemcachedResp(tt.args.resp) 263 | require.ErrorIs(t, wrapErr, tt.wantErr, "wrapMemcachedResp wrap error not equal expected") 
264 | }) 265 | } 266 | } 267 | 268 | func TestDecode(t *testing.T) { 269 | data := []byte{ 270 | RES_MAGIC, byte(SET), 271 | 0x0, 0x7, // length of key 272 | 0x0, // extra length 273 | 0x0, // reserved 274 | 0x6, 0x2e, // status 275 | 0x0, 0x0, 0x0, 0x10, // Length of value 276 | 0x0, 0x0, 0x1c, 0x4a, // opaque 277 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 278 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 279 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 280 | } 281 | 282 | buf := make([]byte, HDR_LEN) 283 | res, _, _ := getResponse(bytes.NewReader(data), buf) 284 | 285 | expected := &Response{ 286 | Opcode: SET, 287 | Status: 1582, 288 | Opaque: 7242, 289 | Cas: 938424885, 290 | Extras: nil, 291 | Key: []byte("somekey"), 292 | Body: []byte("somevalue"), 293 | } 294 | 295 | if !reflect.DeepEqual(res, expected) { 296 | t.Fatalf("Expected\n%#v -- got --\n%#v", expected, res) 297 | } 298 | } 299 | 300 | func BenchmarkDecodeResponse(b *testing.B) { 301 | data := []byte{ 302 | RES_MAGIC, byte(SET), 303 | 0x0, 0x7, // length of key 304 | 0x0, // extra length 305 | 0x0, // reserved 306 | 0x6, 0x2e, // status 307 | 0x0, 0x0, 0x0, 0x10, // Length of value 308 | 0x0, 0x0, 0x1c, 0x4a, // opaque 309 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 310 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 311 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 312 | } 313 | buf := make([]byte, HDR_LEN) 314 | b.SetBytes(int64(len(buf))) 315 | 316 | for i := 0; i < b.N; i++ { 317 | getResponse(bytes.NewReader(data), buf) 318 | } 319 | } 320 | 321 | const localhostTCPAddr = "localhost:11211" 322 | 323 | func TestLocalhost(t *testing.T) { 324 | t.Parallel() 325 | c, err := net.Dial("tcp", localhostTCPAddr) 326 | if err != nil { 327 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr) 328 | } 329 | req := Request{ 330 | Opcode: VERSION, 331 | } 332 | 333 | _, err = transmitRequest(c, &req) 334 | if err != nil { 335 | t.Errorf("Expected errNoConn with no conn, got %v", err) 336 | } 
337 | 338 | buf := make([]byte, HDR_LEN) 339 | resp, _, err := getResponse(c, buf) 340 | if err != nil { 341 | t.Fatalf("Error transmitting request: %v", err) 342 | } 343 | 344 | if resp.Status != SUCCESS { 345 | t.Errorf("Expected SUCCESS, got %v", resp.Status) 346 | } 347 | if err = c.Close(); err != nil { 348 | t.Fatalf("Error with close connection: %v", err) 349 | } 350 | 351 | _, err = newForTests("invalidServerAddr") 352 | require.ErrorIs(t, err, ErrInvalidAddr) 353 | 354 | mc, err := newForTests(localhostTCPAddr) 355 | if err != nil { 356 | t.Fatalf("failed to create new client: %v", err) 357 | } 358 | t.Cleanup(mc.CloseAllConns) 359 | testWithClient(t, mc) 360 | } 361 | 362 | func testWithClient(t *testing.T, c *Client) { 363 | resp, err := c.Store(Set, "bigdata", 0, make([]byte, MaxBodyLen+1)) 364 | assert.ErrorIsf(t, err, ErrDataSizeExceedsLimit, "Store: body > MaxBodyLen, want error ErrDataSizeExceedsLimit") 365 | unwrapResp := UnwrapMemcachedError(err) 366 | if !reflect.DeepEqual(resp, unwrapResp) { 367 | t.Fatalf("Expected\n%#v -- got --\n%#v", resp, unwrapResp) 368 | } 369 | 370 | // multi 371 | err = c.MultiStore(Set, map[string][]byte{}, 0) 372 | assert.Nil(t, err, "MultiStore with 0 items should have no errors") 373 | items, err := c.MultiGet([]string{}) 374 | assert.Nil(t, err, "MultiGet with 0 keys should have no errors") 375 | assert.Empty(t, items, "MultiGet with 0 keys should return empty map") 376 | err = c.MultiDelete([]string{}) 377 | assert.Nil(t, err, "MultiDelete with 0 keys should have no errors") 378 | 379 | // Set 380 | _, err = c.Store(Set, "foo", 0, []byte("fooval-fromset1")) 381 | assert.Nilf(t, err, "first set(foo): %v", err) 382 | _, err = c.Store(Set, "foo", 0, []byte("fooval-fromset2")) 383 | assert.Nilf(t, err, "second set(foo): %v", err) 384 | // Add 385 | _, err = c.Store(Add, "foo", 0, []byte("fooval-fromset3")) 386 | assert.ErrorIsf(t, err, ErrNotStored, "Add with exist key - %s, want error - ErrNotStored, have - %v", 
"foo", err) 387 | 388 | // Get 389 | resp, err = c.Get("foo") 390 | assert.Nilf(t, err, "get(foo): %v", err) 391 | // assert.Equalf(t, []byte("foo"), resp.Key, "get(foo) Key = %s, want foo", string(resp.Key)) only for GETK 392 | assert.Equalf(t, []byte("fooval-fromset2"), resp.Body, "get(foo) Body = %s, want fooval-fromset2", string(resp.Body)) 393 | err = wrapMemcachedResp(resp) 394 | assert.Nil(t, err, "Get: wrapped success resp should be nil") 395 | 396 | // Get and set a unicode key 397 | quxKey := "Hello_世界" 398 | _, err = c.Store(Set, quxKey, 0, []byte("hello world")) 399 | assert.Nilf(t, err, "first set(Hello_世界): %v", err) 400 | resp, err = c.Get(quxKey) 401 | assert.Nilf(t, err, "get(Hello_世界): %v", err) 402 | // assert.Equalf(t, quxKey, string(resp.Key), "get(Hello_世界) Key = %q, want Hello_世界", quxKey) only for GETK 403 | assert.Equalf(t, "hello world", string(resp.Body), "get(Hello_世界) Value = %q, want hello world", string(resp.Body)) 404 | 405 | // Set malformed keys 406 | _, err = c.Store(Set, "foo bar", 0, []byte("foobarval")) 407 | assert.ErrorIsf(t, err, ErrMalformedKey, "set(foo bar) should return ErrMalformedKey instead of %v", err) 408 | _, err = c.Store(Set, "foo"+string(rune(0x7f)), 0, []byte("foobarval")) 409 | assert.ErrorIsf(t, err, ErrMalformedKey, "set(foo<0x7f>) should return ErrMalformedKey instead of %v", err) 410 | 411 | // Append 412 | _, err = c.Append(Append, "append", []byte("appendval")) 413 | assert.ErrorIsf(t, err, ErrNotStored, "first append(append) want ErrNotStored, got %v", err) 414 | 415 | _, err = c.Store(Set, "append", 0, []byte("appendval")) 416 | assert.Nilf(t, err, "Set for append have error - %v", err) 417 | _, err = c.Append(Append, "append", []byte("1")) 418 | assert.Nilf(t, err, "second append(append): %v", err) 419 | appended, err := c.Get("append") 420 | assert.Nilf(t, err, "after append(append): %v", err) 421 | assert.Equalf(t, fmt.Sprintf("%s%s", "appendval", "1"), string(appended.Body), 422 | "Append: 
want=append1, got=%s", string(appended.Body)) 423 | 424 | // Prepend 425 | _, err = c.Append(Prepend, "prepend", []byte("prependval")) 426 | assert.ErrorIsf(t, err, ErrNotStored, "first prepend(prepend) want ErrNotStored, got %v", err) 427 | 428 | _, err = c.Store(Set, "prepend", 0, []byte("prependval")) 429 | assert.Nilf(t, err, "Set for prepend have error - %v", err) 430 | _, err = c.Append(Prepend, "prepend", []byte("1")) 431 | assert.Nilf(t, err, "second prepend(prepend): %v", err) 432 | prepend, err := c.Get("prepend") 433 | assert.Nilf(t, err, "after prepend(prepend): %v", err) 434 | assert.Equalf(t, fmt.Sprintf("%s%s", "1", "prependval"), string(prepend.Body), 435 | "Prepend: want=1prependval, got=%s", string(prepend.Body)) 436 | 437 | // Replace 438 | _, err = c.Store(Replace, "baz", 0, []byte("bazvalue")) 439 | assert.ErrorIsf(t, err, ErrCacheMiss, "expected replace(baz) to return ErrCacheMiss, got %v", err) 440 | _, err = c.Store(Set, "baz", 0, []byte("bazvalue")) 441 | assert.Nilf(t, err, "Set for Replace have error - %v", err) 442 | resp, err = c.Store(Replace, "baz", 0, []byte("42")) 443 | assert.Nilf(t, err, "Replace have error - %v", err) 444 | resp, err = c.Get("baz") 445 | assert.Nilf(t, err, "Get for Replace have error - %v", err) 446 | assert.Equalf(t, "42", string(resp.Body), "Resp after replaces want - 42, have - %s", string(resp.Body)) 447 | 448 | // Incr/Decr 449 | _, err = c.Store(Set, "num", 0, []byte("42")) 450 | assert.Nilf(t, err, "Set for Increment have error - %v", err) 451 | n, err := c.Delta(Increment, "num", 8, 0, 0) 452 | assert.Nilf(t, err, "Increment num + 8: %v", err) 453 | assert.Equalf(t, 50, int(n), "Increment num + 8: want=50, got=%d", n) 454 | n, err = c.Delta(Decrement, "num", 49, 0, 0) 455 | assert.Nilf(t, err, "Decrement: %v", err) 456 | assert.Equalf(t, 1, int(n), "Decrement 49: want=1, got=%d", n) 457 | _, err = c.Delete("num") 458 | assert.Nilf(t, err, "Delete for Increment/Decrement have error - %v", err) 459 | n, 
err = c.Delta(Increment, "num", 1, 10, 0) 460 | assert.Nilf(t, err, "Increment with initial value have error - %v", err) 461 | assert.Equalf(t, 10, int(n), "Increment with initial value 10: want=10, got=%d", n) 462 | n, err = c.Delta(Decrement, "num", 2, 0, 0) 463 | assert.Nilf(t, err, "Increment with initial value have error - %v", err) 464 | assert.Equalf(t, 8, int(n), "Increment with initial value 1: want=8, got=%d", n) 465 | const fakeDeltaMode = DeltaMode(42) 466 | n, err = c.Delta(fakeDeltaMode, "num", 2, 0, 0) 467 | assert.Nilf(t, err, "Increment with fakeDeltaMode have error - %v", err) 468 | 469 | _, err = c.Store(Set, "num", 0, []byte("not-numeric")) 470 | assert.Nilf(t, err, "Set for Increment non-numeric value have error - %v", err) 471 | _, err = c.Delta(Increment, "num", 1, 0, 0) 472 | assert.ErrorIs(t, err, ErrInvalidArguments, "Increment not-numeric value") 473 | 474 | // Delete 475 | _, err = c.Delete("foo") 476 | assert.Nilf(t, err, "Delete: %v", err) 477 | _, err = c.Get("foo") 478 | assert.ErrorIsf(t, err, ErrCacheMiss, "post-Delete want ErrCacheMiss, got %v", err) 479 | 480 | testExpireWithClient(t, c) 481 | 482 | // MutliGet 483 | // Create some test items. 484 | keys := []string{"foo", "bar", "gopher", "42"} 485 | input := make(map[string][]byte, len(keys)) 486 | 487 | addKeys := func() { 488 | for i, key := range keys { 489 | body := []byte(key + strconv.Itoa(i)) 490 | _, err = c.Store(Set, key, 0, body) 491 | assert.Nilf(t, err, "Store for MutliGet have error - %v", err) 492 | input[key] = body 493 | } 494 | } 495 | 496 | checkKeyOnExist := func(method string, input map[string][]byte, output map[string][]byte) { 497 | for key, reqBody := range input { 498 | if respBody, ok := output[key]; ok { 499 | assert.Equalf(t, reqBody, respBody, "%s. Request and response body not equal, have - %v, want - %v", method, respBody, reqBody) 500 | } else { 501 | t.Errorf("%s. 
Don't found requset key %v in response", method, key) 502 | } 503 | } 504 | } 505 | 506 | _, err = c.MultiGet(append(keys, invalidKey)) 507 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiGet: invalid key, want error ErrMalformedKey") 508 | 509 | addKeys() 510 | output, err := c.MultiGet(keys) 511 | assert.Nilf(t, err, "MultiGet have error: %v", err) 512 | if len(input) != len(output) { 513 | t.Errorf("want %d items after MultiGet, have %d", len(input), len(output)) 514 | } else { 515 | checkKeyOnExist("MultiGet", input, output) 516 | } 517 | 518 | // remove one key from cache 519 | _, err = c.Delete(keys[0]) 520 | assert.Nilf(t, err, "Delete for MultiGet have error: %v", err) 521 | output, err = c.MultiGet(keys) 522 | assert.Nilf(t, err, "MultiGet after delete one elem have error: %v", err) 523 | if len(input)-1 != len(output) { 524 | t.Errorf("want %d items after MultiStore, have %d", len(input)-1, len(output)) 525 | } 526 | 527 | // MutliStore 528 | inputMStore := map[string][]byte{ 529 | "foo42": []byte("bar"), 530 | "hello": []byte("world"), 531 | "go": []byte("gopher"), 532 | } 533 | inputMStoreExp := map[string][]byte{ 534 | "exp": []byte("needDelete"), 535 | } 536 | inputExp := uint32(1) 537 | 538 | err = c.MultiStore(Set, inputMStore, 0) 539 | assert.Nilf(t, err, "MultiStore have error: %v", err) 540 | err = c.MultiStore(Set, inputMStoreExp, inputExp) 541 | assert.Nilf(t, err, "MultiStore with exp have error: %v", err) 542 | 543 | time.Sleep(time.Second) 544 | keyWithExp := maps.Keys(inputMStoreExp)[0] 545 | _, err = c.Get(keyWithExp) 546 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for item with 1 sec experetion setted in MultiStore. 
want - %v, have - %v", ErrCacheMiss, err) 547 | 548 | keysInputMStore := maps.Keys(inputMStore) 549 | outputMStoreOne, err := c.Get(keysInputMStore[0]) 550 | assert.Nilf(t, err, "Get for MultiStore have error: %v", err) 551 | assert.NotNil(t, outputMStoreOne.Body, "Get after MultiStore gets item without body") 552 | outputMStore, err := c.MultiGet(keysInputMStore) 553 | assert.Nilf(t, err, "MultiGet for MultiStore have error: %v", err) 554 | checkKeyOnExist("MultiStore", inputMStore, outputMStore) 555 | 556 | singleMStore, err := c.MultiGet([]string{keysInputMStore[0]}) 557 | assert.Nilf(t, err, "MultiGet with 1 item have error: %v", err) 558 | for key, body := range singleMStore { 559 | assert.Equal(t, keysInputMStore[0], key, "MultiGet with 1 item not equals keys") 560 | assert.Equal(t, inputMStore[key], body, "MultiGet with 1 item not equals body") 561 | } 562 | 563 | // Test Flush All 564 | err = c.FlushAll(0) 565 | assert.Nilf(t, err, "FlushAll: %v", err) 566 | _, err = c.Get("bar") 567 | assert.ErrorIsf(t, err, ErrCacheMiss, "post-FlushAll want ErrCacheMiss, got %v", err) 568 | } 569 | 570 | func testExpireWithClient(t *testing.T, c *Client) { 571 | if testing.Short() { 572 | t.Log("Skipping testing memcached Touch with testing in Short mode") 573 | return 574 | } 575 | 576 | const secondsToExpiry = uint32(1) 577 | 578 | _, err := c.Store(Set, "foo", secondsToExpiry, []byte("fooval")) 579 | assert.Nilf(t, err, "Store(Set) with expire have error - %v", err) 580 | _, err = c.Store(Add, "bar", secondsToExpiry, []byte("barval")) 581 | assert.Nilf(t, err, "Store(Add) with expire have error - %v", err) 582 | 583 | time.Sleep(time.Second) 584 | 585 | _, err = c.Get("foo") 586 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for expire item - %v", err) 587 | 588 | _, err = c.Get("bar") 589 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for expire item - %v", err) 590 | } 591 | 592 | func TestLocalhost_FlushAll_MultiDelete(t *testing.T) { 593 | c, err := net.Dial("tcp", 
localhostTCPAddr) 594 | if err != nil { 595 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr) 596 | } 597 | req := Request{ 598 | Opcode: VERSION, 599 | } 600 | 601 | _, err = transmitRequest(c, &req) 602 | if err != nil { 603 | t.Errorf("Expected errNoConn with no conn, got %v", err) 604 | } 605 | 606 | buf := make([]byte, HDR_LEN) 607 | resp, _, err := getResponse(c, buf) 608 | if err != nil { 609 | t.Fatalf("Error transmitting request: %v", err) 610 | } 611 | 612 | if resp.Status != SUCCESS { 613 | t.Errorf("Expected SUCCESS, got %v", resp.Status) 614 | } 615 | if err = c.Close(); err != nil { 616 | t.Fatalf("Error with close connection: %v", err) 617 | } 618 | 619 | mc, err := newForTests(localhostTCPAddr) 620 | if err != nil { 621 | t.Fatalf("failed to create new client: %v", err) 622 | } 623 | t.Cleanup(mc.CloseAllConns) 624 | 625 | keys := []string{"foo", "bar", "gopher", "42"} 626 | 627 | addKeys := func() { 628 | for i, key := range keys { 629 | _, err = mc.Store(Set, key, 0, []byte(key+strconv.Itoa(i))) 630 | assert.Nil(t, err, fmt.Sprintf("Fail to Store item with key - %s", key)) 631 | } 632 | } 633 | 634 | checkKeyOnExist := func(meth string) { 635 | for _, key := range keys { 636 | _, err = mc.Get(key) 637 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get item after %s. 
want - %v, have - %v", meth, ErrCacheMiss, err) 638 | } 639 | } 640 | 641 | addKeys() 642 | err = mc.MultiDelete(append(keys, "fake")) 643 | assert.Nil(t, err, "MultiDelete") 644 | checkKeyOnExist("MultiDelete") 645 | 646 | addKeys() 647 | err = mc.FlushAll(0) 648 | assert.Nil(t, err, "FlushAll") 649 | checkKeyOnExist("FlushAll") 650 | } 651 | 652 | func TestClient_CloseAvailableConnsInAllShardPools(t *testing.T) { 653 | _, err := net.Dial("tcp", localhostTCPAddr) 654 | if err != nil { 655 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr) 656 | } 657 | mc, err := newForTests(localhostTCPAddr) 658 | assert.Nilf(t, err, "failed to create new client: %v", err) 659 | t.Cleanup(mc.CloseAllConns) 660 | 661 | // for create conns in pool 662 | var wg sync.WaitGroup 663 | wg.Add(1) 664 | go func() { 665 | defer wg.Done() 666 | _, err := mc.Store(Set, "foo1", 0, []byte("bar")) 667 | assert.Nilf(t, err, "Set foo1: %v", err) 668 | }() 669 | wg.Add(1) 670 | go func() { 671 | defer wg.Done() 672 | _, err := mc.Store(Set, "foo2", 0, []byte("bar")) 673 | assert.Nilf(t, err, "Set foo2: %v", err) 674 | }() 675 | wg.Add(1) 676 | go func() { 677 | defer wg.Done() 678 | _, err := mc.Store(Set, "foo3", 0, []byte("bar")) 679 | assert.Nilf(t, err, "Set foo3: %v", err) 680 | }() 681 | wg.Add(1) 682 | go func() { 683 | defer wg.Done() 684 | _, err := mc.Store(Set, "foo4", 0, []byte("bar")) 685 | assert.Nilf(t, err, "Set foo4: %v", err) 686 | }() 687 | 688 | wg.Wait() 689 | 690 | addr, err := utils.AddrRepr(localhostTCPAddr) 691 | assert.Nilf(t, err, "AddrRepr: %v", err) 692 | 693 | pool, ok := mc.safeGetFreeConn(addr) 694 | assert.Truef(t, ok, "Get from freeConns not found pool for %s", addr.String()) 695 | 696 | l := pool.Len() 697 | 698 | numOfClose := 1 699 | c := mc.CloseAvailableConnsInAllShardPools(numOfClose) 700 | assert.Equal(t, numOfClose, c, "Request for closed not equal actual") 701 | 702 | assert.Equalf(t, l-numOfClose, pool.Len(), "Resulting pool len not 
equal expected number") 703 | } 704 | 705 | func TestConn(t *testing.T) { 706 | c, err := net.DialTimeout("tcp", localhostTCPAddr, time.Second) 707 | if err != nil { 708 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr) 709 | } 710 | req := Request{ 711 | Opcode: VERSION, 712 | } 713 | 714 | n, err := transmitRequest(c, &req) 715 | if err != nil { 716 | t.Errorf("Expected errNoConn with no conn, got %v", err) 717 | } 718 | 719 | buf := make([]byte, HDR_LEN) 720 | resp, _, err := getResponse(c, buf) 721 | if err != nil { 722 | t.Fatalf("Error transmitting request: %v", err) 723 | } 724 | 725 | if n != len(buf) { 726 | t.Errorf("write bytes - %d != read bytes - %d\n", n, len(buf)) 727 | } 728 | if resp.Status != SUCCESS { 729 | t.Errorf("Expected SUCCESS, got %v", resp.Status) 730 | } 731 | if err = c.Close(); err != nil { 732 | t.Fatalf("Error with close connection: %v", err) 733 | } 734 | } 735 | 736 | func TestClient_Getters(t *testing.T) { 737 | type fields struct { 738 | timeout time.Duration 739 | maxIdleConns int 740 | nodeHCPeriod time.Duration 741 | nodeRBPeriod time.Duration 742 | } 743 | tests := []struct { 744 | name string 745 | fields fields 746 | wantTimeout time.Duration 747 | wantMaxIdleConns int 748 | wantNodeHCPeriod time.Duration 749 | wantNodeRBPeriod time.Duration 750 | }{ 751 | { 752 | name: "Default", 753 | fields: fields{}, 754 | wantTimeout: DefaultTimeout, 755 | wantMaxIdleConns: DefaultMaxIdleConns, 756 | wantNodeHCPeriod: DefaultNodeHealthCheckPeriod, 757 | wantNodeRBPeriod: DefaultRebuildingNodePeriod, 758 | }, 759 | { 760 | name: "Custom", 761 | fields: fields{ 762 | timeout: 5 * time.Second, 763 | maxIdleConns: 50, 764 | nodeHCPeriod: time.Second, 765 | nodeRBPeriod: time.Second, 766 | }, 767 | wantTimeout: 5 * time.Second, 768 | wantMaxIdleConns: 50, 769 | wantNodeHCPeriod: time.Second, 770 | wantNodeRBPeriod: time.Second, 771 | }, 772 | } 773 | for _, tt := range tests { 774 | t.Run(tt.name, func(t *testing.T) { 
775 | c := &Client{ 776 | timeout: tt.fields.timeout, 777 | maxIdleConns: tt.fields.maxIdleConns, 778 | nodeHCPeriod: tt.fields.nodeHCPeriod, 779 | nodeRBPeriod: tt.fields.nodeRBPeriod, 780 | } 781 | assert.Equalf(t, tt.wantTimeout, c.netTimeout(), "netTimeout()") 782 | assert.Equalf(t, tt.wantMaxIdleConns, c.getMaxIdleConns(), "getMaxIdleConns()") 783 | assert.Equalf(t, tt.wantNodeHCPeriod, c.getHCPeriod(), "getHCPeriod()") 784 | assert.Equalf(t, tt.wantNodeRBPeriod, c.getRBPeriod(), "getRBPeriod()") 785 | }) 786 | } 787 | } 788 | 789 | func TestMethodsErrors(t *testing.T) { 790 | c := &Client{ 791 | hr: consistenthash.NewHashRing(), 792 | disableMemcachedDiagnostic: true, 793 | } 794 | 795 | // invalid key 796 | _, err := c.Store(Set, invalidKey, 0, []byte("foo")) 797 | assert.ErrorIsf(t, err, ErrMalformedKey, "Store: invalid key, want error ErrMalformedKey") 798 | _, err = c.Get(invalidKey) 799 | assert.ErrorIsf(t, err, ErrMalformedKey, "Get: invalid key, want error ErrMalformedKey") 800 | _, err = c.Delete(invalidKey) 801 | assert.ErrorIsf(t, err, ErrMalformedKey, "Delete: invalid key, want error ErrMalformedKey") 802 | _, err = c.Delta(Increment, invalidKey, 1, 0, 0) 803 | assert.ErrorIsf(t, err, ErrMalformedKey, "Delta: invalid key, want error ErrMalformedKey") 804 | _, err = c.Append(Append, invalidKey, []byte("foo")) 805 | assert.ErrorIsf(t, err, ErrMalformedKey, "Append: invalid key, want error ErrMalformedKey") 806 | _, err = c.MultiGet([]string{invalidKey, "foo", "bar"}) 807 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiGet: invalid key, want error ErrMalformedKey") 808 | err = c.MultiDelete([]string{invalidKey, "foo", "bar"}) 809 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiDelete: invalid key, want error ErrMalformedKey") 810 | err = c.MultiStore(Set, map[string][]byte{"foo": []byte("bar"), invalidKey: []byte("data")}, 0) 811 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiDelete: invalid key, want error ErrMalformedKey") 812 | 813 | // empty 
hash ring 814 | _, err = c.Store(Set, "store", 0, []byte("foo")) 815 | assert.ErrorIsf(t, err, ErrNoServers, "Store: with empty hash ring, want error ErrNoServers") 816 | _, err = c.Get("get") 817 | assert.ErrorIsf(t, err, ErrNoServers, "Get: with empty hash ring, want error ErrNoServers") 818 | _, err = c.Delete("delete") 819 | assert.ErrorIsf(t, err, ErrNoServers, "Delete: with empty hash ring, want error ErrNoServers") 820 | _, err = c.Delta(Increment, "deltaInc", 1, 0, 0) 821 | assert.ErrorIsf(t, err, ErrNoServers, "Delta: with empty hash ring, want error ErrNoServers") 822 | _, err = c.Append(Append, "append", []byte("foo")) 823 | assert.ErrorIsf(t, err, ErrNoServers, "Append: with empty hash ring, want error ErrNoServers") 824 | 825 | // add invalid node 826 | c.hr.Add("node1") 827 | 828 | // invalid node 829 | _, err = c.Store(Set, "store", 0, []byte("foo")) 830 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Store: invalid node, want error ErrInvalidAddr") 831 | _, err = c.Get("get") 832 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Get: invalid node, want error ErrInvalidAddr") 833 | _, err = c.Delete("delete") 834 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Delete: invalid node, want error ErrInvalidAddr") 835 | _, err = c.Delta(Increment, "deltaInc", 1, 0, 0) 836 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Delta: invalid node, want error ErrInvalidAddr") 837 | _, err = c.Append(Append, "append", []byte("foo")) 838 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Append: invalid node, want error ErrInvalidAddr") 839 | _, err = c.MultiGet([]string{"gopher", "foo", "bar"}) 840 | assert.ErrorIsf(t, err, ErrInvalidAddr, "MutliGet: invalid node, want error ErrInvalidAddr") 841 | err = c.MultiDelete([]string{"gopher", "foo", "bar"}) 842 | assert.ErrorIsf(t, err, ErrInvalidAddr, "MutliDelete: invalid node, want error ErrInvalidAddr") 843 | err = c.MultiStore(Set, map[string][]byte{"foo": []byte("bar"), "data": []byte("data")}, 0) 844 | assert.ErrorIsf(t, err, ErrInvalidAddr, 
"MutliStore: invalid node, want error ErrInvalidAddr") 845 | 846 | var ( 847 | mockNetworkHeadlessErr = new(MockNetworkOperations) 848 | 849 | expectedErr = errors.New("mocked dial error") 850 | 851 | headlessServiceAddress = "example.com" 852 | ) 853 | mockNetworkHeadlessErr.On("LookupHost", headlessServiceAddress).Return(nil, expectedErr) 854 | 855 | op := &options{ 856 | Client: Client{ 857 | nw: &network{lookupHost: mockNetworkHeadlessErr.LookupHost}, 858 | cfg: &config{HeadlessServiceAddress: headlessServiceAddress}, 859 | }, 860 | } 861 | 862 | _, err = newFromConfig(op) 863 | assert.ErrorIs(t, err, ErrInvalidAddr) 864 | 865 | mockNetworkNodeErr := new(MockNetworkOperations) 866 | mockNetworkNodeErr.On("LookupHost", headlessServiceAddress).Return([]string{"wrongNode"}, nil) 867 | 868 | op = &options{ 869 | Client: Client{ 870 | nw: &network{lookupHost: mockNetworkNodeErr.LookupHost}, 871 | cfg: &config{HeadlessServiceAddress: headlessServiceAddress}, 872 | }, 873 | } 874 | 875 | _, err = newFromConfig(op) 876 | assert.ErrorIs(t, err, ErrInvalidAddr) 877 | } 878 | 879 | const invalidKey = `Loremipsumdolorsitamet,consecteturadipiscingelit.Velelitvoluptateeleifendquisproidentnonfeugaitiriureliberminimveniamillumcupiditataliquid,nihiltefeugiatlobortiseleifendnibhproidenttationatoptionesseconsectetuerdeserunt.Gubergrenveroidsolutaquis.Dignissimlobortisloremveroenimrebumconsetetur.` 880 | -------------------------------------------------------------------------------- /memcached/metrics.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | 7 | const ( 8 | methodNameLabel = "method_name" 9 | isSuccessfulLabel = "is_successful" 10 | ) 11 | 12 | var ( 13 | methodDurationSeconds = func() *prometheus.HistogramVec { 14 | return prometheus.NewHistogramVec(prometheus.HistogramOpts{ 15 | Namespace: "", 16 | Name: 
"gomemcached_method_duration_seconds", 17 | Help: "counts the execution time of successful and failed gomemcached methods", 18 | Buckets: []float64{ 19 | 0.0005, 0.001, 0.005, 0.007, 0.015, 0.05, 0.1, 0.2, 0.5, 1, 20 | }, 21 | }, []string{ 22 | methodNameLabel, 23 | isSuccessfulLabel, 24 | }) 25 | }() 26 | ) 27 | 28 | // observeMultiMethodDurationSeconds is observing the duration of a method. 29 | func observeMethodDurationSeconds(methodName string, duration float64, isSuccessful bool) { 30 | flag := "0" 31 | if isSuccessful { 32 | flag = "1" 33 | } 34 | 35 | methodDurationSeconds. 36 | WithLabelValues(methodName, flag). 37 | Observe(duration) 38 | } 39 | -------------------------------------------------------------------------------- /memcached/metrics_test.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_observeMethodDurationSeconds(t *testing.T) { 11 | type args struct { 12 | methodName string 13 | duration float64 14 | isSuccessful bool 15 | } 16 | tests := []struct { 17 | name string 18 | args args 19 | }{ 20 | { 21 | name: "60 true", 22 | args: args{ 23 | methodName: "TestMeth", 24 | duration: 60 * time.Millisecond.Seconds(), 25 | isSuccessful: true, 26 | }, 27 | }, 28 | { 29 | name: "15 true", 30 | args: args{ 31 | methodName: "TestMeth", 32 | duration: 15 * time.Millisecond.Seconds(), 33 | isSuccessful: true, 34 | }, 35 | }, 36 | { 37 | name: "39 true", 38 | args: args{ 39 | methodName: "TestMeth", 40 | duration: 39 * time.Millisecond.Seconds(), 41 | isSuccessful: true, 42 | }, 43 | }, 44 | { 45 | name: "100 false", 46 | args: args{ 47 | methodName: "TestMeth", 48 | duration: 100 * time.Millisecond.Seconds(), 49 | isSuccessful: false, 50 | }, 51 | }, 52 | { 53 | name: "66 true", 54 | args: args{ 55 | methodName: "TestMeth", 56 | duration: 66 * time.Millisecond.Seconds(), 57 | isSuccessful: 
true, 58 | }, 59 | }, 60 | { 61 | name: "11 false", 62 | args: args{ 63 | methodName: "TestMeth", 64 | duration: 11 * time.Millisecond.Seconds(), 65 | isSuccessful: false, 66 | }, 67 | }, 68 | } 69 | for _, tt := range tests { 70 | t.Run(tt.name, func(t *testing.T) { 71 | observeMethodDurationSeconds(tt.args.methodName, tt.args.duration, tt.args.isSuccessful) 72 | 73 | var success = "0" 74 | if tt.args.isSuccessful { 75 | success = "1" 76 | } 77 | 78 | _, err := methodDurationSeconds.GetMetricWith(map[string]string{methodNameLabel: tt.args.methodName, isSuccessfulLabel: success}) 79 | assert.Nil(t, err, "GetMetricWith: returned error is not nil - %v", err) 80 | }) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /memcached/node_provider.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | "slices" 7 | "strconv" 8 | "sync" 9 | "time" 10 | 11 | "golang.org/x/exp/maps" 12 | 13 | "github.com/aliexpressru/gomemcached/logger" 14 | "github.com/aliexpressru/gomemcached/utils" 15 | ) 16 | 17 | func (c *Client) initNodesProvider() { 18 | var ( 19 | periodHC = c.getHCPeriod() 20 | tHC = time.NewTimer(periodHC) 21 | 22 | periodRB = c.getRBPeriod() 23 | tRB = time.NewTimer(periodRB) 24 | ) 25 | 26 | if c.deadNodes == nil { 27 | c.deadNodes = make(map[string]struct{}) 28 | } 29 | 30 | go func() { 31 | for { 32 | select { 33 | case <-tHC.C: 34 | c.checkNodesHealth() 35 | tHC.Reset(periodHC) 36 | case <-c.ctx.Done(): 37 | tHC.Stop() 38 | return 39 | } 40 | } 41 | }() 42 | go func() { 43 | for { 44 | select { 45 | case <-tRB.C: 46 | c.rebuildNodes() 47 | tRB.Reset(periodRB) 48 | case <-c.ctx.Done(): 49 | tRB.Stop() 50 | return 51 | } 52 | } 53 | }() 54 | } 55 | 56 | func (c *Client) checkNodesHealth() { 57 | currentNodes, err := getNodes(c.nw.lookupHost, c.cfg) 58 | if err != nil { 59 | logger.Warnf("%s: Error occurred while checking 
nodes health, getNodes error - %s", libPrefix, err.Error()) 60 | return 61 | } 62 | 63 | recheckDeadNodes := func(node any) { 64 | sNode := utils.Repr(node) 65 | if !slices.Contains(currentNodes, sNode) { 66 | c.safeRemoveFromDeadNodes(sNode) 67 | return 68 | } 69 | 70 | if c.nodeIsDead(node) { 71 | c.safeAddToDeadNodes(sNode) 72 | } else { 73 | c.safeRemoveFromDeadNodes(sNode) 74 | } 75 | } 76 | 77 | wg := sync.WaitGroup{} 78 | for node := range c.safeGetDeadNodes() { 79 | wg.Add(1) 80 | go func(n string) { 81 | defer wg.Done() 82 | recheckDeadNodes(n) 83 | }(node) 84 | } 85 | wg.Wait() 86 | 87 | ringNodes := c.hr.GetAllNodes() 88 | for node := range c.safeGetDeadNodes() { 89 | ringNodes = slices.DeleteFunc(ringNodes, func(a any) bool { return utils.Repr(a) == node }) 90 | } 91 | 92 | for _, node := range ringNodes { 93 | wg.Add(1) 94 | go func(n any) { 95 | defer wg.Done() 96 | if c.nodeIsDead(n) { 97 | sNode := utils.Repr(n) 98 | c.safeAddToDeadNodes(sNode) 99 | } 100 | }(node) 101 | } 102 | 103 | wg.Wait() 104 | 105 | deadNodes := c.safeGetDeadNodes() 106 | if len(deadNodes) != 0 { 107 | nodes := maps.Keys(deadNodes) 108 | 109 | logger.Warnf("%s: Dead nodes - %s", libPrefix, nodes) 110 | 111 | for _, node := range nodes { 112 | addr, cErr := utils.AddrRepr(node) 113 | if cErr != nil { 114 | continue 115 | } 116 | c.hr.Remove(addr) 117 | c.removeFromFreeConns(addr) 118 | } 119 | } 120 | } 121 | 122 | func (c *Client) rebuildNodes() { 123 | currentNodes, err := getNodes(c.nw.lookupHost, c.cfg) 124 | if err != nil { 125 | logger.Warnf("%s: Error occurred while rebuild nodes health, getNodes error - %s", libPrefix, err.Error()) 126 | return 127 | } 128 | slices.Sort(currentNodes) 129 | 130 | for node := range c.safeGetDeadNodes() { 131 | currentNodes = slices.DeleteFunc(currentNodes, func(a string) bool { return a == node }) 132 | } 133 | 134 | var nodesInRing []string 135 | for _, node := range c.hr.GetAllNodes() { 136 | nodesInRing = append(nodesInRing, 
utils.Repr(node)) 137 | } 138 | slices.Sort(nodesInRing) 139 | 140 | var nodesToAdd []string 141 | for _, node := range currentNodes { 142 | if _, ok := slices.BinarySearch(nodesInRing, node); !ok { 143 | nodesToAdd = append(nodesToAdd, node) 144 | } 145 | } 146 | 147 | var nodesToRemove []string 148 | for _, node := range nodesInRing { 149 | if _, ok := slices.BinarySearch(currentNodes, node); !ok { 150 | nodesToRemove = append(nodesToRemove, node) 151 | } 152 | } 153 | 154 | if len(nodesToAdd) != 0 { 155 | for _, node := range nodesToAdd { 156 | addr, cErr := utils.AddrRepr(node) 157 | if cErr != nil { 158 | continue 159 | } 160 | c.hr.Add(addr) 161 | } 162 | } 163 | 164 | if len(nodesToRemove) != 0 { 165 | for _, node := range nodesToRemove { 166 | addr, cErr := utils.AddrRepr(node) 167 | if cErr != nil { 168 | continue 169 | } 170 | c.hr.Remove(addr) 171 | } 172 | } 173 | 174 | if !c.disableRefreshConns { 175 | _ = c.CloseAvailableConnsInAllShardPools(DefaultOfNumberConnsToDestroyPerRBPeriod) 176 | } 177 | } 178 | 179 | func (c *Client) nodeIsDead(node any) bool { 180 | addr, err := utils.AddrRepr(utils.Repr(node)) 181 | if err != nil { 182 | return true 183 | } 184 | 185 | var ( 186 | countRetry uint8 187 | cn net.Conn 188 | ) 189 | 190 | for { 191 | cn, err = c.dial(addr) 192 | if err != nil { 193 | var tErr *ConnectTimeoutError 194 | if errors.As(err, &tErr) { 195 | if countRetry < DefaultRetryCountForConn { 196 | countRetry++ 197 | continue 198 | } 199 | logger.Errorf("%s. Node health check failed. error - %s, with timeout - %d", 200 | ErrServerError.Error(), err.Error(), c.netTimeout(), 201 | ) 202 | return true 203 | } else { 204 | logger.Errorf("%s. 
%s", ErrServerError.Error(), err.Error()) 205 | return true 206 | } 207 | } 208 | _ = cn.Close() 209 | break 210 | } 211 | 212 | return false 213 | } 214 | 215 | func (c *Client) safeGetDeadNodes() map[string]struct{} { 216 | c.dmu.RLock() 217 | defer c.dmu.RUnlock() 218 | return maps.Clone(c.deadNodes) 219 | } 220 | 221 | func (c *Client) safeAddToDeadNodes(node string) { 222 | c.dmu.Lock() 223 | defer c.dmu.Unlock() 224 | c.deadNodes[node] = struct{}{} 225 | } 226 | 227 | func (c *Client) safeRemoveFromDeadNodes(node string) { 228 | c.dmu.Lock() 229 | defer c.dmu.Unlock() 230 | delete(c.deadNodes, node) 231 | } 232 | 233 | func getNodes(lookup func(host string) (addrs []string, err error), cfg *config) ([]string, error) { 234 | if cfg != nil { 235 | if cfg.HeadlessServiceAddress != "" { 236 | nodes, err := lookup(cfg.HeadlessServiceAddress) 237 | if err != nil { 238 | return nil, err 239 | } 240 | 241 | nodesWithHost := make([]string, len(nodes)) 242 | for i := range nodes { 243 | nodesWithHost[i] = net.JoinHostPort(nodes[i], strconv.Itoa(cfg.MemcachedPort)) 244 | } 245 | 246 | return nodesWithHost, nil 247 | } else if len(cfg.Servers) != 0 { 248 | for _, s := range cfg.Servers { 249 | _, _, err := net.SplitHostPort(s) 250 | if err != nil { 251 | return nil, err 252 | } 253 | } 254 | return cfg.Servers, nil 255 | } 256 | } 257 | 258 | return []string{}, nil 259 | } 260 | -------------------------------------------------------------------------------- /memcached/node_provider_test.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net" 7 | "slices" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/mock" 14 | "github.com/stretchr/testify/require" 15 | 16 | "github.com/aliexpressru/gomemcached/consistenthash" 17 | "github.com/aliexpressru/gomemcached/logger" 18 | 
"github.com/aliexpressru/gomemcached/utils" 19 | ) 20 | 21 | func Test_getNodes(t *testing.T) { 22 | type args struct { 23 | cfg *config 24 | mock *network 25 | } 26 | tests := []struct { 27 | name string 28 | args args 29 | want []string 30 | wantErr assert.ErrorAssertionFunc 31 | }{ 32 | { 33 | name: "Servers", 34 | args: args{ 35 | mock: &network{lookupHost: func(host string) (addrs []string, err error) { 36 | return []string{"server1:11211", "server2:11211"}, nil 37 | }}, 38 | cfg: &config{ 39 | Servers: []string{"server1:11211", "server2:11211"}, 40 | }}, 41 | want: []string{"server1:11211", "server2:11211"}, 42 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { 43 | if err != nil { 44 | t.Errorf("getNodes have error - %v", err) 45 | return false 46 | } 47 | return true 48 | }, 49 | }, 50 | { 51 | name: "Headless", 52 | args: args{ 53 | mock: &network{lookupHost: func(host string) (addrs []string, err error) { 54 | return []string{"93.184.216.34", "123.323.32.11"}, nil 55 | }}, 56 | cfg: &config{ 57 | HeadlessServiceAddress: "example.com", 58 | MemcachedPort: 11211, 59 | }}, 60 | want: []string{"93.184.216.34:11211", "123.323.32.11:11211"}, 61 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { 62 | if err != nil { 63 | t.Errorf("getNodes have error - %v", err) 64 | return false 65 | } 66 | return true 67 | }, 68 | }, 69 | { 70 | name: "config nil", 71 | args: args{ 72 | mock: &network{lookupHost: func(_ string) (_ []string, _ error) { 73 | return 74 | }}, 75 | cfg: nil}, 76 | want: []string{}, 77 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { 78 | if err != nil { 79 | t.Errorf("getNodes have error - %v", err) 80 | return false 81 | } 82 | return true 83 | }, 84 | }, 85 | { 86 | name: "error headless", 87 | args: args{ 88 | mock: &network{lookupHost: func(host string) (addrs []string, err error) { 89 | return nil, &net.DNSError{ 90 | Err: "no such host", 91 | Name: "fakeaddress.r", 92 | } 93 | }}, 
94 | cfg: &config{HeadlessServiceAddress: "fakeaddress.r"}}, 95 | want: nil, 96 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { 97 | if err != nil { 98 | dnsError := new(net.DNSError) 99 | assert.ErrorAs(t, err, &dnsError, "Error should be as net.DNSError") 100 | return true 101 | } 102 | return false 103 | }, 104 | }, 105 | { 106 | name: "error servers", 107 | args: args{ 108 | mock: new(network), 109 | cfg: &config{Servers: []string{"localhost:1234", "fakeaddress.r", "localhost"}}}, 110 | want: nil, 111 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { 112 | if err != nil { 113 | return true 114 | } 115 | t.Errorf("getNodes dot't have error") 116 | return false 117 | }, 118 | }, 119 | } 120 | for _, tt := range tests { 121 | t.Run(tt.name, func(t *testing.T) { 122 | got, err := getNodes(tt.args.mock.lookupHost, tt.args.cfg) 123 | if !tt.wantErr(t, err) { 124 | return 125 | } 126 | assert.Equalf(t, tt.want, got, "getNodes(%v)", tt.args.cfg) 127 | }) 128 | } 129 | } 130 | 131 | func Test_safeGetDeadNodes(t *testing.T) { 132 | client := &Client{ 133 | deadNodes: map[string]struct{}{ 134 | "node1": {}, 135 | "node2": {}, 136 | }, 137 | } 138 | 139 | wg := sync.WaitGroup{} 140 | wg.Add(2) 141 | go func() { 142 | defer wg.Done() 143 | deadNodes := client.safeGetDeadNodes() 144 | 145 | expectedDeadNodes := map[string]struct{}{ 146 | "node1": {}, 147 | "node2": {}, 148 | } 149 | assert.Equal(t, expectedDeadNodes, deadNodes) 150 | }() 151 | go func() { 152 | defer wg.Done() 153 | deadNodes := client.safeGetDeadNodes() 154 | 155 | expectedDeadNodes := map[string]struct{}{ 156 | "node1": {}, 157 | "node2": {}, 158 | } 159 | assert.Equal(t, expectedDeadNodes, deadNodes) 160 | }() 161 | 162 | wg.Wait() 163 | } 164 | 165 | func Test_safeAddToDeadNodes(t *testing.T) { 166 | client := &Client{ 167 | deadNodes: map[string]struct{}{}, 168 | } 169 | 170 | wg := sync.WaitGroup{} 171 | wg.Add(2) 172 | go func() { 173 | defer wg.Done() 174 | 
client.safeAddToDeadNodes("node1") 175 | }() 176 | go func() { 177 | defer wg.Done() 178 | client.safeAddToDeadNodes("node2") 179 | }() 180 | 181 | wg.Wait() 182 | 183 | assert.Contains(t, client.deadNodes, "node1") 184 | assert.Contains(t, client.deadNodes, "node2") 185 | } 186 | 187 | func Test_safeRemoveFromDeadNodes(t *testing.T) { 188 | client := &Client{ 189 | deadNodes: map[string]struct{}{ 190 | "node1": {}, 191 | "node2": {}, 192 | }, 193 | } 194 | 195 | wg := sync.WaitGroup{} 196 | wg.Add(2) 197 | go func() { 198 | defer wg.Done() 199 | client.safeRemoveFromDeadNodes("node1") 200 | }() 201 | go func() { 202 | defer wg.Done() 203 | client.safeRemoveFromDeadNodes("node2") 204 | }() 205 | 206 | wg.Wait() 207 | 208 | assert.NotContains(t, client.deadNodes, "node1") 209 | assert.NotContains(t, client.deadNodes, "node2") 210 | } 211 | 212 | func Test_nodeIsDead(t *testing.T) { 213 | logger.DisableLogger() 214 | addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} 215 | 216 | mockNetworkError := new(MockNetworkOperations) 217 | client := &Client{nw: &network{ 218 | dialTimeout: mockNetworkError.DialTimeout, 219 | }} 220 | 221 | assert.True(t, client.nodeIsDead("wrongarrd.r"), "nodeIsDead: wrong addr should be return true") 222 | 223 | expectedErr := errors.New("mocked dial error") 224 | 225 | mockNetworkError.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(nil, expectedErr) 226 | 227 | result := client.nodeIsDead(addr) 228 | 229 | assert.True(t, result) 230 | 231 | mockNetworkError.AssertCalled(t, "DialTimeout", addr.Network(), addr.String(), client.netTimeout()) 232 | 233 | mockNetworkRetry := new(MockNetworkOperations) 234 | client = &Client{nw: &network{ 235 | dialTimeout: mockNetworkRetry.DialTimeout, 236 | }} 237 | 238 | expectedErr = &ConnectTimeoutError{addr} 239 | 240 | mockNetworkRetry.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(nil, expectedErr) 241 | result = 
client.nodeIsDead(addr) 242 | 243 | assert.True(t, result) 244 | 245 | // int(DefaultRetryCountForConn)+1 - the default number of retries plus the first execution. 246 | mockNetworkRetry.AssertNumberOfCalls(t, "DialTimeout", int(DefaultRetryCountForConn)+1) 247 | 248 | mockNetworkSuccess := new(MockNetworkOperations) 249 | client = &Client{nw: &network{ 250 | dialTimeout: mockNetworkSuccess.DialTimeout, 251 | }} 252 | 253 | mockNetworkSuccess.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(&FakeConn{}, nil) 254 | 255 | result = client.nodeIsDead(addr) 256 | 257 | assert.False(t, result) 258 | 259 | mockNetworkSuccess.AssertCalled(t, "DialTimeout", addr.Network(), addr.String(), client.netTimeout()) 260 | } 261 | 262 | func Test_initNodesProvider(t *testing.T) { 263 | var ( 264 | mockNetworkErr = new(MockNetworkOperations) 265 | 266 | period = 10 * time.Millisecond 267 | 268 | ctx, cancel = context.WithCancel(context.TODO()) 269 | expectedErr = errors.New("mocked dial error") 270 | ) 271 | cl := &Client{ 272 | ctx: ctx, 273 | nw: &network{ 274 | dial: mockNetworkErr.Dial, 275 | lookupHost: mockNetworkErr.LookupHost, 276 | }, 277 | cfg: &config{ 278 | HeadlessServiceAddress: "example.com", 279 | }, 280 | nodeHCPeriod: period, 281 | nodeRBPeriod: period, 282 | } 283 | 284 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr) 285 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil) 286 | 287 | cl.initNodesProvider() 288 | 289 | mockNetworkErr.AssertNotCalled(t, "Dial") 290 | 291 | <-time.After(2 * period) 292 | cancel() 293 | 294 | mockNetworkErr.AssertCalled(t, "LookupHost", cl.cfg.HeadlessServiceAddress) 295 | } 296 | 297 | func Test_checkNodesHealth(t *testing.T) { 298 | var ( 299 | mockNetworkErr = new(MockNetworkOperations) 300 | 301 | expectedErr = errors.New("mocked dial error") 302 | ) 303 | cl := &Client{ 304 | nw: &network{ 305 | dial: mockNetworkErr.Dial, 306 | 
lookupHost: mockNetworkErr.LookupHost, 307 | }, 308 | cfg: &config{ 309 | HeadlessServiceAddress: "example.com", 310 | }, 311 | } 312 | 313 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr) 314 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil) 315 | 316 | cl.checkNodesHealth() 317 | 318 | mockNetworkErr.AssertNotCalled(t, "Dial") 319 | mockNetworkErr.AssertNumberOfCalls(t, "LookupHost", 1) 320 | 321 | var ( 322 | currentNodes = []string{"127.0.0.1:12345", "127.0.0.2:12345", "127.0.0.3:12345", "127.0.0.4:12345", "127.0.0.5:12345"} 323 | alreadyDeadNodes = []string{"127.0.0.4:12345", "127.0.0.5:12345"} 324 | disableNodes = []string{"127.0.0.6:12345"} 325 | 326 | mockNetwork = new(MockNetworkOperations) 327 | ) 328 | 329 | cl = &Client{ 330 | hr: consistenthash.NewHashRing(), 331 | timeout: -1, 332 | nw: &network{ 333 | dial: mockNetwork.Dial, 334 | lookupHost: mockNetwork.LookupHost, 335 | }, 336 | cfg: &config{ 337 | Servers: currentNodes, 338 | }, 339 | } 340 | 341 | mockNetwork.On("Dial", "tcp", "127.0.0.2:12345").Return(nil, expectedErr).Once() 342 | mockNetwork.On("Dial", "tcp", "127.0.0.4:12345").Return(nil, expectedErr).Once() 343 | mockNetwork.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil) 344 | 345 | for _, node := range currentNodes { 346 | addr, _ := utils.AddrRepr(node) 347 | cl.hr.Add(addr) 348 | } 349 | cl.deadNodes = make(map[string]struct{}) 350 | for _, node := range alreadyDeadNodes { 351 | cl.deadNodes[node] = struct{}{} 352 | } 353 | for _, node := range disableNodes { 354 | cl.deadNodes[node] = struct{}{} 355 | } 356 | 357 | cl.checkNodesHealth() 358 | 359 | assert.Equal(t, 3, len(cl.hr.GetAllNodes())) 360 | assert.Equal(t, 2, len(cl.deadNodes)) 361 | } 362 | 363 | func Test_rebuildNodes(t *testing.T) { 364 | var ( 365 | mockNetworkErr = new(MockNetworkOperations) 366 | 367 | expectedErr = errors.New("mocked dial error") 368 | ) 369 | cl := &Client{ 370 | 
nw: &network{ 371 | dial: mockNetworkErr.Dial, 372 | lookupHost: mockNetworkErr.LookupHost, 373 | }, 374 | cfg: &config{ 375 | HeadlessServiceAddress: "example.com", 376 | }, 377 | } 378 | 379 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr) 380 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil) 381 | 382 | cl.rebuildNodes() 383 | 384 | mockNetworkErr.AssertNotCalled(t, "Dial") 385 | mockNetworkErr.AssertNumberOfCalls(t, "LookupHost", 1) 386 | 387 | var ( 388 | currentNodes = []string{"127.0.0.1:12345", "127.0.0.2:12345", "127.0.0.3:12345", "127.0.0.4:12345", "127.0.0.5:12345"} 389 | alreadyDeadNodes = []string{"127.0.0.4:12345", "127.0.0.2:12345"} 390 | expectedNodesInRing = []string{"127.0.0.1:12345", "127.0.0.3:12345", "127.0.0.5:12345"} 391 | 392 | mockNetwork = new(MockNetworkOperations) 393 | ) 394 | cl = &Client{ 395 | ctx: context.TODO(), 396 | nw: &network{ 397 | dial: mockNetwork.Dial, 398 | lookupHost: mockNetwork.LookupHost, 399 | }, 400 | cfg: &config{ 401 | Servers: currentNodes, 402 | }, 403 | timeout: -1, 404 | maxIdleConns: 1, 405 | hr: consistenthash.NewHashRing(), 406 | } 407 | 408 | mockNetwork.On("LookupHost", cl.cfg.Servers).Return(currentNodes, nil) 409 | mockNetwork.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil) 410 | 411 | cl.deadNodes = make(map[string]struct{}) 412 | for _, node := range alreadyDeadNodes { 413 | cl.deadNodes[node] = struct{}{} 414 | } 415 | // len(currentNodes)-1 simulates the absence of one node in the hash ring. 416 | for i := 0; i < len(currentNodes)-1; i++ { 417 | addr, _ := utils.AddrRepr(currentNodes[i]) 418 | cl.hr.Add(addr) 419 | } 420 | 421 | // simulating the borrowing of a connection from the connection pool by taking a connection and putting it back. 
422 | // for test CloseAvailableConnsInAllShardPools 423 | for i := 0; i < len(currentNodes)-1; i++ { 424 | node, ok := cl.hr.Get(currentNodes[i]) 425 | require.Truef(t, ok, "Not found node (%s) in hash ring", currentNodes[i]) 426 | cn, err := cl.getConnForNode(node) 427 | require.Nil(t, err, "getConnForNode try get conn") 428 | cn.condRelease(new(error)) 429 | } 430 | 431 | cl.rebuildNodes() 432 | 433 | assert.Equal(t, 3, cl.hr.GetNodesCount()) 434 | 435 | var actualNodesInRing []string 436 | for _, node := range cl.hr.GetAllNodes() { 437 | actualNodesInRing = append(actualNodesInRing, node.(net.Addr).String()) 438 | } 439 | 440 | slices.Sort(actualNodesInRing) 441 | slices.Sort(expectedNodesInRing) 442 | assert.Equal(t, expectedNodesInRing, actualNodesInRing) 443 | 444 | for _, pool := range cl.freeConns { 445 | assert.Equal(t, 0, pool.Len()) 446 | } 447 | } 448 | 449 | type MockNetworkOperations struct { 450 | mock.Mock 451 | } 452 | 453 | func (m *MockNetworkOperations) Dial(network, address string) (net.Conn, error) { 454 | args := m.Called(network, address) 455 | if args.Get(0) == nil { 456 | return nil, args.Error(1) 457 | } 458 | return args.Get(0).(net.Conn), args.Error(1) 459 | } 460 | 461 | func (m *MockNetworkOperations) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 462 | args := m.Called(network, address, timeout) 463 | if args.Get(0) == nil { 464 | return nil, args.Error(1) 465 | } 466 | return args.Get(0).(net.Conn), args.Error(1) 467 | } 468 | 469 | func (m *MockNetworkOperations) LookupHost(host string) ([]string, error) { 470 | args := m.Called(host) 471 | if args.Get(0) == nil { 472 | return nil, args.Error(1) 473 | } 474 | return args.Get(0).([]string), args.Error(1) 475 | } 476 | 477 | type FakeConn struct { 478 | net.TCPConn 479 | } 480 | 481 | func (f *FakeConn) Read(_ []byte) (n int, err error) { 482 | return 0, nil 483 | } 484 | 485 | func (f *FakeConn) Write(_ []byte) (n int, err error) { 486 | return 0, 
nil 487 | } 488 | 489 | func (f *FakeConn) Close() error { 490 | return nil 491 | } 492 | -------------------------------------------------------------------------------- /memcached/options.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/aliexpressru/gomemcached/consistenthash" 7 | ) 8 | 9 | type options struct { 10 | Client 11 | disableLogger bool 12 | } 13 | 14 | type Option func(*options) 15 | 16 | // WithMaxIdleConns is sets a custom value of open connections per address. 17 | // By default, DefaultMaxIdleConns will be used. 18 | func WithMaxIdleConns(num int) Option { 19 | return func(o *options) { 20 | o.Client.maxIdleConns = num 21 | } 22 | } 23 | 24 | // WithTimeout is sets custom timeout for connections. 25 | // By default, DefaultTimeout will be used. 26 | func WithTimeout(tm time.Duration) Option { 27 | return func(o *options) { 28 | o.Client.timeout = tm 29 | } 30 | } 31 | 32 | // WithCustomHashRing for setup use consistenthash.NewCustomHashRing 33 | func WithCustomHashRing(hr *consistenthash.HashRing) Option { 34 | return func(o *options) { 35 | o.Client.hr = hr 36 | } 37 | } 38 | 39 | // WithPeriodForNodeHealthCheck is sets a custom frequency for health checker of physical nodes. 40 | // By default, DefaultNodeHealthCheckPeriod will be used. 41 | func WithPeriodForNodeHealthCheck(t time.Duration) Option { 42 | return func(o *options) { 43 | o.Client.nodeHCPeriod = t 44 | } 45 | } 46 | 47 | // WithPeriodForRebuildingNodes is sets a custom frequency for resharding and checking for dead nodes. 48 | // By default, DefaultRebuildingNodePeriod will be used. 
49 | func WithPeriodForRebuildingNodes(t time.Duration) Option { 50 | return func(o *options) { 51 | o.Client.nodeRBPeriod = t 52 | } 53 | } 54 | 55 | // WithDisableNodeProvider is disabled node health cheek and rebuild nodes for hash ring 56 | func WithDisableNodeProvider() Option { 57 | return func(o *options) { 58 | o.Client.disableNodeProvider = true 59 | } 60 | } 61 | 62 | // WithDisableRefreshConnsInPool is disabled auto close some connections in pool in NodeProvider. 63 | // This is done to refresh connections in the pool. 64 | func WithDisableRefreshConnsInPool() Option { 65 | return func(o *options) { 66 | o.Client.disableRefreshConns = true 67 | } 68 | } 69 | 70 | // WithDisableMemcachedDiagnostic is disabled write library metrics. 71 | // 72 | // gomemcached_method_duration_seconds 73 | func WithDisableMemcachedDiagnostic() Option { 74 | return func(o *options) { 75 | o.Client.disableMemcachedDiagnostic = true 76 | } 77 | } 78 | 79 | // WithDisableLogger is disabled internal library logs. 
80 | func WithDisableLogger() Option { 81 | return func(o *options) { 82 | o.disableLogger = true 83 | } 84 | } 85 | 86 | // WithAuthentication is turn on authenticate for memcached 87 | func WithAuthentication(user, pass string) Option { 88 | return func(o *options) { 89 | o.Client.authEnable = true 90 | o.Client.authData = prepareAuthData(user, pass) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /memcached/options_test.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/aliexpressru/gomemcached/consistenthash" 11 | "github.com/aliexpressru/gomemcached/logger" 12 | ) 13 | 14 | func TestWithOptions(t *testing.T) { 15 | os.Setenv("MEMCACHED_SERVERS", "localhost:11211") 16 | 17 | hMcl, _ := InitFromEnv() 18 | assert.NotNil(t, hMcl.hr, "InitFromEnv: hash ring is nil") 19 | 20 | const ( 21 | maxIdleConns = 10 22 | disable = true 23 | enable 24 | authUser = "admin" 25 | authPass = "password" 26 | timeout = 5 * time.Second 27 | period = time.Second 28 | ) 29 | 30 | hr := consistenthash.NewCustomHashRing(1, nil) 31 | mcl, _ := InitFromEnv( 32 | WithMaxIdleConns(maxIdleConns), 33 | WithTimeout(timeout), 34 | WithCustomHashRing(hr), 35 | WithPeriodForNodeHealthCheck(period), 36 | WithPeriodForRebuildingNodes(period), 37 | WithDisableNodeProvider(), 38 | WithDisableRefreshConnsInPool(), 39 | WithDisableMemcachedDiagnostic(), 40 | WithAuthentication(authUser, authPass), 41 | WithDisableLogger(), 42 | ) 43 | t.Cleanup(func() { 44 | mcl.CloseAllConns() 45 | }) 46 | 47 | assert.Equal(t, maxIdleConns, mcl.maxIdleConns, "WithMaxIdleConns should set maxIdleConns") 48 | assert.Equal(t, timeout, mcl.timeout, "WithTimeout should set timeout") 49 | assert.Equal(t, hr, mcl.hr, "WithCustomHashRing should set hr") 50 | assert.Equal(t, period, mcl.nodeHCPeriod, 
"WithPeriodForNodeHealthCheck should set period") 51 | assert.Equal(t, period, mcl.nodeRBPeriod, "WithPeriodForRebuildingNodes should set period") 52 | assert.Equal(t, disable, mcl.disableNodeProvider, "WithDisableNodeProvider should set disable") 53 | assert.Equal(t, disable, mcl.disableRefreshConns, "WithDisableRefreshConnsInPool should set disable") 54 | assert.Equal(t, disable, mcl.disableMemcachedDiagnostic, "WithDisableMemcachedDiagnostic should set disable") 55 | assert.Equal(t, enable, mcl.authEnable, "WithAuthentication should set enable") 56 | assert.Equal(t, disable, logger.LoggerIsDisable(), "WithDisableLogger should set disable") 57 | } 58 | -------------------------------------------------------------------------------- /memcached/requests.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | const ( 10 | // MaxBodyLen a maximum reasonable body length to expect. 11 | // Anything larger than this will result in an error. 12 | MaxBodyLen = int(22 * 1e6) // 22 MB 13 | 14 | BUF_LEN = 256 15 | 16 | // reserved always 0 17 | reserved8 = uint8(0) 18 | reserved16 = uint16(0) 19 | ) 20 | 21 | // Request a Memcached request 22 | type Request struct { 23 | // The command being issued 24 | Opcode OpCode 25 | // The CAS (if applicable, or 0) 26 | Cas uint64 27 | // An opaque value to be returned with this request 28 | Opaque uint32 29 | // Command extras, key, and body 30 | Extras, Key, Body []byte 31 | } 32 | 33 | // Size is a number of bytes this request requires. 
34 | func (r *Request) Size() int { 35 | return HDR_LEN + len(r.Extras) + len(r.Key) + len(r.Body) 36 | } 37 | 38 | // String is a debugging string representation of this request 39 | func (r Request) String() string { 40 | return fmt.Sprintf("{Request opcode=%s, bodylen=%d, key='%s'}", 41 | r.Opcode, len(r.Body), r.Key) 42 | } 43 | 44 | func (r *Request) fillHeaderBytes(data []byte) int { 45 | /* 46 | Byte/ 0 | 1 | 2 | 3 | 47 | / | | | | 48 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 49 | +---------------+---------------+---------------+---------------+ 50 | 0| Magic | Opcode | Key length | 51 | +---------------+---------------+---------------+---------------+ 52 | 4| Extras length | Data type | vbucket id | 53 | +---------------+---------------+---------------+---------------+ 54 | 8| Total body length | 55 | +---------------+---------------+---------------+---------------+ 56 | 12| Opaque | 57 | +---------------+---------------+---------------+---------------+ 58 | 16| CAS | 59 | | | 60 | +---------------+---------------+---------------+---------------+ 61 | Total 24 bytes 62 | */ 63 | 64 | pos := 0 65 | data[pos] /*0x00*/ = REQ_MAGIC 66 | pos++ // 1 67 | data[pos] /*0x01*/ = byte(r.Opcode) 68 | pos++ // 2 69 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x02 - 0x03*/, uint16(len(r.Key))) 70 | 71 | pos += 2 // 4 72 | data[pos] /*0x04*/ = byte(len(r.Extras)) 73 | 74 | pos++ // 5 75 | data[pos] /*0x05*/ = reserved8 76 | 77 | pos++ // 6 78 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x06*/, reserved16) 79 | 80 | pos += 2 // 8 81 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x08 - 0x09 - 0x0a - 0x0b*/, uint32(len(r.Body)+len(r.Key)+len(r.Extras))) 82 | 83 | pos += 4 // 12 84 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x0c - 0x0d - 0x0e - 0x0f*/, r.Opaque) 85 | 86 | pos += 4 // 16 87 | if r.Cas != 0 { 88 | binary.BigEndian.PutUint64(data[pos:pos+8] /*0x10 - 0x11 - 0x12 - 0x13 - 0x14 - 0x16 - 0x17*/, r.Cas) 89 | } 90 | 91 | pos += 8 
// 24 92 | if len(r.Extras) > 0 { 93 | copy(data[pos:pos+len(r.Extras)], r.Extras) 94 | pos += len(r.Extras) 95 | } 96 | 97 | if len(r.Key) > 0 { 98 | copy(data[pos:pos+len(r.Key)], r.Key) 99 | pos += len(r.Key) 100 | } 101 | return pos 102 | } 103 | 104 | // HeaderBytes is a wire representation of the header (with the extras and key) 105 | func (r *Request) HeaderBytes() []byte { 106 | data := make([]byte, HDR_LEN+len(r.Extras)+len(r.Key)) 107 | 108 | r.fillHeaderBytes(data) 109 | 110 | return data 111 | } 112 | 113 | // Bytes is a wire representation of this request. 114 | func (r *Request) Bytes() []byte { 115 | data := make([]byte, r.Size()) 116 | 117 | pos := r.fillHeaderBytes(data) 118 | 119 | if len(r.Body) > 0 { 120 | copy(data[pos:pos+len(r.Body)], r.Body) 121 | } 122 | 123 | return data 124 | } 125 | 126 | // Transmit is send this request message across a writer. 127 | func (r *Request) Transmit(w io.Writer) (n int, err error) { 128 | if len(r.Body) < BODY_LEN { 129 | n, err = w.Write(r.Bytes()) 130 | } else { 131 | n, err = w.Write(r.HeaderBytes()) 132 | if err == nil { 133 | m := 0 134 | m, err = w.Write(r.Body) 135 | n += m 136 | } 137 | } 138 | return 139 | } 140 | 141 | // Receive a fill this Request with the data from this reader. 
142 | func (r *Request) Receive(rd io.Reader, hdrBytes []byte) (int, error) { 143 | /* 144 | Byte/ 0 | 1 | 2 | 3 | 145 | / | | | | 146 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 147 | +---------------+---------------+---------------+---------------+ 148 | 0| Magic | Opcode | Key length | 149 | +---------------+---------------+---------------+---------------+ 150 | 4| Extras length | Data type | vbucket id | 151 | +---------------+---------------+---------------+---------------+ 152 | 8| Total body length | 153 | +---------------+---------------+---------------+---------------+ 154 | 12| Opaque | 155 | +---------------+---------------+---------------+---------------+ 156 | 16| CAS | 157 | | | 158 | +---------------+---------------+---------------+---------------+ 159 | Total 24 bytes 160 | */ 161 | 162 | if len(hdrBytes) < HDR_LEN { 163 | hdrBytes = make([]byte, HDR_LEN) 164 | } 165 | 166 | n, err := io.ReadFull(rd, hdrBytes) 167 | if err != nil { 168 | return n, err 169 | } 170 | 171 | if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC { 172 | return n, fmt.Errorf("bad magic: 0x%02x", hdrBytes[0]) 173 | } 174 | r.Opcode = OpCode(hdrBytes[1]) 175 | 176 | klen := int(binary.BigEndian.Uint16(hdrBytes[2:])) 177 | elen := int(hdrBytes[4]) 178 | bodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:]) - uint32(klen) - uint32(elen)) 179 | if bodyLen > MaxBodyLen { 180 | return n, fmt.Errorf("%d is too big (max %d)", 181 | bodyLen, MaxBodyLen) 182 | } 183 | r.Opaque = binary.BigEndian.Uint32(hdrBytes[12:]) 184 | r.Cas = binary.BigEndian.Uint64(hdrBytes[16:]) 185 | 186 | buf := make([]byte, klen+elen+bodyLen) 187 | m, err := io.ReadFull(rd, buf) 188 | n += m 189 | if err == nil { 190 | if elen > 0 { 191 | r.Extras = buf[0:elen] 192 | } 193 | if klen > 0 { 194 | r.Key = buf[elen : klen+elen] 195 | } 196 | if klen+elen > 0 { 197 | r.Body = buf[klen+elen:] 198 | } 199 | } 200 | 201 | return n, err 202 | } 203 | 204 | // prepareExtras fills Extras 
depending on OpCode for Request 205 | func (r *Request) prepareExtras(expiration uint32, delta uint64, initVal uint64) { 206 | switch r.Opcode { 207 | case DELETE, DELETEQ, QUIT, QUITQ, NOOP, VERSION, APPEND, APPENDQ, PREPEND, PREPENDQ, STAT, GET, GETQ, GETK, GETKQ: // MUST NOT have extras 208 | case SET, SETQ, ADD, ADDQ, REPLACE, REPLACEQ: 209 | /* 210 | Byte/ 0 | 1 | 2 | 3 | 211 | / | | | | 212 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 213 | +---------------+---------------+---------------+---------------+ 214 | 0| Flags | 215 | +---------------+---------------+---------------+---------------+ 216 | 4| Expiration | 217 | +---------------+---------------+---------------+---------------+ 218 | Total 8 bytes 219 | */ 220 | 221 | r.Extras = make([]byte, 8) 222 | // flags always is 0 223 | binary.BigEndian.PutUint32(r.Extras[:4], uint32(0)) 224 | binary.BigEndian.PutUint32(r.Extras[4:], expiration) 225 | case INCREMENT, INCREMENTQ, DECREMENT, DECREMENTQ: 226 | /* 227 | 228 | Byte/ 0 | 1 | 2 | 3 | 229 | / | | | | 230 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 231 | +---------------+---------------+---------------+---------------+ 232 | 0| Amount to add / subtract (delta) | 233 | | | 234 | +---------------+---------------+---------------+---------------+ 235 | 8| Initial value | 236 | | | 237 | +---------------+---------------+---------------+---------------+ 238 | 16| Expiration | 239 | +---------------+---------------+---------------+---------------+ 240 | Total 20 bytes 241 | */ 242 | r.Extras = make([]byte, 20) 243 | binary.BigEndian.PutUint64(r.Extras[:8], delta) 244 | binary.BigEndian.PutUint64(r.Extras[8:], initVal) 245 | binary.BigEndian.PutUint32(r.Extras[16:], expiration) 246 | case FLUSH, FLUSHQ: 247 | /* 248 | Byte/ 0 | 1 | 2 | 3 | 249 | / | | | | 250 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 251 | +---------------+---------------+---------------+---------------+ 252 | 0| Expiration 
| 253 | +---------------+---------------+---------------+---------------+ 254 | Total 4 bytes 255 | */ 256 | r.Extras = make([]byte, 4) 257 | binary.BigEndian.PutUint32(r.Extras, expiration) 258 | } 259 | } 260 | 261 | type StoreMode uint8 262 | 263 | const ( 264 | // Add - Store the data, but only if the server does not already hold data for a given key 265 | Add StoreMode = iota 266 | // Set - Store the data, overwrite if already exists 267 | Set 268 | // Replace - Store the data, but only if the server does already hold data for a given key 269 | Replace 270 | ) 271 | 272 | func (sm StoreMode) Resolve() OpCode { 273 | switch sm { 274 | case Set: 275 | return SET 276 | case Replace: 277 | return REPLACE 278 | default: 279 | return ADD 280 | } 281 | } 282 | 283 | type DeltaMode uint8 284 | 285 | const ( 286 | // Increment - increases the value by the specified amount 287 | Increment DeltaMode = iota 288 | // Decrement - decreases the value by the specified amount 289 | Decrement 290 | ) 291 | 292 | func (sm DeltaMode) Resolve() OpCode { 293 | switch sm { 294 | case Increment: 295 | return INCREMENT 296 | case Decrement: 297 | return DECREMENT 298 | default: 299 | return INCREMENT 300 | } 301 | } 302 | 303 | type AppendMode uint8 304 | 305 | const ( 306 | // Append - Appends data to the end of an existing key value. 307 | Append AppendMode = iota 308 | // Prepend -Appends data to the beginning of an existing key value. 
309 | Prepend 310 | ) 311 | 312 | func (sm AppendMode) Resolve() OpCode { 313 | switch sm { 314 | case Append: 315 | return APPEND 316 | default: 317 | return PREPEND 318 | } 319 | } 320 | -------------------------------------------------------------------------------- /memcached/requests_test.go: -------------------------------------------------------------------------------- 1 | // nolint 2 | package memcached 3 | 4 | import ( 5 | "bytes" 6 | "fmt" 7 | "io/ioutil" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func TestEncodingRequest(t *testing.T) { 13 | req := Request{ 14 | Opcode: SET, 15 | Cas: 938424885, 16 | Opaque: 7242, 17 | Key: []byte("somekey"), 18 | Body: []byte("somevalue"), 19 | } 20 | 21 | got := req.Bytes() 22 | 23 | expected := []byte{ 24 | REQ_MAGIC, byte(SET), 25 | 0x0, 0x7, // length of key 26 | 0x0, // extra length 27 | 0x0, // reserved 28 | 0x0, 0x0, // vbucket 29 | 0x0, 0x0, 0x0, 0x10, // Length of value 30 | 0x0, 0x0, 0x1c, 0x4a, // opaque 31 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 32 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 33 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 34 | } 35 | 36 | if len(got) != req.Size() { 37 | t.Fatalf("Expected %v bytes, got %v", got, 38 | len(got)) 39 | } 40 | 41 | if !reflect.DeepEqual(got, expected) { 42 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 43 | expected, got) 44 | } 45 | 46 | exp := `{Request opcode=SET, bodylen=9, key='somekey'}` 47 | if req.String() != exp { 48 | t.Errorf("Expected string=%q, got %q", exp, req.String()) 49 | } 50 | } 51 | 52 | func TestEncodingRequestWithExtras(t *testing.T) { 53 | req := Request{ 54 | Opcode: SET, 55 | Cas: 938424885, 56 | Opaque: 7242, 57 | Extras: []byte{1, 2, 3, 4}, 58 | Key: []byte("somekey"), 59 | Body: []byte("somevalue"), 60 | } 61 | 62 | buf := &bytes.Buffer{} 63 | req.Transmit(buf) 64 | got := buf.Bytes() 65 | 66 | expected := []byte{ 67 | REQ_MAGIC, byte(SET), 68 | 0x0, 0x7, // length of key 69 | 0x4, // extra length 70 | 0x0, // reserved 71 
| 0x0, 0x0, // vbucket 72 | 0x0, 0x0, 0x0, 0x14, // Length of remainder 73 | 0x0, 0x0, 0x1c, 0x4a, // opaque 74 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 75 | 1, 2, 3, 4, // extras 76 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 77 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 78 | } 79 | 80 | if len(got) != req.Size() { 81 | t.Fatalf("Expected %v bytes, got %v", got, 82 | len(got)) 83 | } 84 | 85 | if !reflect.DeepEqual(got, expected) { 86 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 87 | expected, got) 88 | } 89 | } 90 | 91 | func TestEncodingRequestWithLargeBody(t *testing.T) { 92 | req := Request{ 93 | Opcode: SET, 94 | Cas: 938424885, 95 | Opaque: 7242, 96 | Extras: []byte{1, 2, 3, 4}, 97 | Key: []byte("somekey"), 98 | Body: make([]byte, 256), 99 | } 100 | 101 | buf := &bytes.Buffer{} 102 | req.Transmit(buf) 103 | got := buf.Bytes() 104 | 105 | expected := append([]byte{ 106 | REQ_MAGIC, byte(SET), 107 | 0x0, 0x7, // length of key 108 | 0x4, // extra length 109 | 0x0, // reserved 110 | 0x0, 0x0, // vbucket 111 | 0x0, 0x0, 0x1, 0xb, // Length of remainder 112 | 0x0, 0x0, 0x1c, 0x4a, // opaque 113 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 114 | 1, 2, 3, 4, // extras 115 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 116 | }, make([]byte, 256)...) 
117 | 118 | if len(got) != req.Size() { 119 | t.Fatalf("Expected %v bytes, got %v", got, 120 | len(got)) 121 | } 122 | 123 | if !reflect.DeepEqual(got, expected) { 124 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 125 | expected, got) 126 | } 127 | } 128 | 129 | func BenchmarkEncodingRequest(b *testing.B) { 130 | req := Request{ 131 | Opcode: SET, 132 | Cas: 938424885, 133 | Opaque: 7242, 134 | Key: []byte("somekey"), 135 | Body: []byte("somevalue"), 136 | } 137 | 138 | b.SetBytes(int64(req.Size())) 139 | 140 | for i := 0; i < b.N; i++ { 141 | req.Bytes() 142 | } 143 | } 144 | 145 | func BenchmarkEncodingRequest0CAS(b *testing.B) { 146 | req := Request{ 147 | Opcode: SET, 148 | Cas: 0, 149 | Opaque: 7242, 150 | Key: []byte("somekey"), 151 | Body: []byte("somevalue"), 152 | } 153 | 154 | b.SetBytes(int64(req.Size())) 155 | 156 | for i := 0; i < b.N; i++ { 157 | req.Bytes() 158 | } 159 | } 160 | 161 | func BenchmarkEncodingRequest1Extra(b *testing.B) { 162 | req := Request{ 163 | Opcode: SET, 164 | Cas: 0, 165 | Opaque: 7242, 166 | Extras: []byte{1}, 167 | Key: []byte("somekey"), 168 | Body: []byte("somevalue"), 169 | } 170 | 171 | b.SetBytes(int64(req.Size())) 172 | 173 | for i := 0; i < b.N; i++ { 174 | req.Bytes() 175 | } 176 | } 177 | 178 | func TestRequestTransmit(t *testing.T) { 179 | res := Request{Key: []byte("thekey")} 180 | _, err := res.Transmit(ioutil.Discard) 181 | if err != nil { 182 | t.Errorf("Error sending small request: %v", err) 183 | } 184 | 185 | res.Body = make([]byte, 256) 186 | _, err = res.Transmit(ioutil.Discard) 187 | if err != nil { 188 | t.Errorf("Error sending large request thing: %v", err) 189 | } 190 | } 191 | 192 | func TestReceiveRequest(t *testing.T) { 193 | req := Request{ 194 | Opcode: SET, 195 | Cas: 0, 196 | Opaque: 7242, 197 | Extras: []byte{1}, 198 | Key: []byte("somekey"), 199 | Body: []byte("somevalue"), 200 | } 201 | 202 | data := req.Bytes() 203 | 204 | req2 := Request{} 205 | n, err := req2.Receive(bytes.NewReader(data), 
nil) 206 | if err != nil { 207 | t.Fatalf("Error receiving: %v", err) 208 | } 209 | if len(data) != n { 210 | t.Errorf("Expected to read %v bytes, read %v", len(data), n) 211 | } 212 | 213 | if !reflect.DeepEqual(req, req2) { 214 | t.Fatalf("Expected %#v == %#v", req, req2) 215 | } 216 | } 217 | 218 | func TestReceiveRequestNoContent(t *testing.T) { 219 | req := Request{ 220 | Opcode: SET, 221 | Cas: 0, 222 | Opaque: 7242, 223 | } 224 | 225 | data := req.Bytes() 226 | 227 | req2 := Request{} 228 | n, err := req2.Receive(bytes.NewReader(data), nil) 229 | if err != nil { 230 | t.Fatalf("Error receiving: %v", err) 231 | } 232 | if len(data) != n { 233 | t.Errorf("Expected to read %v bytes, read %v", len(data), n) 234 | } 235 | 236 | if fmt.Sprintf("%#v", req) != fmt.Sprintf("%#v", req2) { 237 | t.Fatalf("Expected %#v == %#v", req, req2) 238 | } 239 | } 240 | 241 | func TestReceiveRequestShortHdr(t *testing.T) { 242 | req := Request{} 243 | n, err := req.Receive(bytes.NewReader([]byte{1, 2, 3}), nil) 244 | if err == nil { 245 | t.Errorf("Expected error, got %#v", req) 246 | } 247 | if n != 3 { 248 | t.Errorf("Expected to have read 3 bytes, read %v", n) 249 | } 250 | } 251 | 252 | func TestReceiveRequestShortBody(t *testing.T) { 253 | req := Request{ 254 | Opcode: SET, 255 | Cas: 0, 256 | Opaque: 7242, 257 | Extras: []byte{1}, 258 | Key: []byte("somekey"), 259 | Body: []byte("somevalue"), 260 | } 261 | 262 | data := req.Bytes() 263 | 264 | req2 := Request{} 265 | n, err := req2.Receive(bytes.NewReader(data[:len(data)-3]), nil) 266 | if err == nil { 267 | t.Errorf("Expected error, got %#v", req2) 268 | } 269 | if n != len(data)-3 { 270 | t.Errorf("Expected to have read %v bytes, read %v", len(data)-3, n) 271 | } 272 | } 273 | 274 | func TestReceiveRequestBadMagic(t *testing.T) { 275 | req := Request{ 276 | Opcode: SET, 277 | Cas: 0, 278 | Opaque: 7242, 279 | Extras: []byte{1}, 280 | Key: []byte("somekey"), 281 | Body: []byte("somevalue"), 282 | } 283 | 284 | data := 
req.Bytes() 285 | data[0] = 0x83 286 | 287 | req2 := Request{} 288 | _, err := req2.Receive(bytes.NewReader(data), nil) 289 | if err == nil { 290 | t.Fatalf("Expected error, got %#v", req2) 291 | } 292 | } 293 | 294 | func TestReceiveRequestLongBody(t *testing.T) { 295 | req := Request{ 296 | Opcode: SET, 297 | Cas: 0, 298 | Opaque: 7242, 299 | Extras: []byte{1}, 300 | Key: []byte("somekey"), 301 | Body: make([]byte, MaxBodyLen+5), 302 | } 303 | 304 | data := req.Bytes() 305 | 306 | req2 := Request{} 307 | _, err := req2.Receive(bytes.NewReader(data), nil) 308 | if err == nil { 309 | t.Fatalf("Expected error, got %#v", req2) 310 | } 311 | } 312 | 313 | func BenchmarkReceiveRequest(b *testing.B) { 314 | req := Request{ 315 | Opcode: SET, 316 | Cas: 0, 317 | Opaque: 7242, 318 | Extras: []byte{1}, 319 | Key: []byte("somekey"), 320 | Body: []byte("somevalue"), 321 | } 322 | 323 | data := req.Bytes() 324 | data[0] = REQ_MAGIC 325 | rdr := bytes.NewReader(data) 326 | 327 | b.SetBytes(int64(len(data))) 328 | 329 | b.ResetTimer() 330 | buf := make([]byte, HDR_LEN) 331 | for i := 0; i < b.N; i++ { 332 | req2 := Request{} 333 | rdr.Seek(0, 0) 334 | _, err := req2.Receive(rdr, buf) 335 | if err != nil { 336 | b.Fatalf("Error receiving: %v", err) 337 | } 338 | } 339 | } 340 | 341 | func BenchmarkReceiveRequestNoBuf(b *testing.B) { 342 | req := Request{ 343 | Opcode: SET, 344 | Cas: 0, 345 | Opaque: 7242, 346 | Extras: []byte{1}, 347 | Key: []byte("somekey"), 348 | Body: []byte("somevalue"), 349 | } 350 | 351 | data := req.Bytes() 352 | data[0] = REQ_MAGIC 353 | rdr := bytes.NewReader(data) 354 | 355 | b.SetBytes(int64(len(data))) 356 | 357 | b.ResetTimer() 358 | for i := 0; i < b.N; i++ { 359 | req2 := Request{} 360 | rdr.Seek(0, 0) 361 | _, err := req2.Receive(rdr, nil) 362 | if err != nil { 363 | b.Fatalf("Error receiving: %v", err) 364 | } 365 | } 366 | } 367 | 368 | func TestRequest_prepareExtras(t *testing.T) { 369 | type fields struct { 370 | Opcode OpCode 371 | } 372 | 
type args struct { 373 | expiration uint32 374 | delta uint64 375 | initVal uint64 376 | } 377 | tests := []struct { 378 | name string 379 | fields fields 380 | args args 381 | expect []byte 382 | }{ 383 | { 384 | name: "GET must not have extras", 385 | fields: fields{ 386 | Opcode: GET, 387 | }, 388 | args: args{ 389 | expiration: 256, 390 | delta: 1, 391 | initVal: 1, 392 | }, 393 | expect: nil, 394 | }, 395 | { 396 | name: "SET", 397 | fields: fields{ 398 | Opcode: SET, 399 | }, 400 | args: args{ 401 | expiration: 256, 402 | delta: 1, 403 | initVal: 1, 404 | }, 405 | expect: []byte{ 406 | 0x00, 0x00, 0x00, 0x00, 407 | 0x00, 0x00, 0x01, 0x00, 408 | }, 409 | }, 410 | { 411 | name: "INCREMENT", 412 | fields: fields{ 413 | Opcode: INCREMENT, 414 | }, 415 | args: args{ 416 | expiration: 256, 417 | delta: 1, 418 | initVal: 42, 419 | }, 420 | expect: []byte{ 421 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 422 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2a, 423 | 0x00, 0x00, 0x01, 0x00, 424 | }, 425 | }, 426 | { 427 | name: "FLUSH", 428 | fields: fields{ 429 | Opcode: FLUSH, 430 | }, 431 | args: args{ 432 | expiration: 256, 433 | delta: 0, 434 | initVal: 0, 435 | }, 436 | expect: []byte{ 437 | 0x00, 0x00, 0x01, 0x00, 438 | }, 439 | }, 440 | } 441 | for _, tt := range tests { 442 | t.Run(tt.name, func(t *testing.T) { 443 | r := &Request{ 444 | Opcode: tt.fields.Opcode, 445 | } 446 | r.prepareExtras(tt.args.expiration, tt.args.delta, tt.args.initVal) 447 | 448 | if !bytes.Equal(r.Extras, tt.expect) { 449 | t.Fatalf("Expected %#v == %#v", r.Extras, tt.expect) 450 | } 451 | }) 452 | } 453 | } 454 | -------------------------------------------------------------------------------- /memcached/responses.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "encoding/binary" 5 | "fmt" 6 | "io" 7 | ) 8 | 9 | // Response is a memcached response 10 | type Response struct { 11 | // The command opcode of 
the command that sent the request 12 | Opcode OpCode 13 | // The status of the response 14 | Status Status 15 | // The opaque sent in the request 16 | Opaque uint32 17 | // The CAS identifier (if applicable) 18 | Cas uint64 19 | // Extras, key, and body for this response 20 | Extras, Key, Body []byte 21 | } 22 | 23 | // String a debugging string representation of this response 24 | func (r Response) String() string { 25 | return fmt.Sprintf("{Response status=%v keylen=%d, extralen=%d, bodylen=%d}", 26 | r.Status, len(r.Key), len(r.Extras), len(r.Body)) 27 | } 28 | 29 | // Error - Response as an error. 30 | func (r *Response) Error() string { 31 | return fmt.Sprintf("Response status=%v, opcode=%v, opaque=%v, msg: %s", 32 | r.Status, r.Opcode, r.Opaque, string(r.Body)) 33 | } 34 | 35 | // isFatal return false if this error isn't believed to be fatal to a connection. 36 | func isFatal(e error) bool { 37 | if e == nil { 38 | return false 39 | } 40 | switch errStatus(e) { 41 | case KEY_ENOENT, KEY_EEXISTS, NOT_STORED, TMPFAIL, AUTHFAIL: 42 | return false 43 | } 44 | return true 45 | } 46 | 47 | // Size is a number of bytes this response consumes on the wire. 
48 | func (r *Response) Size() int { 49 | return HDR_LEN + len(r.Extras) + len(r.Key) + len(r.Body) 50 | } 51 | 52 | func (r *Response) fillHeaderBytes(data []byte) int { 53 | /* 54 | Byte/ 0 | 1 | 2 | 3 | 55 | / | | | | 56 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 57 | +---------------+---------------+---------------+---------------+ 58 | 0| Magic | Opcode | Key Length | 59 | +---------------+---------------+---------------+---------------+ 60 | 4| Extras length | Data type | Status | 61 | +---------------+---------------+---------------+---------------+ 62 | 8| Total body length | 63 | +---------------+---------------+---------------+---------------+ 64 | 12| Opaque | 65 | +---------------+---------------+---------------+---------------+ 66 | 16| CAS | 67 | | | 68 | +---------------+---------------+---------------+---------------+ 69 | Total 24 bytes 70 | */ 71 | 72 | pos := 0 73 | data[pos] /*0x00*/ = RES_MAGIC 74 | pos++ // 1 75 | data[pos] /*0x01*/ = byte(r.Opcode) 76 | pos++ // 2 77 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x02 - 0x03*/, uint16(len(r.Key))) 78 | 79 | pos += 2 // 4 80 | data[pos] /*0x04*/ = byte(len(r.Extras)) 81 | 82 | pos++ // 5 83 | data[pos] /*0x05*/ = reserved8 84 | 85 | pos++ // 6 86 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x06*/, uint16(r.Status)) 87 | 88 | pos += 2 // 8 89 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x08 - 0x09 - 0x0a - 0x0b*/, uint32(len(r.Body)+len(r.Key)+len(r.Extras))) 90 | 91 | pos += 4 // 12 92 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x0c - 0x0d - 0x0e - 0x0f*/, r.Opaque) 93 | 94 | pos += 4 // 16 95 | if r.Cas != 0 { 96 | binary.BigEndian.PutUint64(data[pos:pos+8] /*0x10 - 0x11 - 0x12 - 0x13 - 0x14 - 0x16 - 0x17*/, r.Cas) 97 | } 98 | 99 | pos += 8 // 24 100 | if len(r.Extras) > 0 { 101 | copy(data[pos:pos+len(r.Extras)], r.Extras) 102 | pos += len(r.Extras) 103 | } 104 | 105 | if len(r.Key) > 0 { 106 | copy(data[pos:pos+len(r.Key)], r.Key) 107 | pos += len(r.Key) 
108 | } 109 | 110 | return pos 111 | } 112 | 113 | // HeaderBytes get just the header bytes for this response. 114 | func (r *Response) HeaderBytes() []byte { 115 | data := make([]byte, HDR_LEN+len(r.Extras)+len(r.Key)) 116 | 117 | r.fillHeaderBytes(data) 118 | 119 | return data 120 | } 121 | 122 | // Bytes the actual bytes transmitted for this response. 123 | func (r *Response) Bytes() []byte { 124 | data := make([]byte, r.Size()) 125 | 126 | pos := r.fillHeaderBytes(data) 127 | 128 | copy(data[pos:pos+len(r.Body)], r.Body) 129 | 130 | return data 131 | } 132 | 133 | // Transmit send this response message across a writer. 134 | func (r *Response) Transmit(w io.Writer) (n int, err error) { 135 | if len(r.Body) < BODY_LEN { 136 | n, err = w.Write(r.Bytes()) 137 | } else { 138 | n, err = w.Write(r.HeaderBytes()) 139 | if err == nil { 140 | m := 0 141 | m, err = w.Write(r.Body) 142 | n += m 143 | } 144 | } 145 | return 146 | } 147 | 148 | // Receive - fill this Response with the data from this reader. 
149 | func (r *Response) Receive(rd io.Reader, hdrBytes []byte) (int, error) { 150 | /* 151 | Byte/ 0 | 1 | 2 | 3 | 152 | / | | | | 153 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7| 154 | +---------------+---------------+---------------+---------------+ 155 | 0| Magic | Opcode | Key Length | 156 | +---------------+---------------+---------------+---------------+ 157 | 4| Extras length | Data type | Status | 158 | +---------------+---------------+---------------+---------------+ 159 | 8| Total body length | 160 | +---------------+---------------+---------------+---------------+ 161 | 12| Opaque | 162 | +---------------+---------------+---------------+---------------+ 163 | 16| CAS | 164 | | | 165 | +---------------+---------------+---------------+---------------+ 166 | Total 24 bytes 167 | */ 168 | 169 | if len(hdrBytes) < HDR_LEN { 170 | hdrBytes = make([]byte, HDR_LEN) 171 | } 172 | 173 | n, err := io.ReadFull(rd, hdrBytes) 174 | if err != nil { 175 | return n, err 176 | } 177 | 178 | if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC { 179 | return n, fmt.Errorf("Bad magic: 0x%02x", hdrBytes[0]) 180 | } 181 | 182 | klen := int(binary.BigEndian.Uint16(hdrBytes[2:4])) 183 | elen := int(hdrBytes[4]) 184 | 185 | r.Opcode = OpCode(hdrBytes[1]) 186 | r.Status = Status(binary.BigEndian.Uint16(hdrBytes[6:8])) 187 | r.Opaque = binary.BigEndian.Uint32(hdrBytes[12:16]) 188 | r.Cas = binary.BigEndian.Uint64(hdrBytes[16:24]) 189 | 190 | bodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:12])) - (klen + elen) 191 | 192 | buf := make([]byte, klen+elen+bodyLen) 193 | m, err := io.ReadFull(rd, buf) 194 | if err == nil { 195 | if elen > 0 { 196 | r.Extras = buf[0:elen] 197 | } 198 | if klen > 0 { 199 | r.Key = buf[elen : klen+elen] 200 | } 201 | if bodyLen > 0 { 202 | r.Body = buf[klen+elen:] 203 | } 204 | } 205 | 206 | return n + m, err 207 | } 208 | -------------------------------------------------------------------------------- 
/memcached/responses_test.go: -------------------------------------------------------------------------------- 1 | // nolint 2 | package memcached 3 | 4 | import ( 5 | "bytes" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | func TestEncodingResponse(t *testing.T) { 14 | req := Response{ 15 | Opcode: SET, 16 | Status: 1582, 17 | Opaque: 7242, 18 | Cas: 938424885, 19 | Key: []byte("somekey"), 20 | Body: []byte("somevalue"), 21 | } 22 | 23 | got := req.Bytes() 24 | 25 | expected := []byte{ 26 | RES_MAGIC, byte(SET), 27 | 0x0, 0x7, // length of key 28 | 0x0, // extra length 29 | 0x0, // reserved16 30 | 0x6, 0x2e, // status 31 | 0x0, 0x0, 0x0, 0x10, // Length of value 32 | 0x0, 0x0, 0x1c, 0x4a, // opaque 33 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 34 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 35 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 36 | } 37 | 38 | if len(got) != req.Size() { 39 | t.Fatalf("Expected %v bytes, got %v", got, 40 | len(got)) 41 | } 42 | 43 | if !reflect.DeepEqual(got, expected) { 44 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 45 | expected, got) 46 | } 47 | 48 | exp := `{Response status=0x62e keylen=7, extralen=0, bodylen=9}` 49 | if req.String() != exp { 50 | t.Errorf("Expected string=%q, got %q", exp, req.String()) 51 | } 52 | 53 | exp = `Response status=0x62e, opcode=SET, opaque=7242, msg: somevalue` 54 | if req.Error() != exp { 55 | t.Errorf("Expected string=%q, got %q", exp, req.Error()) 56 | } 57 | } 58 | 59 | func TestEncodingResponseWithExtras(t *testing.T) { 60 | res := Response{ 61 | Opcode: SET, 62 | Status: 1582, 63 | Opaque: 7242, 64 | Cas: 938424885, 65 | Extras: []byte{1, 2, 3, 4}, 66 | Key: []byte("somekey"), 67 | Body: []byte("somevalue"), 68 | } 69 | 70 | buf := &bytes.Buffer{} 71 | res.Transmit(buf) 72 | got := buf.Bytes() 73 | 74 | expected := []byte{ 75 | RES_MAGIC, byte(SET), 76 | 0x0, 0x7, // length of key 77 | 0x4, // extra length 78 | 0x0, // reserved 79 | 0x6, 0x2e, // status 
80 | 0x0, 0x0, 0x0, 0x14, // Length of remainder 81 | 0x0, 0x0, 0x1c, 0x4a, // opaque 82 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 83 | 1, 2, 3, 4, // extras 84 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 85 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e', 86 | } 87 | 88 | if len(got) != res.Size() { 89 | t.Fatalf("Expected %v bytes, got %v", got, 90 | len(got)) 91 | } 92 | 93 | if !reflect.DeepEqual(got, expected) { 94 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 95 | expected, got) 96 | } 97 | } 98 | 99 | func TestEncodingResponseWithLargeBody(t *testing.T) { 100 | res := Response{ 101 | Opcode: SET, 102 | Status: 1582, 103 | Opaque: 7242, 104 | Cas: 938424885, 105 | Extras: []byte{1, 2, 3, 4}, 106 | Key: []byte("somekey"), 107 | Body: make([]byte, 256), 108 | } 109 | 110 | buf := &bytes.Buffer{} 111 | res.Transmit(buf) 112 | got := buf.Bytes() 113 | 114 | expected := append([]byte{ 115 | RES_MAGIC, byte(SET), 116 | 0x0, 0x7, // length of key 117 | 0x4, // extra length 118 | 0x0, // reserved 119 | 0x6, 0x2e, // status 120 | 0x0, 0x0, 0x1, 0xb, // Length of remainder 121 | 0x0, 0x0, 0x1c, 0x4a, // opaque 122 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS 123 | 1, 2, 3, 4, // extras 124 | 's', 'o', 'm', 'e', 'k', 'e', 'y', 125 | }, make([]byte, 256)...) 
126 | 127 | if len(got) != res.Size() { 128 | t.Fatalf("Expected %v bytes, got %v", got, 129 | len(got)) 130 | } 131 | 132 | if !reflect.DeepEqual(got, expected) { 133 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v", 134 | expected, got) 135 | } 136 | } 137 | 138 | func BenchmarkEncodingResponse(b *testing.B) { 139 | req := Response{ 140 | Opcode: SET, 141 | Status: 1582, 142 | Opaque: 7242, 143 | Cas: 938424885, 144 | Extras: []byte{}, 145 | Key: []byte("somekey"), 146 | Body: []byte("somevalue"), 147 | } 148 | 149 | b.SetBytes(int64(req.Size())) 150 | 151 | for i := 0; i < b.N; i++ { 152 | req.Bytes() 153 | } 154 | } 155 | 156 | func BenchmarkEncodingResponseLarge(b *testing.B) { 157 | req := Response{ 158 | Opcode: SET, 159 | Status: 1582, 160 | Opaque: 7242, 161 | Cas: 938424885, 162 | Extras: []byte{}, 163 | Key: []byte("somekey"), 164 | Body: make([]byte, 24*1024), 165 | } 166 | 167 | b.SetBytes(int64(req.Size())) 168 | 169 | for i := 0; i < b.N; i++ { 170 | req.Bytes() 171 | } 172 | } 173 | 174 | func TestIsNotFound(t *testing.T) { 175 | tests := []struct { 176 | e error 177 | is bool 178 | }{ 179 | {nil, false}, 180 | {errors.New("something"), false}, 181 | {&Response{}, false}, 182 | {&Response{Status: KEY_ENOENT}, true}, 183 | } 184 | 185 | for i, x := range tests { 186 | if isNotFound(x.e) != x.is { 187 | t.Errorf("Expected %v for %#v (%v)", x.is, x.e, i) 188 | } 189 | } 190 | } 191 | 192 | func TestIsFatal(t *testing.T) { 193 | tests := []struct { 194 | e error 195 | is bool 196 | }{ 197 | {nil, false}, 198 | {errors.New("something"), true}, 199 | {&Response{}, true}, 200 | {&Response{Status: KEY_ENOENT}, false}, 201 | {&Response{Status: EINVAL}, true}, 202 | {&Response{Status: TMPFAIL}, false}, 203 | } 204 | 205 | for i, x := range tests { 206 | if isFatal(x.e) != x.is { 207 | t.Errorf("Expected %v for %#v (%v)", x.is, x.e, i) 208 | } 209 | } 210 | } 211 | 212 | func TestResponseTransmit(t *testing.T) { 213 | res := Response{Key: []byte("thekey")} 214 | 
_, err := res.Transmit(ioutil.Discard) 215 | if err != nil { 216 | t.Errorf("Error sending small response: %v", err) 217 | } 218 | 219 | res.Body = make([]byte, 256) 220 | _, err = res.Transmit(ioutil.Discard) 221 | if err != nil { 222 | t.Errorf("Error sending large response thing: %v", err) 223 | } 224 | } 225 | 226 | func TestReceiveResponse(t *testing.T) { 227 | res := Response{ 228 | Opcode: SET, 229 | Status: 74, 230 | Opaque: 7242, 231 | Extras: []byte{1}, 232 | Key: []byte("somekey"), 233 | Body: []byte("somevalue"), 234 | } 235 | 236 | data := res.Bytes() 237 | 238 | res2 := Response{} 239 | _, err := res2.Receive(bytes.NewReader(data), nil) 240 | if err != nil { 241 | t.Fatalf("Error receiving: %v", err) 242 | } 243 | 244 | if !reflect.DeepEqual(res, res2) { 245 | t.Fatalf("Expected %#v == %#v", res, res2) 246 | } 247 | } 248 | 249 | func TestReceiveResponseBadMagic(t *testing.T) { 250 | res := Response{ 251 | Opcode: SET, 252 | Status: 74, 253 | Opaque: 7242, 254 | Extras: []byte{1}, 255 | Key: []byte("somekey"), 256 | Body: []byte("somevalue"), 257 | } 258 | 259 | data := res.Bytes() 260 | data[0] = 0x13 261 | 262 | res2 := Response{} 263 | _, err := res2.Receive(bytes.NewReader(data), nil) 264 | if err == nil { 265 | t.Fatalf("Expected error, got: %#v", res2) 266 | } 267 | } 268 | 269 | func TestReceiveResponseShortHeader(t *testing.T) { 270 | res := Response{ 271 | Opcode: SET, 272 | Status: 74, 273 | Opaque: 7242, 274 | Extras: []byte{1}, 275 | Key: []byte("somekey"), 276 | Body: []byte("somevalue"), 277 | } 278 | 279 | data := res.Bytes() 280 | data[0] = 0x13 281 | 282 | res2 := Response{} 283 | _, err := res2.Receive(bytes.NewReader(data[:13]), nil) 284 | if err == nil { 285 | t.Fatalf("Expected error, got: %#v", res2) 286 | } 287 | } 288 | 289 | func TestReceiveResponseShortBody(t *testing.T) { 290 | res := Response{ 291 | Opcode: SET, 292 | Status: 74, 293 | Opaque: 7242, 294 | Extras: []byte{1}, 295 | Key: []byte("somekey"), 296 | Body: 
[]byte("somevalue"), 297 | } 298 | 299 | data := res.Bytes() 300 | data[0] = 0x13 301 | 302 | res2 := Response{} 303 | _, err := res2.Receive(bytes.NewReader(data[:len(data)-3]), nil) 304 | if err == nil { 305 | t.Fatalf("Expected error, got: %#v", res2) 306 | } 307 | } 308 | 309 | func TestReceiveResponseWithBuffer(t *testing.T) { 310 | res := Response{ 311 | Opcode: SET, 312 | Status: 74, 313 | Opaque: 7242, 314 | Extras: []byte{1}, 315 | Key: []byte("somekey"), 316 | Body: []byte("somevalue"), 317 | } 318 | 319 | data := res.Bytes() 320 | 321 | res2 := Response{} 322 | buf := make([]byte, HDR_LEN) 323 | _, err := res2.Receive(bytes.NewReader(data), buf) 324 | if err != nil { 325 | t.Fatalf("Error receiving: %v", err) 326 | } 327 | 328 | if !reflect.DeepEqual(res, res2) { 329 | t.Fatalf("Expected %#v == %#v", res, res2) 330 | } 331 | } 332 | 333 | func TestReceiveResponseNoContent(t *testing.T) { 334 | res := Response{ 335 | Opcode: SET, 336 | Status: 74, 337 | Opaque: 7242, 338 | } 339 | 340 | data := res.Bytes() 341 | 342 | res2 := Response{} 343 | _, err := res2.Receive(bytes.NewReader(data), nil) 344 | if err != nil { 345 | t.Fatalf("Error receiving: %v", err) 346 | } 347 | 348 | // Can't use reflect here because []byte{} != nil, though they 349 | // look the same. 
350 | if fmt.Sprintf("%#v", res) != fmt.Sprintf("%#v", res2) { 351 | t.Fatalf("Expected %#v == %#v", res, res2) 352 | } 353 | } 354 | 355 | func BenchmarkReceiveResponse(b *testing.B) { 356 | req := Response{ 357 | Opcode: SET, 358 | Status: 183, 359 | Cas: 0, 360 | Opaque: 7242, 361 | Extras: []byte{1}, 362 | Key: []byte("somekey"), 363 | Body: []byte("somevalue"), 364 | } 365 | 366 | data := req.Bytes() 367 | rdr := bytes.NewReader(data) 368 | 369 | b.SetBytes(int64(len(data))) 370 | 371 | b.ResetTimer() 372 | buf := make([]byte, HDR_LEN) 373 | for i := 0; i < b.N; i++ { 374 | res2 := Response{} 375 | rdr.Seek(0, 0) 376 | res2.Receive(rdr, buf) 377 | } 378 | } 379 | 380 | func BenchmarkReceiveResponseNoBuf(b *testing.B) { 381 | req := Response{ 382 | Opcode: SET, 383 | Status: 183, 384 | Cas: 0, 385 | Opaque: 7242, 386 | Extras: []byte{1}, 387 | Key: []byte("somekey"), 388 | Body: []byte("somevalue"), 389 | } 390 | 391 | data := req.Bytes() 392 | rdr := bytes.NewReader(data) 393 | 394 | b.SetBytes(int64(len(data))) 395 | 396 | b.ResetTimer() 397 | for i := 0; i < b.N; i++ { 398 | res2 := Response{} 399 | rdr.Seek(0, 0) 400 | res2.Receive(rdr, nil) 401 | } 402 | } 403 | 404 | func isNotFound(e error) bool { 405 | return errStatus(e) == KEY_ENOENT 406 | } 407 | -------------------------------------------------------------------------------- /memcached/transport.go: -------------------------------------------------------------------------------- 1 | package memcached 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | ) 7 | 8 | // UnwrapMemcachedError converts memcached errors to normal responses. 9 | // 10 | // If the error is a memcached response, declare the error to be nil 11 | // so a client can handle the status without worrying about whether it 12 | // indicates success or failure. 
13 | func UnwrapMemcachedError(err error) *Response { 14 | var res *Response 15 | if errors.As(err, &res) { 16 | return res 17 | } 18 | return nil 19 | } 20 | 21 | func getResponse(s io.Reader, hdrBytes []byte) (rv *Response, n int, err error) { 22 | if s == nil { 23 | return nil, 0, ErrNoServers 24 | } 25 | 26 | rv = &Response{} 27 | n, err = rv.Receive(s, hdrBytes) 28 | if err == nil && rv.Status != SUCCESS { 29 | err = wrapMemcachedResp(rv) 30 | } 31 | return rv, n, err 32 | } 33 | 34 | func transmitRequest(o io.Writer, req *Request) (int, error) { 35 | if o == nil { 36 | return 0, ErrNoServers 37 | } 38 | n, err := req.Transmit(o) 39 | return n, err 40 | } 41 | -------------------------------------------------------------------------------- /pool/pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "golang.org/x/sync/semaphore" 9 | ) 10 | 11 | const token int64 = 1 12 | 13 | var ( 14 | ErrClosedPool = fmt.Errorf("pool is closed") 15 | ErrNewFuncNil = fmt.Errorf("newFunc for pool is nil, can not create connection") 16 | ErrAcquireTimeout = fmt.Errorf("timeout for Acquire from the pool. Need to increase the maxCap for pool") 17 | ) 18 | 19 | var _ ConnPool = (*Pool)(nil) 20 | 21 | type ConnPool interface { 22 | Get() (any, error) 23 | Pop() (any, bool) 24 | Put(v any) 25 | Destroy() 26 | Len() int 27 | Close(v any) 28 | } 29 | 30 | // Pool common connection pool 31 | type Pool struct { 32 | ctx context.Context 33 | 34 | // newConn are functions for creating new connections if maxCap is not reached. 35 | newConn func() (any, error) 36 | // closeConn is a function for graceful closed connections. 
37 | closeConn func(any) 38 | 39 | // sema is a semaphore implementation for control a max capacity of pool 40 | sema *semaphore.Weighted 41 | // aqSemaTimeout is an amount of time to acquire conn from pool 42 | aqSemaTimeout time.Duration 43 | 44 | // store is a chan with connections. 45 | store chan any 46 | // storeClose is a flag indicating that store is closed. 47 | storeClose chan struct{} 48 | // maxCap is maximum of total connections used 49 | maxCap int32 50 | } 51 | 52 | // New create a pool with capacity 53 | func New(ctx context.Context, maxCap int32, acquireSemaTimeout time.Duration, newFunc func() (any, error), closeFunc func(any)) *Pool { 54 | if maxCap <= 0 { 55 | panic("invalid memcached maxCap") 56 | } 57 | 58 | return &Pool{ 59 | ctx: ctx, 60 | newConn: newFunc, 61 | closeConn: closeFunc, 62 | sema: semaphore.NewWeighted(int64(maxCap)), 63 | aqSemaTimeout: acquireSemaTimeout, 64 | store: make(chan any, maxCap), 65 | storeClose: make(chan struct{}), 66 | maxCap: maxCap, 67 | } 68 | } 69 | 70 | // Len returns current connections in pool 71 | func (p *Pool) Len() int { 72 | return len(p.store) 73 | } 74 | 75 | // Get returns a conn from store or create one 76 | func (p *Pool) Get() (any, error) { 77 | var aqTimeout bool 78 | 79 | for { 80 | select { 81 | case v, ok := <-p.store: 82 | if ok { 83 | return v, nil 84 | } 85 | return nil, ErrClosedPool 86 | default: 87 | if aqTimeout { 88 | return nil, ErrAcquireTimeout 89 | } 90 | if cn, timeout, err := p.create(); timeout { 91 | // last try get conn after timeout 92 | aqTimeout = true 93 | continue 94 | } else { 95 | return cn, err 96 | } 97 | } 98 | } 99 | } 100 | 101 | // Pop return available conn without block 102 | func (p *Pool) Pop() (any, bool) { 103 | if p.isClosed() { 104 | return nil, false 105 | } 106 | 107 | select { 108 | case v, ok := <-p.store: 109 | return v, ok 110 | default: 111 | return nil, false 112 | } 113 | } 114 | 115 | // Put set back conn into store again 116 | func (p *Pool) 
Put(v any) { 117 | if p.isClosed() { 118 | return 119 | } 120 | select { 121 | case p.store <- v: 122 | default: 123 | } 124 | } 125 | 126 | // Destroy close all connections and deactivate the pool 127 | func (p *Pool) Destroy() { 128 | if p.isClosed() { 129 | // pool already destroyed 130 | return 131 | } 132 | 133 | close(p.storeClose) 134 | close(p.store) 135 | for v := range p.store { 136 | p.close(v) 137 | } 138 | } 139 | 140 | // Close is closed a connection 141 | func (p *Pool) Close(v any) { 142 | p.close(v) 143 | } 144 | 145 | func (p *Pool) create() (any, bool, error) { 146 | ctx, cancel := context.WithTimeout(p.ctx, p.aqSemaTimeout) 147 | defer cancel() 148 | 149 | if err := p.sema.Acquire(ctx, token); err != nil { 150 | return nil, true, nil 151 | } 152 | 153 | if p.isClosed() { 154 | p.sema.Release(token) 155 | return nil, false, ErrClosedPool 156 | } 157 | 158 | if p.newConn == nil { 159 | return nil, false, ErrNewFuncNil 160 | } 161 | cn, err := p.newConn() 162 | if err != nil { 163 | p.sema.Release(token) 164 | return nil, false, err 165 | } 166 | return cn, false, nil 167 | } 168 | 169 | func (p *Pool) close(v any) { 170 | p.sema.Release(token) 171 | if p.closeConn != nil { 172 | p.closeConn(v) 173 | } 174 | } 175 | 176 | func (p *Pool) isClosed() bool { 177 | select { 178 | case <-p.storeClose: 179 | return true 180 | default: 181 | return false 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /pool/pool_test.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | "math/rand" 6 | "net/http" 7 | "sync" 8 | "sync/atomic" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | const defaultSocketPoolingTimeout = 50 * time.Millisecond 16 | 17 | type testConnection struct{} 18 | 19 | func newTestConnection() (any, error) { 20 | return &testConnection{}, nil 21 | } 22 | 23 | func 
newTestConnectionWithErr() (any, error) { 24 | return nil, http.ErrHandlerTimeout 25 | } 26 | 27 | func closeTestConnection(any) { 28 | // Do nothing 29 | } 30 | 31 | func TestPool(t *testing.T) { 32 | assert.Panics(t, func() { 33 | _ = New(context.TODO(), 0, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 34 | }, "was expected panic") 35 | 36 | defer func() { 37 | if pErr := recover(); pErr != nil { 38 | t.Fatalf("pool have panic - %v", pErr) 39 | } 40 | }() 41 | 42 | p := New(context.TODO(), 2, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 43 | defer p.Destroy() 44 | 45 | _, ok := p.Pop() 46 | assert.False(t, ok, "Pop return ok != false for empty pool") 47 | 48 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0, got %d", p.Len()) 49 | 50 | conn, err := p.Get() 51 | assert.Nilf(t, err, "Get from empty pool have error - %v", err) 52 | 53 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after getting a connection, got %d", p.Len()) 54 | 55 | p.Put(conn) 56 | assert.Equalf(t, 1, p.Len(), "Expected pool length to be 1 after putting back a connection, got %d", p.Len()) 57 | 58 | _, ok = p.Pop() 59 | assert.True(t, ok, "Pop return ok != true for non-empty pool") 60 | 61 | conn, err = p.Get() 62 | assert.Nilf(t, err, "Get from pool have error - %v", err) 63 | 64 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after getting a connection from the pool, got %d", p.Len()) 65 | 66 | p.Put(conn) 67 | p.Destroy() 68 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after destroying the pool, got %d", p.Len()) 69 | 70 | _, err = p.Get() 71 | assert.ErrorIsf(t, err, ErrClosedPool, "Expected to get an error when getting from a destroyed pool, got %v", err) 72 | 73 | p.Put(conn) 74 | assert.ErrorIsf(t, err, ErrClosedPool, "Expected to put an error when putting a destroyed pool, got %v", err) 75 | } 76 | 77 | func TestPoolConcurrency(t *testing.T) { 78 | p := New(context.TODO(), 10, 
defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 79 | defer p.Destroy() 80 | 81 | var wg sync.WaitGroup 82 | for i := 0; i < 10; i++ { 83 | wg.Add(1) 84 | go func() { 85 | defer wg.Done() 86 | conn, err := p.Get() 87 | assert.Nilf(t, err, "Get have error %v", err) 88 | <-time.After(5 * time.Millisecond) 89 | p.Put(conn) 90 | }() 91 | } 92 | wg.Wait() 93 | 94 | assert.Equalf(t, 10, p.Len(), "Expected pool length to be 10, got %d", p.Len()) 95 | } 96 | 97 | func TestCountConns(t *testing.T) { 98 | const count = 300 99 | p := New(context.TODO(), int32(count), defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 100 | 101 | conn := atomic.Int32{} 102 | wg1 := sync.WaitGroup{} 103 | 104 | wg1.Add(3) 105 | go func() { 106 | defer wg1.Done() 107 | for i := 0; i < count/3; i++ { 108 | _, pErr := p.Get() 109 | conn.Add(1) 110 | assert.Nilf(t, pErr, "Get have error - %v", pErr) 111 | } 112 | }() 113 | go func() { 114 | defer wg1.Done() 115 | for i := 0; i < count/3; i++ { 116 | _, pErr := p.Get() 117 | conn.Add(1) 118 | assert.Nilf(t, pErr, "Get have error - %v", pErr) 119 | } 120 | }() 121 | go func() { 122 | defer wg1.Done() 123 | for i := 0; i < count/3; i++ { 124 | _, pErr := p.Get() 125 | conn.Add(1) 126 | assert.Nilf(t, pErr, "Get have error - %v", pErr) 127 | } 128 | }() 129 | 130 | wg1.Wait() 131 | 132 | assert.Equalf(t, conn.Load(), int32(count), "Not equal init and received conns. 
have - %d, expacted - %d ", conn.Load(), int32(count)) 133 | 134 | for i := 0; i < int(conn.Load()); i++ { 135 | p.Put(testConnection{}) 136 | } 137 | 138 | // p.store is full, over-conn 139 | p.Put(testConnection{}) 140 | 141 | wg1.Add(2) 142 | 143 | go func() { 144 | defer wg1.Done() 145 | p.Destroy() 146 | }() 147 | go func() { 148 | defer wg1.Done() 149 | p.Destroy() 150 | }() 151 | wg1.Wait() 152 | 153 | cn, err := p.Get() 154 | assert.Nil(t, cn, "Get: after method Destroy, pool is closed and should return cn == nil") 155 | assert.ErrorIs(t, err, ErrClosedPool, "Get: after method Destroy, pool is closed, want error ErrClosedPool") 156 | 157 | p2 := New(context.TODO(), count, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 158 | 159 | var ( 160 | mu sync.RWMutex 161 | conns []any 162 | wg2 sync.WaitGroup 163 | ) 164 | 165 | addToSl := func(c any) { 166 | mu.Lock() 167 | defer mu.Unlock() 168 | conns = append(conns, c) 169 | } 170 | 171 | getFromSl := func() any { 172 | mu.Lock() 173 | defer mu.Unlock() 174 | return conns[len(conns)-1] 175 | } 176 | 177 | getSlLen := func() int { 178 | mu.RLock() 179 | defer mu.RUnlock() 180 | return len(conns) 181 | } 182 | 183 | wg2.Add(1) 184 | go func() { 185 | defer wg2.Done() 186 | for i := 0; i < count/2; i++ { 187 | c, gErr := p2.Get() 188 | assert.Nilf(t, gErr, "Get have error") 189 | //nolint:gosec 190 | if rand.Int()%2 == 0 { 191 | addToSl(c) 192 | } else { 193 | p2.Put(c) 194 | } 195 | } 196 | }() 197 | wg2.Add(1) 198 | go func() { 199 | defer wg2.Done() 200 | for i := 0; i < count/2; i++ { 201 | c, gErr := p2.Get() 202 | assert.Nilf(t, gErr, "Get have error") 203 | //nolint:gosec 204 | if rand.Int()%2 == 0 { 205 | addToSl(c) 206 | } else { 207 | p2.Put(c) 208 | } 209 | } 210 | }() 211 | 212 | wg2.Add(1) 213 | go func() { 214 | defer wg2.Done() 215 | <-time.After(200 * time.Millisecond) 216 | c, gErr := p2.Get() 217 | assert.Nilf(t, gErr, "Get with full cap have error") 218 | addToSl(c) 219 | }() 
220 | 221 | wg2.Wait() 222 | 223 | for i := 0; i < getSlLen(); i++ { 224 | go func() { 225 | p2.Put(getFromSl()) 226 | }() 227 | } 228 | 229 | p3 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 230 | 231 | // maxConns is full 232 | _, _ = p3.Get() 233 | 234 | cn, err = p3.Get() 235 | assert.Nil(t, cn, "Get: after a timeout, it should return cn == nil") 236 | assert.ErrorIsf(t, ErrAcquireTimeout, err, "Get: after a timeout, it should return ErrAcquireTimeout") 237 | 238 | _, ok := p3.Pop() 239 | assert.False(t, ok, "Pop: pool with empty pool it should return false for second arg") 240 | 241 | p3.Destroy() 242 | cn, ok = p3.Pop() 243 | assert.Nil(t, cn, "Pop: after method Destroy, pool is closed and should return cn == nil") 244 | assert.False(t, ok, "Pop: after method Destroy, pool is closed and should return false for second arg") 245 | 246 | p4 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnectionWithErr, closeTestConnection) 247 | 248 | cn, err = p4.Get() 249 | assert.Nil(t, cn, "Get: create new conn returned an error, conn should be nil") 250 | assert.ErrorIs(t, err, http.ErrHandlerTimeout, "Get: error should be equal - http.ErrHandlerTimeout") 251 | 252 | p5 := New(context.TODO(), 1, time.Second, nil, nil) 253 | 254 | cn, err = p5.Get() 255 | assert.Nil(t, cn, "Get: newFunc equal nil, conn should be nil") 256 | assert.ErrorIs(t, err, ErrNewFuncNil, "Get: error should be equal ErrNewFuncNil") 257 | 258 | p6 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection) 259 | bcn, err := p6.Get() 260 | assert.NotNil(t, bcn, "Get: conn cannot be nil") 261 | assert.Nil(t, err, "Get: error should be nil") 262 | wg := sync.WaitGroup{} 263 | wg.Add(1) 264 | go func() { 265 | defer wg.Done() 266 | cn, err = p6.Get() 267 | assert.Nil(t, cn, "Get: conn should be nil") 268 | assert.ErrorIs(t, err, ErrClosedPool, "Get: error should be equal ErrClosedPool") 269 | }() 270 | 
wg.Add(1) 271 | go func() { 272 | defer wg.Done() 273 | <-time.After(10 * time.Millisecond) 274 | p6.Close(bcn) 275 | p6.Destroy() 276 | }() 277 | 278 | wg.Wait() 279 | } 280 | -------------------------------------------------------------------------------- /utils/addr.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "net" 5 | "strings" 6 | ) 7 | 8 | // staticAddr caches the Network() and String() values from any net.Addr. 9 | type staticAddr struct { 10 | ntw, str string 11 | } 12 | 13 | func newStaticAddr(a net.Addr) net.Addr { 14 | return &staticAddr{ 15 | ntw: a.Network(), 16 | str: a.String(), 17 | } 18 | } 19 | 20 | func (s *staticAddr) Network() string { return s.ntw } 21 | func (s *staticAddr) String() string { return s.str } 22 | 23 | // AddrRepr a string representation of the server address implements net.Addr 24 | func AddrRepr(server string) (net.Addr, error) { 25 | var nAddr net.Addr 26 | if strings.Contains(server, "/") { 27 | addr, _ := net.ResolveUnixAddr("unix", server) 28 | nAddr = newStaticAddr(addr) 29 | } else { 30 | tcpAddr, err := net.ResolveTCPAddr("tcp", server) 31 | if err != nil { 32 | return nil, err 33 | } 34 | nAddr = newStaticAddr(tcpAddr) 35 | } 36 | 37 | return nAddr, nil 38 | } 39 | -------------------------------------------------------------------------------- /utils/addr_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewStaticAddr(t *testing.T) { 12 | tcpAddr := &net.TCPAddr{ 13 | IP: net.IPv4(127, 0, 0, 1), 14 | Port: 8080, 15 | } 16 | staticAddr := newStaticAddr(tcpAddr) 17 | if staticAddr.Network() != tcpAddr.Network() { 18 | t.Errorf("Expected Network() to be %s, got %s", tcpAddr.Network(), staticAddr.Network()) 19 | } 20 | if staticAddr.String() != tcpAddr.String() { 21 | 
// epsilon is the smallest probability fed to the logarithm in CalcEntropy;
// smaller (including zero) probabilities are clamped up to it.
const epsilon = 1e-6

// CalcEntropy calculates the normalized Shannon entropy of the counts in m.
//
// Each value is divided by the sum of all values to obtain a probability,
// probabilities below epsilon are clamped to epsilon, and the resulting
// entropy (base 2) is divided by log2(len(m)), so equal counts yield 1.
//
// It returns 1 when m has fewer than two keys or when the counts sum to zero;
// in the latter case no probabilities can be derived and the division would
// otherwise produce NaN.
func CalcEntropy(m map[any]int) float64 {
	if len(m) <= 1 {
		return 1
	}

	var total int
	for _, v := range m {
		total += v
	}
	if total == 0 {
		// 0/0 is NaN, which slips past the epsilon clamp below and would
		// poison the whole result.
		return 1
	}

	var entropy float64
	sum := float64(total)
	for _, v := range m {
		proba := float64(v) / sum
		if proba < epsilon {
			// Keep log2 away from zero (log2(0) is -Inf).
			proba = epsilon
		}
		entropy -= proba * math.Log2(proba)
	}

	return entropy / math.Log2(float64(len(m)))
}
// Repr returns the string representation of v.
//
// A nil interface yields "". If v implements fmt.Stringer its String method
// is used as is; otherwise pointers are followed until a non-pointer or a nil
// pointer is reached and the result is formatted by reprOfValue. A nil
// pointer whose String method cannot run on a nil receiver yields "<nil>",
// the same text package fmt prints, instead of panicking.
func Repr(v any) string {
	if v == nil {
		return ""
	}

	// Check for Stringer before dereferencing: a String method declared on a
	// pointer receiver is not part of the pointed-to value's method set, so
	// it would be lost after Elem().
	if s, ok := v.(fmt.Stringer); ok {
		return guardNilReceiver(v, func() string { return s.String() })
	}

	val := reflect.ValueOf(v)
	for val.Kind() == reflect.Ptr && !val.IsNil() {
		val = val.Elem()
	}

	return reprOfValue(val)
}

// reprOfValue formats the value held by val. Basic types go through strconv,
// error and fmt.Stringer values through their own methods (error wins when
// both are implemented), []byte is converted to a string, and everything else
// falls back to fmt.Sprint. An invalid (zero) reflect.Value yields "".
func reprOfValue(val reflect.Value) string {
	if !val.IsValid() {
		// Interface() panics on the zero Value.
		return ""
	}

	iface := val.Interface()
	switch vt := iface.(type) {
	case bool:
		return strconv.FormatBool(vt)
	case error:
		return guardNilReceiver(iface, func() string { return vt.Error() })
	case float32:
		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
	case float64:
		return strconv.FormatFloat(vt, 'f', -1, 64)
	case fmt.Stringer:
		return guardNilReceiver(iface, func() string { return vt.String() })
	case int:
		return strconv.Itoa(vt)
	case int8:
		return strconv.Itoa(int(vt))
	case int16:
		return strconv.Itoa(int(vt))
	case int32:
		return strconv.Itoa(int(vt))
	case int64:
		return strconv.FormatInt(vt, 10)
	case string:
		return vt
	case uint:
		return strconv.FormatUint(uint64(vt), 10)
	case uint8:
		return strconv.FormatUint(uint64(vt), 10)
	case uint16:
		return strconv.FormatUint(uint64(vt), 10)
	case uint32:
		return strconv.FormatUint(uint64(vt), 10)
	case uint64:
		return strconv.FormatUint(vt, 10)
	case []byte:
		return string(vt)
	default:
		return fmt.Sprint(iface)
	}
}

// guardNilReceiver runs call, which is expected to invoke a String or Error
// method on recv. If call panics and recv holds a nil pointer, the panic is
// swallowed and "<nil>" is returned; any other panic is re-raised untouched.
// Methods that cope with a nil receiver are unaffected because they never
// panic in the first place.
func guardNilReceiver(recv any, call func() string) (s string) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if rv := reflect.ValueOf(recv); rv.Kind() == reflect.Ptr && rv.IsNil() {
			s = "<nil>"
			return
		}
		panic(r)
	}()

	return call()
}
-------------------------------------------------------------------------------- /utils/stringer_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestRepr(t *testing.T) { 13 | var ( 14 | f32 float32 = 1.1 15 | f64 = 2.2 16 | i8 int8 = 1 17 | i16 int16 = 2 18 | i32 int32 = 3 19 | i64 int64 = 4 20 | u8 uint8 = 5 21 | u16 uint16 = 6 22 | u32 uint32 = 7 23 | u64 uint64 = 8 24 | ) 25 | tests := []struct { 26 | v any 27 | expect string 28 | }{ 29 | { 30 | nil, 31 | "", 32 | }, 33 | { 34 | mockStringable{}, 35 | "mocked", 36 | }, 37 | { 38 | new(mockStringable), 39 | "mocked", 40 | }, 41 | { 42 | newMockPtr(), 43 | "mockptr", 44 | }, 45 | { 46 | &mockOpacity{ 47 | val: 1, 48 | }, 49 | "{1}", 50 | }, 51 | { 52 | true, 53 | "true", 54 | }, 55 | { 56 | false, 57 | "false", 58 | }, 59 | { 60 | f32, 61 | "1.1", 62 | }, 63 | { 64 | f64, 65 | "2.2", 66 | }, 67 | { 68 | i8, 69 | "1", 70 | }, 71 | { 72 | i16, 73 | "2", 74 | }, 75 | { 76 | i32, 77 | "3", 78 | }, 79 | { 80 | i64, 81 | "4", 82 | }, 83 | { 84 | u8, 85 | "5", 86 | }, 87 | { 88 | u16, 89 | "6", 90 | }, 91 | { 92 | u32, 93 | "7", 94 | }, 95 | { 96 | u64, 97 | "8", 98 | }, 99 | { 100 | []byte(`abcd`), 101 | "abcd", 102 | }, 103 | { 104 | mockOpacity{val: 1}, 105 | "{1}", 106 | }, 107 | } 108 | 109 | for _, test := range tests { 110 | t.Run(test.expect, func(t *testing.T) { 111 | assert.Equal(t, test.expect, Repr(test.v)) 112 | }) 113 | } 114 | } 115 | 116 | func TestReprOfValue(t *testing.T) { 117 | t.Run("error", func(t *testing.T) { 118 | assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error")))) 119 | }) 120 | 121 | t.Run("stringer", func(t *testing.T) { 122 | assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23")))) 123 | }) 124 | 125 | t.Run("int", func(t *testing.T) { 126 | 
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1))) 127 | }) 128 | 129 | t.Run("int", func(t *testing.T) { 130 | assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1"))) 131 | }) 132 | 133 | t.Run("int", func(t *testing.T) { 134 | assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1)))) 135 | }) 136 | } 137 | 138 | type mockStringable struct{} 139 | 140 | func (m mockStringable) String() string { 141 | return "mocked" 142 | } 143 | 144 | type mockPtr struct{} 145 | 146 | func newMockPtr() *mockPtr { 147 | return new(mockPtr) 148 | } 149 | 150 | func (m *mockPtr) String() string { 151 | return "mockptr" 152 | } 153 | 154 | type mockOpacity struct { 155 | val int 156 | } 157 | --------------------------------------------------------------------------------