├── .github
└── workflows
│ └── coverage.yml
├── .gitignore
├── .golangci.yml
├── LICENSE
├── Makefile
├── README.md
├── assets
└── logo.png
├── consistenthash
├── hash_calculate.go
├── hashring.go
└── hashring_test.go
├── docker-compose.yml
├── examples
└── main.go
├── go.mod
├── go.sum
├── logger
└── zap.go
├── memcached
├── constants.go
├── constants_test.go
├── errors.go
├── memcached.go
├── memcached_test.go
├── metrics.go
├── metrics_test.go
├── node_provider.go
├── node_provider_test.go
├── options.go
├── options_test.go
├── requests.go
├── requests_test.go
├── responses.go
├── responses_test.go
└── transport.go
├── pool
├── pool.go
└── pool_test.go
└── utils
├── addr.go
├── addr_test.go
├── math.go
├── math_test.go
├── stringer.go
└── stringer_test.go
/.github/workflows/coverage.yml:
--------------------------------------------------------------------------------
1 | name: Coverage
2 |
3 | on:
4 | pull_request:
5 | branches: [ main ]
6 |
7 | jobs:
8 | build:
9 | runs-on: ${{ matrix.os }}
10 | permissions:
11 | contents: write
12 | services:
13 | memcached:
14 | image: memcached:latest
15 | ports:
16 | - 11211:11211
17 | strategy:
18 | matrix:
19 | os: [ ubuntu-latest ]
20 | go: [ 1.21.x ]
21 | steps:
22 | - uses: actions/checkout@v4
23 |
24 | - name: Set up Go
25 | uses: actions/setup-go@v2
26 | with:
27 | go-version: ${{ matrix.go }}
28 |
29 | - name: Test
30 | run: |
31 | go test -v -cover ./... -coverprofile coverage.out
32 | go tool cover -func coverage.out -o coverage.out
33 |
34 | - name: Go Coverage Badge
35 | uses: tj-actions/coverage-badge-go@v1
36 | if: ${{ runner.os == 'linux' }}
37 | with:
38 | green: 80
39 | filename: coverage.out
40 |
41 | - uses: stefanzweifel/git-auto-commit-action@v5
42 | id: auto-commit-action
43 | with:
44 | commit_message: 'chore: apply code coverage badge'
45 | file_pattern: ./README.md
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Go template
2 | # If you prefer the allow list template instead of the deny list, see community template:
3 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
4 | #
5 | # Binaries for programs and plugins
6 | *.exe
7 | *.exe~
8 | *.dll
9 | *.so
10 | *.dylib
11 | /bin
12 |
13 | # IDEs
14 | .idea
15 | .vscode
16 |
17 | # Test binary, built with `go test -c`
18 | *.test
19 |
20 | # Output of the go coverage tool, specifically when used with LiteIDE
21 | *.out
22 |
23 | # macos
24 | .DS_STORE
25 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | # More info on config here: https://github.com/golangci/golangci-lint#config-file
2 | run:
3 | deadline: 60s
4 | issues-exit-code: 1
5 | tests: true
6 |
7 | output:
8 | format: colored-line-number
9 | print-issued-lines: true
10 | print-linter-name: true
11 |
12 | linters-settings:
13 | govet:
14 | check-shadowing: true
15 | golint:
16 | min-confidence: 0
17 | dupl:
18 | threshold: 100
19 | goconst:
20 | min-len: 2
21 | min-occurrences: 2
22 |
23 | linters:
24 | disable-all: true
25 | enable:
26 | - govet
27 | - gosimple
28 | - errcheck
29 | - ineffassign
30 | - typecheck
31 | - goconst
32 | - gosec
33 | - nilnil
34 | - goimports
35 | - megacheck
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 AliExpress Russia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | LOCAL_BIN=$(CURDIR)/bin
2 |
3 | GOENV=PATH=$(LOCAL_BIN):$(PATH)
4 |
5 | GOLANGCI_BIN=$(LOCAL_BIN)/golangci-lint
6 | $(GOLANGCI_BIN):
7 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(LOCAL_BIN) v1.55.2
8 |
9 | .PHONY: lint
10 | lint: $(GOLANGCI_BIN)
11 | $(GOENV) $(GOLANGCI_BIN) run --fix -v ./...
12 |
13 | .PHONY: test
14 | test:
15 | $(GOENV) go test -race ./...
16 |
17 | .PHONY: test-cover
18 | test-cover:
19 | $(GOENV) go test ./... -cover
20 |
21 | .PHONY: test-cover-html
22 | test-cover-html:
23 | $(GOENV) go test ./... -coverprofile=prof.out
24 | $(GOENV) go tool cover -html=prof.out
25 |
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gomemcached
2 |
3 | ---
4 |
5 |

6 |
7 | [](https://github.com/aliexpressru/gomemcached)
8 | [](https://github.com/aliexpressru/gomemcached/tags)
9 | [](https://pkg.go.dev/github.com/aliexpressru/gomemcached)
10 |
11 | [](https://goreportcard.com/report/github.com/aliexpressru/gomemcached)
12 | 
13 |
14 | [](https://github.com/avelino/awesome-go?tab=readme-ov-file#nosql-database-drivers)
15 |
16 |
17 | ___
18 | `Gomemcached` is a Golang Memcached client designed to interact with multiple instances as shards. It implements sharding using consistent hashing.
19 | ___
20 |
21 | ### Configuration
22 |
23 | ```yaml
24 | - name: MEMCACHED_HEADLESS_SERVICE_ADDRESS
25 | value: "my-memcached-service-headless.namespace.svc.cluster.local"
26 | ```
27 |
28 | `MEMCACHED_HEADLESS_SERVICE_ADDRESS` is resolved via DNS lookup to discover the IP addresses of all memcached instances.
29 |
30 | The default Memcached port is `11211`; you can override it in the config:
31 |
32 | ```yaml
33 | - name: MEMCACHED_PORT
34 | value: "12345"
35 | ```
36 |
37 | For local runs, or if you have a static set of pods, you can specify the servers manually (a comma-separated list of `host:port` pairs) instead of setting the
38 | headless service address (`MEMCACHED_HEADLESS_SERVICE_ADDRESS`):
39 |
40 | ```yaml
41 | - name: MEMCACHED_SERVERS
42 | value: "127.0.0.1:11211,192.168.0.1:1234"
43 | ```
44 |
45 | ___
46 |
47 | ### Usage
48 |
49 | Initialize the client and connect to the memcached servers:
50 |
51 | ```go
52 | mcl, err := memcached.InitFromEnv()
53 | mustInit(err)
54 | a.AddCloser(func () error {
55 | mcl.CloseAllConns()
56 | return nil
57 | })
58 | ```
59 | [More examples](examples/main.go)
60 |
61 | To use SASL authentication, pass the `WithAuthentication` option to `InitFromEnv`:
62 |
63 | ```go
64 | memcached.InitFromEnv(memcached.WithAuthentication("", ""))
65 | ```
66 |
67 | You can pass Options to `InitFromEnv` to customize the client to suit your needs. However, for basic use, it is recommended
68 | to use the default client settings.
69 |
70 | ---
71 |
72 | ### Recommended Versions
73 |
74 | This project is developed and tested with the following recommended versions:
75 |
76 | - Go: 1.21 or higher
77 | - [Download Go](https://golang.org/dl/)
78 |
79 | - Memcached: 1.6.9 or higher
80 | - [Memcached Releases](https://memcached.org/downloads)
81 |
82 | ---
83 |
84 | ### Dependencies
85 |
86 | This project utilizes the following third-party libraries; the license of each is listed below:
87 |
88 | 1. [gomemcache](https://github.com/bradfitz/gomemcache)
89 | - Description: A Go client library for the memcached server.
90 | - Used for: Primary client methods for interacting with the library.
91 | - License: Apache License 2.0
92 |
93 | 2. [go-zero](https://github.com/zeromicro/go-zero)
94 | - Description: A cloud-native Go microservices framework with a CLI tool for productivity.
95 | - Used for: Implementation of Consistent Hash.
96 | - License: MIT License
97 |
98 | 3. [gomemcached](https://github.com/dustin/gomemcached)
99 | - Description: A memcached binary protocol toolkit for go.
100 | - Used for: Implementation of a binary client for Memcached.
101 | - License: MIT License
102 |
103 | Please review the respective license files in the linked repositories for more details.
104 |
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliexpressru/gomemcached/b7e131624a1b957e510aa03bf6d7f466ebcabe0d/assets/logo.png
--------------------------------------------------------------------------------
/consistenthash/hash_calculate.go:
--------------------------------------------------------------------------------
1 | package consistenthash
2 |
3 | import "github.com/cespare/xxhash"
4 |
5 | // Hash returns the hash value of data.
6 | func Hash(data []byte) uint64 {
7 | return xxhash.Sum64(data)
8 | }
9 |
--------------------------------------------------------------------------------
/consistenthash/hashring.go:
--------------------------------------------------------------------------------
1 | package consistenthash
2 |
3 | import (
4 | "fmt"
5 | "sort"
6 | "strconv"
7 | "sync"
8 |
9 | "github.com/aliexpressru/gomemcached/utils"
10 | )
11 |
12 | const (
13 | // TopWeight is the top weight that one entry might set.
14 | TopWeight = 100
15 |
16 | minReplicas = 256
17 | prime = 18245165
18 | )
19 |
20 | var _ ConsistentHash = (*HashRing)(nil)
21 |
22 | type (
23 | ConsistentHash interface {
24 | Add(node any)
25 | AddWithReplicas(node any, replicas int)
26 | AddWithWeight(node any, weight int)
27 | Get(v any) (any, bool)
28 | GetAllNodes() []any
29 | Remove(node any)
30 | GetNodesCount() int
31 | }
32 |
33 | // Func defines the hash method.
34 | Func func(data []byte) uint64
35 |
36 | // A HashRing is implementation of consistent hash.
37 | HashRing struct {
38 | hashFunc Func
39 | replicas int
40 | keys []uint64
41 | ring map[uint64][]any
42 | nodes map[string]struct{}
43 | lock sync.RWMutex
44 | }
45 | )
46 |
47 | // NewHashRing returns a HashRing.
48 | func NewHashRing() *HashRing {
49 | return NewCustomHashRing(minReplicas, Hash)
50 | }
51 |
52 | // NewCustomHashRing returns a HashRing with given replicas and hash func.
53 | func NewCustomHashRing(replicas int, fn Func) *HashRing {
54 | if replicas < minReplicas {
55 | replicas = minReplicas
56 | }
57 |
58 | if fn == nil {
59 | fn = Hash
60 | }
61 |
62 | return &HashRing{
63 | hashFunc: fn,
64 | replicas: replicas,
65 | ring: make(map[uint64][]any),
66 | nodes: make(map[string]struct{}),
67 | }
68 | }
69 |
70 | // Add adds the node with the number of h.replicas,
71 | // the later call will overwrite the replicas of the former calls.
72 | func (h *HashRing) Add(node any) {
73 | h.AddWithReplicas(node, h.replicas)
74 | }
75 |
76 | // AddWithReplicas adds the node with the number of replicas,
77 | // replicas will be truncated to h.replicas if it's larger than h.replicas,
78 | // the later call will overwrite the replicas of the former calls.
79 | func (h *HashRing) AddWithReplicas(node any, replicas int) {
80 | h.Remove(node)
81 |
82 | if replicas > h.replicas {
83 | replicas = h.replicas
84 | }
85 |
86 | nodeRepr := repr(node)
87 | h.lock.Lock()
88 | defer h.lock.Unlock()
89 | h.addNode(nodeRepr)
90 |
91 | for i := 0; i < replicas; i++ {
92 | hash := h.hashFunc([]byte(replicaRepr(nodeRepr, i)))
93 | h.keys = append(h.keys, hash)
94 | h.ring[hash] = append(h.ring[hash], node)
95 | }
96 |
97 | sort.Slice(h.keys, func(i, j int) bool {
98 | return h.keys[i] < h.keys[j]
99 | })
100 | }
101 |
102 | // AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent,
103 | // the later call will overwrite the replicas of the former calls.
104 | func (h *HashRing) AddWithWeight(node any, weight int) {
105 | // don't need to make sure weight not larger than TopWeight,
106 | // because AddWithReplicas makes sure replicas cannot be larger than h.replicas
107 | replicas := h.replicas * weight / TopWeight
108 | h.AddWithReplicas(node, replicas)
109 | }
110 |
111 | // Get returns the corresponding node from h base on the given v.
112 | func (h *HashRing) Get(v any) (any, bool) {
113 | h.lock.RLock()
114 | defer h.lock.RUnlock()
115 |
116 | if len(h.ring) == 0 {
117 | return nil, false
118 | }
119 |
120 | hash := h.hashFunc([]byte(repr(v)))
121 | index := sort.Search(len(h.keys), func(i int) bool {
122 | return h.keys[i] >= hash
123 | }) % len(h.keys)
124 |
125 | nodes := h.ring[h.keys[index]]
126 | switch len(nodes) {
127 | case 0:
128 | return nil, false
129 | case 1:
130 | return nodes[0], true
131 | default:
132 | innerIndex := h.hashFunc([]byte(innerRepr(v)))
133 | pos := int(innerIndex % uint64(len(nodes)))
134 | return nodes[pos], true
135 | }
136 | }
137 |
138 | // GetAllNodes returns all nodes used in hash ring
139 | //
140 | // return a slice with a string representation of the nodes
141 | func (h *HashRing) GetAllNodes() []any {
142 | h.lock.RLock()
143 | defer h.lock.RUnlock()
144 |
145 | if len(h.ring) == 0 {
146 | return nil
147 | }
148 |
149 | var (
150 | allNodes = make([]any, 0, len(h.nodes))
151 | uqNodes = make(map[any]struct{}, len(h.nodes))
152 | )
153 |
154 | for _, nodes := range h.ring {
155 | for _, node := range nodes {
156 | if _, ok := uqNodes[node]; !ok {
157 | allNodes = append(allNodes, node)
158 | uqNodes[node] = struct{}{}
159 | }
160 | }
161 | }
162 |
163 | return allNodes
164 | }
165 |
166 | // Remove removes the given node from h.
167 | func (h *HashRing) Remove(node any) {
168 | nodeRepr := repr(node)
169 |
170 | h.lock.Lock()
171 | defer h.lock.Unlock()
172 |
173 | if !h.containsNode(nodeRepr) {
174 | return
175 | }
176 |
177 | for i := 0; i < h.replicas; i++ {
178 | hash := h.hashFunc([]byte(replicaRepr(nodeRepr, i)))
179 | index := sort.Search(len(h.keys), func(i int) bool {
180 | return h.keys[i] >= hash
181 | })
182 | if index < len(h.keys) && h.keys[index] == hash {
183 | h.keys = append(h.keys[:index], h.keys[index+1:]...)
184 | }
185 | h.removeRingNode(hash, nodeRepr)
186 | }
187 |
188 | h.removeNode(nodeRepr)
189 | }
190 |
191 | // GetNodesCount returns the current number of nodes
192 | func (h *HashRing) GetNodesCount() int {
193 | h.lock.RLock()
194 | defer h.lock.RUnlock()
195 | return len(h.nodes)
196 | }
197 |
198 | func (h *HashRing) removeRingNode(hash uint64, nodeRepr string) {
199 | if nodes, ok := h.ring[hash]; ok {
200 | newNodes := nodes[:0]
201 | for _, x := range nodes {
202 | if repr(x) != nodeRepr {
203 | newNodes = append(newNodes, x)
204 | }
205 | }
206 | if len(newNodes) > 0 {
207 | h.ring[hash] = newNodes
208 | } else {
209 | delete(h.ring, hash)
210 | }
211 | }
212 | }
213 |
214 | func (h *HashRing) addNode(nodeRepr string) {
215 | h.nodes[nodeRepr] = struct{}{}
216 | }
217 |
218 | func (h *HashRing) containsNode(nodeRepr string) bool {
219 | _, ok := h.nodes[nodeRepr]
220 | return ok
221 | }
222 |
223 | func (h *HashRing) removeNode(nodeRepr string) {
224 | delete(h.nodes, nodeRepr)
225 | }
226 |
227 | func innerRepr(node any) string {
228 | return fmt.Sprintf("%d:%v", prime, node)
229 | }
230 |
231 | func repr(node any) string {
232 | return utils.Repr(node)
233 | }
234 |
// replicaRepr returns the string hashed to place virtual replica number
// replicaNumber of the node nodeRepr, e.g. "host:11211_virtual3".
func replicaRepr(nodeRepr string, replicaNumber int) string {
	// Plain concatenation yields the same bytes as the former
	// fmt.Sprintf("%s_virtual%s", ...) without fmt's boxing and reflection.
	// This runs once per virtual replica on every add and remove.
	return nodeRepr + "_virtual" + strconv.Itoa(replicaNumber)
}
238 |
--------------------------------------------------------------------------------
/consistenthash/hashring_test.go:
--------------------------------------------------------------------------------
1 | //nolint:goconst
2 | package consistenthash
3 |
4 | import (
5 | "fmt"
6 | "strconv"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 |
11 | "github.com/aliexpressru/gomemcached/utils"
12 | )
13 |
14 | const (
15 | keySize = 20
16 | requestSize = 1000
17 | )
18 |
19 | func BenchmarkHashRingGet(b *testing.B) {
20 | ch := NewHashRing()
21 | for i := 0; i < keySize; i++ {
22 | ch.Add("localhost:" + strconv.Itoa(i))
23 | }
24 |
25 | for i := 0; i < b.N; i++ {
26 | ch.Get(i)
27 | }
28 | }
29 |
30 | func TestHashRing_GetAllNodes(t *testing.T) {
31 | ch := NewHashRing()
32 |
33 | allNodes := ch.GetAllNodes()
34 | assert.Nil(t, allNodes, "GetAllNodes: without added nodes")
35 |
36 | for i := 0; i < keySize; i++ {
37 | ch.Add("localhost:" + strconv.Itoa(i))
38 | }
39 | count := ch.GetNodesCount()
40 | assert.Equalf(t, keySize, count, "GetNodesCount: have - %d; want - %d", count, keySize)
41 |
42 | allNodes = ch.GetAllNodes()
43 | assert.Equal(t, keySize, len(allNodes))
44 |
45 | for i := 0; i < keySize; i++ {
46 | node := "localhost:" + strconv.Itoa(i)
47 | found := false
48 | for _, n := range allNodes {
49 | if n == node {
50 | found = true
51 | break
52 | }
53 | }
54 | assert.True(t, found, "Node not found in GetAllNodes: "+node)
55 | }
56 | }
57 |
58 | func TestHashRingWithEntropy(t *testing.T) {
59 | ch := NewCustomHashRing(0, nil)
60 | val, ok := ch.Get("any")
61 | assert.False(t, ok)
62 | assert.Nil(t, val)
63 |
64 | for i := 0; i < keySize; i++ {
65 | ch.AddWithReplicas("localhost:"+strconv.Itoa(i), minReplicas<<1)
66 | }
67 |
68 | keys := make(map[string]int)
69 | for i := 0; i < requestSize; i++ {
70 | key, ok := ch.Get(requestSize + i)
71 | assert.True(t, ok)
72 | keys[key.(string)]++
73 | }
74 |
75 | mi := make(map[any]int, len(keys))
76 | for k, v := range keys {
77 | mi[k] = v
78 | }
79 | entropy := utils.CalcEntropy(mi)
80 | assert.True(t, entropy > .95)
81 | }
82 |
83 | func TestHashRingIncrementalTransfer(t *testing.T) {
84 | prefix := "anything"
85 | create := func() *HashRing {
86 | ch := NewHashRing()
87 | for i := 0; i < keySize; i++ {
88 | ch.Add(prefix + strconv.Itoa(i))
89 | }
90 | return ch
91 | }
92 |
93 | originCh := create()
94 | keys := make(map[int]string, requestSize)
95 | for i := 0; i < requestSize; i++ {
96 | key, ok := originCh.Get(requestSize + i)
97 | assert.True(t, ok)
98 | assert.NotNil(t, key)
99 | keys[i] = key.(string)
100 | }
101 |
102 | node := fmt.Sprintf("%s%d", prefix, keySize)
103 | for i := 0; i < 10; i++ {
104 | laterCh := create()
105 | laterCh.AddWithWeight(node, 10*(i+1))
106 |
107 | for j := 0; j < requestSize; j++ {
108 | key, ok := laterCh.Get(requestSize + j)
109 | assert.True(t, ok)
110 | assert.NotNil(t, key)
111 | value := key.(string)
112 | assert.True(t, value == keys[j] || value == node)
113 | }
114 | }
115 | }
116 |
117 | func TestHashRingTransferOnFailure(t *testing.T) {
118 | index := 41
119 | keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
120 | var transferred int
121 | for k, v := range newKeys {
122 | if v != keys[k] {
123 | transferred++
124 | }
125 | }
126 |
127 | ratio := float32(transferred) / float32(requestSize)
128 | assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
129 | }
130 |
131 | func TestHashRingLeastTransferOnFailure(t *testing.T) {
132 | prefix := "localhost:"
133 | index := 41
134 | keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
135 | for k, v := range keys {
136 | newV := newKeys[k]
137 | if v != prefix+strconv.Itoa(index) {
138 | assert.Equal(t, v, newV)
139 | }
140 | }
141 | }
142 |
143 | func TestHashRing_Remove(t *testing.T) {
144 | ch := NewHashRing()
145 | ch.Add("first")
146 | ch.Add("second")
147 | ch.Remove("first")
148 | for i := 0; i < 100; i++ {
149 | val, ok := ch.Get(i)
150 | assert.True(t, ok)
151 | assert.Equal(t, "second", val)
152 | }
153 | }
154 |
155 | func TestHashRing_RemoveInterface(t *testing.T) {
156 | const key = "any"
157 | ch := NewHashRing()
158 | node1 := newMockNode(key, 1)
159 | node2 := newMockNode(key, 2)
160 | ch.AddWithWeight(node1, 80)
161 | ch.AddWithWeight(node2, 50)
162 | assert.Equal(t, 1, len(ch.nodes))
163 | node, ok := ch.Get(1)
164 | assert.True(t, ok)
165 | assert.Equal(t, key, node.(*mockNode).addr)
166 | assert.Equal(t, 2, node.(*mockNode).id)
167 | }
168 |
169 | func Test_innerRepr(t *testing.T) {
170 | type args struct {
171 | node any
172 | }
173 | tests := []struct {
174 | name string
175 | args args
176 | want string
177 | }{
178 | {
179 | name: "localhost:11211",
180 | args: args{node: "localhost:11211"},
181 | want: fmt.Sprintf("%d:%v", prime, "localhost:11211"),
182 | },
183 | {
184 | name: "127.0.0.1:11211",
185 | args: args{node: "127.0.0.1:11211"},
186 | want: fmt.Sprintf("%d:%v", prime, "127.0.0.1:11211"),
187 | },
188 | }
189 | for _, tt := range tests {
190 | t.Run(tt.name, func(t *testing.T) {
191 | out := innerRepr(tt.args.node)
192 | assert.Equal(t, tt.want, out, "innerRepr: returned wrong format")
193 | })
194 | }
195 | }
196 |
197 | func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[int]string, map[int]string) {
198 | ch := NewHashRing()
199 | for i := 0; i < keySize; i++ {
200 | ch.Add(prefix + strconv.Itoa(i))
201 | }
202 |
203 | keys := make(map[int]string, requestSize)
204 | for i := 0; i < requestSize; i++ {
205 | key, ok := ch.Get(requestSize + i)
206 | assert.True(t, ok)
207 | assert.NotNil(t, key)
208 | keys[i] = key.(string)
209 | }
210 |
211 | remove := fmt.Sprintf("%s%d", prefix, index)
212 | ch.Remove(remove)
213 | newKeys := make(map[int]string, requestSize)
214 | for i := 0; i < requestSize; i++ {
215 | key, ok := ch.Get(requestSize + i)
216 | assert.True(t, ok)
217 | assert.NotNil(t, key)
218 | assert.NotEqual(t, remove, key)
219 | newKeys[i] = key.(string)
220 | }
221 |
222 | return keys, newKeys
223 | }
224 |
// mockNode is a test node whose String() is addr; id tells instances with the
// same addr apart.
type mockNode struct {
	addr string
	id   int
}

// newMockNode returns a *mockNode with the given addr and id.
func newMockNode(addr string, id int) *mockNode {
	n := new(mockNode)
	n.addr, n.id = addr, id
	return n
}

// String returns the node's address.
func (n *mockNode) String() string { return n.addr }
240 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 | services:
3 | memcached:
4 | image: memcached:latest
5 | ports:
6 | - "11211:11211"
7 |
--------------------------------------------------------------------------------
/examples/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "os"
5 |
6 | "golang.org/x/exp/maps"
7 |
8 | "github.com/aliexpressru/gomemcached/memcached"
9 | )
10 |
11 | func main() {
12 | _ = os.Setenv("MEMCACHED_SERVERS", "localhost:11211")
13 |
14 | mcl, err := memcached.InitFromEnv(
15 | memcached.WithMaxIdleConns(10),
16 | memcached.WithAuthentication("admin", "mysecretpassword"),
17 | memcached.WithDisableLogger(),
18 | memcached.WithDisableMemcachedDiagnostic(),
19 | )
20 | mustInit(err)
21 | defer mcl.CloseAllConns()
22 |
23 | _, err = mcl.Store(memcached.Set, "foo", 10, []byte("bar"))
24 | mustInit(err)
25 |
26 | _, err = mcl.Get("foo")
27 | mustInit(err)
28 |
29 | _, err = mcl.Delete("foo")
30 | mustInit(err)
31 |
32 | _, err = mcl.Delta(memcached.Increment, "incappend", 1, 1, 0)
33 | mustInit(err)
34 |
35 | _, err = mcl.Append(memcached.Append, "incappend", []byte("add"))
36 | mustInit(err)
37 |
38 | items := map[string][]byte{
39 | "foo": []byte("bar"),
40 | "gopher": []byte("golang"),
41 | "answer": []byte("42"),
42 | }
43 |
44 | err = mcl.MultiStore(memcached.Add, items, 0)
45 | mustInit(err)
46 |
47 | _, err = mcl.MultiGet(maps.Keys(items))
48 | mustInit(err)
49 |
50 | err = mcl.MultiDelete(maps.Keys(items))
51 | mustInit(err)
52 |
53 | err = mcl.FlushAll(0)
54 | mustInit(err)
55 | }
56 |
// mustInit panics with e when it is non-nil and does nothing otherwise.
func mustInit(e error) {
	if e == nil {
		return
	}
	panic(e)
}
62 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/aliexpressru/gomemcached
2 |
3 | go 1.21
4 |
5 | require (
6 | github.com/cespare/xxhash v1.1.0
7 | github.com/kelseyhightower/envconfig v1.4.0
8 | github.com/prometheus/client_golang v1.18.0
9 | github.com/stretchr/testify v1.8.4
10 | go.uber.org/zap v1.26.0
11 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a
12 | golang.org/x/sync v0.6.0
13 | )
14 |
15 | require (
16 | github.com/beorn7/perks v1.0.1 // indirect
17 | github.com/cespare/xxhash/v2 v2.2.0 // indirect
18 | github.com/davecgh/go-spew v1.1.1 // indirect
19 | github.com/pmezard/go-difflib v1.0.0 // indirect
20 | github.com/prometheus/client_model v0.5.0 // indirect
21 | github.com/prometheus/common v0.46.0 // indirect
22 | github.com/prometheus/procfs v0.12.0 // indirect
23 | github.com/stretchr/objx v0.5.0 // indirect
24 | go.uber.org/multierr v1.11.0 // indirect
25 | golang.org/x/sys v0.16.0 // indirect
26 | google.golang.org/protobuf v1.32.0 // indirect
27 | gopkg.in/yaml.v3 v3.0.1 // indirect
28 | )
29 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
2 | github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
3 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
4 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
5 | github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
6 | github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
7 | github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
8 | github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
9 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
10 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
11 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
12 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
13 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
14 | github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
15 | github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
16 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
17 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
18 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
19 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
22 | github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk=
23 | github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA=
24 | github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
25 | github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
26 | github.com/prometheus/common v0.46.0 h1:doXzt5ybi1HBKpsZOL0sSkaNHJJqkyfEWZGGqqScV0Y=
27 | github.com/prometheus/common v0.46.0/go.mod h1:Tp0qkxpb9Jsg54QMe+EAmqXkSV7Evdy1BTn+g2pa/hQ=
28 | github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
29 | github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
30 | github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
31 | github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
32 | github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
33 | github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
34 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
35 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
36 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
37 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
38 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
39 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
40 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
41 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
42 | go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
43 | go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
44 | go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
45 | go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
46 | go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
47 | go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
48 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
49 | golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
50 | golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
51 | golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
52 | golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
53 | golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
54 | google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
55 | google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
56 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
57 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
58 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
59 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
60 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
61 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
62 |
--------------------------------------------------------------------------------
/logger/zap.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "os"
5 | "sync/atomic"
6 |
7 | "go.uber.org/zap"
8 | "go.uber.org/zap/zapcore"
9 | )
10 |
11 | var (
12 | // global logger instance.
13 | global *zap.SugaredLogger
14 | disableLogger atomic.Bool
15 | defaultLevel = zap.NewAtomicLevelAt(zap.DebugLevel)
16 | generationArgs = []any{"@gen", "1"}
17 | )
18 |
19 | func init() {
20 | SetLogger(newSugaredLogger(defaultLevel))
21 | }
22 |
23 | // SetLogger sets to global logger a new *zap.SugaredLogger.
24 | func SetLogger(l *zap.SugaredLogger) {
25 | global = l
26 | }
27 |
28 | // GetLogger returns current global logger.
29 | func GetLogger() *zap.SugaredLogger {
30 | return global
31 | }
32 |
33 | // DisableLogger turn off all logs, globally.
34 | func DisableLogger() {
35 | disableLogger.Store(true)
36 | }
37 |
38 | // LoggerIsDisable checks the status of the logger (true - disabled, false - enabled)
39 | func LoggerIsDisable() bool {
40 | return disableLogger.Load()
41 | }
42 |
// newSugaredLogger builds a sugared zap logger that writes JSON records to
// stdout. Records below level are dropped; a nil level falls back to
// defaultLevel. The extra zap options are applied, and generationArgs are
// attached to every record.
func newSugaredLogger(level zapcore.LevelEnabler, options ...zap.Option) *zap.SugaredLogger {
	enabler := level
	if enabler == nil {
		enabler = defaultLevel
	}

	encoderCfg := zapcore.EncoderConfig{
		TimeKey:        "ts",
		LevelKey:       "level",
		NameKey:        "logger",
		CallerKey:      "caller",
		MessageKey:     "message",
		StacktraceKey:  "stacktrace",
		LineEnding:     zapcore.DefaultLineEnding,
		EncodeLevel:    capitalLevelEncoder,
		EncodeTime:     zapcore.ISO8601TimeEncoder,
		EncodeDuration: zapcore.SecondsDurationEncoder,
		EncodeCaller:   zapcore.ShortCallerEncoder,
	}
	core := zapcore.NewCore(zapcore.NewJSONEncoder(encoderCfg), zapcore.AddSync(os.Stdout), enabler)

	base := zap.New(core, options...)
	return base.Sugar().With(generationArgs...)
}
68 |
69 | func capitalLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
70 | level := ""
71 | switch l {
72 | case zapcore.ErrorLevel:
73 | level = "ERR"
74 | case zapcore.WarnLevel:
75 | level = "WARNING"
76 | default:
77 | level = l.CapitalString()
78 | }
79 | enc.AppendString(level)
80 | }
81 |
82 | // Debug ...
83 | func Debug(args ...any) {
84 | if log := GetLogger(); !LoggerIsDisable() {
85 | log.Debug(args...)
86 | }
87 | }
88 |
89 | // Debugf ...
90 | func Debugf(format string, args ...any) {
91 | if log := GetLogger(); !LoggerIsDisable() {
92 | log.Debugf(format, args...)
93 | }
94 | }
95 |
96 | // Info ...
97 | func Info(args ...any) {
98 | if log := GetLogger(); !LoggerIsDisable() {
99 | log.Info(args...)
100 | }
101 | }
102 |
103 | // Infof ...
104 | func Infof(format string, args ...any) {
105 | if log := GetLogger(); !LoggerIsDisable() {
106 | log.Infof(format, args...)
107 | }
108 | }
109 |
110 | // Warn ...
111 | func Warn(args ...any) {
112 | if log := GetLogger(); !LoggerIsDisable() {
113 | log.Warn(args...)
114 | }
115 | }
116 |
117 | // Warnf ...
118 | func Warnf(format string, args ...any) {
119 | if log := GetLogger(); !LoggerIsDisable() {
120 | log.Warnf(format, args...)
121 | }
122 | }
123 |
124 | // Error ...
125 | func Error(args ...any) {
126 | if log := GetLogger(); !LoggerIsDisable() {
127 | log.Error(args...)
128 | }
129 | }
130 |
131 | // Errorf ...
132 | func Errorf(format string, args ...any) {
133 | if log := GetLogger(); !LoggerIsDisable() {
134 | log.Errorf(format, args...)
135 | }
136 | }
137 |
138 | // Fatal ...
139 | func Fatal(args ...any) {
140 | if log := GetLogger(); !LoggerIsDisable() {
141 | log.Fatal(args...)
142 | }
143 | }
144 |
145 | // Fatalf ...
146 | func Fatalf(format string, args ...any) {
147 | if log := GetLogger(); !LoggerIsDisable() {
148 | log.Fatalf(format, args...)
149 | }
150 | }
151 |
--------------------------------------------------------------------------------
/memcached/constants.go:
--------------------------------------------------------------------------------
1 | // Package memcached binary protocol packet formats and constants.
2 | package memcached
3 |
4 | import (
5 | "fmt"
6 | )
7 |
// Binary-protocol framing and authentication constants.
const (
	// REQ_MAGIC is the magic byte that opens a request packet.
	REQ_MAGIC = 0x80
	// RES_MAGIC is the magic byte that opens a response packet.
	RES_MAGIC = 0x81

	// SaslMechanism is the SASL mechanism name used for authentication; the
	// payload format is produced by prepareAuthData.
	SaslMechanism = "PLAIN"
)
16 |
// OpCode is a memcached binary-protocol command code.
type OpCode uint8

// Command opcodes. Names ending in Q are the "quiet" variants (see IsQuiet).
const (
	GET        OpCode = 0x00
	SET        OpCode = 0x01
	ADD        OpCode = 0x02
	REPLACE    OpCode = 0x03
	DELETE     OpCode = 0x04
	INCREMENT  OpCode = 0x05
	DECREMENT  OpCode = 0x06
	QUIT       OpCode = 0x07
	FLUSH      OpCode = 0x08
	GETQ       OpCode = 0x09
	NOOP       OpCode = 0x0a
	VERSION    OpCode = 0x0b
	GETK       OpCode = 0x0c
	GETKQ      OpCode = 0x0d
	APPEND     OpCode = 0x0e
	PREPEND    OpCode = 0x0f
	STAT       OpCode = 0x10
	SETQ       OpCode = 0x11
	ADDQ       OpCode = 0x12
	REPLACEQ   OpCode = 0x13
	DELETEQ    OpCode = 0x14
	INCREMENTQ OpCode = 0x15
	DECREMENTQ OpCode = 0x16
	QUITQ      OpCode = 0x17
	FLUSHQ     OpCode = 0x18
	APPENDQ    OpCode = 0x19
	PREPENDQ   OpCode = 0x1a

	// SASL authentication commands.
	SASL_LIST_MECHS OpCode = 0x20
	SASL_AUTH       OpCode = 0x21
	SASL_STEP       OpCode = 0x22
)
52 |
// Status is a memcached binary-protocol response status code.
type Status uint16

// Response status codes, plus one client-side sentinel (UNKNOWN_STATUS).
const (
	// SUCCESS - Successful operation.
	SUCCESS Status = 0x00
	// KEY_ENOENT - Key not found.
	KEY_ENOENT Status = 0x01
	// KEY_EEXISTS - Key already exists.
	KEY_EEXISTS Status = 0x02
	// E2BIG - Data size exceeds limit.
	E2BIG Status = 0x03
	// EINVAL - Invalid arguments or operation parameters.
	EINVAL Status = 0x04
	// NOT_STORED - Operation was not performed because the data was not stored.
	NOT_STORED Status = 0x05
	// DELTA_BADVAL - Invalid value specified for increment/decrement.
	DELTA_BADVAL Status = 0x06
	// AUTHFAIL - Authentication required / not successful.
	AUTHFAIL Status = 0x20
	// FURTHER_AUTH - Further authentication steps required.
	FURTHER_AUTH Status = 0x21
	// UNKNOWN_COMMAND - Unknown command.
	UNKNOWN_COMMAND Status = 0x81
	// ENOMEM - Insufficient memory for the operation.
	ENOMEM Status = 0x82
	// TMPFAIL - Temporary failure, the operation cannot be performed at the moment.
	TMPFAIL Status = 0x86

	// UNKNOWN_STATUS is not a memcached status. errStatus returns it for
	// errors that carry no *Response.
	UNKNOWN_STATUS Status = 0xffff
)
84 |
// HDR_LEN is a number of bytes in a binary protocol header.
const HDR_LEN = 24

// BODY_LEN is 128. NOTE(review): its use is not visible in this file;
// presumably a default body buffer size — confirm against callers.
const BODY_LEN = 128
90 |
91 | // Mapping of OpCode -> name of command (not exhaustive)
92 | var CommandNames map[OpCode]string
93 |
94 | var StatusNames map[Status]string
95 |
96 | // nolint:goconst
97 | func init() {
98 | CommandNames = make(map[OpCode]string)
99 | CommandNames[GET] = "GET"
100 | CommandNames[SET] = "SET"
101 | CommandNames[ADD] = "ADD"
102 | CommandNames[REPLACE] = "REPLACE"
103 | CommandNames[DELETE] = "DELETE"
104 | CommandNames[INCREMENT] = "INCREMENT"
105 | CommandNames[DECREMENT] = "DECREMENT"
106 | CommandNames[QUIT] = "QUIT"
107 | CommandNames[FLUSH] = "FLUSH"
108 | CommandNames[GETQ] = "GETQ"
109 | CommandNames[NOOP] = "NOOP"
110 | CommandNames[VERSION] = "VERSION"
111 | CommandNames[GETK] = "GETK"
112 | CommandNames[GETKQ] = "GETKQ"
113 | CommandNames[APPEND] = "APPEND"
114 | CommandNames[PREPEND] = "PREPEND"
115 | CommandNames[STAT] = "STAT"
116 | CommandNames[SETQ] = "SETQ"
117 | CommandNames[ADDQ] = "ADDQ"
118 | CommandNames[REPLACEQ] = "REPLACEQ"
119 | CommandNames[DELETEQ] = "DELETEQ"
120 | CommandNames[INCREMENTQ] = "INCREMENTQ"
121 | CommandNames[DECREMENTQ] = "DECREMENTQ"
122 | CommandNames[QUITQ] = "QUITQ"
123 | CommandNames[FLUSHQ] = "FLUSHQ"
124 | CommandNames[APPENDQ] = "APPENDQ"
125 | CommandNames[PREPENDQ] = "PREPENDQ"
126 |
127 | CommandNames[SASL_LIST_MECHS] = "SASL_LIST_MECHS"
128 | CommandNames[SASL_AUTH] = "SASL_AUTH"
129 | CommandNames[SASL_STEP] = "SASL_STEP"
130 |
131 | StatusNames = make(map[Status]string)
132 | StatusNames[SUCCESS] = "SUCCESS"
133 | StatusNames[KEY_ENOENT] = "KEY_ENOENT"
134 | StatusNames[KEY_EEXISTS] = "KEY_EEXISTS"
135 | StatusNames[E2BIG] = "E2BIG"
136 | StatusNames[EINVAL] = "EINVAL"
137 | StatusNames[NOT_STORED] = "NOT_STORED"
138 | StatusNames[DELTA_BADVAL] = "DELTA_BADVAL"
139 | StatusNames[AUTHFAIL] = "AUTHFAIL"
140 | StatusNames[FURTHER_AUTH] = "FURTHER_AUTH"
141 | StatusNames[UNKNOWN_COMMAND] = "UNKNOWN_COMMAND"
142 | StatusNames[ENOMEM] = "ENOMEM"
143 | StatusNames[TMPFAIL] = "TMPFAIL"
144 | }
145 |
// String returns the protocol name of the opcode (e.g. "GET"). Opcodes
// without an entry in CommandNames are rendered as hex, such as "0x80".
func (o OpCode) String() string {
	if name := CommandNames[o]; name != "" {
		return name
	}
	return fmt.Sprintf("0x%02x", int(o))
}

// String returns the protocol name of the status (e.g. "SUCCESS"). Statuses
// without an entry in StatusNames are rendered as hex, such as "0x80".
func (s Status) String() string {
	if name := StatusNames[s]; name != "" {
		return name
	}
	return fmt.Sprintf("0x%02x", int(s))
}
163 |
164 | // IsQuiet return true if a command is a "quiet" command.
165 | func (o OpCode) IsQuiet() bool {
166 | switch o {
167 | case GETQ,
168 | GETKQ,
169 | SETQ,
170 | ADDQ,
171 | REPLACEQ,
172 | DELETEQ,
173 | INCREMENTQ,
174 | DECREMENTQ,
175 | QUITQ,
176 | FLUSHQ,
177 | APPENDQ,
178 | PREPENDQ:
179 | return true
180 | }
181 | return false
182 | }
183 |
184 | func (o OpCode) changeOnQuiet(def OpCode) OpCode {
185 | if o.IsQuiet() {
186 | return o
187 | }
188 | switch o {
189 | case GET:
190 | return GETQ
191 | case SET:
192 | return SETQ
193 | case ADD:
194 | return ADDQ
195 | case REPLACE:
196 | return REPLACEQ
197 | case DELETE:
198 | return DELETEQ
199 | case INCREMENT:
200 | return INCREMENTQ
201 | case DECREMENT:
202 | return DECREMENTQ
203 | case FLUSH:
204 | return FLUSHQ
205 | case APPEND:
206 | return APPENDQ
207 | case PREPEND:
208 | return PREPENDQ
209 | default:
210 | return def
211 | }
212 | }
213 |
// prepareAuthData builds the SASL PLAIN payload "\x00<user>\x00<pass>": an
// empty authorization identity, then the user name and the password, each
// preceded by a NUL byte.
func prepareAuthData(user, pass string) []byte {
	payload := make([]byte, 0, 2+len(user)+len(pass))
	payload = append(payload, 0)
	payload = append(payload, user...)
	payload = append(payload, 0)
	return append(payload, pass...)
}
217 |
--------------------------------------------------------------------------------
/memcached/constants_test.go:
--------------------------------------------------------------------------------
1 | // nolint
2 | package memcached
3 |
4 | import (
5 | "strings"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
func TestCommandCodeString(t *testing.T) {
	// A known opcode renders as its name.
	if got := GET.String(); got != "GET" {
		t.Fatalf("Expected \"GET\" for GET, got \"%v\"", got)
	}

	// An unknown opcode renders as hex.
	if got := OpCode(0x80).String(); got != "0x80" {
		t.Fatalf("Expected \"0x80\" for 0x80, got \"%v\"", got)
	}
}
21 |
func TestStatusNameString(t *testing.T) {
	// A known status renders as its name.
	if got := SUCCESS.String(); got != "SUCCESS" {
		t.Fatalf("Expected \"SUCCESS\" for SUCCESS, got \"%v\"", got)
	}

	// An unknown status renders as hex.
	if got := Status(0x80).String(); got != "0x80" {
		t.Fatalf("Expected \"0x80\" for 0x80, got \"%v\"", got)
	}
}
33 |
// TestIsQuiet checks that IsQuiet agrees with the naming convention: an opcode
// is quiet exactly when its name ends in "Q".
func TestIsQuiet(t *testing.T) {
	for op, name := range CommandNames {
		want := strings.HasSuffix(name, "Q")
		if got := op.IsQuiet(); got != want {
			t.Errorf("Expected quiet=%v for %v, got %v", want, op, got)
		}
	}
}
43 |
44 | func Test_prepareAuthData(t *testing.T) {
45 | type args struct {
46 | user string
47 | pass string
48 | }
49 | tests := []struct {
50 | name string
51 | args args
52 | want []byte
53 | }{
54 | {
55 | name: "1", args: args{
56 | user: "testuser",
57 | pass: "testpass",
58 | },
59 | want: []byte("\x00testuser\x00testpass"),
60 | },
61 | {
62 | name: "2", args: args{
63 | user: "anotheruser",
64 | pass: "anotherpass",
65 | },
66 | want: []byte("\x00anotheruser\x00anotherpass"),
67 | },
68 | }
69 | for _, tt := range tests {
70 | t.Run(tt.name, func(t *testing.T) {
71 | assert.Equalf(t, tt.want, prepareAuthData(tt.args.user, tt.args.pass), "prepareAuthData(%v, %v)", tt.args.user, tt.args.pass)
72 | })
73 | }
74 | }
75 |
76 | func TestOpCode_changeOnQuiet(t *testing.T) {
77 | type args struct {
78 | def OpCode
79 | }
80 | tests := []struct {
81 | name string
82 | o OpCode
83 | args args
84 | want OpCode
85 | }{
86 | {
87 | name: GET.String(),
88 | o: GET,
89 | args: args{def: GETQ},
90 | want: GETQ,
91 | },
92 | {
93 | name: GETQ.String(),
94 | o: GETQ,
95 | args: args{def: GETQ},
96 | want: GETQ,
97 | },
98 | {
99 | name: "unknown opcode",
100 | o: OpCode(0x1b),
101 | args: args{def: GETQ},
102 | want: GETQ,
103 | },
104 | {
105 | name: SET.String(),
106 | o: SET,
107 | args: args{def: GETQ},
108 | want: SETQ,
109 | },
110 | {
111 | name: ADD.String(),
112 | o: ADD,
113 | args: args{def: GETQ},
114 | want: ADDQ,
115 | },
116 | {
117 | name: REPLACE.String(),
118 | o: REPLACE,
119 | args: args{def: GETQ},
120 | want: REPLACEQ,
121 | },
122 | {
123 | name: DELETE.String(),
124 | o: DELETE,
125 | args: args{def: GETQ},
126 | want: DELETEQ,
127 | },
128 | {
129 | name: INCREMENT.String(),
130 | o: INCREMENT,
131 | args: args{def: GETQ},
132 | want: INCREMENTQ,
133 | },
134 | {
135 | name: DECREMENT.String(),
136 | o: DECREMENT,
137 | args: args{def: GETQ},
138 | want: DECREMENTQ,
139 | },
140 | {
141 | name: FLUSH.String(),
142 | o: FLUSH,
143 | args: args{def: GETQ},
144 | want: FLUSHQ,
145 | },
146 | {
147 | name: APPEND.String(),
148 | o: APPEND,
149 | args: args{def: GETQ},
150 | want: APPENDQ,
151 | },
152 | {
153 | name: PREPEND.String(),
154 | o: PREPEND,
155 | args: args{def: GETQ},
156 | want: PREPENDQ,
157 | },
158 | }
159 | for _, tt := range tests {
160 | t.Run(tt.name, func(t *testing.T) {
161 | assert.Equalf(t, tt.want, tt.o.changeOnQuiet(tt.args.def), "changeOnQuiet(%v)", tt.args.def)
162 | })
163 | }
164 | }
165 |
--------------------------------------------------------------------------------
/memcached/errors.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | )
7 |
8 | const libPrefix = "gomemcached"
9 |
10 | var (
11 | // ErrCacheMiss means that a Get failed because the item wasn't present.
12 | ErrCacheMiss = errors.New("gomemcached: cache miss")
13 |
14 | // ErrCASConflict means that a CompareAndSwap call failed due to the
15 | // cached value being modified between the Get and the CompareAndSwap.
16 | // If the cached value was simply evicted rather than replaced,
17 | // ErrNotStored will be returned instead.
18 | ErrCASConflict = errors.New("gomemcached: compare-and-swap conflict")
19 |
20 | // ErrNotStored means that a conditional write operation (i.e. Add or
21 | // CompareAndSwap) failed because the condition was not satisfied.
22 | ErrNotStored = errors.New("gomemcached: item not stored")
23 |
24 | // ErrServerError means that a server error occurred.
25 | ErrServerError = errors.New("gomemcached: server error")
26 |
27 | // ErrNoStats means that no statistics were available.
28 | ErrNoStats = errors.New("gomemcached: no statistics available")
29 |
30 | // ErrMalformedKey is returned when an invalid key is used.
31 | // Keys must be at maximum 250 bytes long and not
32 | // contain whitespace or control characters.
33 | ErrMalformedKey = errors.New("gomemcached: key is too long or contains invalid characters")
34 |
35 | // ErrNoServers is returned when no servers are configured or available.
36 | ErrNoServers = errors.New("gomemcached: no servers configured or available")
37 |
38 | // ErrInvalidAddr means that an incorrect address was passed and could not be cast to net.Addr
39 | ErrInvalidAddr = errors.New("gomemcached: invalid address for server")
40 |
41 | // ErrServerNotAvailable means that one of the nodes is currently unavailable
42 | ErrServerNotAvailable = errors.New("gomemcached: server(s) is not available")
43 |
44 | // ErrNotConfigured means that some required parameter is not set in the configuration
45 | ErrNotConfigured = errors.New("gomemcached: not complete configuration")
46 |
47 | // ErrUnknownCommand means that in request consumer use unknown command for memcached.
48 | ErrUnknownCommand = errors.New("gomemcached: Unknown command")
49 |
50 | // ErrDataSizeExceedsLimit means that memcached cannot process the request data due to its size.
51 | ErrDataSizeExceedsLimit = errors.New("gomemcached: Data size exceeds limit")
52 |
53 | // ErrInvalidArguments indicates invalid arguments or operation parameters (non-user request error).
54 | ErrInvalidArguments = errors.New("gomemcached: Invalid arguments or operation parameters")
55 |
56 | // ErrAuthFail indicates that an authorization attempt was made, but it did not work
57 | ErrAuthFail = errors.New("gomemcached: authentication enabled but operation failed")
58 | )
59 |
// resumableError reports whether err is only a protocol-level cache error
// (cache miss, CAS conflict, not stored, malformed key). condRelease uses it to
// decide whether a connection may go back to the pool. Any other error means
// the connection is not reused.
func resumableError(err error) bool {
	for _, target := range [...]error{ErrCacheMiss, ErrCASConflict, ErrNotStored, ErrMalformedKey} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
72 |
73 | func wrapMemcachedResp(resp *Response) error {
74 | switch resp.Status {
75 | case SUCCESS:
76 | return nil
77 | case KEY_ENOENT:
78 | return fmt.Errorf("%w. %w", ErrCacheMiss, resp)
79 | case NOT_STORED, KEY_EEXISTS:
80 | return fmt.Errorf("%w. %w", ErrNotStored, resp)
81 | case EINVAL, DELTA_BADVAL:
82 | return fmt.Errorf("%w. %w", ErrInvalidArguments, resp)
83 | case ENOMEM:
84 | return fmt.Errorf("%w. %w", ErrServerError, resp)
85 | case TMPFAIL:
86 | return fmt.Errorf("%w. %w", ErrServerNotAvailable, resp)
87 | case UNKNOWN_COMMAND:
88 | return fmt.Errorf("%w. %w", ErrUnknownCommand, resp)
89 | case E2BIG:
90 | return fmt.Errorf("%w. %w", ErrDataSizeExceedsLimit, resp)
91 | default:
92 | return fmt.Errorf("%w. %w", ErrServerError, resp)
93 | }
94 | }
95 |
// errStatus extracts the memcached status from the first *Response in e's
// error chain. It returns UNKNOWN_STATUS when the chain contains none.
func errStatus(e error) Status {
	var res *Response
	if !errors.As(e, &res) {
		return UNKNOWN_STATUS
	}
	return res.Status
}
104 |
--------------------------------------------------------------------------------
/memcached/memcached.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "bufio"
5 | "context"
6 | "encoding/binary"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "math"
11 | "net"
12 | "sync"
13 | "sync/atomic"
14 | "time"
15 |
16 | "github.com/kelseyhightower/envconfig"
17 | "golang.org/x/exp/maps"
18 |
19 | "github.com/aliexpressru/gomemcached/consistenthash"
20 | "github.com/aliexpressru/gomemcached/logger"
21 | "github.com/aliexpressru/gomemcached/pool"
22 | "github.com/aliexpressru/gomemcached/utils"
23 | )
24 |
25 | const (
26 | // DefaultTimeout is the default socket read/write timeout.
27 | DefaultTimeout = 500 * time.Millisecond
28 |
29 | // DefaultMaxIdleConns is the default maximum number of idle connections
30 | // kept for any single address.
31 | DefaultMaxIdleConns = 100
32 |
33 | // DefaultNodeHealthCheckPeriod is the default time period for start check available nods
34 | DefaultNodeHealthCheckPeriod = 15 * time.Second
35 | // DefaultRebuildingNodePeriod is the default time period for rebuilds the nodes in hash ring using freshly discovered
36 | DefaultRebuildingNodePeriod = 15 * time.Second
37 |
38 | // DefaultRetryCountForConn is a default number of connection retries before return i/o timeout error
39 | DefaultRetryCountForConn = uint8(3)
40 |
41 | // DefaultOfNumberConnsToDestroyPerRBPeriod is number of connections in pool whose needed close in every rebuild node cycle
42 | DefaultOfNumberConnsToDestroyPerRBPeriod = 1
43 |
44 | // DefaultSocketPoolingTimeout Amount of time to acquire socket from pool
45 | DefaultSocketPoolingTimeout = 50 * time.Millisecond
46 | )
47 |
48 | var _ Memcached = (*Client)(nil)
49 |
50 | type (
51 | Memcached interface {
52 | Store(storeMode StoreMode, key string, exp uint32, body []byte) (*Response, error)
53 | Get(key string) (*Response, error)
54 | Delete(key string) (*Response, error)
55 | Delta(deltaMode DeltaMode, key string, delta, initial uint64, exp uint32) (newValue uint64, err error)
56 | Append(appendMode AppendMode, key string, data []byte) (*Response, error)
57 | FlushAll(exp uint32) error
58 | MultiDelete(keys []string) error
59 | MultiStore(storeMode StoreMode, items map[string][]byte, exp uint32) error
60 | MultiGet(keys []string) (map[string][]byte, error)
61 |
62 | CloseAllConns()
63 | CloseAvailableConnsInAllShardPools(numOfClose int) int
64 | }
65 |
66 | // Client is a memcached client.
67 | // It is safe for unlocked use by multiple concurrent goroutines.
68 | Client struct {
69 | ctx context.Context
70 | nw *network
71 | cfg *config
72 |
73 | // opaque - a unique identifier for the request, used to associate the request with its corresponding response.
74 | opaque *uint32
75 |
76 | // timeout specifies the socket read/write timeout.
77 | // If zero, DefaultTimeout is used.
78 | timeout time.Duration
79 |
80 | // maxIdleConns specifies the maximum number of idle connections that will
81 | // be maintained per address. If less than one, DefaultMaxIdleConns will be
82 | // used.
83 | //
84 | // Consider your expected traffic rates and latency carefully. This should
85 | // be set to a number higher than your peak parallel requests.
86 | maxIdleConns int
87 |
88 | // hr - hash ring implementation (can be a custom consistenthash.NewCustomHashRing)
89 | hr consistenthash.ConsistentHash
90 |
91 | // disableMemcachedDiagnostic - is flag for turn off write metrics from lib.
92 | disableMemcachedDiagnostic bool
93 | // disableNodeProvider - is flag for turn off rebuild and health check nodes.
94 | disableNodeProvider bool
95 | // disableRefreshConns - is flag for turn off to refresh conns in the pool.
96 | disableRefreshConns bool
97 | // nodeHCPeriod - period for execute nodes health checker
98 | // if zero, DefaultNodeHealthCheckPeriod is used.
99 | nodeHCPeriod time.Duration
100 | // nodeRBPeriod - period for execute rebuilding nodes
101 | // if zero, DefaultNodeHealthCheckPeriod is used.
102 | nodeRBPeriod time.Duration
103 |
104 | // fmu - mutex for freeConns
105 | fmu sync.RWMutex
106 | // freeConns hashmap with nodes and their open dial connections
107 | freeConns map[string]*pool.Pool
108 | // dmu - mutex for deadNodes
109 | dmu sync.RWMutex
110 | // deadNodes hashmap with nodes that did not respond to health check
111 | deadNodes map[string]struct{}
112 |
113 | authEnable bool
114 | // authData ready body for authentication request
115 | authData []byte
116 | }
117 |
118 | network struct {
119 | dial func(network string, address string) (net.Conn, error)
120 | dialTimeout func(network string, address string, timeout time.Duration) (net.Conn, error)
121 | lookupHost func(host string) (addrs []string, err error)
122 | }
123 |
124 | config struct {
125 | // HeadlessServiceAddress Headless service to lookup all the memcached ip addresses.
126 | HeadlessServiceAddress string `envconfig:"MEMCACHED_HEADLESS_SERVICE_ADDRESS"`
127 | // Servers List of servers with hosted memcached
128 | Servers []string `envconfig:"MEMCACHED_SERVERS"`
129 | // MemcachedPort The optional port override for cases when memcached IP addresses are obtained from headless service.
130 | MemcachedPort int `envconfig:"MEMCACHED_PORT" default:"11211"`
131 | }
132 | conn struct {
133 | rc io.ReadCloser
134 | addr net.Addr
135 | c *Client
136 | hdrBuf []byte
137 | healthy bool
138 | wrtBuf *bufio.Writer
139 | authed bool
140 | }
141 | )
142 |
// InitFromEnv returns a memcached client built from the environment
// (MEMCACHED_HEADLESS_SERVICE_ADDRESS or MEMCACHED_SERVERS, plus
// MEMCACHED_PORT) and the supplied options. Servers get equal weight; if a
// server is listed multiple times, it gets a proportional amount of weight.
//
// Dependencies the options leave unset are defaulted: net.Dial, net.DialTimeout
// and net.LookupHost for networking, consistenthash.NewHashRing for the ring,
// context.Background for the context, and a fresh opaque counter.
//
// Configuration errors wrap the underlying envconfig error, so callers can
// inspect it with errors.Is / errors.As.
func InitFromEnv(opts ...Option) (*Client, error) {
	var (
		op  = new(options)
		cfg = new(config)
	)
	if err := envconfig.Process("", cfg); err != nil {
		// %w keeps the cause inspectable; the rendered message is unchanged.
		return nil, fmt.Errorf("%s: client init err: %w", libPrefix, err)
	}

	op.cfg = cfg

	for _, opt := range opts {
		opt(op)
	}

	if op.Client.nw == nil {
		op.Client.nw = &network{
			dial:        net.Dial,
			dialTimeout: net.DialTimeout,
			lookupHost:  net.LookupHost,
		}
	}
	if op.Client.hr == nil {
		op.Client.hr = consistenthash.NewHashRing()
	}
	if op.Client.ctx == nil {
		op.Client.ctx = context.Background()
	}
	if op.Client.opaque == nil {
		op.Client.opaque = new(uint32)
	}
	if op.disableLogger {
		logger.DisableLogger()
	}

	return newFromConfig(op)
}
183 |
// newForTests builds a minimal Client for tests. It puts the given servers on
// a fresh hash ring, disables diagnostics, uses the real net package
// functions, and does not start the node provider.
func newForTests(servers ...string) (*Client, error) {
	ring := consistenthash.NewHashRing()
	for _, server := range servers {
		addr, err := utils.AddrRepr(server)
		if err != nil {
			return nil, fmt.Errorf("%w: %s", ErrInvalidAddr, err.Error())
		}
		ring.Add(addr)
	}

	nw := &network{
		dial:        net.Dial,
		dialTimeout: net.DialTimeout,
		lookupHost:  net.LookupHost,
	}
	return &Client{
		ctx:                        context.Background(),
		opaque:                     new(uint32),
		hr:                         ring,
		disableMemcachedDiagnostic: true,
		nw:                         nw,
	}, nil
}
207 |
// newFromConfig validates op.cfg and resolves the configured nodes via
// getNodes. It adds every node address to the hash ring of op.Client and,
// unless disableNodeProvider is set, calls initNodesProvider.
// Errors wrap ErrNotConfigured or ErrInvalidAddr together with the underlying
// cause, so both are reachable through errors.Is / errors.As.
func newFromConfig(op *options) (*Client, error) {
	if cfg := op.cfg; cfg != nil && cfg.HeadlessServiceAddress == "" && len(cfg.Servers) == 0 {
		return nil, fmt.Errorf("%w, you must fill in either MEMCACHED_HEADLESS_SERVICE_ADDRESS or MEMCACHED_SERVERS", ErrNotConfigured)
	}
	nodes, err := getNodes(op.nw.lookupHost, op.cfg)
	if err != nil {
		return nil, fmt.Errorf("%w, %w", ErrInvalidAddr, err)
	}

	mc := &op.Client

	for _, n := range nodes {
		addr, err := utils.AddrRepr(n)
		if err != nil {
			return nil, fmt.Errorf("%w: %w", ErrInvalidAddr, err)
		}
		mc.hr.Add(addr)
	}

	if !mc.disableNodeProvider {
		mc.initNodesProvider()
	}
	return mc, nil
}
232 |
233 | // release returns this connection back to the client's free pool
234 | func (cn *conn) release() {
235 | cn.c.putFreeConn(cn)
236 | }
237 |
238 | func (cn *conn) close() {
239 | if p, ok := cn.c.safeGetFreeConn(cn.addr); ok {
240 | p.Close(cn)
241 | } else {
242 | _ = cn.rc.Close()
243 | }
244 | }
245 |
246 | // condRelease releases this connection if the error pointed to by err
247 | // is nil (not an error) or is only a protocol level error (e.g. a
248 | // cache miss). The purpose is to not recycle TCP connections that
249 | // are bad.
250 | func (cn *conn) condRelease(err *error) {
251 | if (*err == nil || resumableError(*err)) && cn.healthy {
252 | cn.release()
253 | } else {
254 | cn.close()
255 | }
256 | }
257 |
// getOpaque returns the next request identifier, which is used to match a
// response to its request. Values run 1..math.MaxUint32 and then wrap back
// to 1, so 0 is never handed out.
//
// The check and the increment happen in one compare-and-swap loop. A separate
// "reset at MaxUint32" step followed by an atomic add is not atomic as a
// whole: two goroutines racing near the top of the range could overflow the
// counter and receive 0.
func (c *Client) getOpaque() uint32 {
	for {
		cur := atomic.LoadUint32(c.opaque)
		next := cur + 1
		if cur == math.MaxUint32 {
			next = 1
		}
		if atomic.CompareAndSwapUint32(c.opaque, cur, next) {
			return next
		}
	}
}
262 |
// safeGetFreeConn looks up the connection pool for addr under the read lock.
// It reports false when no pool exists for that address.
func (c *Client) safeGetFreeConn(addr net.Addr) (*pool.Pool, bool) {
	key := addr.String()
	c.fmu.RLock()
	p, found := c.freeConns[key]
	c.fmu.RUnlock()
	return p, found
}
269 |
// safeGetOrInitFreeConn returns the connection pool for addr. If none exists,
// it creates and registers one while holding the write lock. The new pool
// dials connections with c.dial and closes them by closing the underlying
// connection.
func (c *Client) safeGetOrInitFreeConn(addr net.Addr) *pool.Pool {
	key := addr.String()

	c.fmu.Lock()
	defer c.fmu.Unlock()

	if existing, found := c.freeConns[key]; found {
		return existing
	}
	if c.freeConns == nil {
		c.freeConns = make(map[string]*pool.Pool)
	}

	created := pool.New(
		c.ctx,
		int32(c.getMaxIdleConns()),
		DefaultSocketPoolingTimeout,
		// Factory: dial a fresh connection to addr.
		func() (any, error) {
			nc, err := c.dial(addr)
			if err != nil {
				return nil, err
			}
			return &conn{
				rc:      nc,
				addr:    addr,
				c:       c,
				hdrBuf:  make([]byte, HDR_LEN),
				wrtBuf:  bufio.NewWriter(nc),
				healthy: true,
			}, nil
		},
		// Closer: release the underlying network connection.
		func(raw any) {
			_ = raw.(*conn).rc.Close()
		},
	)
	c.freeConns[key] = created
	return created
}
307 |
308 | func (c *Client) freeConnsIsNil() bool {
309 | c.fmu.RLock()
310 | defer c.fmu.RUnlock()
311 | return c.freeConns == nil
312 | }
313 |
314 | func (c *Client) putFreeConn(cn *conn) {
315 | connPool, ok := c.safeGetFreeConn(cn.addr)
316 | if ok {
317 | connPool.Put(cn)
318 | } else {
319 | _ = cn.rc.Close()
320 | }
321 | }
322 |
// getFreeConn takes a connection for addr from its pool, creating the pool on
// first use. When authentication is enabled, a connection that has not yet
// authenticated is authenticated first. If that fails, ErrAuthFail is
// returned.
func (c *Client) getFreeConn(addr net.Addr) (*conn, error) {
	raw, err := c.safeGetOrInitFreeConn(addr).Get()
	if err != nil {
		return nil, fmt.Errorf("%s: Get from pool error - %w", libPrefix, err)
	}

	cn := raw.(*conn)
	if !c.authEnable || cn.authed {
		return cn, nil
	}

	if !c.authenticate(cn) {
		// NOTE(review): cn is neither released nor closed on this path; confirm
		// that authenticate disposes of the connection, otherwise it may leak.
		return nil, ErrAuthFail
	}
	cn.authed = true
	return cn, nil
}
344 |
// removeFromFreeConns destroys the connection pool for addr, if any, and drops
// it from the pool map.
//
// The lookup, Destroy and delete all happen under one write lock. Looking the
// pool up under the read lock and deleting it later under the write lock lets
// another goroutine install a new pool in between. That new pool would then be
// removed from the map without being destroyed (and the old one could be
// destroyed twice).
func (c *Client) removeFromFreeConns(addr net.Addr) {
	key := addr.String()

	c.fmu.Lock()
	defer c.fmu.Unlock()

	// Lookups on a nil map are safe and report !ok, so no separate nil check is needed.
	if connPool, ok := c.freeConns[key]; ok {
		connPool.Destroy()
		delete(c.freeConns, key)
	}
}
358 |
// netTimeout returns the socket read/write timeout, falling back to
// DefaultTimeout when none is configured (zero).
func (c *Client) netTimeout() time.Duration {
	if c.timeout == 0 {
		return DefaultTimeout
	}
	return c.timeout
}

// getMaxIdleConns returns the per-address idle connection limit, falling back
// to DefaultMaxIdleConns when the configured value is less than one.
func (c *Client) getMaxIdleConns() int {
	if c.maxIdleConns <= 0 {
		return DefaultMaxIdleConns
	}
	return c.maxIdleConns
}

// getHCPeriod returns the node health-check period, falling back to
// DefaultNodeHealthCheckPeriod when the configured value is not positive.
func (c *Client) getHCPeriod() time.Duration {
	if c.nodeHCPeriod <= 0 {
		return DefaultNodeHealthCheckPeriod
	}
	return c.nodeHCPeriod
}

// getRBPeriod returns the node-rebuild period, falling back to
// DefaultRebuildingNodePeriod when the configured value is not positive.
func (c *Client) getRBPeriod() time.Duration {
	if c.nodeRBPeriod <= 0 {
		return DefaultRebuildingNodePeriod
	}
	return c.nodeRBPeriod
}
386 |
387 | // ConnectTimeoutError is the error type used when it takes
388 | // too long to connect to the desired host. This level of
389 | // detail can generally be ignored.
390 | type ConnectTimeoutError struct {
391 | Addr net.Addr
392 | }
393 |
394 | func (cte *ConnectTimeoutError) Error() string {
395 | return "connect timeout to " + cte.Addr.String()
396 | }
397 |
// dial opens a network connection to addr. When a positive timeout is
// configured it uses dialTimeout, and a timeout from the dialer is reported
// as *ConnectTimeoutError. Otherwise it dials without a deadline.
func (c *Client) dial(addr net.Addr) (net.Conn, error) {
	netw, hostport := addr.Network(), addr.String()

	timeout := c.netTimeout()
	if timeout <= 0 {
		return c.nw.dial(netw, hostport)
	}

	nc, err := c.nw.dialTimeout(netw, hostport, timeout)
	if err == nil {
		return nc, nil
	}
	var ne net.Error
	if errors.As(err, &ne) && ne.Timeout() {
		return nil, &ConnectTimeoutError{addr}
	}
	return nil, err
}
412 |
// getConnForNode returns a pooled connection for a hash-ring node. The node
// must be a net.Addr; anything else yields ErrInvalidAddr.
func (c *Client) getConnForNode(node any) (*conn, error) {
	if addr, ok := node.(net.Addr); ok {
		cn, err := c.getFreeConn(addr)
		if err != nil {
			return nil, err
		}
		return cn, nil
	}
	return nil, ErrInvalidAddr
}
425 |
// Store writes body under key with expiration exp. storeMode selects the
// opcode (via storeMode.Resolve). The call's duration and error are reported
// to writeMethodDiagnostics.
func (c *Client) Store(storeMode StoreMode, key string, exp uint32, body []byte) (_ *Response, err error) {
	// time.Now() is evaluated here, when the defer is registered.
	defer c.writeMethodDiagnostics("Store", time.Now(), &err)

	if !legalKey(key) {
		return nil, ErrMalformedKey
	}

	node, found := c.hr.Get(key)
	if !found {
		return nil, ErrNoServers
	}

	cn, err := c.getConnForNode(node)
	if err != nil {
		return nil, err
	}
	opcode := storeMode.Resolve()
	return c.store(cn, opcode, key, exp, c.getOpaque(), body)
}
446 |
447 | func (c *Client) store(cn *conn, opcode OpCode, key string, exp, opaque uint32, body []byte) (*Response, error) {
448 | req := &Request{
449 | Opcode: opcode,
450 | Key: []byte(key),
451 | Opaque: opaque,
452 | Body: body,
453 | }
454 | req.prepareExtras(exp, 0, 0)
455 | return c.send(cn, req)
456 | }
457 |
458 | func (c *Client) send(cn *conn, req *Request) (resp *Response, err error) {
459 | defer cn.condRelease(&err)
460 | _, err = transmitRequest(cn.wrtBuf, req)
461 | if err != nil {
462 | cn.healthy = false
463 | return
464 | }
465 |
466 | if err = cn.wrtBuf.Flush(); err != nil {
467 | return nil, err
468 | }
469 |
470 | resp, _, err = getResponse(cn.rc, cn.hdrBuf)
471 | cn.healthy = !isFatal(err)
472 | return resp, err
473 | }
474 |
// Get returns the item stored under key. The call's duration and error are
// reported to writeMethodDiagnostics.
func (c *Client) Get(key string) (_ *Response, err error) {
	// time.Now() is evaluated here, when the defer is registered.
	defer c.writeMethodDiagnostics("Get", time.Now(), &err)

	if !legalKey(key) {
		return nil, ErrMalformedKey
	}

	node, found := c.hr.Get(key)
	if !found {
		return nil, ErrNoServers
	}

	cn, err := c.getConnForNode(node)
	if err != nil {
		return nil, err
	}

	req := &Request{Opcode: GET, Opaque: c.getOpaque(), Key: []byte(key)}
	req.prepareExtras(0, 0, 0)
	return c.send(cn, req)
}
503 |
504 | // Delete is a deletes the element with the provided key.
505 | // If the element does not exist, an ErrCacheMiss error is returned.
506 | func (c *Client) Delete(key string) (_ *Response, err error) {
507 | timer := time.Now()
508 | defer c.writeMethodDiagnostics("Delete", timer, &err)
509 |
510 | if !legalKey(key) {
511 | return nil, ErrMalformedKey
512 | }
513 |
514 | node, find := c.hr.Get(key)
515 | if !find {
516 | return nil, ErrNoServers
517 | }
518 |
519 | cn, err := c.getConnForNode(node)
520 | if err != nil {
521 | return nil, err
522 | }
523 |
524 | req := &Request{
525 | Opcode: DELETE,
526 | Opaque: c.getOpaque(),
527 | Key: []byte(key),
528 | }
529 | req.prepareExtras(0, 0, 0)
530 |
531 | return c.send(cn, req)
532 | }
533 |
534 | // Delta is an atomically increments/decrements value by delta. The return value is
535 | // the new value after being incremented/decrements or an error.
536 | func (c *Client) Delta(deltaMode DeltaMode, key string, delta, initial uint64, exp uint32) (newValue uint64, err error) {
537 | timer := time.Now()
538 | defer c.writeMethodDiagnostics("Delta", timer, &err)
539 |
540 | if !legalKey(key) {
541 | return 0, ErrMalformedKey
542 | }
543 |
544 | node, find := c.hr.Get(key)
545 | if !find {
546 | return 0, ErrNoServers
547 | }
548 |
549 | cn, err := c.getConnForNode(node)
550 | if err != nil {
551 | return 0, err
552 | }
553 |
554 | req := &Request{
555 | Opcode: deltaMode.Resolve(),
556 | Key: []byte(key),
557 | }
558 | req.prepareExtras(exp, delta, initial)
559 |
560 | resp, err := c.send(cn, req)
561 | if err != nil {
562 | return 0, err
563 | }
564 |
565 | return binary.BigEndian.Uint64(resp.Body), nil
566 | }
567 |
568 | // Append is an appends/prepends the given item to the existing item, if a value already
569 | // exists for its key. ErrNotStored is returned if that condition is not met.
570 | func (c *Client) Append(appendMode AppendMode, key string, data []byte) (_ *Response, err error) {
571 | timer := time.Now()
572 | defer c.writeMethodDiagnostics("Append", timer, &err)
573 |
574 | if !legalKey(key) {
575 | return nil, ErrMalformedKey
576 | }
577 |
578 | node, find := c.hr.Get(key)
579 | if !find {
580 | return nil, ErrNoServers
581 | }
582 |
583 | cn, err := c.getConnForNode(node)
584 | if err != nil {
585 | return nil, err
586 | }
587 |
588 | req := &Request{
589 | Opcode: appendMode.Resolve(),
590 | Opaque: c.getOpaque(),
591 | Key: []byte(key),
592 | Body: data,
593 | }
594 | req.prepareExtras(0, 0, 0)
595 |
596 | return c.send(cn, req)
597 | }
598 |
599 | // FlushAll is a deletes all items in the cache.
600 | func (c *Client) FlushAll(exp uint32) (err error) {
601 | timerMethod := time.Now()
602 | defer c.writeMethodDiagnostics("FlushAll", timerMethod, &err)
603 |
604 | var (
605 | wg sync.WaitGroup
606 | mu sync.Mutex
607 | multiErr error
608 |
609 | nodes = c.hr.GetAllNodes()
610 | )
611 |
612 | addToMultiErr := func(e error) {
613 | mu.Lock()
614 | defer mu.Unlock()
615 | multiErr = errors.Join(multiErr, e)
616 | }
617 |
618 | for _, node := range nodes {
619 | wg.Add(1)
620 | go func(node any) {
621 | defer wg.Done()
622 |
623 | var cn *conn
624 | cn, err = c.getConnForNode(node)
625 | if err != nil {
626 | addToMultiErr(err)
627 | return
628 | }
629 | defer cn.condRelease(&err)
630 |
631 | req := &Request{
632 | Opcode: FLUSH,
633 | }
634 | req.prepareExtras(exp, 0, 0)
635 |
636 | _, err = transmitRequest(cn.wrtBuf, req)
637 | if err != nil {
638 | cn.healthy = false
639 | return
640 | }
641 |
642 | if err = cn.wrtBuf.Flush(); err != nil {
643 | logger.Errorf("%s. %s", ErrServerError.Error(), err.Error())
644 | return
645 | }
646 |
647 | _, _, err = getResponse(cn.rc, cn.hdrBuf)
648 | if err != nil {
649 | if isFatal(err) {
650 | cn.healthy = false
651 | return
652 | }
653 | addToMultiErr(err)
654 | }
655 | }(node)
656 | }
657 |
658 | wg.Wait()
659 |
660 | return multiErr
661 | }
662 |
663 | // MultiGet is a batch version of Get. The returned map from keys to
664 | // items may have fewer elements than the input slice, due to memcached
665 | // cache misses. Each key must be at most 250 bytes in length.
666 | // If no error is returned, the returned map will also be non-nil.
667 | func (c *Client) MultiGet(keys []string) (_ map[string][]byte, err error) {
668 | var (
669 | wg sync.WaitGroup
670 | mu sync.Mutex
671 |
672 | ret = make(map[string][]byte, len(keys))
673 | )
674 | if len(keys) == 0 {
675 | return ret, nil
676 | }
677 |
678 | timerMethod := time.Now()
679 | defer c.writeMethodDiagnostics("MultiGet", timerMethod, &err)
680 |
681 | if len(keys) == 1 {
682 | var res *Response
683 | res, err = c.Get(keys[0])
684 | if res != nil {
685 | if res.Status == SUCCESS {
686 | ret[keys[0]] = res.Body
687 | } else if res.Status == KEY_ENOENT {
688 | // MultiGet never returns a ENOENT
689 | err = nil
690 | }
691 | }
692 | return ret, err
693 | }
694 |
695 | var (
696 | once sync.Once
697 | singleError error
698 | )
699 |
700 | addToRet := func(key string, body []byte) {
701 | mu.Lock()
702 | defer mu.Unlock()
703 | ret[key] = body
704 | }
705 |
706 | nodes, err := getNodesForKeys(c.hr, keys)
707 | if err != nil {
708 | return ret, err
709 | }
710 |
711 | for node, ks := range nodes {
712 | wg.Add(1)
713 | go func(node any, keys []string) {
714 | defer wg.Done()
715 |
716 | var cnErr error
717 |
718 | cn, nErr := c.getConnForNode(node)
719 | if nErr != nil {
720 | once.Do(func() {
721 | singleError = nErr
722 | })
723 | return
724 | }
725 | defer cn.condRelease(&cnErr)
726 |
727 | idToKey := make(map[uint32]string, len(keys))
728 |
729 | for _, key := range keys {
730 | opaqueGet := c.getOpaque()
731 | req := &Request{
732 | Opcode: GETQ,
733 | Opaque: opaqueGet,
734 | Key: []byte(key),
735 | }
736 | req.prepareExtras(0, 0, 0)
737 |
738 | _, cnErr = transmitRequest(cn.wrtBuf, req)
739 | if cnErr != nil {
740 | cn.healthy = false
741 | return
742 | }
743 |
744 | idToKey[opaqueGet] = key
745 | }
746 |
747 | opaqueNOOP := c.getOpaque()
748 | req := &Request{
749 | Opcode: NOOP,
750 | Opaque: opaqueNOOP,
751 | }
752 | req.prepareExtras(0, 0, 0)
753 |
754 | _, cnErr = transmitRequest(cn.wrtBuf, req)
755 | if cnErr != nil {
756 | cn.healthy = false
757 | return
758 | }
759 |
760 | if cnErr = cn.wrtBuf.Flush(); err != nil {
761 | logger.Errorf("%s. %s", ErrServerError.Error(), cnErr.Error())
762 | return
763 | }
764 |
765 | for {
766 | var resp *Response
767 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf)
768 | if isFatal(cnErr) {
769 | cn.healthy = false
770 | return
771 | }
772 |
773 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP {
774 | break
775 | }
776 |
777 | if key, ok := idToKey[resp.Opaque]; ok && cnErr == nil {
778 | addToRet(key, resp.Body)
779 | }
780 | }
781 | }(node, ks)
782 | }
783 |
784 | wg.Wait()
785 |
786 | return ret, singleError
787 | }
788 |
789 | // MultiStore is a batch version of Store.
790 | // Writes the provided items with expiration.
791 | func (c *Client) MultiStore(storeMode StoreMode, items map[string][]byte, exp uint32) (err error) {
792 | if len(items) == 0 {
793 | return nil
794 | }
795 |
796 | timerMethod := time.Now()
797 | defer c.writeMethodDiagnostics("MultiStore", timerMethod, &err)
798 |
799 | var (
800 | wg sync.WaitGroup
801 | muMErr sync.Mutex
802 | multiErr error
803 | )
804 |
805 | addToMultiErr := func(e error) {
806 | muMErr.Lock()
807 | defer muMErr.Unlock()
808 | multiErr = errors.Join(multiErr, e)
809 | }
810 |
811 | var muItems sync.RWMutex
812 | safeGetItems := func(key string) []byte {
813 | muItems.RLock()
814 | defer muItems.RUnlock()
815 | return items[key]
816 | }
817 |
818 | quietCode := storeMode.Resolve().changeOnQuiet(SETQ)
819 |
820 | keys := maps.Keys(items)
821 | nodes, err := getNodesForKeys(c.hr, keys)
822 | if err != nil {
823 | return err
824 | }
825 |
826 | for node, ks := range nodes {
827 | wg.Add(1)
828 | go func(node any, keys []string, exp uint32) {
829 | defer wg.Done()
830 |
831 | var cnErr error
832 |
833 | cn, nErr := c.getConnForNode(node)
834 | if nErr != nil {
835 | addToMultiErr(nErr)
836 | return
837 | }
838 | defer cn.condRelease(&cnErr)
839 |
840 | idToKey := make(map[uint32]string, len(keys))
841 |
842 | for _, key := range keys {
843 | opaqueStore := c.getOpaque()
844 | req := &Request{
845 | Opcode: quietCode,
846 | Opaque: opaqueStore,
847 | Key: []byte(key),
848 | Body: safeGetItems(key),
849 | }
850 | req.prepareExtras(exp, 0, 0)
851 |
852 | _, cnErr = transmitRequest(cn.wrtBuf, req)
853 | if cnErr != nil {
854 | cn.healthy = false
855 | return
856 | }
857 |
858 | idToKey[opaqueStore] = key
859 | }
860 |
861 | opaqueNOOP := c.getOpaque()
862 | req := &Request{
863 | Opcode: NOOP,
864 | Opaque: opaqueNOOP,
865 | }
866 | req.prepareExtras(0, 0, 0)
867 |
868 | _, cnErr = transmitRequest(cn.wrtBuf, req)
869 | if cnErr != nil {
870 | cn.healthy = false
871 | return
872 | }
873 |
874 | if cnErr = cn.wrtBuf.Flush(); err != nil {
875 | logger.Errorf("%s. %s", ErrServerError.Error(), cnErr.Error())
876 | return
877 | }
878 |
879 | for {
880 | var resp *Response
881 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf)
882 | if isFatal(cnErr) {
883 | cn.healthy = false
884 | return
885 | }
886 |
887 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP {
888 | break
889 | }
890 |
891 | if key, ok := idToKey[resp.Opaque]; ok {
892 | if resp.Status != SUCCESS {
893 | addToMultiErr(fmt.Errorf("%w. Error for key - %s", cnErr, key))
894 | }
895 | }
896 | }
897 | }(node, ks, exp)
898 | }
899 |
900 | wg.Wait()
901 |
902 | return multiErr
903 | }
904 |
905 | // MultiDelete is a batch version of Delete.
906 | // Deletes the items with the provided keys.
907 | // If there is a key in the provided keys that is missing in the cache,
908 | // the ErrCacheMiss error is ignored.
909 | func (c *Client) MultiDelete(keys []string) (err error) {
910 | if len(keys) == 0 {
911 | return nil
912 | }
913 |
914 | timerMethod := time.Now()
915 | defer c.writeMethodDiagnostics("MultiDelete", timerMethod, &err)
916 |
917 | var (
918 | wg sync.WaitGroup
919 | mu sync.Mutex
920 | multiErr error
921 | )
922 |
923 | addToMultiErr := func(e error) {
924 | mu.Lock()
925 | defer mu.Unlock()
926 | multiErr = errors.Join(multiErr, e)
927 | }
928 |
929 | nodes, err := getNodesForKeys(c.hr, keys)
930 | if err != nil {
931 | return err
932 | }
933 |
934 | for node, ks := range nodes {
935 | wg.Add(1)
936 | go func(node any, keys []string) {
937 | defer wg.Done()
938 |
939 | var cnErr error
940 |
941 | cn, nErr := c.getConnForNode(node)
942 | if nErr != nil {
943 | addToMultiErr(nErr)
944 | return
945 | }
946 | defer cn.condRelease(&cnErr)
947 |
948 | idToKey := make(map[uint32]string, len(keys))
949 |
950 | for _, key := range keys {
951 | opaqueDel := c.getOpaque()
952 | req := &Request{
953 | Opcode: DELETEQ,
954 | Opaque: opaqueDel,
955 | Key: []byte(key),
956 | }
957 | req.prepareExtras(0, 0, 0)
958 |
959 | _, cnErr = transmitRequest(cn.wrtBuf, req)
960 | if cnErr != nil {
961 | cn.healthy = false
962 | return
963 | }
964 |
965 | idToKey[opaqueDel] = key
966 | }
967 |
968 | opaqueNOOP := c.getOpaque()
969 | req := &Request{
970 | Opcode: NOOP,
971 | Opaque: opaqueNOOP,
972 | }
973 | req.prepareExtras(0, 0, 0)
974 |
975 | _, cnErr = transmitRequest(cn.wrtBuf, req)
976 | if cnErr != nil {
977 | cn.healthy = false
978 | return
979 | }
980 |
981 | if cnErr = cn.wrtBuf.Flush(); err != nil {
982 | logger.Errorf("%s. %s", ErrServerError.Error(), cnErr.Error())
983 | return
984 | }
985 |
986 | for {
987 | var resp *Response
988 | resp, _, cnErr = getResponse(cn.rc, cn.hdrBuf)
989 | if isFatal(cnErr) {
990 | cn.healthy = false
991 | return
992 | }
993 |
994 | if resp.Opcode == NOOP && resp.Opaque == opaqueNOOP {
995 | break
996 | }
997 |
998 | if key, ok := idToKey[resp.Opaque]; ok {
999 | if resp.Status != SUCCESS && resp.Status != KEY_ENOENT {
1000 | addToMultiErr(fmt.Errorf("%w. Error for key - %s", cnErr, key))
1001 | }
1002 | }
1003 | }
1004 | }(node, ks)
1005 | }
1006 |
1007 | wg.Wait()
1008 |
1009 | return multiErr
1010 | }
1011 |
1012 | // CloseAllConns is close all opened connection per shards.
1013 | // Once closed, resources should be released.
1014 | func (c *Client) CloseAllConns() {
1015 | c.fmu.Lock()
1016 | defer c.fmu.Unlock()
1017 |
1018 | for addr, connPool := range c.freeConns {
1019 | connPool.Destroy()
1020 | delete(c.freeConns, addr)
1021 | }
1022 | }
1023 |
1024 | // CloseAvailableConnsInAllShardPools - removes the specified number of connections from the pools of all shards.
1025 | func (c *Client) CloseAvailableConnsInAllShardPools(numOfClose int) int {
1026 | var closed int
1027 |
1028 | c.fmu.Lock()
1029 | defer c.fmu.Unlock()
1030 |
1031 | for _, p := range c.freeConns {
1032 | for i := 0; i < numOfClose; i++ {
1033 | if connRaw, ok := p.Pop(); ok {
1034 | p.Close(connRaw)
1035 | closed++
1036 | }
1037 | }
1038 | }
1039 |
1040 | return closed
1041 | }
1042 |
1043 | func (c *Client) writeMethodDiagnostics(methodName string, timer time.Time, err *error) {
1044 | if methodName == "" || c.disableMemcachedDiagnostic {
1045 | return
1046 | }
1047 |
1048 | observeMethodDurationSeconds(methodName, time.Since(timer).Seconds(), *err == nil)
1049 | }
1050 |
1051 | func (c *Client) authenticate(cn *conn) (ok bool) {
1052 | req := &Request{
1053 | Key: []byte(SaslMechanism),
1054 | Body: c.authData,
1055 | }
1056 |
1057 | req.Opcode = SASL_AUTH
1058 | _, err := transmitRequest(cn.wrtBuf, req)
1059 | if err != nil {
1060 | return
1061 | }
1062 |
1063 | if err = cn.wrtBuf.Flush(); err != nil {
1064 | return
1065 | }
1066 |
1067 | resp, _, err := getResponse(cn.rc, cn.hdrBuf)
1068 | if err == nil {
1069 | return true
1070 | }
1071 | if err != nil && resp.Status != FURTHER_AUTH {
1072 | logger.Errorf("%s: Error from sasl auth - %v", libPrefix, resp)
1073 | return
1074 | }
1075 |
1076 | req.Opcode = SASL_STEP
1077 | _, err = transmitRequest(cn.wrtBuf, req)
1078 | if err != nil {
1079 | return
1080 | }
1081 |
1082 | resp, _, err = getResponse(cn.rc, cn.hdrBuf)
1083 | if err != nil {
1084 | logger.Errorf("%s: Error from sasl step - %v", libPrefix, resp)
1085 | return
1086 | }
1087 |
1088 | if err = cn.wrtBuf.Flush(); err != nil {
1089 | return
1090 | }
1091 |
1092 | return true
1093 | }
1094 |
1095 | func legalKey(key string) bool {
1096 | if len(key) > 250 {
1097 | return false
1098 | }
1099 | for i := 0; i < len(key); i++ {
1100 | if key[i] <= ' ' || key[i] == 0x7f {
1101 | return false
1102 | }
1103 | }
1104 | return true
1105 | }
1106 |
1107 | // getNodesForKeys return a map where key is a node and value is a suitable keys
1108 | func getNodesForKeys(hr consistenthash.ConsistentHash, keys []string) (map[any][]string, error) {
1109 | resp := make(map[any][]string, hr.GetNodesCount())
1110 |
1111 | for _, key := range keys {
1112 | if !legalKey(key) {
1113 | return nil, fmt.Errorf("%w. Invalid key - %v", ErrMalformedKey, key)
1114 | }
1115 | if node, found := hr.Get(key); found {
1116 | resp[node] = append(resp[node], key)
1117 | }
1118 | }
1119 |
1120 | return resp, nil
1121 | }
1122 |
--------------------------------------------------------------------------------
/memcached/memcached_test.go:
--------------------------------------------------------------------------------
1 | // nolint
2 | package memcached
3 |
4 | import (
5 | "bufio"
6 | "bytes"
7 | "errors"
8 | "fmt"
9 | "io/ioutil"
10 | "net"
11 | "reflect"
12 | "strconv"
13 | "sync"
14 | "testing"
15 | "time"
16 |
17 | "github.com/stretchr/testify/assert"
18 | "github.com/stretchr/testify/require"
19 | "golang.org/x/exp/maps"
20 |
21 | "github.com/aliexpressru/gomemcached/consistenthash"
22 | "github.com/aliexpressru/gomemcached/utils"
23 | )
24 |
25 | func TestTransmitReq(t *testing.T) {
26 | b := bytes.NewBuffer([]byte{})
27 | buf := bufio.NewWriter(b)
28 |
29 | req := Request{
30 | Opcode: SET,
31 | Cas: 938424885,
32 | Opaque: 7242,
33 | Extras: []byte{},
34 | Key: []byte("somekey"),
35 | Body: []byte("somevalue"),
36 | }
37 |
38 | // Verify nil transmit is OK
39 | _, err := transmitRequest(nil, &req)
40 | if !errors.Is(err, ErrNoServers) {
41 | t.Errorf("Expected errNoConn with no conn, got %v", err)
42 | }
43 |
44 | _, err = transmitRequest(buf, &req)
45 | if err != nil {
46 | t.Fatalf("Error transmitting request: %v", err)
47 | }
48 |
49 | buf.Flush()
50 |
51 | expected := []byte{
52 | REQ_MAGIC, byte(SET),
53 | 0x0, 0x7, // length of key
54 | 0x0, // extra length
55 | 0x0, // reserved
56 | 0x0, 0x0, // reserved
57 | 0x0, 0x0, 0x0, 0x10, // Length of value
58 | 0x0, 0x0, 0x1c, 0x4a, // opaque
59 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
60 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
61 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
62 | }
63 |
64 | if len(b.Bytes()) != req.Size() {
65 | t.Fatalf("Expected %v bytes, got %v", req.Size(),
66 | len(b.Bytes()))
67 | }
68 |
69 | if !reflect.DeepEqual(b.Bytes(), expected) {
70 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
71 | expected, b.Bytes())
72 | }
73 | }
74 |
75 | func BenchmarkTransmitReq(b *testing.B) {
76 | bout := bytes.NewBuffer([]byte{})
77 |
78 | req := Request{
79 | Opcode: SET,
80 | Cas: 938424885,
81 | Opaque: 7242,
82 | Extras: []byte{},
83 | Key: []byte("somekey"),
84 | Body: []byte("somevalue"),
85 | }
86 |
87 | b.SetBytes(int64(req.Size()))
88 |
89 | for i := 0; i < b.N; i++ {
90 | bout.Reset()
91 | buf := bufio.NewWriterSize(bout, req.Size()*2)
92 | _, err := transmitRequest(buf, &req)
93 | if err != nil {
94 | b.Fatalf("Error transmitting request: %v", err)
95 | }
96 | }
97 | }
98 |
99 | func BenchmarkTransmitReqLarge(b *testing.B) {
100 | bout := bytes.NewBuffer([]byte{})
101 |
102 | req := Request{
103 | Opcode: SET,
104 | Cas: 938424885,
105 | Opaque: 7242,
106 | Extras: []byte{},
107 | Key: []byte("somekey"),
108 | Body: make([]byte, 24*1024),
109 | }
110 |
111 | b.SetBytes(int64(req.Size()))
112 |
113 | for i := 0; i < b.N; i++ {
114 | bout.Reset()
115 | buf := bufio.NewWriterSize(bout, req.Size()*2)
116 | _, err := transmitRequest(buf, &req)
117 | if err != nil {
118 | b.Fatalf("Error transmitting request: %v", err)
119 | }
120 | }
121 | }
122 |
123 | func BenchmarkTransmitReqNull(b *testing.B) {
124 | req := Request{
125 | Opcode: SET,
126 | Cas: 938424885,
127 | Opaque: 7242,
128 | Extras: []byte{},
129 | Key: []byte("somekey"),
130 | Body: []byte("somevalue"),
131 | }
132 |
133 | b.SetBytes(int64(req.Size()))
134 |
135 | for i := 0; i < b.N; i++ {
136 | _, err := transmitRequest(ioutil.Discard, &req)
137 | if err != nil {
138 | b.Fatalf("Error transmitting request: %v", err)
139 | }
140 | }
141 | }
142 |
143 | /*
144 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
145 | +---------------+---------------+---------------+---------------+
146 | 0| 0x81 | 0x00 | 0x00 | 0x00 |
147 | +---------------+---------------+---------------+---------------+
148 | 4| 0x04 | 0x00 | 0x00 | 0x00 |
149 | +---------------+---------------+---------------+---------------+
150 | 8| 0x00 | 0x00 | 0x00 | 0x09 |
151 | +---------------+---------------+---------------+---------------+
152 | 12| 0x00 | 0x00 | 0x00 | 0x00 |
153 | +---------------+---------------+---------------+---------------+
154 | 16| 0x00 | 0x00 | 0x00 | 0x00 |
155 | +---------------+---------------+---------------+---------------+
156 | 20| 0x00 | 0x00 | 0x00 | 0x01 |
157 | +---------------+---------------+---------------+---------------+
158 | 24| 0xde | 0xad | 0xbe | 0xef |
159 | +---------------+---------------+---------------+---------------+
160 | 28| 0x57 ('W') | 0x6f ('o') | 0x72 ('r') | 0x6c ('l') |
161 | +---------------+---------------+---------------+---------------+
162 | 32| 0x64 ('d') |
163 | +---------------+
164 |
165 | Field (offset) (value)
166 | Magic (0) : 0x81
167 | Opcode (1) : 0x00
168 | Key length (2,3) : 0x0000
169 | Extra length (4) : 0x04
170 | Data type (5) : 0x00
171 | Status (6,7) : 0x0000
172 | Total body (8-11) : 0x00000009
173 | Opaque (12-15): 0x00000000
174 | CAS (16-23): 0x0000000000000001
175 | Extras :
176 | Flags (24-27): 0xdeadbeef
177 | Key : None
178 | Value (28-32): The textual string "World"
179 |
180 | */
181 |
182 | func TestDecodeSpecSample(t *testing.T) {
183 | data := []byte{
184 | 0x81, 0x00, 0x00, 0x00, // 0
185 | 0x04, 0x00, 0x00, 0x00, // 4
186 | 0x00, 0x00, 0x00, 0x09, // 8
187 | 0x00, 0x00, 0x00, 0x00, // 12
188 | 0x00, 0x00, 0x00, 0x00, // 16
189 | 0x00, 0x00, 0x00, 0x01, // 20
190 | 0xde, 0xad, 0xbe, 0xef, // 24
191 | 0x57, 0x6f, 0x72, 0x6c, // 28
192 | 0x64, // 32
193 | }
194 |
195 | buf := make([]byte, HDR_LEN)
196 | res, _, err := getResponse(bytes.NewReader(data), buf)
197 | if err != nil {
198 | t.Fatalf("Error parsing response: %v", err)
199 | }
200 |
201 | expected := &Response{
202 | Opcode: GET,
203 | Status: 0,
204 | Opaque: 0,
205 | Cas: 1,
206 | Extras: []byte{0xde, 0xad, 0xbe, 0xef},
207 | Body: []byte("World"),
208 | }
209 |
210 | if !reflect.DeepEqual(res, expected) {
211 | t.Fatalf("Expected\n%#v -- got --\n%#v", expected, res)
212 | }
213 | assert.Nil(t, UnwrapMemcachedError(err), "UnwrapMemcachedError: should be return nil for success getResponse")
214 | }
215 |
216 | func TestNilReader(t *testing.T) {
217 | res, _, err := getResponse(nil, nil)
218 | if !errors.Is(err, ErrNoServers) {
219 | t.Fatalf("Expected error reading from nil, got %#v", res)
220 | }
221 | }
222 |
223 | func TestNilConfig(t *testing.T) {
224 | mcl, err := InitFromEnv()
225 | assert.Nil(t, mcl, "InitFromEnv without config should be return nil client")
226 | assert.ErrorIs(t, err, ErrNotConfigured, "InitFromEnv without config should be return error == ErrNotConfigured")
227 | }
228 |
229 | func TestErrWrap(t *testing.T) {
230 | type args struct {
231 | resp *Response
232 | }
233 | tests := []struct {
234 | name string
235 | args args
236 | wantErr error
237 | }{
238 | {
239 | name: ENOMEM.String(),
240 | args: args{resp: &Response{
241 | Status: ENOMEM,
242 | }},
243 | wantErr: ErrServerError,
244 | },
245 | {
246 | name: TMPFAIL.String(),
247 | args: args{resp: &Response{
248 | Status: TMPFAIL,
249 | }},
250 | wantErr: ErrServerNotAvailable,
251 | },
252 | {
253 | name: UNKNOWN_COMMAND.String(),
254 | args: args{resp: &Response{
255 | Status: UNKNOWN_COMMAND,
256 | }},
257 | wantErr: ErrUnknownCommand,
258 | },
259 | }
260 | for _, tt := range tests {
261 | t.Run(tt.name, func(t *testing.T) {
262 | wrapErr := wrapMemcachedResp(tt.args.resp)
263 | require.ErrorIs(t, wrapErr, tt.wantErr, "wrapMemcachedResp wrap error not equal expected")
264 | })
265 | }
266 | }
267 |
268 | func TestDecode(t *testing.T) {
269 | data := []byte{
270 | RES_MAGIC, byte(SET),
271 | 0x0, 0x7, // length of key
272 | 0x0, // extra length
273 | 0x0, // reserved
274 | 0x6, 0x2e, // status
275 | 0x0, 0x0, 0x0, 0x10, // Length of value
276 | 0x0, 0x0, 0x1c, 0x4a, // opaque
277 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
278 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
279 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
280 | }
281 |
282 | buf := make([]byte, HDR_LEN)
283 | res, _, _ := getResponse(bytes.NewReader(data), buf)
284 |
285 | expected := &Response{
286 | Opcode: SET,
287 | Status: 1582,
288 | Opaque: 7242,
289 | Cas: 938424885,
290 | Extras: nil,
291 | Key: []byte("somekey"),
292 | Body: []byte("somevalue"),
293 | }
294 |
295 | if !reflect.DeepEqual(res, expected) {
296 | t.Fatalf("Expected\n%#v -- got --\n%#v", expected, res)
297 | }
298 | }
299 |
300 | func BenchmarkDecodeResponse(b *testing.B) {
301 | data := []byte{
302 | RES_MAGIC, byte(SET),
303 | 0x0, 0x7, // length of key
304 | 0x0, // extra length
305 | 0x0, // reserved
306 | 0x6, 0x2e, // status
307 | 0x0, 0x0, 0x0, 0x10, // Length of value
308 | 0x0, 0x0, 0x1c, 0x4a, // opaque
309 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
310 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
311 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
312 | }
313 | buf := make([]byte, HDR_LEN)
314 | b.SetBytes(int64(len(buf)))
315 |
316 | for i := 0; i < b.N; i++ {
317 | getResponse(bytes.NewReader(data), buf)
318 | }
319 | }
320 |
321 | const localhostTCPAddr = "localhost:11211"
322 |
323 | func TestLocalhost(t *testing.T) {
324 | t.Parallel()
325 | c, err := net.Dial("tcp", localhostTCPAddr)
326 | if err != nil {
327 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr)
328 | }
329 | req := Request{
330 | Opcode: VERSION,
331 | }
332 |
333 | _, err = transmitRequest(c, &req)
334 | if err != nil {
335 | t.Errorf("Expected errNoConn with no conn, got %v", err)
336 | }
337 |
338 | buf := make([]byte, HDR_LEN)
339 | resp, _, err := getResponse(c, buf)
340 | if err != nil {
341 | t.Fatalf("Error transmitting request: %v", err)
342 | }
343 |
344 | if resp.Status != SUCCESS {
345 | t.Errorf("Expected SUCCESS, got %v", resp.Status)
346 | }
347 | if err = c.Close(); err != nil {
348 | t.Fatalf("Error with close connection: %v", err)
349 | }
350 |
351 | _, err = newForTests("invalidServerAddr")
352 | require.ErrorIs(t, err, ErrInvalidAddr)
353 |
354 | mc, err := newForTests(localhostTCPAddr)
355 | if err != nil {
356 | t.Fatalf("failed to create new client: %v", err)
357 | }
358 | t.Cleanup(mc.CloseAllConns)
359 | testWithClient(t, mc)
360 | }
361 |
362 | func testWithClient(t *testing.T, c *Client) {
363 | resp, err := c.Store(Set, "bigdata", 0, make([]byte, MaxBodyLen+1))
364 | assert.ErrorIsf(t, err, ErrDataSizeExceedsLimit, "Store: body > MaxBodyLen, want error ErrDataSizeExceedsLimit")
365 | unwrapResp := UnwrapMemcachedError(err)
366 | if !reflect.DeepEqual(resp, unwrapResp) {
367 | t.Fatalf("Expected\n%#v -- got --\n%#v", resp, unwrapResp)
368 | }
369 |
370 | // multi
371 | err = c.MultiStore(Set, map[string][]byte{}, 0)
372 | assert.Nil(t, err, "MultiStore with 0 items should have no errors")
373 | items, err := c.MultiGet([]string{})
374 | assert.Nil(t, err, "MultiGet with 0 keys should have no errors")
375 | assert.Empty(t, items, "MultiGet with 0 keys should return empty map")
376 | err = c.MultiDelete([]string{})
377 | assert.Nil(t, err, "MultiDelete with 0 keys should have no errors")
378 |
379 | // Set
380 | _, err = c.Store(Set, "foo", 0, []byte("fooval-fromset1"))
381 | assert.Nilf(t, err, "first set(foo): %v", err)
382 | _, err = c.Store(Set, "foo", 0, []byte("fooval-fromset2"))
383 | assert.Nilf(t, err, "second set(foo): %v", err)
384 | // Add
385 | _, err = c.Store(Add, "foo", 0, []byte("fooval-fromset3"))
386 | assert.ErrorIsf(t, err, ErrNotStored, "Add with exist key - %s, want error - ErrNotStored, have - %v", "foo", err)
387 |
388 | // Get
389 | resp, err = c.Get("foo")
390 | assert.Nilf(t, err, "get(foo): %v", err)
391 | // assert.Equalf(t, []byte("foo"), resp.Key, "get(foo) Key = %s, want foo", string(resp.Key)) only for GETK
392 | assert.Equalf(t, []byte("fooval-fromset2"), resp.Body, "get(foo) Body = %s, want fooval-fromset2", string(resp.Body))
393 | err = wrapMemcachedResp(resp)
394 | assert.Nil(t, err, "Get: wrapped success resp should be nil")
395 |
396 | // Get and set a unicode key
397 | quxKey := "Hello_世界"
398 | _, err = c.Store(Set, quxKey, 0, []byte("hello world"))
399 | assert.Nilf(t, err, "first set(Hello_世界): %v", err)
400 | resp, err = c.Get(quxKey)
401 | assert.Nilf(t, err, "get(Hello_世界): %v", err)
402 | // assert.Equalf(t, quxKey, string(resp.Key), "get(Hello_世界) Key = %q, want Hello_世界", quxKey) only for GETK
403 | assert.Equalf(t, "hello world", string(resp.Body), "get(Hello_世界) Value = %q, want hello world", string(resp.Body))
404 |
405 | // Set malformed keys
406 | _, err = c.Store(Set, "foo bar", 0, []byte("foobarval"))
407 | assert.ErrorIsf(t, err, ErrMalformedKey, "set(foo bar) should return ErrMalformedKey instead of %v", err)
408 | _, err = c.Store(Set, "foo"+string(rune(0x7f)), 0, []byte("foobarval"))
409 | assert.ErrorIsf(t, err, ErrMalformedKey, "set(foo<0x7f>) should return ErrMalformedKey instead of %v", err)
410 |
411 | // Append
412 | _, err = c.Append(Append, "append", []byte("appendval"))
413 | assert.ErrorIsf(t, err, ErrNotStored, "first append(append) want ErrNotStored, got %v", err)
414 |
415 | _, err = c.Store(Set, "append", 0, []byte("appendval"))
416 | assert.Nilf(t, err, "Set for append have error - %v", err)
417 | _, err = c.Append(Append, "append", []byte("1"))
418 | assert.Nilf(t, err, "second append(append): %v", err)
419 | appended, err := c.Get("append")
420 | assert.Nilf(t, err, "after append(append): %v", err)
421 | assert.Equalf(t, fmt.Sprintf("%s%s", "appendval", "1"), string(appended.Body),
422 | "Append: want=append1, got=%s", string(appended.Body))
423 |
424 | // Prepend
425 | _, err = c.Append(Prepend, "prepend", []byte("prependval"))
426 | assert.ErrorIsf(t, err, ErrNotStored, "first prepend(prepend) want ErrNotStored, got %v", err)
427 |
428 | _, err = c.Store(Set, "prepend", 0, []byte("prependval"))
429 | assert.Nilf(t, err, "Set for prepend have error - %v", err)
430 | _, err = c.Append(Prepend, "prepend", []byte("1"))
431 | assert.Nilf(t, err, "second prepend(prepend): %v", err)
432 | prepend, err := c.Get("prepend")
433 | assert.Nilf(t, err, "after prepend(prepend): %v", err)
434 | assert.Equalf(t, fmt.Sprintf("%s%s", "1", "prependval"), string(prepend.Body),
435 | "Prepend: want=1prependval, got=%s", string(prepend.Body))
436 |
437 | // Replace
438 | _, err = c.Store(Replace, "baz", 0, []byte("bazvalue"))
439 | assert.ErrorIsf(t, err, ErrCacheMiss, "expected replace(baz) to return ErrCacheMiss, got %v", err)
440 | _, err = c.Store(Set, "baz", 0, []byte("bazvalue"))
441 | assert.Nilf(t, err, "Set for Replace have error - %v", err)
442 | resp, err = c.Store(Replace, "baz", 0, []byte("42"))
443 | assert.Nilf(t, err, "Replace have error - %v", err)
444 | resp, err = c.Get("baz")
445 | assert.Nilf(t, err, "Get for Replace have error - %v", err)
446 | assert.Equalf(t, "42", string(resp.Body), "Resp after replaces want - 42, have - %s", string(resp.Body))
447 |
448 | // Incr/Decr
449 | _, err = c.Store(Set, "num", 0, []byte("42"))
450 | assert.Nilf(t, err, "Set for Increment have error - %v", err)
451 | n, err := c.Delta(Increment, "num", 8, 0, 0)
452 | assert.Nilf(t, err, "Increment num + 8: %v", err)
453 | assert.Equalf(t, 50, int(n), "Increment num + 8: want=50, got=%d", n)
454 | n, err = c.Delta(Decrement, "num", 49, 0, 0)
455 | assert.Nilf(t, err, "Decrement: %v", err)
456 | assert.Equalf(t, 1, int(n), "Decrement 49: want=1, got=%d", n)
457 | _, err = c.Delete("num")
458 | assert.Nilf(t, err, "Delete for Increment/Decrement have error - %v", err)
459 | n, err = c.Delta(Increment, "num", 1, 10, 0)
460 | assert.Nilf(t, err, "Increment with initial value have error - %v", err)
461 | assert.Equalf(t, 10, int(n), "Increment with initial value 10: want=10, got=%d", n)
462 | n, err = c.Delta(Decrement, "num", 2, 0, 0)
463 | assert.Nilf(t, err, "Increment with initial value have error - %v", err)
464 | assert.Equalf(t, 8, int(n), "Increment with initial value 1: want=8, got=%d", n)
465 | const fakeDeltaMode = DeltaMode(42)
466 | n, err = c.Delta(fakeDeltaMode, "num", 2, 0, 0)
467 | assert.Nilf(t, err, "Increment with fakeDeltaMode have error - %v", err)
468 |
469 | _, err = c.Store(Set, "num", 0, []byte("not-numeric"))
470 | assert.Nilf(t, err, "Set for Increment non-numeric value have error - %v", err)
471 | _, err = c.Delta(Increment, "num", 1, 0, 0)
472 | assert.ErrorIs(t, err, ErrInvalidArguments, "Increment not-numeric value")
473 |
474 | // Delete
475 | _, err = c.Delete("foo")
476 | assert.Nilf(t, err, "Delete: %v", err)
477 | _, err = c.Get("foo")
478 | assert.ErrorIsf(t, err, ErrCacheMiss, "post-Delete want ErrCacheMiss, got %v", err)
479 |
480 | testExpireWithClient(t, c)
481 |
482 | // MutliGet
483 | // Create some test items.
484 | keys := []string{"foo", "bar", "gopher", "42"}
485 | input := make(map[string][]byte, len(keys))
486 |
487 | addKeys := func() {
488 | for i, key := range keys {
489 | body := []byte(key + strconv.Itoa(i))
490 | _, err = c.Store(Set, key, 0, body)
491 | assert.Nilf(t, err, "Store for MutliGet have error - %v", err)
492 | input[key] = body
493 | }
494 | }
495 |
496 | checkKeyOnExist := func(method string, input map[string][]byte, output map[string][]byte) {
497 | for key, reqBody := range input {
498 | if respBody, ok := output[key]; ok {
499 | assert.Equalf(t, reqBody, respBody, "%s. Request and response body not equal, have - %v, want - %v", method, respBody, reqBody)
500 | } else {
501 | t.Errorf("%s. Don't found requset key %v in response", method, key)
502 | }
503 | }
504 | }
505 |
506 | _, err = c.MultiGet(append(keys, invalidKey))
507 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiGet: invalid key, want error ErrMalformedKey")
508 |
509 | addKeys()
510 | output, err := c.MultiGet(keys)
511 | assert.Nilf(t, err, "MultiGet have error: %v", err)
512 | if len(input) != len(output) {
513 | t.Errorf("want %d items after MultiGet, have %d", len(input), len(output))
514 | } else {
515 | checkKeyOnExist("MultiGet", input, output)
516 | }
517 |
518 | // remove one key from cache
519 | _, err = c.Delete(keys[0])
520 | assert.Nilf(t, err, "Delete for MultiGet have error: %v", err)
521 | output, err = c.MultiGet(keys)
522 | assert.Nilf(t, err, "MultiGet after delete one elem have error: %v", err)
523 | if len(input)-1 != len(output) {
524 | t.Errorf("want %d items after MultiStore, have %d", len(input)-1, len(output))
525 | }
526 |
527 | // MutliStore
528 | inputMStore := map[string][]byte{
529 | "foo42": []byte("bar"),
530 | "hello": []byte("world"),
531 | "go": []byte("gopher"),
532 | }
533 | inputMStoreExp := map[string][]byte{
534 | "exp": []byte("needDelete"),
535 | }
536 | inputExp := uint32(1)
537 |
538 | err = c.MultiStore(Set, inputMStore, 0)
539 | assert.Nilf(t, err, "MultiStore have error: %v", err)
540 | err = c.MultiStore(Set, inputMStoreExp, inputExp)
541 | assert.Nilf(t, err, "MultiStore with exp have error: %v", err)
542 |
543 | time.Sleep(time.Second)
544 | keyWithExp := maps.Keys(inputMStoreExp)[0]
545 | _, err = c.Get(keyWithExp)
546 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for item with 1 sec experetion setted in MultiStore. want - %v, have - %v", ErrCacheMiss, err)
547 |
548 | keysInputMStore := maps.Keys(inputMStore)
549 | outputMStoreOne, err := c.Get(keysInputMStore[0])
550 | assert.Nilf(t, err, "Get for MultiStore have error: %v", err)
551 | assert.NotNil(t, outputMStoreOne.Body, "Get after MultiStore gets item without body")
552 | outputMStore, err := c.MultiGet(keysInputMStore)
553 | assert.Nilf(t, err, "MultiGet for MultiStore have error: %v", err)
554 | checkKeyOnExist("MultiStore", inputMStore, outputMStore)
555 |
556 | singleMStore, err := c.MultiGet([]string{keysInputMStore[0]})
557 | assert.Nilf(t, err, "MultiGet with 1 item have error: %v", err)
558 | for key, body := range singleMStore {
559 | assert.Equal(t, keysInputMStore[0], key, "MultiGet with 1 item not equals keys")
560 | assert.Equal(t, inputMStore[key], body, "MultiGet with 1 item not equals body")
561 | }
562 |
563 | // Test Flush All
564 | err = c.FlushAll(0)
565 | assert.Nilf(t, err, "FlushAll: %v", err)
566 | _, err = c.Get("bar")
567 | assert.ErrorIsf(t, err, ErrCacheMiss, "post-FlushAll want ErrCacheMiss, got %v", err)
568 | }
569 |
570 | func testExpireWithClient(t *testing.T, c *Client) {
571 | if testing.Short() {
572 | t.Log("Skipping testing memcached Touch with testing in Short mode")
573 | return
574 | }
575 |
576 | const secondsToExpiry = uint32(1)
577 |
578 | _, err := c.Store(Set, "foo", secondsToExpiry, []byte("fooval"))
579 | assert.Nilf(t, err, "Store(Set) with expire have error - %v", err)
580 | _, err = c.Store(Add, "bar", secondsToExpiry, []byte("barval"))
581 | assert.Nilf(t, err, "Store(Add) with expire have error - %v", err)
582 |
583 | time.Sleep(time.Second)
584 |
585 | _, err = c.Get("foo")
586 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for expire item - %v", err)
587 |
588 | _, err = c.Get("bar")
589 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get for expire item - %v", err)
590 | }
591 |
592 | func TestLocalhost_FlushAll_MultiDelete(t *testing.T) {
593 | c, err := net.Dial("tcp", localhostTCPAddr)
594 | if err != nil {
595 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr)
596 | }
597 | req := Request{
598 | Opcode: VERSION,
599 | }
600 |
601 | _, err = transmitRequest(c, &req)
602 | if err != nil {
603 | t.Errorf("Expected errNoConn with no conn, got %v", err)
604 | }
605 |
606 | buf := make([]byte, HDR_LEN)
607 | resp, _, err := getResponse(c, buf)
608 | if err != nil {
609 | t.Fatalf("Error transmitting request: %v", err)
610 | }
611 |
612 | if resp.Status != SUCCESS {
613 | t.Errorf("Expected SUCCESS, got %v", resp.Status)
614 | }
615 | if err = c.Close(); err != nil {
616 | t.Fatalf("Error with close connection: %v", err)
617 | }
618 |
619 | mc, err := newForTests(localhostTCPAddr)
620 | if err != nil {
621 | t.Fatalf("failed to create new client: %v", err)
622 | }
623 | t.Cleanup(mc.CloseAllConns)
624 |
625 | keys := []string{"foo", "bar", "gopher", "42"}
626 |
627 | addKeys := func() {
628 | for i, key := range keys {
629 | _, err = mc.Store(Set, key, 0, []byte(key+strconv.Itoa(i)))
630 | assert.Nil(t, err, fmt.Sprintf("Fail to Store item with key - %s", key))
631 | }
632 | }
633 |
634 | checkKeyOnExist := func(meth string) {
635 | for _, key := range keys {
636 | _, err = mc.Get(key)
637 | assert.ErrorIsf(t, err, ErrCacheMiss, "Get item after %s. want - %v, have - %v", meth, ErrCacheMiss, err)
638 | }
639 | }
640 |
641 | addKeys()
642 | err = mc.MultiDelete(append(keys, "fake"))
643 | assert.Nil(t, err, "MultiDelete")
644 | checkKeyOnExist("MultiDelete")
645 |
646 | addKeys()
647 | err = mc.FlushAll(0)
648 | assert.Nil(t, err, "FlushAll")
649 | checkKeyOnExist("FlushAll")
650 | }
651 |
652 | func TestClient_CloseAvailableConnsInAllShardPools(t *testing.T) {
653 | _, err := net.Dial("tcp", localhostTCPAddr)
654 | if err != nil {
655 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr)
656 | }
657 | mc, err := newForTests(localhostTCPAddr)
658 | assert.Nilf(t, err, "failed to create new client: %v", err)
659 | t.Cleanup(mc.CloseAllConns)
660 |
661 | // for create conns in pool
662 | var wg sync.WaitGroup
663 | wg.Add(1)
664 | go func() {
665 | defer wg.Done()
666 | _, err := mc.Store(Set, "foo1", 0, []byte("bar"))
667 | assert.Nilf(t, err, "Set foo1: %v", err)
668 | }()
669 | wg.Add(1)
670 | go func() {
671 | defer wg.Done()
672 | _, err := mc.Store(Set, "foo2", 0, []byte("bar"))
673 | assert.Nilf(t, err, "Set foo2: %v", err)
674 | }()
675 | wg.Add(1)
676 | go func() {
677 | defer wg.Done()
678 | _, err := mc.Store(Set, "foo3", 0, []byte("bar"))
679 | assert.Nilf(t, err, "Set foo3: %v", err)
680 | }()
681 | wg.Add(1)
682 | go func() {
683 | defer wg.Done()
684 | _, err := mc.Store(Set, "foo4", 0, []byte("bar"))
685 | assert.Nilf(t, err, "Set foo4: %v", err)
686 | }()
687 |
688 | wg.Wait()
689 |
690 | addr, err := utils.AddrRepr(localhostTCPAddr)
691 | assert.Nilf(t, err, "AddrRepr: %v", err)
692 |
693 | pool, ok := mc.safeGetFreeConn(addr)
694 | assert.Truef(t, ok, "Get from freeConns not found pool for %s", addr.String())
695 |
696 | l := pool.Len()
697 |
698 | numOfClose := 1
699 | c := mc.CloseAvailableConnsInAllShardPools(numOfClose)
700 | assert.Equal(t, numOfClose, c, "Request for closed not equal actual")
701 |
702 | assert.Equalf(t, l-numOfClose, pool.Len(), "Resulting pool len not equal expected number")
703 | }
704 |
705 | func TestConn(t *testing.T) {
706 | c, err := net.DialTimeout("tcp", localhostTCPAddr, time.Second)
707 | if err != nil {
708 | t.Skipf("skipping test; no server running at %s", localhostTCPAddr)
709 | }
710 | req := Request{
711 | Opcode: VERSION,
712 | }
713 |
714 | n, err := transmitRequest(c, &req)
715 | if err != nil {
716 | t.Errorf("Expected errNoConn with no conn, got %v", err)
717 | }
718 |
719 | buf := make([]byte, HDR_LEN)
720 | resp, _, err := getResponse(c, buf)
721 | if err != nil {
722 | t.Fatalf("Error transmitting request: %v", err)
723 | }
724 |
725 | if n != len(buf) {
726 | t.Errorf("write bytes - %d != read bytes - %d\n", n, len(buf))
727 | }
728 | if resp.Status != SUCCESS {
729 | t.Errorf("Expected SUCCESS, got %v", resp.Status)
730 | }
731 | if err = c.Close(); err != nil {
732 | t.Fatalf("Error with close connection: %v", err)
733 | }
734 | }
735 |
736 | func TestClient_Getters(t *testing.T) {
737 | type fields struct {
738 | timeout time.Duration
739 | maxIdleConns int
740 | nodeHCPeriod time.Duration
741 | nodeRBPeriod time.Duration
742 | }
743 | tests := []struct {
744 | name string
745 | fields fields
746 | wantTimeout time.Duration
747 | wantMaxIdleConns int
748 | wantNodeHCPeriod time.Duration
749 | wantNodeRBPeriod time.Duration
750 | }{
751 | {
752 | name: "Default",
753 | fields: fields{},
754 | wantTimeout: DefaultTimeout,
755 | wantMaxIdleConns: DefaultMaxIdleConns,
756 | wantNodeHCPeriod: DefaultNodeHealthCheckPeriod,
757 | wantNodeRBPeriod: DefaultRebuildingNodePeriod,
758 | },
759 | {
760 | name: "Custom",
761 | fields: fields{
762 | timeout: 5 * time.Second,
763 | maxIdleConns: 50,
764 | nodeHCPeriod: time.Second,
765 | nodeRBPeriod: time.Second,
766 | },
767 | wantTimeout: 5 * time.Second,
768 | wantMaxIdleConns: 50,
769 | wantNodeHCPeriod: time.Second,
770 | wantNodeRBPeriod: time.Second,
771 | },
772 | }
773 | for _, tt := range tests {
774 | t.Run(tt.name, func(t *testing.T) {
775 | c := &Client{
776 | timeout: tt.fields.timeout,
777 | maxIdleConns: tt.fields.maxIdleConns,
778 | nodeHCPeriod: tt.fields.nodeHCPeriod,
779 | nodeRBPeriod: tt.fields.nodeRBPeriod,
780 | }
781 | assert.Equalf(t, tt.wantTimeout, c.netTimeout(), "netTimeout()")
782 | assert.Equalf(t, tt.wantMaxIdleConns, c.getMaxIdleConns(), "getMaxIdleConns()")
783 | assert.Equalf(t, tt.wantNodeHCPeriod, c.getHCPeriod(), "getHCPeriod()")
784 | assert.Equalf(t, tt.wantNodeRBPeriod, c.getRBPeriod(), "getRBPeriod()")
785 | })
786 | }
787 | }
788 |
789 | func TestMethodsErrors(t *testing.T) {
790 | c := &Client{
791 | hr: consistenthash.NewHashRing(),
792 | disableMemcachedDiagnostic: true,
793 | }
794 |
795 | // invalid key
796 | _, err := c.Store(Set, invalidKey, 0, []byte("foo"))
797 | assert.ErrorIsf(t, err, ErrMalformedKey, "Store: invalid key, want error ErrMalformedKey")
798 | _, err = c.Get(invalidKey)
799 | assert.ErrorIsf(t, err, ErrMalformedKey, "Get: invalid key, want error ErrMalformedKey")
800 | _, err = c.Delete(invalidKey)
801 | assert.ErrorIsf(t, err, ErrMalformedKey, "Delete: invalid key, want error ErrMalformedKey")
802 | _, err = c.Delta(Increment, invalidKey, 1, 0, 0)
803 | assert.ErrorIsf(t, err, ErrMalformedKey, "Delta: invalid key, want error ErrMalformedKey")
804 | _, err = c.Append(Append, invalidKey, []byte("foo"))
805 | assert.ErrorIsf(t, err, ErrMalformedKey, "Append: invalid key, want error ErrMalformedKey")
806 | _, err = c.MultiGet([]string{invalidKey, "foo", "bar"})
807 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiGet: invalid key, want error ErrMalformedKey")
808 | err = c.MultiDelete([]string{invalidKey, "foo", "bar"})
809 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiDelete: invalid key, want error ErrMalformedKey")
810 | err = c.MultiStore(Set, map[string][]byte{"foo": []byte("bar"), invalidKey: []byte("data")}, 0)
811 | assert.ErrorIsf(t, err, ErrMalformedKey, "MultiDelete: invalid key, want error ErrMalformedKey")
812 |
813 | // empty hash ring
814 | _, err = c.Store(Set, "store", 0, []byte("foo"))
815 | assert.ErrorIsf(t, err, ErrNoServers, "Store: with empty hash ring, want error ErrNoServers")
816 | _, err = c.Get("get")
817 | assert.ErrorIsf(t, err, ErrNoServers, "Get: with empty hash ring, want error ErrNoServers")
818 | _, err = c.Delete("delete")
819 | assert.ErrorIsf(t, err, ErrNoServers, "Delete: with empty hash ring, want error ErrNoServers")
820 | _, err = c.Delta(Increment, "deltaInc", 1, 0, 0)
821 | assert.ErrorIsf(t, err, ErrNoServers, "Delta: with empty hash ring, want error ErrNoServers")
822 | _, err = c.Append(Append, "append", []byte("foo"))
823 | assert.ErrorIsf(t, err, ErrNoServers, "Append: with empty hash ring, want error ErrNoServers")
824 |
825 | // add invalid node
826 | c.hr.Add("node1")
827 |
828 | // invalid node
829 | _, err = c.Store(Set, "store", 0, []byte("foo"))
830 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Store: invalid node, want error ErrInvalidAddr")
831 | _, err = c.Get("get")
832 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Get: invalid node, want error ErrInvalidAddr")
833 | _, err = c.Delete("delete")
834 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Delete: invalid node, want error ErrInvalidAddr")
835 | _, err = c.Delta(Increment, "deltaInc", 1, 0, 0)
836 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Delta: invalid node, want error ErrInvalidAddr")
837 | _, err = c.Append(Append, "append", []byte("foo"))
838 | assert.ErrorIsf(t, err, ErrInvalidAddr, "Append: invalid node, want error ErrInvalidAddr")
839 | _, err = c.MultiGet([]string{"gopher", "foo", "bar"})
840 | assert.ErrorIsf(t, err, ErrInvalidAddr, "MutliGet: invalid node, want error ErrInvalidAddr")
841 | err = c.MultiDelete([]string{"gopher", "foo", "bar"})
842 | assert.ErrorIsf(t, err, ErrInvalidAddr, "MutliDelete: invalid node, want error ErrInvalidAddr")
843 | err = c.MultiStore(Set, map[string][]byte{"foo": []byte("bar"), "data": []byte("data")}, 0)
844 | assert.ErrorIsf(t, err, ErrInvalidAddr, "MutliStore: invalid node, want error ErrInvalidAddr")
845 |
846 | var (
847 | mockNetworkHeadlessErr = new(MockNetworkOperations)
848 |
849 | expectedErr = errors.New("mocked dial error")
850 |
851 | headlessServiceAddress = "example.com"
852 | )
853 | mockNetworkHeadlessErr.On("LookupHost", headlessServiceAddress).Return(nil, expectedErr)
854 |
855 | op := &options{
856 | Client: Client{
857 | nw: &network{lookupHost: mockNetworkHeadlessErr.LookupHost},
858 | cfg: &config{HeadlessServiceAddress: headlessServiceAddress},
859 | },
860 | }
861 |
862 | _, err = newFromConfig(op)
863 | assert.ErrorIs(t, err, ErrInvalidAddr)
864 |
865 | mockNetworkNodeErr := new(MockNetworkOperations)
866 | mockNetworkNodeErr.On("LookupHost", headlessServiceAddress).Return([]string{"wrongNode"}, nil)
867 |
868 | op = &options{
869 | Client: Client{
870 | nw: &network{lookupHost: mockNetworkNodeErr.LookupHost},
871 | cfg: &config{HeadlessServiceAddress: headlessServiceAddress},
872 | },
873 | }
874 |
875 | _, err = newFromConfig(op)
876 | assert.ErrorIs(t, err, ErrInvalidAddr)
877 | }
878 |
// invalidKey is a key the client must reject with ErrMalformedKey (see
// TestMethodsErrors). It is well over 250 bytes long.
// NOTE(review): presumably it is rejected because of memcached's 250-byte key
// limit; confirm against the client's key validation.
const invalidKey = "Loremipsumdolorsitamet,consecteturadipiscingelit." +
	"Velelitvoluptateeleifendquisproidentnonfeugaitiriureliberminimveniamillumcupiditataliquid," +
	"nihiltefeugiatlobortiseleifendnibhproidenttationatoptionesseconsectetuerdeserunt." +
	"Gubergrenveroidsolutaquis." +
	"Dignissimlobortisloremveroenimrebumconsetetur."
880 |
--------------------------------------------------------------------------------
/memcached/metrics.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "github.com/prometheus/client_golang/prometheus"
5 | )
6 |
// methodNameLabel is the histogram label carrying the measured method's name.
const methodNameLabel = "method_name"

// isSuccessfulLabel is the histogram label that is "1" for a successful call
// and "0" for a failed one.
const isSuccessfulLabel = "is_successful"
11 |
12 | var (
13 | methodDurationSeconds = func() *prometheus.HistogramVec {
14 | return prometheus.NewHistogramVec(prometheus.HistogramOpts{
15 | Namespace: "",
16 | Name: "gomemcached_method_duration_seconds",
17 | Help: "counts the execution time of successful and failed gomemcached methods",
18 | Buckets: []float64{
19 | 0.0005, 0.001, 0.005, 0.007, 0.015, 0.05, 0.1, 0.2, 0.5, 1,
20 | },
21 | }, []string{
22 | methodNameLabel,
23 | isSuccessfulLabel,
24 | })
25 | }()
26 | )
27 |
28 | // observeMultiMethodDurationSeconds is observing the duration of a method.
29 | func observeMethodDurationSeconds(methodName string, duration float64, isSuccessful bool) {
30 | flag := "0"
31 | if isSuccessful {
32 | flag = "1"
33 | }
34 |
35 | methodDurationSeconds.
36 | WithLabelValues(methodName, flag).
37 | Observe(duration)
38 | }
39 |
--------------------------------------------------------------------------------
/memcached/metrics_test.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "testing"
5 | "time"
6 |
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func Test_observeMethodDurationSeconds(t *testing.T) {
11 | type args struct {
12 | methodName string
13 | duration float64
14 | isSuccessful bool
15 | }
16 | tests := []struct {
17 | name string
18 | args args
19 | }{
20 | {
21 | name: "60 true",
22 | args: args{
23 | methodName: "TestMeth",
24 | duration: 60 * time.Millisecond.Seconds(),
25 | isSuccessful: true,
26 | },
27 | },
28 | {
29 | name: "15 true",
30 | args: args{
31 | methodName: "TestMeth",
32 | duration: 15 * time.Millisecond.Seconds(),
33 | isSuccessful: true,
34 | },
35 | },
36 | {
37 | name: "39 true",
38 | args: args{
39 | methodName: "TestMeth",
40 | duration: 39 * time.Millisecond.Seconds(),
41 | isSuccessful: true,
42 | },
43 | },
44 | {
45 | name: "100 false",
46 | args: args{
47 | methodName: "TestMeth",
48 | duration: 100 * time.Millisecond.Seconds(),
49 | isSuccessful: false,
50 | },
51 | },
52 | {
53 | name: "66 true",
54 | args: args{
55 | methodName: "TestMeth",
56 | duration: 66 * time.Millisecond.Seconds(),
57 | isSuccessful: true,
58 | },
59 | },
60 | {
61 | name: "11 false",
62 | args: args{
63 | methodName: "TestMeth",
64 | duration: 11 * time.Millisecond.Seconds(),
65 | isSuccessful: false,
66 | },
67 | },
68 | }
69 | for _, tt := range tests {
70 | t.Run(tt.name, func(t *testing.T) {
71 | observeMethodDurationSeconds(tt.args.methodName, tt.args.duration, tt.args.isSuccessful)
72 |
73 | var success = "0"
74 | if tt.args.isSuccessful {
75 | success = "1"
76 | }
77 |
78 | _, err := methodDurationSeconds.GetMetricWith(map[string]string{methodNameLabel: tt.args.methodName, isSuccessfulLabel: success})
79 | assert.Nil(t, err, "GetMetricWith: returned error is not nil - %v", err)
80 | })
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/memcached/node_provider.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "errors"
5 | "net"
6 | "slices"
7 | "strconv"
8 | "sync"
9 | "time"
10 |
11 | "golang.org/x/exp/maps"
12 |
13 | "github.com/aliexpressru/gomemcached/logger"
14 | "github.com/aliexpressru/gomemcached/utils"
15 | )
16 |
17 | func (c *Client) initNodesProvider() {
18 | var (
19 | periodHC = c.getHCPeriod()
20 | tHC = time.NewTimer(periodHC)
21 |
22 | periodRB = c.getRBPeriod()
23 | tRB = time.NewTimer(periodRB)
24 | )
25 |
26 | if c.deadNodes == nil {
27 | c.deadNodes = make(map[string]struct{})
28 | }
29 |
30 | go func() {
31 | for {
32 | select {
33 | case <-tHC.C:
34 | c.checkNodesHealth()
35 | tHC.Reset(periodHC)
36 | case <-c.ctx.Done():
37 | tHC.Stop()
38 | return
39 | }
40 | }
41 | }()
42 | go func() {
43 | for {
44 | select {
45 | case <-tRB.C:
46 | c.rebuildNodes()
47 | tRB.Reset(periodRB)
48 | case <-c.ctx.Done():
49 | tRB.Stop()
50 | return
51 | }
52 | }
53 | }()
54 | }
55 |
56 | func (c *Client) checkNodesHealth() {
57 | currentNodes, err := getNodes(c.nw.lookupHost, c.cfg)
58 | if err != nil {
59 | logger.Warnf("%s: Error occurred while checking nodes health, getNodes error - %s", libPrefix, err.Error())
60 | return
61 | }
62 |
63 | recheckDeadNodes := func(node any) {
64 | sNode := utils.Repr(node)
65 | if !slices.Contains(currentNodes, sNode) {
66 | c.safeRemoveFromDeadNodes(sNode)
67 | return
68 | }
69 |
70 | if c.nodeIsDead(node) {
71 | c.safeAddToDeadNodes(sNode)
72 | } else {
73 | c.safeRemoveFromDeadNodes(sNode)
74 | }
75 | }
76 |
77 | wg := sync.WaitGroup{}
78 | for node := range c.safeGetDeadNodes() {
79 | wg.Add(1)
80 | go func(n string) {
81 | defer wg.Done()
82 | recheckDeadNodes(n)
83 | }(node)
84 | }
85 | wg.Wait()
86 |
87 | ringNodes := c.hr.GetAllNodes()
88 | for node := range c.safeGetDeadNodes() {
89 | ringNodes = slices.DeleteFunc(ringNodes, func(a any) bool { return utils.Repr(a) == node })
90 | }
91 |
92 | for _, node := range ringNodes {
93 | wg.Add(1)
94 | go func(n any) {
95 | defer wg.Done()
96 | if c.nodeIsDead(n) {
97 | sNode := utils.Repr(n)
98 | c.safeAddToDeadNodes(sNode)
99 | }
100 | }(node)
101 | }
102 |
103 | wg.Wait()
104 |
105 | deadNodes := c.safeGetDeadNodes()
106 | if len(deadNodes) != 0 {
107 | nodes := maps.Keys(deadNodes)
108 |
109 | logger.Warnf("%s: Dead nodes - %s", libPrefix, nodes)
110 |
111 | for _, node := range nodes {
112 | addr, cErr := utils.AddrRepr(node)
113 | if cErr != nil {
114 | continue
115 | }
116 | c.hr.Remove(addr)
117 | c.removeFromFreeConns(addr)
118 | }
119 | }
120 | }
121 |
122 | func (c *Client) rebuildNodes() {
123 | currentNodes, err := getNodes(c.nw.lookupHost, c.cfg)
124 | if err != nil {
125 | logger.Warnf("%s: Error occurred while rebuild nodes health, getNodes error - %s", libPrefix, err.Error())
126 | return
127 | }
128 | slices.Sort(currentNodes)
129 |
130 | for node := range c.safeGetDeadNodes() {
131 | currentNodes = slices.DeleteFunc(currentNodes, func(a string) bool { return a == node })
132 | }
133 |
134 | var nodesInRing []string
135 | for _, node := range c.hr.GetAllNodes() {
136 | nodesInRing = append(nodesInRing, utils.Repr(node))
137 | }
138 | slices.Sort(nodesInRing)
139 |
140 | var nodesToAdd []string
141 | for _, node := range currentNodes {
142 | if _, ok := slices.BinarySearch(nodesInRing, node); !ok {
143 | nodesToAdd = append(nodesToAdd, node)
144 | }
145 | }
146 |
147 | var nodesToRemove []string
148 | for _, node := range nodesInRing {
149 | if _, ok := slices.BinarySearch(currentNodes, node); !ok {
150 | nodesToRemove = append(nodesToRemove, node)
151 | }
152 | }
153 |
154 | if len(nodesToAdd) != 0 {
155 | for _, node := range nodesToAdd {
156 | addr, cErr := utils.AddrRepr(node)
157 | if cErr != nil {
158 | continue
159 | }
160 | c.hr.Add(addr)
161 | }
162 | }
163 |
164 | if len(nodesToRemove) != 0 {
165 | for _, node := range nodesToRemove {
166 | addr, cErr := utils.AddrRepr(node)
167 | if cErr != nil {
168 | continue
169 | }
170 | c.hr.Remove(addr)
171 | }
172 | }
173 |
174 | if !c.disableRefreshConns {
175 | _ = c.CloseAvailableConnsInAllShardPools(DefaultOfNumberConnsToDestroyPerRBPeriod)
176 | }
177 | }
178 |
179 | func (c *Client) nodeIsDead(node any) bool {
180 | addr, err := utils.AddrRepr(utils.Repr(node))
181 | if err != nil {
182 | return true
183 | }
184 |
185 | var (
186 | countRetry uint8
187 | cn net.Conn
188 | )
189 |
190 | for {
191 | cn, err = c.dial(addr)
192 | if err != nil {
193 | var tErr *ConnectTimeoutError
194 | if errors.As(err, &tErr) {
195 | if countRetry < DefaultRetryCountForConn {
196 | countRetry++
197 | continue
198 | }
199 | logger.Errorf("%s. Node health check failed. error - %s, with timeout - %d",
200 | ErrServerError.Error(), err.Error(), c.netTimeout(),
201 | )
202 | return true
203 | } else {
204 | logger.Errorf("%s. %s", ErrServerError.Error(), err.Error())
205 | return true
206 | }
207 | }
208 | _ = cn.Close()
209 | break
210 | }
211 |
212 | return false
213 | }
214 |
215 | func (c *Client) safeGetDeadNodes() map[string]struct{} {
216 | c.dmu.RLock()
217 | defer c.dmu.RUnlock()
218 | return maps.Clone(c.deadNodes)
219 | }
220 |
221 | func (c *Client) safeAddToDeadNodes(node string) {
222 | c.dmu.Lock()
223 | defer c.dmu.Unlock()
224 | c.deadNodes[node] = struct{}{}
225 | }
226 |
227 | func (c *Client) safeRemoveFromDeadNodes(node string) {
228 | c.dmu.Lock()
229 | defer c.dmu.Unlock()
230 | delete(c.deadNodes, node)
231 | }
232 |
233 | func getNodes(lookup func(host string) (addrs []string, err error), cfg *config) ([]string, error) {
234 | if cfg != nil {
235 | if cfg.HeadlessServiceAddress != "" {
236 | nodes, err := lookup(cfg.HeadlessServiceAddress)
237 | if err != nil {
238 | return nil, err
239 | }
240 |
241 | nodesWithHost := make([]string, len(nodes))
242 | for i := range nodes {
243 | nodesWithHost[i] = net.JoinHostPort(nodes[i], strconv.Itoa(cfg.MemcachedPort))
244 | }
245 |
246 | return nodesWithHost, nil
247 | } else if len(cfg.Servers) != 0 {
248 | for _, s := range cfg.Servers {
249 | _, _, err := net.SplitHostPort(s)
250 | if err != nil {
251 | return nil, err
252 | }
253 | }
254 | return cfg.Servers, nil
255 | }
256 | }
257 |
258 | return []string{}, nil
259 | }
260 |
--------------------------------------------------------------------------------
/memcached/node_provider_test.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net"
7 | "slices"
8 | "sync"
9 | "testing"
10 | "time"
11 |
12 | "github.com/stretchr/testify/assert"
13 | "github.com/stretchr/testify/mock"
14 | "github.com/stretchr/testify/require"
15 |
16 | "github.com/aliexpressru/gomemcached/consistenthash"
17 | "github.com/aliexpressru/gomemcached/logger"
18 | "github.com/aliexpressru/gomemcached/utils"
19 | )
20 |
21 | func Test_getNodes(t *testing.T) {
22 | type args struct {
23 | cfg *config
24 | mock *network
25 | }
26 | tests := []struct {
27 | name string
28 | args args
29 | want []string
30 | wantErr assert.ErrorAssertionFunc
31 | }{
32 | {
33 | name: "Servers",
34 | args: args{
35 | mock: &network{lookupHost: func(host string) (addrs []string, err error) {
36 | return []string{"server1:11211", "server2:11211"}, nil
37 | }},
38 | cfg: &config{
39 | Servers: []string{"server1:11211", "server2:11211"},
40 | }},
41 | want: []string{"server1:11211", "server2:11211"},
42 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
43 | if err != nil {
44 | t.Errorf("getNodes have error - %v", err)
45 | return false
46 | }
47 | return true
48 | },
49 | },
50 | {
51 | name: "Headless",
52 | args: args{
53 | mock: &network{lookupHost: func(host string) (addrs []string, err error) {
54 | return []string{"93.184.216.34", "123.323.32.11"}, nil
55 | }},
56 | cfg: &config{
57 | HeadlessServiceAddress: "example.com",
58 | MemcachedPort: 11211,
59 | }},
60 | want: []string{"93.184.216.34:11211", "123.323.32.11:11211"},
61 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
62 | if err != nil {
63 | t.Errorf("getNodes have error - %v", err)
64 | return false
65 | }
66 | return true
67 | },
68 | },
69 | {
70 | name: "config nil",
71 | args: args{
72 | mock: &network{lookupHost: func(_ string) (_ []string, _ error) {
73 | return
74 | }},
75 | cfg: nil},
76 | want: []string{},
77 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
78 | if err != nil {
79 | t.Errorf("getNodes have error - %v", err)
80 | return false
81 | }
82 | return true
83 | },
84 | },
85 | {
86 | name: "error headless",
87 | args: args{
88 | mock: &network{lookupHost: func(host string) (addrs []string, err error) {
89 | return nil, &net.DNSError{
90 | Err: "no such host",
91 | Name: "fakeaddress.r",
92 | }
93 | }},
94 | cfg: &config{HeadlessServiceAddress: "fakeaddress.r"}},
95 | want: nil,
96 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
97 | if err != nil {
98 | dnsError := new(net.DNSError)
99 | assert.ErrorAs(t, err, &dnsError, "Error should be as net.DNSError")
100 | return true
101 | }
102 | return false
103 | },
104 | },
105 | {
106 | name: "error servers",
107 | args: args{
108 | mock: new(network),
109 | cfg: &config{Servers: []string{"localhost:1234", "fakeaddress.r", "localhost"}}},
110 | want: nil,
111 | wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
112 | if err != nil {
113 | return true
114 | }
115 | t.Errorf("getNodes dot't have error")
116 | return false
117 | },
118 | },
119 | }
120 | for _, tt := range tests {
121 | t.Run(tt.name, func(t *testing.T) {
122 | got, err := getNodes(tt.args.mock.lookupHost, tt.args.cfg)
123 | if !tt.wantErr(t, err) {
124 | return
125 | }
126 | assert.Equalf(t, tt.want, got, "getNodes(%v)", tt.args.cfg)
127 | })
128 | }
129 | }
130 |
131 | func Test_safeGetDeadNodes(t *testing.T) {
132 | client := &Client{
133 | deadNodes: map[string]struct{}{
134 | "node1": {},
135 | "node2": {},
136 | },
137 | }
138 |
139 | wg := sync.WaitGroup{}
140 | wg.Add(2)
141 | go func() {
142 | defer wg.Done()
143 | deadNodes := client.safeGetDeadNodes()
144 |
145 | expectedDeadNodes := map[string]struct{}{
146 | "node1": {},
147 | "node2": {},
148 | }
149 | assert.Equal(t, expectedDeadNodes, deadNodes)
150 | }()
151 | go func() {
152 | defer wg.Done()
153 | deadNodes := client.safeGetDeadNodes()
154 |
155 | expectedDeadNodes := map[string]struct{}{
156 | "node1": {},
157 | "node2": {},
158 | }
159 | assert.Equal(t, expectedDeadNodes, deadNodes)
160 | }()
161 |
162 | wg.Wait()
163 | }
164 |
165 | func Test_safeAddToDeadNodes(t *testing.T) {
166 | client := &Client{
167 | deadNodes: map[string]struct{}{},
168 | }
169 |
170 | wg := sync.WaitGroup{}
171 | wg.Add(2)
172 | go func() {
173 | defer wg.Done()
174 | client.safeAddToDeadNodes("node1")
175 | }()
176 | go func() {
177 | defer wg.Done()
178 | client.safeAddToDeadNodes("node2")
179 | }()
180 |
181 | wg.Wait()
182 |
183 | assert.Contains(t, client.deadNodes, "node1")
184 | assert.Contains(t, client.deadNodes, "node2")
185 | }
186 |
187 | func Test_safeRemoveFromDeadNodes(t *testing.T) {
188 | client := &Client{
189 | deadNodes: map[string]struct{}{
190 | "node1": {},
191 | "node2": {},
192 | },
193 | }
194 |
195 | wg := sync.WaitGroup{}
196 | wg.Add(2)
197 | go func() {
198 | defer wg.Done()
199 | client.safeRemoveFromDeadNodes("node1")
200 | }()
201 | go func() {
202 | defer wg.Done()
203 | client.safeRemoveFromDeadNodes("node2")
204 | }()
205 |
206 | wg.Wait()
207 |
208 | assert.NotContains(t, client.deadNodes, "node1")
209 | assert.NotContains(t, client.deadNodes, "node2")
210 | }
211 |
212 | func Test_nodeIsDead(t *testing.T) {
213 | logger.DisableLogger()
214 | addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
215 |
216 | mockNetworkError := new(MockNetworkOperations)
217 | client := &Client{nw: &network{
218 | dialTimeout: mockNetworkError.DialTimeout,
219 | }}
220 |
221 | assert.True(t, client.nodeIsDead("wrongarrd.r"), "nodeIsDead: wrong addr should be return true")
222 |
223 | expectedErr := errors.New("mocked dial error")
224 |
225 | mockNetworkError.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(nil, expectedErr)
226 |
227 | result := client.nodeIsDead(addr)
228 |
229 | assert.True(t, result)
230 |
231 | mockNetworkError.AssertCalled(t, "DialTimeout", addr.Network(), addr.String(), client.netTimeout())
232 |
233 | mockNetworkRetry := new(MockNetworkOperations)
234 | client = &Client{nw: &network{
235 | dialTimeout: mockNetworkRetry.DialTimeout,
236 | }}
237 |
238 | expectedErr = &ConnectTimeoutError{addr}
239 |
240 | mockNetworkRetry.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(nil, expectedErr)
241 | result = client.nodeIsDead(addr)
242 |
243 | assert.True(t, result)
244 |
245 | // int(DefaultRetryCountForConn)+1 - the default number of retries plus the first execution.
246 | mockNetworkRetry.AssertNumberOfCalls(t, "DialTimeout", int(DefaultRetryCountForConn)+1)
247 |
248 | mockNetworkSuccess := new(MockNetworkOperations)
249 | client = &Client{nw: &network{
250 | dialTimeout: mockNetworkSuccess.DialTimeout,
251 | }}
252 |
253 | mockNetworkSuccess.On("DialTimeout", addr.Network(), addr.String(), client.netTimeout()).Return(&FakeConn{}, nil)
254 |
255 | result = client.nodeIsDead(addr)
256 |
257 | assert.False(t, result)
258 |
259 | mockNetworkSuccess.AssertCalled(t, "DialTimeout", addr.Network(), addr.String(), client.netTimeout())
260 | }
261 |
262 | func Test_initNodesProvider(t *testing.T) {
263 | var (
264 | mockNetworkErr = new(MockNetworkOperations)
265 |
266 | period = 10 * time.Millisecond
267 |
268 | ctx, cancel = context.WithCancel(context.TODO())
269 | expectedErr = errors.New("mocked dial error")
270 | )
271 | cl := &Client{
272 | ctx: ctx,
273 | nw: &network{
274 | dial: mockNetworkErr.Dial,
275 | lookupHost: mockNetworkErr.LookupHost,
276 | },
277 | cfg: &config{
278 | HeadlessServiceAddress: "example.com",
279 | },
280 | nodeHCPeriod: period,
281 | nodeRBPeriod: period,
282 | }
283 |
284 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr)
285 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil)
286 |
287 | cl.initNodesProvider()
288 |
289 | mockNetworkErr.AssertNotCalled(t, "Dial")
290 |
291 | <-time.After(2 * period)
292 | cancel()
293 |
294 | mockNetworkErr.AssertCalled(t, "LookupHost", cl.cfg.HeadlessServiceAddress)
295 | }
296 |
297 | func Test_checkNodesHealth(t *testing.T) {
298 | var (
299 | mockNetworkErr = new(MockNetworkOperations)
300 |
301 | expectedErr = errors.New("mocked dial error")
302 | )
303 | cl := &Client{
304 | nw: &network{
305 | dial: mockNetworkErr.Dial,
306 | lookupHost: mockNetworkErr.LookupHost,
307 | },
308 | cfg: &config{
309 | HeadlessServiceAddress: "example.com",
310 | },
311 | }
312 |
313 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr)
314 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil)
315 |
316 | cl.checkNodesHealth()
317 |
318 | mockNetworkErr.AssertNotCalled(t, "Dial")
319 | mockNetworkErr.AssertNumberOfCalls(t, "LookupHost", 1)
320 |
321 | var (
322 | currentNodes = []string{"127.0.0.1:12345", "127.0.0.2:12345", "127.0.0.3:12345", "127.0.0.4:12345", "127.0.0.5:12345"}
323 | alreadyDeadNodes = []string{"127.0.0.4:12345", "127.0.0.5:12345"}
324 | disableNodes = []string{"127.0.0.6:12345"}
325 |
326 | mockNetwork = new(MockNetworkOperations)
327 | )
328 |
329 | cl = &Client{
330 | hr: consistenthash.NewHashRing(),
331 | timeout: -1,
332 | nw: &network{
333 | dial: mockNetwork.Dial,
334 | lookupHost: mockNetwork.LookupHost,
335 | },
336 | cfg: &config{
337 | Servers: currentNodes,
338 | },
339 | }
340 |
341 | mockNetwork.On("Dial", "tcp", "127.0.0.2:12345").Return(nil, expectedErr).Once()
342 | mockNetwork.On("Dial", "tcp", "127.0.0.4:12345").Return(nil, expectedErr).Once()
343 | mockNetwork.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil)
344 |
345 | for _, node := range currentNodes {
346 | addr, _ := utils.AddrRepr(node)
347 | cl.hr.Add(addr)
348 | }
349 | cl.deadNodes = make(map[string]struct{})
350 | for _, node := range alreadyDeadNodes {
351 | cl.deadNodes[node] = struct{}{}
352 | }
353 | for _, node := range disableNodes {
354 | cl.deadNodes[node] = struct{}{}
355 | }
356 |
357 | cl.checkNodesHealth()
358 |
359 | assert.Equal(t, 3, len(cl.hr.GetAllNodes()))
360 | assert.Equal(t, 2, len(cl.deadNodes))
361 | }
362 |
363 | func Test_rebuildNodes(t *testing.T) {
364 | var (
365 | mockNetworkErr = new(MockNetworkOperations)
366 |
367 | expectedErr = errors.New("mocked dial error")
368 | )
369 | cl := &Client{
370 | nw: &network{
371 | dial: mockNetworkErr.Dial,
372 | lookupHost: mockNetworkErr.LookupHost,
373 | },
374 | cfg: &config{
375 | HeadlessServiceAddress: "example.com",
376 | },
377 | }
378 |
379 | mockNetworkErr.On("LookupHost", cl.cfg.HeadlessServiceAddress).Return(nil, expectedErr)
380 | mockNetworkErr.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil)
381 |
382 | cl.rebuildNodes()
383 |
384 | mockNetworkErr.AssertNotCalled(t, "Dial")
385 | mockNetworkErr.AssertNumberOfCalls(t, "LookupHost", 1)
386 |
387 | var (
388 | currentNodes = []string{"127.0.0.1:12345", "127.0.0.2:12345", "127.0.0.3:12345", "127.0.0.4:12345", "127.0.0.5:12345"}
389 | alreadyDeadNodes = []string{"127.0.0.4:12345", "127.0.0.2:12345"}
390 | expectedNodesInRing = []string{"127.0.0.1:12345", "127.0.0.3:12345", "127.0.0.5:12345"}
391 |
392 | mockNetwork = new(MockNetworkOperations)
393 | )
394 | cl = &Client{
395 | ctx: context.TODO(),
396 | nw: &network{
397 | dial: mockNetwork.Dial,
398 | lookupHost: mockNetwork.LookupHost,
399 | },
400 | cfg: &config{
401 | Servers: currentNodes,
402 | },
403 | timeout: -1,
404 | maxIdleConns: 1,
405 | hr: consistenthash.NewHashRing(),
406 | }
407 |
408 | mockNetwork.On("LookupHost", cl.cfg.Servers).Return(currentNodes, nil)
409 | mockNetwork.On("Dial", mock.Anything, mock.Anything).Return(&FakeConn{}, nil)
410 |
411 | cl.deadNodes = make(map[string]struct{})
412 | for _, node := range alreadyDeadNodes {
413 | cl.deadNodes[node] = struct{}{}
414 | }
415 | // len(currentNodes)-1 simulates the absence of one node in the hash ring.
416 | for i := 0; i < len(currentNodes)-1; i++ {
417 | addr, _ := utils.AddrRepr(currentNodes[i])
418 | cl.hr.Add(addr)
419 | }
420 |
421 | // simulating the borrowing of a connection from the connection pool by taking a connection and putting it back.
422 | // for test CloseAvailableConnsInAllShardPools
423 | for i := 0; i < len(currentNodes)-1; i++ {
424 | node, ok := cl.hr.Get(currentNodes[i])
425 | require.Truef(t, ok, "Not found node (%s) in hash ring", currentNodes[i])
426 | cn, err := cl.getConnForNode(node)
427 | require.Nil(t, err, "getConnForNode try get conn")
428 | cn.condRelease(new(error))
429 | }
430 |
431 | cl.rebuildNodes()
432 |
433 | assert.Equal(t, 3, cl.hr.GetNodesCount())
434 |
435 | var actualNodesInRing []string
436 | for _, node := range cl.hr.GetAllNodes() {
437 | actualNodesInRing = append(actualNodesInRing, node.(net.Addr).String())
438 | }
439 |
440 | slices.Sort(actualNodesInRing)
441 | slices.Sort(expectedNodesInRing)
442 | assert.Equal(t, expectedNodesInRing, actualNodesInRing)
443 |
444 | for _, pool := range cl.freeConns {
445 | assert.Equal(t, 0, pool.Len())
446 | }
447 | }
448 |
449 | type MockNetworkOperations struct {
450 | mock.Mock
451 | }
452 |
453 | func (m *MockNetworkOperations) Dial(network, address string) (net.Conn, error) {
454 | args := m.Called(network, address)
455 | if args.Get(0) == nil {
456 | return nil, args.Error(1)
457 | }
458 | return args.Get(0).(net.Conn), args.Error(1)
459 | }
460 |
461 | func (m *MockNetworkOperations) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
462 | args := m.Called(network, address, timeout)
463 | if args.Get(0) == nil {
464 | return nil, args.Error(1)
465 | }
466 | return args.Get(0).(net.Conn), args.Error(1)
467 | }
468 |
469 | func (m *MockNetworkOperations) LookupHost(host string) ([]string, error) {
470 | args := m.Called(host)
471 | if args.Get(0) == nil {
472 | return nil, args.Error(1)
473 | }
474 | return args.Get(0).([]string), args.Error(1)
475 | }
476 |
// FakeConn is an inert connection for tests. Embedding a zero net.TCPConn makes
// it satisfy net.Conn. Read, Write and Close are overridden to do nothing and
// report success.
//
// NOTE(review): Write returning (0, nil) does not satisfy the io.Writer contract,
// which requires a non-nil error when n < len(p). It is kept as-is because other
// tests may depend on it.
type FakeConn struct {
	net.TCPConn
}

// Read leaves the buffer untouched and reports 0 bytes read with no error.
func (c *FakeConn) Read([]byte) (int, error) { return 0, nil }

// Write discards the data and reports 0 bytes written with no error.
func (c *FakeConn) Write([]byte) (int, error) { return 0, nil }

// Close does nothing and always succeeds.
func (c *FakeConn) Close() error { return nil }
492 |
--------------------------------------------------------------------------------
/memcached/options.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/aliexpressru/gomemcached/consistenthash"
7 | )
8 |
9 | type options struct {
10 | Client
11 | disableLogger bool
12 | }
13 |
14 | type Option func(*options)
15 |
16 | // WithMaxIdleConns is sets a custom value of open connections per address.
17 | // By default, DefaultMaxIdleConns will be used.
18 | func WithMaxIdleConns(num int) Option {
19 | return func(o *options) {
20 | o.Client.maxIdleConns = num
21 | }
22 | }
23 |
24 | // WithTimeout is sets custom timeout for connections.
25 | // By default, DefaultTimeout will be used.
26 | func WithTimeout(tm time.Duration) Option {
27 | return func(o *options) {
28 | o.Client.timeout = tm
29 | }
30 | }
31 |
32 | // WithCustomHashRing for setup use consistenthash.NewCustomHashRing
33 | func WithCustomHashRing(hr *consistenthash.HashRing) Option {
34 | return func(o *options) {
35 | o.Client.hr = hr
36 | }
37 | }
38 |
39 | // WithPeriodForNodeHealthCheck is sets a custom frequency for health checker of physical nodes.
40 | // By default, DefaultNodeHealthCheckPeriod will be used.
41 | func WithPeriodForNodeHealthCheck(t time.Duration) Option {
42 | return func(o *options) {
43 | o.Client.nodeHCPeriod = t
44 | }
45 | }
46 |
47 | // WithPeriodForRebuildingNodes is sets a custom frequency for resharding and checking for dead nodes.
48 | // By default, DefaultRebuildingNodePeriod will be used.
49 | func WithPeriodForRebuildingNodes(t time.Duration) Option {
50 | return func(o *options) {
51 | o.Client.nodeRBPeriod = t
52 | }
53 | }
54 |
55 | // WithDisableNodeProvider is disabled node health cheek and rebuild nodes for hash ring
56 | func WithDisableNodeProvider() Option {
57 | return func(o *options) {
58 | o.Client.disableNodeProvider = true
59 | }
60 | }
61 |
62 | // WithDisableRefreshConnsInPool is disabled auto close some connections in pool in NodeProvider.
63 | // This is done to refresh connections in the pool.
64 | func WithDisableRefreshConnsInPool() Option {
65 | return func(o *options) {
66 | o.Client.disableRefreshConns = true
67 | }
68 | }
69 |
70 | // WithDisableMemcachedDiagnostic is disabled write library metrics.
71 | //
72 | // gomemcached_method_duration_seconds
73 | func WithDisableMemcachedDiagnostic() Option {
74 | return func(o *options) {
75 | o.Client.disableMemcachedDiagnostic = true
76 | }
77 | }
78 |
79 | // WithDisableLogger is disabled internal library logs.
80 | func WithDisableLogger() Option {
81 | return func(o *options) {
82 | o.disableLogger = true
83 | }
84 | }
85 |
86 | // WithAuthentication is turn on authenticate for memcached
87 | func WithAuthentication(user, pass string) Option {
88 | return func(o *options) {
89 | o.Client.authEnable = true
90 | o.Client.authData = prepareAuthData(user, pass)
91 | }
92 | }
93 |
--------------------------------------------------------------------------------
/memcached/options_test.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "os"
5 | "testing"
6 | "time"
7 |
8 | "github.com/stretchr/testify/assert"
9 |
10 | "github.com/aliexpressru/gomemcached/consistenthash"
11 | "github.com/aliexpressru/gomemcached/logger"
12 | )
13 |
14 | func TestWithOptions(t *testing.T) {
15 | os.Setenv("MEMCACHED_SERVERS", "localhost:11211")
16 |
17 | hMcl, _ := InitFromEnv()
18 | assert.NotNil(t, hMcl.hr, "InitFromEnv: hash ring is nil")
19 |
20 | const (
21 | maxIdleConns = 10
22 | disable = true
23 | enable
24 | authUser = "admin"
25 | authPass = "password"
26 | timeout = 5 * time.Second
27 | period = time.Second
28 | )
29 |
30 | hr := consistenthash.NewCustomHashRing(1, nil)
31 | mcl, _ := InitFromEnv(
32 | WithMaxIdleConns(maxIdleConns),
33 | WithTimeout(timeout),
34 | WithCustomHashRing(hr),
35 | WithPeriodForNodeHealthCheck(period),
36 | WithPeriodForRebuildingNodes(period),
37 | WithDisableNodeProvider(),
38 | WithDisableRefreshConnsInPool(),
39 | WithDisableMemcachedDiagnostic(),
40 | WithAuthentication(authUser, authPass),
41 | WithDisableLogger(),
42 | )
43 | t.Cleanup(func() {
44 | mcl.CloseAllConns()
45 | })
46 |
47 | assert.Equal(t, maxIdleConns, mcl.maxIdleConns, "WithMaxIdleConns should set maxIdleConns")
48 | assert.Equal(t, timeout, mcl.timeout, "WithTimeout should set timeout")
49 | assert.Equal(t, hr, mcl.hr, "WithCustomHashRing should set hr")
50 | assert.Equal(t, period, mcl.nodeHCPeriod, "WithPeriodForNodeHealthCheck should set period")
51 | assert.Equal(t, period, mcl.nodeRBPeriod, "WithPeriodForRebuildingNodes should set period")
52 | assert.Equal(t, disable, mcl.disableNodeProvider, "WithDisableNodeProvider should set disable")
53 | assert.Equal(t, disable, mcl.disableRefreshConns, "WithDisableRefreshConnsInPool should set disable")
54 | assert.Equal(t, disable, mcl.disableMemcachedDiagnostic, "WithDisableMemcachedDiagnostic should set disable")
55 | assert.Equal(t, enable, mcl.authEnable, "WithAuthentication should set enable")
56 | assert.Equal(t, disable, logger.LoggerIsDisable(), "WithDisableLogger should set disable")
57 | }
58 |
--------------------------------------------------------------------------------
/memcached/requests.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | "io"
7 | )
8 |
const (
	// MaxBodyLen is the largest body length (22 MB) accepted when receiving a
	// message. Request.Receive rejects a header announcing a larger body with
	// an error.
	MaxBodyLen = int(22_000_000)

	// BUF_LEN is a 256-byte length constant.
	// NOTE(review): not referenced in this file; confirm where it is used.
	BUF_LEN = 256

	// Reserved header fields are always written as zero.
	reserved8  = uint8(0)
	reserved16 = uint16(0)
)
20 |
21 | // Request a Memcached request
22 | type Request struct {
23 | // The command being issued
24 | Opcode OpCode
25 | // The CAS (if applicable, or 0)
26 | Cas uint64
27 | // An opaque value to be returned with this request
28 | Opaque uint32
29 | // Command extras, key, and body
30 | Extras, Key, Body []byte
31 | }
32 |
33 | // Size is a number of bytes this request requires.
34 | func (r *Request) Size() int {
35 | return HDR_LEN + len(r.Extras) + len(r.Key) + len(r.Body)
36 | }
37 |
38 | // String is a debugging string representation of this request
39 | func (r Request) String() string {
40 | return fmt.Sprintf("{Request opcode=%s, bodylen=%d, key='%s'}",
41 | r.Opcode, len(r.Body), r.Key)
42 | }
43 |
44 | func (r *Request) fillHeaderBytes(data []byte) int {
45 | /*
46 | Byte/ 0 | 1 | 2 | 3 |
47 | / | | | |
48 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
49 | +---------------+---------------+---------------+---------------+
50 | 0| Magic | Opcode | Key length |
51 | +---------------+---------------+---------------+---------------+
52 | 4| Extras length | Data type | vbucket id |
53 | +---------------+---------------+---------------+---------------+
54 | 8| Total body length |
55 | +---------------+---------------+---------------+---------------+
56 | 12| Opaque |
57 | +---------------+---------------+---------------+---------------+
58 | 16| CAS |
59 | | |
60 | +---------------+---------------+---------------+---------------+
61 | Total 24 bytes
62 | */
63 |
64 | pos := 0
65 | data[pos] /*0x00*/ = REQ_MAGIC
66 | pos++ // 1
67 | data[pos] /*0x01*/ = byte(r.Opcode)
68 | pos++ // 2
69 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x02 - 0x03*/, uint16(len(r.Key)))
70 |
71 | pos += 2 // 4
72 | data[pos] /*0x04*/ = byte(len(r.Extras))
73 |
74 | pos++ // 5
75 | data[pos] /*0x05*/ = reserved8
76 |
77 | pos++ // 6
78 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x06*/, reserved16)
79 |
80 | pos += 2 // 8
81 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x08 - 0x09 - 0x0a - 0x0b*/, uint32(len(r.Body)+len(r.Key)+len(r.Extras)))
82 |
83 | pos += 4 // 12
84 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x0c - 0x0d - 0x0e - 0x0f*/, r.Opaque)
85 |
86 | pos += 4 // 16
87 | if r.Cas != 0 {
88 | binary.BigEndian.PutUint64(data[pos:pos+8] /*0x10 - 0x11 - 0x12 - 0x13 - 0x14 - 0x16 - 0x17*/, r.Cas)
89 | }
90 |
91 | pos += 8 // 24
92 | if len(r.Extras) > 0 {
93 | copy(data[pos:pos+len(r.Extras)], r.Extras)
94 | pos += len(r.Extras)
95 | }
96 |
97 | if len(r.Key) > 0 {
98 | copy(data[pos:pos+len(r.Key)], r.Key)
99 | pos += len(r.Key)
100 | }
101 | return pos
102 | }
103 |
104 | // HeaderBytes is a wire representation of the header (with the extras and key)
105 | func (r *Request) HeaderBytes() []byte {
106 | data := make([]byte, HDR_LEN+len(r.Extras)+len(r.Key))
107 |
108 | r.fillHeaderBytes(data)
109 |
110 | return data
111 | }
112 |
113 | // Bytes is a wire representation of this request.
114 | func (r *Request) Bytes() []byte {
115 | data := make([]byte, r.Size())
116 |
117 | pos := r.fillHeaderBytes(data)
118 |
119 | if len(r.Body) > 0 {
120 | copy(data[pos:pos+len(r.Body)], r.Body)
121 | }
122 |
123 | return data
124 | }
125 |
126 | // Transmit is send this request message across a writer.
127 | func (r *Request) Transmit(w io.Writer) (n int, err error) {
128 | if len(r.Body) < BODY_LEN {
129 | n, err = w.Write(r.Bytes())
130 | } else {
131 | n, err = w.Write(r.HeaderBytes())
132 | if err == nil {
133 | m := 0
134 | m, err = w.Write(r.Body)
135 | n += m
136 | }
137 | }
138 | return
139 | }
140 |
141 | // Receive a fill this Request with the data from this reader.
142 | func (r *Request) Receive(rd io.Reader, hdrBytes []byte) (int, error) {
143 | /*
144 | Byte/ 0 | 1 | 2 | 3 |
145 | / | | | |
146 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
147 | +---------------+---------------+---------------+---------------+
148 | 0| Magic | Opcode | Key length |
149 | +---------------+---------------+---------------+---------------+
150 | 4| Extras length | Data type | vbucket id |
151 | +---------------+---------------+---------------+---------------+
152 | 8| Total body length |
153 | +---------------+---------------+---------------+---------------+
154 | 12| Opaque |
155 | +---------------+---------------+---------------+---------------+
156 | 16| CAS |
157 | | |
158 | +---------------+---------------+---------------+---------------+
159 | Total 24 bytes
160 | */
161 |
162 | if len(hdrBytes) < HDR_LEN {
163 | hdrBytes = make([]byte, HDR_LEN)
164 | }
165 |
166 | n, err := io.ReadFull(rd, hdrBytes)
167 | if err != nil {
168 | return n, err
169 | }
170 |
171 | if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC {
172 | return n, fmt.Errorf("bad magic: 0x%02x", hdrBytes[0])
173 | }
174 | r.Opcode = OpCode(hdrBytes[1])
175 |
176 | klen := int(binary.BigEndian.Uint16(hdrBytes[2:]))
177 | elen := int(hdrBytes[4])
178 | bodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:]) - uint32(klen) - uint32(elen))
179 | if bodyLen > MaxBodyLen {
180 | return n, fmt.Errorf("%d is too big (max %d)",
181 | bodyLen, MaxBodyLen)
182 | }
183 | r.Opaque = binary.BigEndian.Uint32(hdrBytes[12:])
184 | r.Cas = binary.BigEndian.Uint64(hdrBytes[16:])
185 |
186 | buf := make([]byte, klen+elen+bodyLen)
187 | m, err := io.ReadFull(rd, buf)
188 | n += m
189 | if err == nil {
190 | if elen > 0 {
191 | r.Extras = buf[0:elen]
192 | }
193 | if klen > 0 {
194 | r.Key = buf[elen : klen+elen]
195 | }
196 | if klen+elen > 0 {
197 | r.Body = buf[klen+elen:]
198 | }
199 | }
200 |
201 | return n, err
202 | }
203 |
204 | // prepareExtras fills Extras depending on OpCode for Request
205 | func (r *Request) prepareExtras(expiration uint32, delta uint64, initVal uint64) {
206 | switch r.Opcode {
207 | case DELETE, DELETEQ, QUIT, QUITQ, NOOP, VERSION, APPEND, APPENDQ, PREPEND, PREPENDQ, STAT, GET, GETQ, GETK, GETKQ: // MUST NOT have extras
208 | case SET, SETQ, ADD, ADDQ, REPLACE, REPLACEQ:
209 | /*
210 | Byte/ 0 | 1 | 2 | 3 |
211 | / | | | |
212 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
213 | +---------------+---------------+---------------+---------------+
214 | 0| Flags |
215 | +---------------+---------------+---------------+---------------+
216 | 4| Expiration |
217 | +---------------+---------------+---------------+---------------+
218 | Total 8 bytes
219 | */
220 |
221 | r.Extras = make([]byte, 8)
222 | // flags always is 0
223 | binary.BigEndian.PutUint32(r.Extras[:4], uint32(0))
224 | binary.BigEndian.PutUint32(r.Extras[4:], expiration)
225 | case INCREMENT, INCREMENTQ, DECREMENT, DECREMENTQ:
226 | /*
227 |
228 | Byte/ 0 | 1 | 2 | 3 |
229 | / | | | |
230 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
231 | +---------------+---------------+---------------+---------------+
232 | 0| Amount to add / subtract (delta) |
233 | | |
234 | +---------------+---------------+---------------+---------------+
235 | 8| Initial value |
236 | | |
237 | +---------------+---------------+---------------+---------------+
238 | 16| Expiration |
239 | +---------------+---------------+---------------+---------------+
240 | Total 20 bytes
241 | */
242 | r.Extras = make([]byte, 20)
243 | binary.BigEndian.PutUint64(r.Extras[:8], delta)
244 | binary.BigEndian.PutUint64(r.Extras[8:], initVal)
245 | binary.BigEndian.PutUint32(r.Extras[16:], expiration)
246 | case FLUSH, FLUSHQ:
247 | /*
248 | Byte/ 0 | 1 | 2 | 3 |
249 | / | | | |
250 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
251 | +---------------+---------------+---------------+---------------+
252 | 0| Expiration |
253 | +---------------+---------------+---------------+---------------+
254 | Total 4 bytes
255 | */
256 | r.Extras = make([]byte, 4)
257 | binary.BigEndian.PutUint32(r.Extras, expiration)
258 | }
259 | }
260 |
261 | type StoreMode uint8
262 |
263 | const (
264 | // Add - Store the data, but only if the server does not already hold data for a given key
265 | Add StoreMode = iota
266 | // Set - Store the data, overwrite if already exists
267 | Set
268 | // Replace - Store the data, but only if the server does already hold data for a given key
269 | Replace
270 | )
271 |
272 | func (sm StoreMode) Resolve() OpCode {
273 | switch sm {
274 | case Set:
275 | return SET
276 | case Replace:
277 | return REPLACE
278 | default:
279 | return ADD
280 | }
281 | }
282 |
283 | type DeltaMode uint8
284 |
285 | const (
286 | // Increment - increases the value by the specified amount
287 | Increment DeltaMode = iota
288 | // Decrement - decreases the value by the specified amount
289 | Decrement
290 | )
291 |
292 | func (sm DeltaMode) Resolve() OpCode {
293 | switch sm {
294 | case Increment:
295 | return INCREMENT
296 | case Decrement:
297 | return DECREMENT
298 | default:
299 | return INCREMENT
300 | }
301 | }
302 |
303 | type AppendMode uint8
304 |
305 | const (
306 | // Append - Appends data to the end of an existing key value.
307 | Append AppendMode = iota
308 | // Prepend -Appends data to the beginning of an existing key value.
309 | Prepend
310 | )
311 |
312 | func (sm AppendMode) Resolve() OpCode {
313 | switch sm {
314 | case Append:
315 | return APPEND
316 | default:
317 | return PREPEND
318 | }
319 | }
320 |
--------------------------------------------------------------------------------
/memcached/requests_test.go:
--------------------------------------------------------------------------------
1 | // nolint
2 | package memcached
3 |
4 | import (
5 | "bytes"
6 | "fmt"
7 | "io/ioutil"
8 | "reflect"
9 | "testing"
10 | )
11 |
12 | func TestEncodingRequest(t *testing.T) {
13 | req := Request{
14 | Opcode: SET,
15 | Cas: 938424885,
16 | Opaque: 7242,
17 | Key: []byte("somekey"),
18 | Body: []byte("somevalue"),
19 | }
20 |
21 | got := req.Bytes()
22 |
23 | expected := []byte{
24 | REQ_MAGIC, byte(SET),
25 | 0x0, 0x7, // length of key
26 | 0x0, // extra length
27 | 0x0, // reserved
28 | 0x0, 0x0, // vbucket
29 | 0x0, 0x0, 0x0, 0x10, // Length of value
30 | 0x0, 0x0, 0x1c, 0x4a, // opaque
31 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
32 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
33 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
34 | }
35 |
36 | if len(got) != req.Size() {
37 | t.Fatalf("Expected %v bytes, got %v", got,
38 | len(got))
39 | }
40 |
41 | if !reflect.DeepEqual(got, expected) {
42 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
43 | expected, got)
44 | }
45 |
46 | exp := `{Request opcode=SET, bodylen=9, key='somekey'}`
47 | if req.String() != exp {
48 | t.Errorf("Expected string=%q, got %q", exp, req.String())
49 | }
50 | }
51 |
52 | func TestEncodingRequestWithExtras(t *testing.T) {
53 | req := Request{
54 | Opcode: SET,
55 | Cas: 938424885,
56 | Opaque: 7242,
57 | Extras: []byte{1, 2, 3, 4},
58 | Key: []byte("somekey"),
59 | Body: []byte("somevalue"),
60 | }
61 |
62 | buf := &bytes.Buffer{}
63 | req.Transmit(buf)
64 | got := buf.Bytes()
65 |
66 | expected := []byte{
67 | REQ_MAGIC, byte(SET),
68 | 0x0, 0x7, // length of key
69 | 0x4, // extra length
70 | 0x0, // reserved
71 | 0x0, 0x0, // vbucket
72 | 0x0, 0x0, 0x0, 0x14, // Length of remainder
73 | 0x0, 0x0, 0x1c, 0x4a, // opaque
74 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
75 | 1, 2, 3, 4, // extras
76 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
77 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
78 | }
79 |
80 | if len(got) != req.Size() {
81 | t.Fatalf("Expected %v bytes, got %v", got,
82 | len(got))
83 | }
84 |
85 | if !reflect.DeepEqual(got, expected) {
86 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
87 | expected, got)
88 | }
89 | }
90 |
91 | func TestEncodingRequestWithLargeBody(t *testing.T) {
92 | req := Request{
93 | Opcode: SET,
94 | Cas: 938424885,
95 | Opaque: 7242,
96 | Extras: []byte{1, 2, 3, 4},
97 | Key: []byte("somekey"),
98 | Body: make([]byte, 256),
99 | }
100 |
101 | buf := &bytes.Buffer{}
102 | req.Transmit(buf)
103 | got := buf.Bytes()
104 |
105 | expected := append([]byte{
106 | REQ_MAGIC, byte(SET),
107 | 0x0, 0x7, // length of key
108 | 0x4, // extra length
109 | 0x0, // reserved
110 | 0x0, 0x0, // vbucket
111 | 0x0, 0x0, 0x1, 0xb, // Length of remainder
112 | 0x0, 0x0, 0x1c, 0x4a, // opaque
113 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
114 | 1, 2, 3, 4, // extras
115 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
116 | }, make([]byte, 256)...)
117 |
118 | if len(got) != req.Size() {
119 | t.Fatalf("Expected %v bytes, got %v", got,
120 | len(got))
121 | }
122 |
123 | if !reflect.DeepEqual(got, expected) {
124 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
125 | expected, got)
126 | }
127 | }
128 |
129 | func BenchmarkEncodingRequest(b *testing.B) {
130 | req := Request{
131 | Opcode: SET,
132 | Cas: 938424885,
133 | Opaque: 7242,
134 | Key: []byte("somekey"),
135 | Body: []byte("somevalue"),
136 | }
137 |
138 | b.SetBytes(int64(req.Size()))
139 |
140 | for i := 0; i < b.N; i++ {
141 | req.Bytes()
142 | }
143 | }
144 |
145 | func BenchmarkEncodingRequest0CAS(b *testing.B) {
146 | req := Request{
147 | Opcode: SET,
148 | Cas: 0,
149 | Opaque: 7242,
150 | Key: []byte("somekey"),
151 | Body: []byte("somevalue"),
152 | }
153 |
154 | b.SetBytes(int64(req.Size()))
155 |
156 | for i := 0; i < b.N; i++ {
157 | req.Bytes()
158 | }
159 | }
160 |
161 | func BenchmarkEncodingRequest1Extra(b *testing.B) {
162 | req := Request{
163 | Opcode: SET,
164 | Cas: 0,
165 | Opaque: 7242,
166 | Extras: []byte{1},
167 | Key: []byte("somekey"),
168 | Body: []byte("somevalue"),
169 | }
170 |
171 | b.SetBytes(int64(req.Size()))
172 |
173 | for i := 0; i < b.N; i++ {
174 | req.Bytes()
175 | }
176 | }
177 |
178 | func TestRequestTransmit(t *testing.T) {
179 | res := Request{Key: []byte("thekey")}
180 | _, err := res.Transmit(ioutil.Discard)
181 | if err != nil {
182 | t.Errorf("Error sending small request: %v", err)
183 | }
184 |
185 | res.Body = make([]byte, 256)
186 | _, err = res.Transmit(ioutil.Discard)
187 | if err != nil {
188 | t.Errorf("Error sending large request thing: %v", err)
189 | }
190 | }
191 |
192 | func TestReceiveRequest(t *testing.T) {
193 | req := Request{
194 | Opcode: SET,
195 | Cas: 0,
196 | Opaque: 7242,
197 | Extras: []byte{1},
198 | Key: []byte("somekey"),
199 | Body: []byte("somevalue"),
200 | }
201 |
202 | data := req.Bytes()
203 |
204 | req2 := Request{}
205 | n, err := req2.Receive(bytes.NewReader(data), nil)
206 | if err != nil {
207 | t.Fatalf("Error receiving: %v", err)
208 | }
209 | if len(data) != n {
210 | t.Errorf("Expected to read %v bytes, read %v", len(data), n)
211 | }
212 |
213 | if !reflect.DeepEqual(req, req2) {
214 | t.Fatalf("Expected %#v == %#v", req, req2)
215 | }
216 | }
217 |
218 | func TestReceiveRequestNoContent(t *testing.T) {
219 | req := Request{
220 | Opcode: SET,
221 | Cas: 0,
222 | Opaque: 7242,
223 | }
224 |
225 | data := req.Bytes()
226 |
227 | req2 := Request{}
228 | n, err := req2.Receive(bytes.NewReader(data), nil)
229 | if err != nil {
230 | t.Fatalf("Error receiving: %v", err)
231 | }
232 | if len(data) != n {
233 | t.Errorf("Expected to read %v bytes, read %v", len(data), n)
234 | }
235 |
236 | if fmt.Sprintf("%#v", req) != fmt.Sprintf("%#v", req2) {
237 | t.Fatalf("Expected %#v == %#v", req, req2)
238 | }
239 | }
240 |
241 | func TestReceiveRequestShortHdr(t *testing.T) {
242 | req := Request{}
243 | n, err := req.Receive(bytes.NewReader([]byte{1, 2, 3}), nil)
244 | if err == nil {
245 | t.Errorf("Expected error, got %#v", req)
246 | }
247 | if n != 3 {
248 | t.Errorf("Expected to have read 3 bytes, read %v", n)
249 | }
250 | }
251 |
252 | func TestReceiveRequestShortBody(t *testing.T) {
253 | req := Request{
254 | Opcode: SET,
255 | Cas: 0,
256 | Opaque: 7242,
257 | Extras: []byte{1},
258 | Key: []byte("somekey"),
259 | Body: []byte("somevalue"),
260 | }
261 |
262 | data := req.Bytes()
263 |
264 | req2 := Request{}
265 | n, err := req2.Receive(bytes.NewReader(data[:len(data)-3]), nil)
266 | if err == nil {
267 | t.Errorf("Expected error, got %#v", req2)
268 | }
269 | if n != len(data)-3 {
270 | t.Errorf("Expected to have read %v bytes, read %v", len(data)-3, n)
271 | }
272 | }
273 |
274 | func TestReceiveRequestBadMagic(t *testing.T) {
275 | req := Request{
276 | Opcode: SET,
277 | Cas: 0,
278 | Opaque: 7242,
279 | Extras: []byte{1},
280 | Key: []byte("somekey"),
281 | Body: []byte("somevalue"),
282 | }
283 |
284 | data := req.Bytes()
285 | data[0] = 0x83
286 |
287 | req2 := Request{}
288 | _, err := req2.Receive(bytes.NewReader(data), nil)
289 | if err == nil {
290 | t.Fatalf("Expected error, got %#v", req2)
291 | }
292 | }
293 |
294 | func TestReceiveRequestLongBody(t *testing.T) {
295 | req := Request{
296 | Opcode: SET,
297 | Cas: 0,
298 | Opaque: 7242,
299 | Extras: []byte{1},
300 | Key: []byte("somekey"),
301 | Body: make([]byte, MaxBodyLen+5),
302 | }
303 |
304 | data := req.Bytes()
305 |
306 | req2 := Request{}
307 | _, err := req2.Receive(bytes.NewReader(data), nil)
308 | if err == nil {
309 | t.Fatalf("Expected error, got %#v", req2)
310 | }
311 | }
312 |
313 | func BenchmarkReceiveRequest(b *testing.B) {
314 | req := Request{
315 | Opcode: SET,
316 | Cas: 0,
317 | Opaque: 7242,
318 | Extras: []byte{1},
319 | Key: []byte("somekey"),
320 | Body: []byte("somevalue"),
321 | }
322 |
323 | data := req.Bytes()
324 | data[0] = REQ_MAGIC
325 | rdr := bytes.NewReader(data)
326 |
327 | b.SetBytes(int64(len(data)))
328 |
329 | b.ResetTimer()
330 | buf := make([]byte, HDR_LEN)
331 | for i := 0; i < b.N; i++ {
332 | req2 := Request{}
333 | rdr.Seek(0, 0)
334 | _, err := req2.Receive(rdr, buf)
335 | if err != nil {
336 | b.Fatalf("Error receiving: %v", err)
337 | }
338 | }
339 | }
340 |
341 | func BenchmarkReceiveRequestNoBuf(b *testing.B) {
342 | req := Request{
343 | Opcode: SET,
344 | Cas: 0,
345 | Opaque: 7242,
346 | Extras: []byte{1},
347 | Key: []byte("somekey"),
348 | Body: []byte("somevalue"),
349 | }
350 |
351 | data := req.Bytes()
352 | data[0] = REQ_MAGIC
353 | rdr := bytes.NewReader(data)
354 |
355 | b.SetBytes(int64(len(data)))
356 |
357 | b.ResetTimer()
358 | for i := 0; i < b.N; i++ {
359 | req2 := Request{}
360 | rdr.Seek(0, 0)
361 | _, err := req2.Receive(rdr, nil)
362 | if err != nil {
363 | b.Fatalf("Error receiving: %v", err)
364 | }
365 | }
366 | }
367 |
368 | func TestRequest_prepareExtras(t *testing.T) {
369 | type fields struct {
370 | Opcode OpCode
371 | }
372 | type args struct {
373 | expiration uint32
374 | delta uint64
375 | initVal uint64
376 | }
377 | tests := []struct {
378 | name string
379 | fields fields
380 | args args
381 | expect []byte
382 | }{
383 | {
384 | name: "GET must not have extras",
385 | fields: fields{
386 | Opcode: GET,
387 | },
388 | args: args{
389 | expiration: 256,
390 | delta: 1,
391 | initVal: 1,
392 | },
393 | expect: nil,
394 | },
395 | {
396 | name: "SET",
397 | fields: fields{
398 | Opcode: SET,
399 | },
400 | args: args{
401 | expiration: 256,
402 | delta: 1,
403 | initVal: 1,
404 | },
405 | expect: []byte{
406 | 0x00, 0x00, 0x00, 0x00,
407 | 0x00, 0x00, 0x01, 0x00,
408 | },
409 | },
410 | {
411 | name: "INCREMENT",
412 | fields: fields{
413 | Opcode: INCREMENT,
414 | },
415 | args: args{
416 | expiration: 256,
417 | delta: 1,
418 | initVal: 42,
419 | },
420 | expect: []byte{
421 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
422 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2a,
423 | 0x00, 0x00, 0x01, 0x00,
424 | },
425 | },
426 | {
427 | name: "FLUSH",
428 | fields: fields{
429 | Opcode: FLUSH,
430 | },
431 | args: args{
432 | expiration: 256,
433 | delta: 0,
434 | initVal: 0,
435 | },
436 | expect: []byte{
437 | 0x00, 0x00, 0x01, 0x00,
438 | },
439 | },
440 | }
441 | for _, tt := range tests {
442 | t.Run(tt.name, func(t *testing.T) {
443 | r := &Request{
444 | Opcode: tt.fields.Opcode,
445 | }
446 | r.prepareExtras(tt.args.expiration, tt.args.delta, tt.args.initVal)
447 |
448 | if !bytes.Equal(r.Extras, tt.expect) {
449 | t.Fatalf("Expected %#v == %#v", r.Extras, tt.expect)
450 | }
451 | })
452 | }
453 | }
454 |
--------------------------------------------------------------------------------
/memcached/responses.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "encoding/binary"
5 | "fmt"
6 | "io"
7 | )
8 |
9 | // Response is a memcached response
10 | type Response struct {
11 | // The command opcode of the command that sent the request
12 | Opcode OpCode
13 | // The status of the response
14 | Status Status
15 | // The opaque sent in the request
16 | Opaque uint32
17 | // The CAS identifier (if applicable)
18 | Cas uint64
19 | // Extras, key, and body for this response
20 | Extras, Key, Body []byte
21 | }
22 |
23 | // String a debugging string representation of this response
24 | func (r Response) String() string {
25 | return fmt.Sprintf("{Response status=%v keylen=%d, extralen=%d, bodylen=%d}",
26 | r.Status, len(r.Key), len(r.Extras), len(r.Body))
27 | }
28 |
29 | // Error - Response as an error.
30 | func (r *Response) Error() string {
31 | return fmt.Sprintf("Response status=%v, opcode=%v, opaque=%v, msg: %s",
32 | r.Status, r.Opcode, r.Opaque, string(r.Body))
33 | }
34 |
35 | // isFatal return false if this error isn't believed to be fatal to a connection.
36 | func isFatal(e error) bool {
37 | if e == nil {
38 | return false
39 | }
40 | switch errStatus(e) {
41 | case KEY_ENOENT, KEY_EEXISTS, NOT_STORED, TMPFAIL, AUTHFAIL:
42 | return false
43 | }
44 | return true
45 | }
46 |
47 | // Size is a number of bytes this response consumes on the wire.
48 | func (r *Response) Size() int {
49 | return HDR_LEN + len(r.Extras) + len(r.Key) + len(r.Body)
50 | }
51 |
52 | func (r *Response) fillHeaderBytes(data []byte) int {
53 | /*
54 | Byte/ 0 | 1 | 2 | 3 |
55 | / | | | |
56 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
57 | +---------------+---------------+---------------+---------------+
58 | 0| Magic | Opcode | Key Length |
59 | +---------------+---------------+---------------+---------------+
60 | 4| Extras length | Data type | Status |
61 | +---------------+---------------+---------------+---------------+
62 | 8| Total body length |
63 | +---------------+---------------+---------------+---------------+
64 | 12| Opaque |
65 | +---------------+---------------+---------------+---------------+
66 | 16| CAS |
67 | | |
68 | +---------------+---------------+---------------+---------------+
69 | Total 24 bytes
70 | */
71 |
72 | pos := 0
73 | data[pos] /*0x00*/ = RES_MAGIC
74 | pos++ // 1
75 | data[pos] /*0x01*/ = byte(r.Opcode)
76 | pos++ // 2
77 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x02 - 0x03*/, uint16(len(r.Key)))
78 |
79 | pos += 2 // 4
80 | data[pos] /*0x04*/ = byte(len(r.Extras))
81 |
82 | pos++ // 5
83 | data[pos] /*0x05*/ = reserved8
84 |
85 | pos++ // 6
86 | binary.BigEndian.PutUint16(data[pos:pos+2] /*0x06*/, uint16(r.Status))
87 |
88 | pos += 2 // 8
89 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x08 - 0x09 - 0x0a - 0x0b*/, uint32(len(r.Body)+len(r.Key)+len(r.Extras)))
90 |
91 | pos += 4 // 12
92 | binary.BigEndian.PutUint32(data[pos:pos+4] /*0x0c - 0x0d - 0x0e - 0x0f*/, r.Opaque)
93 |
94 | pos += 4 // 16
95 | if r.Cas != 0 {
96 | binary.BigEndian.PutUint64(data[pos:pos+8] /*0x10 - 0x11 - 0x12 - 0x13 - 0x14 - 0x16 - 0x17*/, r.Cas)
97 | }
98 |
99 | pos += 8 // 24
100 | if len(r.Extras) > 0 {
101 | copy(data[pos:pos+len(r.Extras)], r.Extras)
102 | pos += len(r.Extras)
103 | }
104 |
105 | if len(r.Key) > 0 {
106 | copy(data[pos:pos+len(r.Key)], r.Key)
107 | pos += len(r.Key)
108 | }
109 |
110 | return pos
111 | }
112 |
113 | // HeaderBytes get just the header bytes for this response.
114 | func (r *Response) HeaderBytes() []byte {
115 | data := make([]byte, HDR_LEN+len(r.Extras)+len(r.Key))
116 |
117 | r.fillHeaderBytes(data)
118 |
119 | return data
120 | }
121 |
122 | // Bytes the actual bytes transmitted for this response.
123 | func (r *Response) Bytes() []byte {
124 | data := make([]byte, r.Size())
125 |
126 | pos := r.fillHeaderBytes(data)
127 |
128 | copy(data[pos:pos+len(r.Body)], r.Body)
129 |
130 | return data
131 | }
132 |
133 | // Transmit send this response message across a writer.
134 | func (r *Response) Transmit(w io.Writer) (n int, err error) {
135 | if len(r.Body) < BODY_LEN {
136 | n, err = w.Write(r.Bytes())
137 | } else {
138 | n, err = w.Write(r.HeaderBytes())
139 | if err == nil {
140 | m := 0
141 | m, err = w.Write(r.Body)
142 | n += m
143 | }
144 | }
145 | return
146 | }
147 |
148 | // Receive - fill this Response with the data from this reader.
149 | func (r *Response) Receive(rd io.Reader, hdrBytes []byte) (int, error) {
150 | /*
151 | Byte/ 0 | 1 | 2 | 3 |
152 | / | | | |
153 | |0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|0 1 2 3 4 5 6 7|
154 | +---------------+---------------+---------------+---------------+
155 | 0| Magic | Opcode | Key Length |
156 | +---------------+---------------+---------------+---------------+
157 | 4| Extras length | Data type | Status |
158 | +---------------+---------------+---------------+---------------+
159 | 8| Total body length |
160 | +---------------+---------------+---------------+---------------+
161 | 12| Opaque |
162 | +---------------+---------------+---------------+---------------+
163 | 16| CAS |
164 | | |
165 | +---------------+---------------+---------------+---------------+
166 | Total 24 bytes
167 | */
168 |
169 | if len(hdrBytes) < HDR_LEN {
170 | hdrBytes = make([]byte, HDR_LEN)
171 | }
172 |
173 | n, err := io.ReadFull(rd, hdrBytes)
174 | if err != nil {
175 | return n, err
176 | }
177 |
178 | if hdrBytes[0] != RES_MAGIC && hdrBytes[0] != REQ_MAGIC {
179 | return n, fmt.Errorf("Bad magic: 0x%02x", hdrBytes[0])
180 | }
181 |
182 | klen := int(binary.BigEndian.Uint16(hdrBytes[2:4]))
183 | elen := int(hdrBytes[4])
184 |
185 | r.Opcode = OpCode(hdrBytes[1])
186 | r.Status = Status(binary.BigEndian.Uint16(hdrBytes[6:8]))
187 | r.Opaque = binary.BigEndian.Uint32(hdrBytes[12:16])
188 | r.Cas = binary.BigEndian.Uint64(hdrBytes[16:24])
189 |
190 | bodyLen := int(binary.BigEndian.Uint32(hdrBytes[8:12])) - (klen + elen)
191 |
192 | buf := make([]byte, klen+elen+bodyLen)
193 | m, err := io.ReadFull(rd, buf)
194 | if err == nil {
195 | if elen > 0 {
196 | r.Extras = buf[0:elen]
197 | }
198 | if klen > 0 {
199 | r.Key = buf[elen : klen+elen]
200 | }
201 | if bodyLen > 0 {
202 | r.Body = buf[klen+elen:]
203 | }
204 | }
205 |
206 | return n + m, err
207 | }
208 |
--------------------------------------------------------------------------------
/memcached/responses_test.go:
--------------------------------------------------------------------------------
1 | // nolint
2 | package memcached
3 |
4 | import (
5 | "bytes"
6 | "errors"
7 | "fmt"
8 | "io/ioutil"
9 | "reflect"
10 | "testing"
11 | )
12 |
13 | func TestEncodingResponse(t *testing.T) {
14 | req := Response{
15 | Opcode: SET,
16 | Status: 1582,
17 | Opaque: 7242,
18 | Cas: 938424885,
19 | Key: []byte("somekey"),
20 | Body: []byte("somevalue"),
21 | }
22 |
23 | got := req.Bytes()
24 |
25 | expected := []byte{
26 | RES_MAGIC, byte(SET),
27 | 0x0, 0x7, // length of key
28 | 0x0, // extra length
29 | 0x0, // reserved16
30 | 0x6, 0x2e, // status
31 | 0x0, 0x0, 0x0, 0x10, // Length of value
32 | 0x0, 0x0, 0x1c, 0x4a, // opaque
33 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
34 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
35 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
36 | }
37 |
38 | if len(got) != req.Size() {
39 | t.Fatalf("Expected %v bytes, got %v", got,
40 | len(got))
41 | }
42 |
43 | if !reflect.DeepEqual(got, expected) {
44 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
45 | expected, got)
46 | }
47 |
48 | exp := `{Response status=0x62e keylen=7, extralen=0, bodylen=9}`
49 | if req.String() != exp {
50 | t.Errorf("Expected string=%q, got %q", exp, req.String())
51 | }
52 |
53 | exp = `Response status=0x62e, opcode=SET, opaque=7242, msg: somevalue`
54 | if req.Error() != exp {
55 | t.Errorf("Expected string=%q, got %q", exp, req.Error())
56 | }
57 | }
58 |
59 | func TestEncodingResponseWithExtras(t *testing.T) {
60 | res := Response{
61 | Opcode: SET,
62 | Status: 1582,
63 | Opaque: 7242,
64 | Cas: 938424885,
65 | Extras: []byte{1, 2, 3, 4},
66 | Key: []byte("somekey"),
67 | Body: []byte("somevalue"),
68 | }
69 |
70 | buf := &bytes.Buffer{}
71 | res.Transmit(buf)
72 | got := buf.Bytes()
73 |
74 | expected := []byte{
75 | RES_MAGIC, byte(SET),
76 | 0x0, 0x7, // length of key
77 | 0x4, // extra length
78 | 0x0, // reserved
79 | 0x6, 0x2e, // status
80 | 0x0, 0x0, 0x0, 0x14, // Length of remainder
81 | 0x0, 0x0, 0x1c, 0x4a, // opaque
82 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
83 | 1, 2, 3, 4, // extras
84 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
85 | 's', 'o', 'm', 'e', 'v', 'a', 'l', 'u', 'e',
86 | }
87 |
88 | if len(got) != res.Size() {
89 | t.Fatalf("Expected %v bytes, got %v", got,
90 | len(got))
91 | }
92 |
93 | if !reflect.DeepEqual(got, expected) {
94 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
95 | expected, got)
96 | }
97 | }
98 |
99 | func TestEncodingResponseWithLargeBody(t *testing.T) {
100 | res := Response{
101 | Opcode: SET,
102 | Status: 1582,
103 | Opaque: 7242,
104 | Cas: 938424885,
105 | Extras: []byte{1, 2, 3, 4},
106 | Key: []byte("somekey"),
107 | Body: make([]byte, 256),
108 | }
109 |
110 | buf := &bytes.Buffer{}
111 | res.Transmit(buf)
112 | got := buf.Bytes()
113 |
114 | expected := append([]byte{
115 | RES_MAGIC, byte(SET),
116 | 0x0, 0x7, // length of key
117 | 0x4, // extra length
118 | 0x0, // reserved
119 | 0x6, 0x2e, // status
120 | 0x0, 0x0, 0x1, 0xb, // Length of remainder
121 | 0x0, 0x0, 0x1c, 0x4a, // opaque
122 | 0x0, 0x0, 0x0, 0x0, 0x37, 0xef, 0x3a, 0x35, // CAS
123 | 1, 2, 3, 4, // extras
124 | 's', 'o', 'm', 'e', 'k', 'e', 'y',
125 | }, make([]byte, 256)...)
126 |
127 | if len(got) != res.Size() {
128 | t.Fatalf("Expected %v bytes, got %v", got,
129 | len(got))
130 | }
131 |
132 | if !reflect.DeepEqual(got, expected) {
133 | t.Fatalf("Expected:\n%#v\n -- got -- \n%#v",
134 | expected, got)
135 | }
136 | }
137 |
138 | func BenchmarkEncodingResponse(b *testing.B) {
139 | req := Response{
140 | Opcode: SET,
141 | Status: 1582,
142 | Opaque: 7242,
143 | Cas: 938424885,
144 | Extras: []byte{},
145 | Key: []byte("somekey"),
146 | Body: []byte("somevalue"),
147 | }
148 |
149 | b.SetBytes(int64(req.Size()))
150 |
151 | for i := 0; i < b.N; i++ {
152 | req.Bytes()
153 | }
154 | }
155 |
156 | func BenchmarkEncodingResponseLarge(b *testing.B) {
157 | req := Response{
158 | Opcode: SET,
159 | Status: 1582,
160 | Opaque: 7242,
161 | Cas: 938424885,
162 | Extras: []byte{},
163 | Key: []byte("somekey"),
164 | Body: make([]byte, 24*1024),
165 | }
166 |
167 | b.SetBytes(int64(req.Size()))
168 |
169 | for i := 0; i < b.N; i++ {
170 | req.Bytes()
171 | }
172 | }
173 |
174 | func TestIsNotFound(t *testing.T) {
175 | tests := []struct {
176 | e error
177 | is bool
178 | }{
179 | {nil, false},
180 | {errors.New("something"), false},
181 | {&Response{}, false},
182 | {&Response{Status: KEY_ENOENT}, true},
183 | }
184 |
185 | for i, x := range tests {
186 | if isNotFound(x.e) != x.is {
187 | t.Errorf("Expected %v for %#v (%v)", x.is, x.e, i)
188 | }
189 | }
190 | }
191 |
192 | func TestIsFatal(t *testing.T) {
193 | tests := []struct {
194 | e error
195 | is bool
196 | }{
197 | {nil, false},
198 | {errors.New("something"), true},
199 | {&Response{}, true},
200 | {&Response{Status: KEY_ENOENT}, false},
201 | {&Response{Status: EINVAL}, true},
202 | {&Response{Status: TMPFAIL}, false},
203 | }
204 |
205 | for i, x := range tests {
206 | if isFatal(x.e) != x.is {
207 | t.Errorf("Expected %v for %#v (%v)", x.is, x.e, i)
208 | }
209 | }
210 | }
211 |
212 | func TestResponseTransmit(t *testing.T) {
213 | res := Response{Key: []byte("thekey")}
214 | _, err := res.Transmit(ioutil.Discard)
215 | if err != nil {
216 | t.Errorf("Error sending small response: %v", err)
217 | }
218 |
219 | res.Body = make([]byte, 256)
220 | _, err = res.Transmit(ioutil.Discard)
221 | if err != nil {
222 | t.Errorf("Error sending large response thing: %v", err)
223 | }
224 | }
225 |
226 | func TestReceiveResponse(t *testing.T) {
227 | res := Response{
228 | Opcode: SET,
229 | Status: 74,
230 | Opaque: 7242,
231 | Extras: []byte{1},
232 | Key: []byte("somekey"),
233 | Body: []byte("somevalue"),
234 | }
235 |
236 | data := res.Bytes()
237 |
238 | res2 := Response{}
239 | _, err := res2.Receive(bytes.NewReader(data), nil)
240 | if err != nil {
241 | t.Fatalf("Error receiving: %v", err)
242 | }
243 |
244 | if !reflect.DeepEqual(res, res2) {
245 | t.Fatalf("Expected %#v == %#v", res, res2)
246 | }
247 | }
248 |
249 | func TestReceiveResponseBadMagic(t *testing.T) {
250 | res := Response{
251 | Opcode: SET,
252 | Status: 74,
253 | Opaque: 7242,
254 | Extras: []byte{1},
255 | Key: []byte("somekey"),
256 | Body: []byte("somevalue"),
257 | }
258 |
259 | data := res.Bytes()
260 | data[0] = 0x13
261 |
262 | res2 := Response{}
263 | _, err := res2.Receive(bytes.NewReader(data), nil)
264 | if err == nil {
265 | t.Fatalf("Expected error, got: %#v", res2)
266 | }
267 | }
268 |
269 | func TestReceiveResponseShortHeader(t *testing.T) {
270 | res := Response{
271 | Opcode: SET,
272 | Status: 74,
273 | Opaque: 7242,
274 | Extras: []byte{1},
275 | Key: []byte("somekey"),
276 | Body: []byte("somevalue"),
277 | }
278 |
279 | data := res.Bytes()
280 | data[0] = 0x13
281 |
282 | res2 := Response{}
283 | _, err := res2.Receive(bytes.NewReader(data[:13]), nil)
284 | if err == nil {
285 | t.Fatalf("Expected error, got: %#v", res2)
286 | }
287 | }
288 |
289 | func TestReceiveResponseShortBody(t *testing.T) {
290 | res := Response{
291 | Opcode: SET,
292 | Status: 74,
293 | Opaque: 7242,
294 | Extras: []byte{1},
295 | Key: []byte("somekey"),
296 | Body: []byte("somevalue"),
297 | }
298 |
299 | data := res.Bytes()
300 | data[0] = 0x13
301 |
302 | res2 := Response{}
303 | _, err := res2.Receive(bytes.NewReader(data[:len(data)-3]), nil)
304 | if err == nil {
305 | t.Fatalf("Expected error, got: %#v", res2)
306 | }
307 | }
308 |
309 | func TestReceiveResponseWithBuffer(t *testing.T) {
310 | res := Response{
311 | Opcode: SET,
312 | Status: 74,
313 | Opaque: 7242,
314 | Extras: []byte{1},
315 | Key: []byte("somekey"),
316 | Body: []byte("somevalue"),
317 | }
318 |
319 | data := res.Bytes()
320 |
321 | res2 := Response{}
322 | buf := make([]byte, HDR_LEN)
323 | _, err := res2.Receive(bytes.NewReader(data), buf)
324 | if err != nil {
325 | t.Fatalf("Error receiving: %v", err)
326 | }
327 |
328 | if !reflect.DeepEqual(res, res2) {
329 | t.Fatalf("Expected %#v == %#v", res, res2)
330 | }
331 | }
332 |
333 | func TestReceiveResponseNoContent(t *testing.T) {
334 | res := Response{
335 | Opcode: SET,
336 | Status: 74,
337 | Opaque: 7242,
338 | }
339 |
340 | data := res.Bytes()
341 |
342 | res2 := Response{}
343 | _, err := res2.Receive(bytes.NewReader(data), nil)
344 | if err != nil {
345 | t.Fatalf("Error receiving: %v", err)
346 | }
347 |
348 | // Can't use reflect here because []byte{} != nil, though they
349 | // look the same.
350 | if fmt.Sprintf("%#v", res) != fmt.Sprintf("%#v", res2) {
351 | t.Fatalf("Expected %#v == %#v", res, res2)
352 | }
353 | }
354 |
355 | func BenchmarkReceiveResponse(b *testing.B) {
356 | req := Response{
357 | Opcode: SET,
358 | Status: 183,
359 | Cas: 0,
360 | Opaque: 7242,
361 | Extras: []byte{1},
362 | Key: []byte("somekey"),
363 | Body: []byte("somevalue"),
364 | }
365 |
366 | data := req.Bytes()
367 | rdr := bytes.NewReader(data)
368 |
369 | b.SetBytes(int64(len(data)))
370 |
371 | b.ResetTimer()
372 | buf := make([]byte, HDR_LEN)
373 | for i := 0; i < b.N; i++ {
374 | res2 := Response{}
375 | rdr.Seek(0, 0)
376 | res2.Receive(rdr, buf)
377 | }
378 | }
379 |
380 | func BenchmarkReceiveResponseNoBuf(b *testing.B) {
381 | req := Response{
382 | Opcode: SET,
383 | Status: 183,
384 | Cas: 0,
385 | Opaque: 7242,
386 | Extras: []byte{1},
387 | Key: []byte("somekey"),
388 | Body: []byte("somevalue"),
389 | }
390 |
391 | data := req.Bytes()
392 | rdr := bytes.NewReader(data)
393 |
394 | b.SetBytes(int64(len(data)))
395 |
396 | b.ResetTimer()
397 | for i := 0; i < b.N; i++ {
398 | res2 := Response{}
399 | rdr.Seek(0, 0)
400 | res2.Receive(rdr, nil)
401 | }
402 | }
403 |
404 | func isNotFound(e error) bool {
405 | return errStatus(e) == KEY_ENOENT
406 | }
407 |
--------------------------------------------------------------------------------
/memcached/transport.go:
--------------------------------------------------------------------------------
1 | package memcached
2 |
3 | import (
4 | "errors"
5 | "io"
6 | )
7 |
8 | // UnwrapMemcachedError converts memcached errors to normal responses.
9 | //
10 | // If the error is a memcached response, declare the error to be nil
11 | // so a client can handle the status without worrying about whether it
12 | // indicates success or failure.
13 | func UnwrapMemcachedError(err error) *Response {
14 | var res *Response
15 | if errors.As(err, &res) {
16 | return res
17 | }
18 | return nil
19 | }
20 |
21 | func getResponse(s io.Reader, hdrBytes []byte) (rv *Response, n int, err error) {
22 | if s == nil {
23 | return nil, 0, ErrNoServers
24 | }
25 |
26 | rv = &Response{}
27 | n, err = rv.Receive(s, hdrBytes)
28 | if err == nil && rv.Status != SUCCESS {
29 | err = wrapMemcachedResp(rv)
30 | }
31 | return rv, n, err
32 | }
33 |
34 | func transmitRequest(o io.Writer, req *Request) (int, error) {
35 | if o == nil {
36 | return 0, ErrNoServers
37 | }
38 | n, err := req.Transmit(o)
39 | return n, err
40 | }
41 |
--------------------------------------------------------------------------------
/pool/pool.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
import (
	"context"
	"fmt"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"
)
10 |
11 | const token int64 = 1
12 |
13 | var (
14 | ErrClosedPool = fmt.Errorf("pool is closed")
15 | ErrNewFuncNil = fmt.Errorf("newFunc for pool is nil, can not create connection")
16 | ErrAcquireTimeout = fmt.Errorf("timeout for Acquire from the pool. Need to increase the maxCap for pool")
17 | )
18 |
19 | var _ ConnPool = (*Pool)(nil)
20 |
21 | type ConnPool interface {
22 | Get() (any, error)
23 | Pop() (any, bool)
24 | Put(v any)
25 | Destroy()
26 | Len() int
27 | Close(v any)
28 | }
29 |
30 | // Pool common connection pool
31 | type Pool struct {
32 | ctx context.Context
33 |
34 | // newConn are functions for creating new connections if maxCap is not reached.
35 | newConn func() (any, error)
36 | // closeConn is a function for graceful closed connections.
37 | closeConn func(any)
38 |
39 | // sema is a semaphore implementation for control a max capacity of pool
40 | sema *semaphore.Weighted
41 | // aqSemaTimeout is an amount of time to acquire conn from pool
42 | aqSemaTimeout time.Duration
43 |
44 | // store is a chan with connections.
45 | store chan any
46 | // storeClose is a flag indicating that store is closed.
47 | storeClose chan struct{}
48 | // maxCap is maximum of total connections used
49 | maxCap int32
50 | }
51 |
52 | // New create a pool with capacity
53 | func New(ctx context.Context, maxCap int32, acquireSemaTimeout time.Duration, newFunc func() (any, error), closeFunc func(any)) *Pool {
54 | if maxCap <= 0 {
55 | panic("invalid memcached maxCap")
56 | }
57 |
58 | return &Pool{
59 | ctx: ctx,
60 | newConn: newFunc,
61 | closeConn: closeFunc,
62 | sema: semaphore.NewWeighted(int64(maxCap)),
63 | aqSemaTimeout: acquireSemaTimeout,
64 | store: make(chan any, maxCap),
65 | storeClose: make(chan struct{}),
66 | maxCap: maxCap,
67 | }
68 | }
69 |
70 | // Len returns current connections in pool
71 | func (p *Pool) Len() int {
72 | return len(p.store)
73 | }
74 |
75 | // Get returns a conn from store or create one
76 | func (p *Pool) Get() (any, error) {
77 | var aqTimeout bool
78 |
79 | for {
80 | select {
81 | case v, ok := <-p.store:
82 | if ok {
83 | return v, nil
84 | }
85 | return nil, ErrClosedPool
86 | default:
87 | if aqTimeout {
88 | return nil, ErrAcquireTimeout
89 | }
90 | if cn, timeout, err := p.create(); timeout {
91 | // last try get conn after timeout
92 | aqTimeout = true
93 | continue
94 | } else {
95 | return cn, err
96 | }
97 | }
98 | }
99 | }
100 |
101 | // Pop return available conn without block
102 | func (p *Pool) Pop() (any, bool) {
103 | if p.isClosed() {
104 | return nil, false
105 | }
106 |
107 | select {
108 | case v, ok := <-p.store:
109 | return v, ok
110 | default:
111 | return nil, false
112 | }
113 | }
114 |
115 | // Put set back conn into store again
116 | func (p *Pool) Put(v any) {
117 | if p.isClosed() {
118 | return
119 | }
120 | select {
121 | case p.store <- v:
122 | default:
123 | }
124 | }
125 |
126 | // Destroy close all connections and deactivate the pool
127 | func (p *Pool) Destroy() {
128 | if p.isClosed() {
129 | // pool already destroyed
130 | return
131 | }
132 |
133 | close(p.storeClose)
134 | close(p.store)
135 | for v := range p.store {
136 | p.close(v)
137 | }
138 | }
139 |
140 | // Close is closed a connection
141 | func (p *Pool) Close(v any) {
142 | p.close(v)
143 | }
144 |
145 | func (p *Pool) create() (any, bool, error) {
146 | ctx, cancel := context.WithTimeout(p.ctx, p.aqSemaTimeout)
147 | defer cancel()
148 |
149 | if err := p.sema.Acquire(ctx, token); err != nil {
150 | return nil, true, nil
151 | }
152 |
153 | if p.isClosed() {
154 | p.sema.Release(token)
155 | return nil, false, ErrClosedPool
156 | }
157 |
158 | if p.newConn == nil {
159 | return nil, false, ErrNewFuncNil
160 | }
161 | cn, err := p.newConn()
162 | if err != nil {
163 | p.sema.Release(token)
164 | return nil, false, err
165 | }
166 | return cn, false, nil
167 | }
168 |
169 | func (p *Pool) close(v any) {
170 | p.sema.Release(token)
171 | if p.closeConn != nil {
172 | p.closeConn(v)
173 | }
174 | }
175 |
176 | func (p *Pool) isClosed() bool {
177 | select {
178 | case <-p.storeClose:
179 | return true
180 | default:
181 | return false
182 | }
183 | }
184 |
--------------------------------------------------------------------------------
/pool/pool_test.go:
--------------------------------------------------------------------------------
1 | package pool
2 |
3 | import (
4 | "context"
5 | "math/rand"
6 | "net/http"
7 | "sync"
8 | "sync/atomic"
9 | "testing"
10 | "time"
11 |
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
// defaultSocketPoolingTimeout is the slot-acquire timeout used by most pool tests.
const defaultSocketPoolingTimeout = 50 * time.Millisecond

// testConnection is a stand-in connection type with no behavior.
type testConnection struct{}

// newTestConnection is a pool factory that always succeeds.
func newTestConnection() (any, error) {
	conn := new(testConnection)
	return conn, nil
}

// newTestConnectionWithErr is a pool factory that always fails with
// http.ErrHandlerTimeout.
func newTestConnectionWithErr() (any, error) {
	var none any
	return none, http.ErrHandlerTimeout
}

// closeTestConnection is a no-op close hook.
func closeTestConnection(_ any) {}
30 |
31 | func TestPool(t *testing.T) {
32 | assert.Panics(t, func() {
33 | _ = New(context.TODO(), 0, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
34 | }, "was expected panic")
35 |
36 | defer func() {
37 | if pErr := recover(); pErr != nil {
38 | t.Fatalf("pool have panic - %v", pErr)
39 | }
40 | }()
41 |
42 | p := New(context.TODO(), 2, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
43 | defer p.Destroy()
44 |
45 | _, ok := p.Pop()
46 | assert.False(t, ok, "Pop return ok != false for empty pool")
47 |
48 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0, got %d", p.Len())
49 |
50 | conn, err := p.Get()
51 | assert.Nilf(t, err, "Get from empty pool have error - %v", err)
52 |
53 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after getting a connection, got %d", p.Len())
54 |
55 | p.Put(conn)
56 | assert.Equalf(t, 1, p.Len(), "Expected pool length to be 1 after putting back a connection, got %d", p.Len())
57 |
58 | _, ok = p.Pop()
59 | assert.True(t, ok, "Pop return ok != true for non-empty pool")
60 |
61 | conn, err = p.Get()
62 | assert.Nilf(t, err, "Get from pool have error - %v", err)
63 |
64 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after getting a connection from the pool, got %d", p.Len())
65 |
66 | p.Put(conn)
67 | p.Destroy()
68 | assert.Equalf(t, 0, p.Len(), "Expected pool length to be 0 after destroying the pool, got %d", p.Len())
69 |
70 | _, err = p.Get()
71 | assert.ErrorIsf(t, err, ErrClosedPool, "Expected to get an error when getting from a destroyed pool, got %v", err)
72 |
73 | p.Put(conn)
74 | assert.ErrorIsf(t, err, ErrClosedPool, "Expected to put an error when putting a destroyed pool, got %v", err)
75 | }
76 |
77 | func TestPoolConcurrency(t *testing.T) {
78 | p := New(context.TODO(), 10, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
79 | defer p.Destroy()
80 |
81 | var wg sync.WaitGroup
82 | for i := 0; i < 10; i++ {
83 | wg.Add(1)
84 | go func() {
85 | defer wg.Done()
86 | conn, err := p.Get()
87 | assert.Nilf(t, err, "Get have error %v", err)
88 | <-time.After(5 * time.Millisecond)
89 | p.Put(conn)
90 | }()
91 | }
92 | wg.Wait()
93 |
94 | assert.Equalf(t, 10, p.Len(), "Expected pool length to be 10, got %d", p.Len())
95 | }
96 |
97 | func TestCountConns(t *testing.T) {
98 | const count = 300
99 | p := New(context.TODO(), int32(count), defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
100 |
101 | conn := atomic.Int32{}
102 | wg1 := sync.WaitGroup{}
103 |
104 | wg1.Add(3)
105 | go func() {
106 | defer wg1.Done()
107 | for i := 0; i < count/3; i++ {
108 | _, pErr := p.Get()
109 | conn.Add(1)
110 | assert.Nilf(t, pErr, "Get have error - %v", pErr)
111 | }
112 | }()
113 | go func() {
114 | defer wg1.Done()
115 | for i := 0; i < count/3; i++ {
116 | _, pErr := p.Get()
117 | conn.Add(1)
118 | assert.Nilf(t, pErr, "Get have error - %v", pErr)
119 | }
120 | }()
121 | go func() {
122 | defer wg1.Done()
123 | for i := 0; i < count/3; i++ {
124 | _, pErr := p.Get()
125 | conn.Add(1)
126 | assert.Nilf(t, pErr, "Get have error - %v", pErr)
127 | }
128 | }()
129 |
130 | wg1.Wait()
131 |
132 | assert.Equalf(t, conn.Load(), int32(count), "Not equal init and received conns. have - %d, expacted - %d ", conn.Load(), int32(count))
133 |
134 | for i := 0; i < int(conn.Load()); i++ {
135 | p.Put(testConnection{})
136 | }
137 |
138 | // p.store is full, over-conn
139 | p.Put(testConnection{})
140 |
141 | wg1.Add(2)
142 |
143 | go func() {
144 | defer wg1.Done()
145 | p.Destroy()
146 | }()
147 | go func() {
148 | defer wg1.Done()
149 | p.Destroy()
150 | }()
151 | wg1.Wait()
152 |
153 | cn, err := p.Get()
154 | assert.Nil(t, cn, "Get: after method Destroy, pool is closed and should return cn == nil")
155 | assert.ErrorIs(t, err, ErrClosedPool, "Get: after method Destroy, pool is closed, want error ErrClosedPool")
156 |
157 | p2 := New(context.TODO(), count, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
158 |
159 | var (
160 | mu sync.RWMutex
161 | conns []any
162 | wg2 sync.WaitGroup
163 | )
164 |
165 | addToSl := func(c any) {
166 | mu.Lock()
167 | defer mu.Unlock()
168 | conns = append(conns, c)
169 | }
170 |
171 | getFromSl := func() any {
172 | mu.Lock()
173 | defer mu.Unlock()
174 | return conns[len(conns)-1]
175 | }
176 |
177 | getSlLen := func() int {
178 | mu.RLock()
179 | defer mu.RUnlock()
180 | return len(conns)
181 | }
182 |
183 | wg2.Add(1)
184 | go func() {
185 | defer wg2.Done()
186 | for i := 0; i < count/2; i++ {
187 | c, gErr := p2.Get()
188 | assert.Nilf(t, gErr, "Get have error")
189 | //nolint:gosec
190 | if rand.Int()%2 == 0 {
191 | addToSl(c)
192 | } else {
193 | p2.Put(c)
194 | }
195 | }
196 | }()
197 | wg2.Add(1)
198 | go func() {
199 | defer wg2.Done()
200 | for i := 0; i < count/2; i++ {
201 | c, gErr := p2.Get()
202 | assert.Nilf(t, gErr, "Get have error")
203 | //nolint:gosec
204 | if rand.Int()%2 == 0 {
205 | addToSl(c)
206 | } else {
207 | p2.Put(c)
208 | }
209 | }
210 | }()
211 |
212 | wg2.Add(1)
213 | go func() {
214 | defer wg2.Done()
215 | <-time.After(200 * time.Millisecond)
216 | c, gErr := p2.Get()
217 | assert.Nilf(t, gErr, "Get with full cap have error")
218 | addToSl(c)
219 | }()
220 |
221 | wg2.Wait()
222 |
223 | for i := 0; i < getSlLen(); i++ {
224 | go func() {
225 | p2.Put(getFromSl())
226 | }()
227 | }
228 |
229 | p3 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
230 |
231 | // maxConns is full
232 | _, _ = p3.Get()
233 |
234 | cn, err = p3.Get()
235 | assert.Nil(t, cn, "Get: after a timeout, it should return cn == nil")
236 | assert.ErrorIsf(t, ErrAcquireTimeout, err, "Get: after a timeout, it should return ErrAcquireTimeout")
237 |
238 | _, ok := p3.Pop()
239 | assert.False(t, ok, "Pop: pool with empty pool it should return false for second arg")
240 |
241 | p3.Destroy()
242 | cn, ok = p3.Pop()
243 | assert.Nil(t, cn, "Pop: after method Destroy, pool is closed and should return cn == nil")
244 | assert.False(t, ok, "Pop: after method Destroy, pool is closed and should return false for second arg")
245 |
246 | p4 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnectionWithErr, closeTestConnection)
247 |
248 | cn, err = p4.Get()
249 | assert.Nil(t, cn, "Get: create new conn returned an error, conn should be nil")
250 | assert.ErrorIs(t, err, http.ErrHandlerTimeout, "Get: error should be equal - http.ErrHandlerTimeout")
251 |
252 | p5 := New(context.TODO(), 1, time.Second, nil, nil)
253 |
254 | cn, err = p5.Get()
255 | assert.Nil(t, cn, "Get: newFunc equal nil, conn should be nil")
256 | assert.ErrorIs(t, err, ErrNewFuncNil, "Get: error should be equal ErrNewFuncNil")
257 |
258 | p6 := New(context.TODO(), 1, defaultSocketPoolingTimeout, newTestConnection, closeTestConnection)
259 | bcn, err := p6.Get()
260 | assert.NotNil(t, bcn, "Get: conn cannot be nil")
261 | assert.Nil(t, err, "Get: error should be nil")
262 | wg := sync.WaitGroup{}
263 | wg.Add(1)
264 | go func() {
265 | defer wg.Done()
266 | cn, err = p6.Get()
267 | assert.Nil(t, cn, "Get: conn should be nil")
268 | assert.ErrorIs(t, err, ErrClosedPool, "Get: error should be equal ErrClosedPool")
269 | }()
270 | wg.Add(1)
271 | go func() {
272 | defer wg.Done()
273 | <-time.After(10 * time.Millisecond)
274 | p6.Close(bcn)
275 | p6.Destroy()
276 | }()
277 |
278 | wg.Wait()
279 | }
280 |
--------------------------------------------------------------------------------
/utils/addr.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "net"
5 | "strings"
6 | )
7 |
// staticAddr is a net.Addr that snapshots the Network() and String() values
// of another net.Addr when it is constructed.
type staticAddr struct {
	ntw, str string
}

// newStaticAddr captures a's network name and string form in a staticAddr.
// a must be non-nil.
func newStaticAddr(a net.Addr) net.Addr {
	return &staticAddr{
		ntw: a.Network(),
		str: a.String(),
	}
}

// Network returns the cached network name (for example "tcp" or "unix").
func (s *staticAddr) Network() string { return s.ntw }

// String returns the cached address string.
func (s *staticAddr) String() string { return s.str }

// AddrRepr resolves server into a net.Addr that holds a fixed snapshot of the
// resolved address.
//
// A server string containing "/" is treated as a Unix socket path. Anything
// else is resolved as a TCP host:port, which may involve a DNS lookup when the
// host is a name. Resolution errors are returned to the caller and never
// ignored, so a nil address can never reach newStaticAddr.
func AddrRepr(server string) (net.Addr, error) {
	if strings.Contains(server, "/") {
		unixAddr, err := net.ResolveUnixAddr("unix", server)
		if err != nil {
			return nil, err
		}
		return newStaticAddr(unixAddr), nil
	}

	tcpAddr, err := net.ResolveTCPAddr("tcp", server)
	if err != nil {
		return nil, err
	}
	return newStaticAddr(tcpAddr), nil
}
39 |
--------------------------------------------------------------------------------
/utils/addr_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "net"
6 | "testing"
7 |
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNewStaticAddr(t *testing.T) {
12 | tcpAddr := &net.TCPAddr{
13 | IP: net.IPv4(127, 0, 0, 1),
14 | Port: 8080,
15 | }
16 | staticAddr := newStaticAddr(tcpAddr)
17 | if staticAddr.Network() != tcpAddr.Network() {
18 | t.Errorf("Expected Network() to be %s, got %s", tcpAddr.Network(), staticAddr.Network())
19 | }
20 | if staticAddr.String() != tcpAddr.String() {
21 | t.Errorf("Expected String() to be %s, got %s", tcpAddr.String(), staticAddr.String())
22 | }
23 | }
24 |
25 | func TestAddrRepr(t *testing.T) {
26 | type args struct {
27 | server string
28 | }
29 | tests := []struct {
30 | name string
31 | args args
32 | want net.Addr
33 | wantErr bool
34 | }{
35 | {
36 | name: "invalid address",
37 | args: args{server: "invalid-address"},
38 | want: nil,
39 | wantErr: true,
40 | },
41 | {
42 | name: "unix",
43 | args: args{server: "/var/unix.sock"},
44 | want: &staticAddr{
45 | ntw: "unix",
46 | str: "/var/unix.sock",
47 | },
48 | wantErr: false,
49 | },
50 | {
51 | name: "tcp",
52 | args: args{server: "127.0.0.1:8080"},
53 | want: &staticAddr{
54 | ntw: "tcp",
55 | str: "127.0.0.1:8080",
56 | },
57 | wantErr: false,
58 | },
59 | }
60 | for _, tt := range tests {
61 | t.Run(tt.name, func(t *testing.T) {
62 | got, err := AddrRepr(tt.args.server)
63 | if tt.wantErr {
64 | assert.NotNilf(t, err, fmt.Sprintf("AddrRepr(%v) Expected an error, got nil", tt.args.server))
65 | }
66 | assert.Equalf(t, tt.want, got, "AddrRepr(%v)", tt.args.server)
67 | })
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/utils/math.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import "math"
4 |
// epsilon is the smallest probability CalcEntropy passes to math.Log2.
// Smaller probabilities, including zero, are raised to it so that the
// logarithm stays finite.
const epsilon = 1e-6

// CalcEntropy returns the Shannon entropy of the distribution described by
// the counts in m, divided by log2(len(m)). A perfectly even distribution
// therefore yields 1.
//
// Maps with zero or one key, and maps whose counts sum to zero, have no
// meaningful distribution; for them CalcEntropy returns 1.
// NOTE(review): counts are assumed to be non-negative.
func CalcEntropy(m map[any]int) float64 {
	if len(m) <= 1 {
		return 1
	}

	var total int
	for _, v := range m {
		total += v
	}
	if total == 0 {
		// Every probability would be 0/0 = NaN and poison the result.
		// Treat "no observations" like the empty map.
		return 1
	}

	var entropy float64
	for _, v := range m {
		proba := float64(v) / float64(total)
		if proba < epsilon {
			proba = epsilon
		}
		entropy -= proba * math.Log2(proba)
	}

	return entropy / math.Log2(float64(len(m)))
}
29 |
--------------------------------------------------------------------------------
/utils/math_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "math"
5 | "testing"
6 | )
7 |
8 | func TestCalcEntropyWithEmptyMap(t *testing.T) {
9 | m := make(map[any]int)
10 | result := CalcEntropy(m)
11 | expected := 1.0
12 |
13 | if math.Abs(result-expected) > epsilon {
14 | t.Errorf("Expected: %f, Got: %f", expected, result)
15 | }
16 | }
17 |
18 | func TestCalcEntropyWithSingleValue(t *testing.T) {
19 | m := map[any]int{"A": 5}
20 | result := CalcEntropy(m)
21 | expected := 1.0
22 |
23 | if math.Abs(result-expected) > epsilon {
24 | t.Errorf("Expected: %f, Got: %f", expected, result)
25 | }
26 | }
27 |
28 | func TestCalcEntropyWithMultipleValues(t *testing.T) {
29 | m := map[any]int{"A": 5, "B": 5, "C": 5}
30 | result := CalcEntropy(m)
31 | expected := 1.0
32 |
33 | if math.Abs(result-expected) > epsilon {
34 | t.Errorf("Expected: %f, Got: %f", expected, result)
35 | }
36 | }
37 |
38 | func TestCalcEntropyWithDifferentProbabilities(t *testing.T) {
39 | m := map[any]int{"A": 1, "B": 3, "C": 6}
40 | result := CalcEntropy(m)
41 |
42 | total := 0
43 | probabilities := make(map[any]float64)
44 | for _, count := range m {
45 | total += count
46 | }
47 | for key, count := range m {
48 | probabilities[key] = float64(count) / float64(total)
49 | }
50 |
51 | entropy := 0.0
52 | for _, p := range probabilities {
53 | entropy -= p * math.Log2(p)
54 | }
55 | expected := entropy / math.Log2(float64(len(m)))
56 |
57 | if math.Abs(result-expected) > epsilon {
58 | t.Errorf("Expected: %f, Got: %f", expected, result)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/utils/stringer.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | "strconv"
7 | )
8 |
// nilRepr is how Repr renders a nil pointer whose String/Error method would
// dereference it. It matches what fmt.Sprint prints for nil pointers.
const nilRepr = "<nil>"

var (
	errorType    = reflect.TypeOf((*error)(nil)).Elem()
	stringerType = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
)

// Repr returns the string representation of v.
//
// A nil v yields "". Otherwise values are handled in this order:
//   - values implementing fmt.Stringer are rendered with String;
//   - values implementing error are rendered with Error;
//   - everything else has its pointers followed until a non-pointer or a nil
//     pointer is reached, and the result is formatted by reprOfValue.
//
// The fmt.Stringer and error checks run on v itself, before any
// dereferencing. This keeps methods declared on a pointer type (for example
// the *errorString returned by errors.New) reachable.
//
// A nil pointer whose element type implements String/Error with a value
// receiver is rendered as "<nil>" instead of panicking.
func Repr(v any) string {
	if v == nil {
		return ""
	}

	switch vt := v.(type) {
	case fmt.Stringer:
		if isNilPtrToValueImpl(v, stringerType) {
			return nilRepr
		}
		return vt.String()
	case error:
		if isNilPtrToValueImpl(v, errorType) {
			return nilRepr
		}
		return vt.Error()
	}

	val := reflect.ValueOf(v)
	for val.Kind() == reflect.Ptr && !val.IsNil() {
		val = val.Elem()
	}

	return reprOfValue(val)
}

// reprOfValue formats the value held by val. It handles bool, error,
// fmt.Stringer, every sized int/uint/float kind, string and []byte directly,
// and falls back to fmt.Sprint for anything else.
//
// val must be valid and interfaceable, because val.Interface() is called on
// it. Repr only passes values obtained from reflect.ValueOf or from Elem of
// non-nil pointers.
func reprOfValue(val reflect.Value) string {
	iv := val.Interface()
	switch vt := iv.(type) {
	case bool:
		return strconv.FormatBool(vt)
	case error:
		if isNilPtrToValueImpl(iv, errorType) {
			return nilRepr
		}
		return vt.Error()
	case float32:
		return strconv.FormatFloat(float64(vt), 'f', -1, 32)
	case float64:
		return strconv.FormatFloat(vt, 'f', -1, 64)
	case fmt.Stringer:
		if isNilPtrToValueImpl(iv, stringerType) {
			return nilRepr
		}
		return vt.String()
	case int:
		return strconv.Itoa(vt)
	case int8:
		return strconv.Itoa(int(vt))
	case int16:
		return strconv.Itoa(int(vt))
	case int32:
		return strconv.Itoa(int(vt))
	case int64:
		return strconv.FormatInt(vt, 10)
	case string:
		return vt
	case uint:
		return strconv.FormatUint(uint64(vt), 10)
	case uint8:
		return strconv.FormatUint(uint64(vt), 10)
	case uint16:
		return strconv.FormatUint(uint64(vt), 10)
	case uint32:
		return strconv.FormatUint(uint64(vt), 10)
	case uint64:
		return strconv.FormatUint(vt, 10)
	case []byte:
		return string(vt)
	default:
		return fmt.Sprint(iv)
	}
}

// isNilPtrToValueImpl reports whether v holds a nil pointer whose element
// type implements iface with value receivers. Calling such a method through
// the nil pointer must dereference it, which panics.
func isNilPtrToValueImpl(v any, iface reflect.Type) bool {
	rv := reflect.ValueOf(v)
	return rv.Kind() == reflect.Ptr && rv.IsNil() && rv.Type().Elem().Implements(iface)
}
69 |
--------------------------------------------------------------------------------
/utils/stringer_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "reflect"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestRepr(t *testing.T) {
13 | var (
14 | f32 float32 = 1.1
15 | f64 = 2.2
16 | i8 int8 = 1
17 | i16 int16 = 2
18 | i32 int32 = 3
19 | i64 int64 = 4
20 | u8 uint8 = 5
21 | u16 uint16 = 6
22 | u32 uint32 = 7
23 | u64 uint64 = 8
24 | )
25 | tests := []struct {
26 | v any
27 | expect string
28 | }{
29 | {
30 | nil,
31 | "",
32 | },
33 | {
34 | mockStringable{},
35 | "mocked",
36 | },
37 | {
38 | new(mockStringable),
39 | "mocked",
40 | },
41 | {
42 | newMockPtr(),
43 | "mockptr",
44 | },
45 | {
46 | &mockOpacity{
47 | val: 1,
48 | },
49 | "{1}",
50 | },
51 | {
52 | true,
53 | "true",
54 | },
55 | {
56 | false,
57 | "false",
58 | },
59 | {
60 | f32,
61 | "1.1",
62 | },
63 | {
64 | f64,
65 | "2.2",
66 | },
67 | {
68 | i8,
69 | "1",
70 | },
71 | {
72 | i16,
73 | "2",
74 | },
75 | {
76 | i32,
77 | "3",
78 | },
79 | {
80 | i64,
81 | "4",
82 | },
83 | {
84 | u8,
85 | "5",
86 | },
87 | {
88 | u16,
89 | "6",
90 | },
91 | {
92 | u32,
93 | "7",
94 | },
95 | {
96 | u64,
97 | "8",
98 | },
99 | {
100 | []byte(`abcd`),
101 | "abcd",
102 | },
103 | {
104 | mockOpacity{val: 1},
105 | "{1}",
106 | },
107 | }
108 |
109 | for _, test := range tests {
110 | t.Run(test.expect, func(t *testing.T) {
111 | assert.Equal(t, test.expect, Repr(test.v))
112 | })
113 | }
114 | }
115 |
116 | func TestReprOfValue(t *testing.T) {
117 | t.Run("error", func(t *testing.T) {
118 | assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error"))))
119 | })
120 |
121 | t.Run("stringer", func(t *testing.T) {
122 | assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23"))))
123 | })
124 |
125 | t.Run("int", func(t *testing.T) {
126 | assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1)))
127 | })
128 |
129 | t.Run("int", func(t *testing.T) {
130 | assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1")))
131 | })
132 |
133 | t.Run("int", func(t *testing.T) {
134 | assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1))))
135 | })
136 | }
137 |
// mockStringable implements fmt.Stringer with a value receiver, so both
// mockStringable and *mockStringable satisfy the interface.
type mockStringable struct{}

func (mockStringable) String() string { return "mocked" }

// mockPtr implements fmt.Stringer only through its pointer type. Its String
// never touches the receiver, so a nil *mockPtr is safe to call.
type mockPtr struct{}

// newMockPtr returns a freshly allocated *mockPtr.
func newMockPtr() *mockPtr { return &mockPtr{} }

func (*mockPtr) String() string { return "mockptr" }

// mockOpacity has no String or Error method, so Repr formats it through the
// fmt fallback.
type mockOpacity struct {
	val int
}
--------------------------------------------------------------------------------