├── .gitignore ├── LICENSE ├── README.md ├── pool.go └── pool_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 ProcessOut 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # grpc-go-pool 2 | 3 | [![GoDoc](https://godoc.org/github.com/processout/grpc-go-pool?status.svg)](https://godoc.org/github.com/processout/grpc-go-pool) 4 | 5 | This package aims to provide an easy to use and lightweight GRPC connection pool. 6 | 7 | Please note that the goal isn't to replicate the client-side load-balancing feature of the official grpc package: the goal is rather to have multiple connections established to one endpoint (which can be server-side load-balanced). 
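## Usage

A minimal usage sketch (the target address and the generated client mentioned in the comments are placeholders, and error handling is kept short):

```go
package main

import (
	"context"
	"log"
	"time"

	grpcpool "github.com/processout/grpc-go-pool"
	"google.golang.org/grpc"
)

func main() {
	// Factory dialing a single endpoint; "localhost:50051" is a placeholder.
	factory := func() (*grpc.ClientConn, error) {
		return grpc.Dial("localhost:50051", grpc.WithInsecure())
	}

	// 1 connection created up front, at most 3 in total,
	// idle connections recycled after one minute.
	pool, err := grpcpool.New(factory, 1, 3, time.Minute)
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()

	// Borrow a connection; the context bounds how long we wait for one.
	conn, err := pool.Get(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	// conn embeds *grpc.ClientConn, so it can be handed to a generated
	// client constructor, e.g. pb.NewGreeterClient(conn.ClientConn).

	// If an RPC failed in a way that suggests the connection itself is
	// broken, mark it so it gets re-dialed instead of reused:
	// conn.Unhealthy()

	// Close returns the connection to the pool; it does not necessarily
	// close the underlying gRPC connection.
	if err := conn.Close(); err != nil {
		log.Println(err)
	}
}
```

Connections borrowed with `Get` are handed back with `Close`; call `Unhealthy` first if the underlying connection should be discarded rather than reused.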
8 | -------------------------------------------------------------------------------- /pool.go: -------------------------------------------------------------------------------- 1 | // Package grpcpool provides a pool of grpc clients 2 | package grpcpool 3 | 4 | import ( 5 | "context" 6 | "errors" 7 | "sync" 8 | "time" 9 | 10 | "google.golang.org/grpc" 11 | ) 12 | 13 | var ( 14 | // ErrClosed is the error when the client pool is closed 15 | ErrClosed = errors.New("grpc pool: client pool is closed") 16 | // ErrTimeout is the error when the client pool timed out 17 | ErrTimeout = errors.New("grpc pool: client pool timed out") 18 | // ErrAlreadyClosed is the error when the client conn was already closed 19 | ErrAlreadyClosed = errors.New("grpc pool: the connection was already closed") 20 | // ErrFullPool is the error when the pool is already full 21 | ErrFullPool = errors.New("grpc pool: closing a ClientConn into a full pool") 22 | ) 23 | 24 | // Factory is a function type creating a grpc client 25 | type Factory func() (*grpc.ClientConn, error) 26 | 27 | // FactoryWithContext is a function type creating a grpc client 28 | // that accepts the context parameter that could be passed from 29 | // Get or NewWithContext method. 30 | type FactoryWithContext func(context.Context) (*grpc.ClientConn, error) 31 | 32 | // Pool is the grpc client pool 33 | type Pool struct { 34 | clients chan ClientConn 35 | factory FactoryWithContext 36 | idleTimeout time.Duration 37 | maxLifeDuration time.Duration 38 | mu sync.RWMutex 39 | } 40 | 41 | // ClientConn is the wrapper for a grpc client conn 42 | type ClientConn struct { 43 | *grpc.ClientConn 44 | pool *Pool 45 | timeUsed time.Time 46 | timeInitiated time.Time 47 | unhealthy bool 48 | } 49 | 50 | // New creates a new clients pool with the given initial and maximum capacity, 51 | // and the timeout for the idle clients. Returns an error if the initial 52 | // clients could not be created 53 | func New(factory Factory, init, capacity int, idleTimeout time.Duration, 54 | maxLifeDuration ...time.Duration) (*Pool, error) { 55 | return NewWithContext(context.Background(), func(ctx context.Context) (*grpc.ClientConn, error) { return factory() }, 56 | init, capacity, idleTimeout, maxLifeDuration...) 57 | } 58 | 59 | // NewWithContext creates a new clients pool with the given initial and maximum 60 | // capacity, and the timeout for the idle clients. The context parameter would 61 | // be passed to the factory method during initialization. Returns an error if the 62 | // initial clients could not be created. 
63 | func NewWithContext(ctx context.Context, factory FactoryWithContext, init, capacity int, idleTimeout time.Duration,
64 | maxLifeDuration ...time.Duration) (*Pool, error) {
65 |
66 | if capacity <= 0 {
67 | capacity = 1
68 | }
69 | if init < 0 {
70 | init = 0
71 | }
72 | if init > capacity {
73 | init = capacity
74 | }
75 | p := &Pool{
76 | clients: make(chan ClientConn, capacity),
77 | factory: factory,
78 | idleTimeout: idleTimeout,
79 | }
80 | if len(maxLifeDuration) > 0 {
81 | p.maxLifeDuration = maxLifeDuration[0]
82 | }
83 | for i := 0; i < init; i++ {
84 | c, err := factory(ctx)
85 | if err != nil {
86 | return nil, err
87 | }
88 |
89 | p.clients <- ClientConn{
90 | ClientConn: c,
91 | pool: p,
92 | timeUsed: time.Now(),
93 | timeInitiated: time.Now(),
94 | }
95 | }
96 | // Fill the rest of the pool with empty clients
97 | for i := 0; i < capacity-init; i++ {
98 | p.clients <- ClientConn{
99 | pool: p,
100 | }
101 | }
102 | return p, nil
103 | }
104 |
105 | func (p *Pool) getClients() chan ClientConn {
106 | p.mu.RLock()
107 | defer p.mu.RUnlock()
108 |
109 | return p.clients
110 | }
111 |
112 | // Close empties the pool, calling Close on all its clients.
113 | // You can call Close while there are outstanding clients.
114 | // The pool channel is then closed, and Get is no longer allowed.
115 | func (p *Pool) Close() {
116 | p.mu.Lock()
117 | clients := p.clients
118 | p.clients = nil
119 | p.mu.Unlock()
120 |
121 | if clients == nil {
122 | return
123 | }
124 |
125 | close(clients)
126 | for client := range clients {
127 | if client.ClientConn == nil {
128 | continue
129 | }
130 | client.ClientConn.Close()
131 | }
132 | }
133 |
134 | // IsClosed returns true if the client pool is closed.
135 | func (p *Pool) IsClosed() bool {
136 | return p == nil || p.getClients() == nil
137 | }
138 |
139 | // Get will return the next available client. If capacity
140 | // has not been reached, it will create a new one using the factory. Otherwise,
141 | // it will wait until the next client becomes available or the context is done.
142 | // A context without a deadline waits indefinitely.
143 | func (p *Pool) Get(ctx context.Context) (*ClientConn, error) {
144 | clients := p.getClients()
145 | if clients == nil {
146 | return nil, ErrClosed
147 | }
148 |
149 | wrapper := ClientConn{
150 | pool: p,
151 | }
152 | select {
153 | case wrapper = <-clients:
154 | // All good
155 | case <-ctx.Done():
156 | return nil, ErrTimeout // arguably this should return ctx.Err() instead
157 | }
158 |
159 | // If the wrapper was idle too long, close the connection and create a new
160 | // one. It's safe to assume that there isn't any newer client, as the client
161 | // we fetched is the first in the channel
162 | idleTimeout := p.idleTimeout
163 | if wrapper.ClientConn != nil && idleTimeout > 0 &&
164 | wrapper.timeUsed.Add(idleTimeout).Before(time.Now()) {
165 |
166 | wrapper.ClientConn.Close()
167 | wrapper.ClientConn = nil
168 | }
169 |
170 | var err error
171 | if wrapper.ClientConn == nil {
172 | wrapper.ClientConn, err = p.factory(ctx)
173 | if err != nil {
174 | // If there was an error, we want to put back a placeholder
175 | // client in the channel
176 | clients <- ClientConn{
177 | pool: p,
178 | }
179 | }
180 | // This is a new connection, reset its initiated time
181 | wrapper.timeInitiated = time.Now()
182 | }
183 |
184 | return &wrapper, err
185 | }
186 |
187 | // Unhealthy marks the client conn as unhealthy, so that the connection
188 | // gets reset when closed
189 | func (c *ClientConn) Unhealthy() {
190 | c.unhealthy = true
191 | }
192 |
193 | // Close returns a ClientConn to the pool. It is safe to call multiple times,
194 | // but it will return an error after the first time
195 | func (c *ClientConn) Close() error {
196 | if c == nil {
197 | return nil
198 | }
199 | if c.ClientConn == nil {
200 | return ErrAlreadyClosed
201 | }
202 | if c.pool.IsClosed() {
203 | return ErrClosed
204 | }
205 | // If the wrapper connection has become too old, we want to recycle it. To
206 | // clarify the logic: if the sum of the initialization time and the max
207 | // duration is before Now(), it means the initialization is so old that even
208 | // adding the maximum duration couldn't bring it into the future. This sum
209 | // therefore corresponds to the cut-off point: if it's in the future we still
210 | // have time, if it's in the past it's too old
211 | maxDuration := c.pool.maxLifeDuration
212 | if maxDuration > 0 && c.timeInitiated.Add(maxDuration).Before(time.Now()) {
213 | c.Unhealthy()
214 | }
215 |
216 | // We're cloning the wrapper so we can set ClientConn to nil in the one
217 | // used by the user
218 | wrapper := ClientConn{
219 | pool: c.pool,
220 | ClientConn: c.ClientConn,
221 | timeUsed: time.Now(),
222 | }
223 | if c.unhealthy {
224 | wrapper.ClientConn.Close()
225 | wrapper.ClientConn = nil
226 | } else {
227 | wrapper.timeInitiated = c.timeInitiated
228 | }
229 | select {
230 | case c.pool.clients <- wrapper:
231 | // All good
232 | default:
233 | return ErrFullPool
234 | }
235 |
236 | c.ClientConn = nil // Mark as closed
237 | return nil
238 | }
239 |
240 | // Capacity returns the capacity of the pool
241 | func (p *Pool) Capacity() int {
242 | if p.IsClosed() {
243 | return 0
244 | }
245 | return cap(p.clients)
246 | }
247 |
248 | // Available returns the number of currently unused clients
249 | func (p *Pool) Available() int {
250 | if p.IsClosed() {
251 | return 0
252 | }
253 | return len(p.clients)
254 | }
255 |
--------------------------------------------------------------------------------
/pool_test.go:
--------------------------------------------------------------------------------
1 | package grpcpool
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "google.golang.org/grpc"
9 | "google.golang.org/grpc/connectivity"
10 | )
11 |
12 | func TestNew(t *testing.T) {
13 | p, err := New(func() (*grpc.ClientConn, error) {
14 | return grpc.Dial("example.com", grpc.WithInsecure())
15 | }, 1, 3, 0)
16 | if err != nil {
17 | t.Errorf("The pool returned an error: %s", err.Error())
18 | }
19 | if a := p.Available(); a != 3 {
20 | t.Errorf("The pool available was %d but should
be 3", a) 21 | } 22 | if a := p.Capacity(); a != 3 { 23 | t.Errorf("The pool capacity was %d but should be 3", a) 24 | } 25 | 26 | // Get a client 27 | client, err := p.Get(context.Background()) 28 | if err != nil { 29 | t.Errorf("Get returned an error: %s", err.Error()) 30 | } 31 | if client == nil { 32 | t.Error("client was nil") 33 | } 34 | if a := p.Available(); a != 2 { 35 | t.Errorf("The pool available was %d but should be 2", a) 36 | } 37 | if a := p.Capacity(); a != 3 { 38 | t.Errorf("The pool capacity was %d but should be 3", a) 39 | } 40 | 41 | // Return the client 42 | err = client.Close() 43 | if err != nil { 44 | t.Errorf("Close returned an error: %s", err.Error()) 45 | } 46 | if a := p.Available(); a != 3 { 47 | t.Errorf("The pool available was %d but should be 3", a) 48 | } 49 | if a := p.Capacity(); a != 3 { 50 | t.Errorf("The pool capacity was %d but should be 3", a) 51 | } 52 | 53 | // Attempt to return the client again 54 | err = client.Close() 55 | if err != ErrAlreadyClosed { 56 | t.Errorf("Expected error \"%s\" but got \"%s\"", 57 | ErrAlreadyClosed.Error(), err.Error()) 58 | } 59 | 60 | // Take 3 clients 61 | cl1, err1 := p.Get(context.Background()) 62 | cl2, err2 := p.Get(context.Background()) 63 | cl3, err3 := p.Get(context.Background()) 64 | if err1 != nil { 65 | t.Errorf("Err1 was not nil: %s", err1.Error()) 66 | } 67 | if err2 != nil { 68 | t.Errorf("Err2 was not nil: %s", err2.Error()) 69 | } 70 | if err3 != nil { 71 | t.Errorf("Err3 was not nil: %s", err3.Error()) 72 | } 73 | 74 | if a := p.Available(); a != 0 { 75 | t.Errorf("The pool available was %d but should be 0", a) 76 | } 77 | if a := p.Capacity(); a != 3 { 78 | t.Errorf("The pool capacity was %d but should be 3", a) 79 | } 80 | 81 | // Returning all of them 82 | err1 = cl1.Close() 83 | if err1 != nil { 84 | t.Errorf("Close returned an error: %s", err1.Error()) 85 | } 86 | err2 = cl2.Close() 87 | if err2 != nil { 88 | t.Errorf("Close returned an error: %s", err2.Error()) 89 | } 90 | err3 = cl3.Close() 91 | if err3 != nil { 92 | t.Errorf("Close returned an error: %s", err3.Error()) 93 | } 94 | } 95 | 96 | func TestTimeout(t *testing.T) { 97 | p, err := New(func() (*grpc.ClientConn, error) { 98 | return grpc.Dial("example.com", grpc.WithInsecure()) 99 | }, 1, 1, 0) 100 | if err != nil { 101 | t.Errorf("The pool returned an error: %s", err.Error()) 102 | } 103 | 104 | _, err = p.Get(context.Background()) 105 | if err != nil { 106 | t.Errorf("Get returned an error: %s", err.Error()) 107 | } 108 | if a := p.Available(); a != 0 { 109 | t.Errorf("The pool available was %d but expected 0", a) 110 | } 111 | 112 | // We want to fetch a second one, with a timeout. 
If the timeout was
113 | // omitted, the pool would wait indefinitely as it'd wait for another
114 | // client to get back into the queue
115 | ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond))
116 | _, err2 := p.Get(ctx)
117 | if err2 != ErrTimeout {
118 | t.Errorf("Expected error \"%s\" but got \"%s\"", ErrTimeout, err2.Error())
119 | }
120 | }
121 |
122 | func TestMaxLifeDuration(t *testing.T) {
123 | p, err := New(func() (*grpc.ClientConn, error) {
124 | return grpc.Dial("example.com", grpc.WithInsecure())
125 | }, 1, 1, 0, 1)
126 | if err != nil {
127 | t.Errorf("The pool returned an error: %s", err.Error())
128 | }
129 |
130 | c, err := p.Get(context.Background())
131 | if err != nil {
132 | t.Errorf("Get returned an error: %s", err.Error())
133 | }
134 |
135 | // The max life of the connection was very low (1ns), so when we close
136 | // the connection it should get marked as unhealthy
137 | if err := c.Close(); err != nil {
138 | t.Errorf("Close returned an error: %s", err.Error())
139 | }
140 | if !c.unhealthy {
141 | t.Errorf("the connection should've been marked as unhealthy")
142 | }
143 |
144 | // Let's also make sure we don't prematurely close the connection
145 | count := 0
146 | p, err = New(func() (*grpc.ClientConn, error) {
147 | count++
148 | return grpc.Dial("example.com", grpc.WithInsecure())
149 | }, 1, 1, 0, time.Minute)
150 | if err != nil {
151 | t.Errorf("The pool returned an error: %s", err.Error())
152 | }
153 |
154 | for i := 0; i < 3; i++ {
155 | c, err = p.Get(context.Background())
156 | if err != nil {
157 | t.Errorf("Get returned an error: %s", err.Error())
158 | }
159 |
160 | // The max life of the connection is high, so when we close
161 | // the connection it shouldn't be marked as unhealthy
162 | if err := c.Close(); err != nil {
163 | t.Errorf("Close returned an error: %s", err.Error())
164 | }
165 | if c.unhealthy {
166 | t.Errorf("the connection shouldn't have been marked as unhealthy")
167 | }
168 | }
169 |
170 | // Count should be 1, as the dial function should only have been called once
171 | if count > 1 {
172 | t.Errorf("Dial function has been called multiple times")
173 | }
174 |
175 | }
176 |
177 | func TestPoolClose(t *testing.T) {
178 | p, err := New(func() (*grpc.ClientConn, error) {
179 | return grpc.Dial("example.com", grpc.WithInsecure())
180 | }, 1, 1, 0)
181 | if err != nil {
182 | t.Errorf("The pool returned an error: %s", err.Error())
183 | }
184 |
185 | c, err := p.Get(context.Background())
186 | if err != nil {
187 | t.Errorf("Get returned an error: %s", err.Error())
188 | }
189 |
190 | cc := c.ClientConn
191 | if err := c.Close(); err != nil {
192 | t.Errorf("Close returned an error: %s", err.Error())
193 | }
194 |
195 | // Closing the pool should close all underlying gRPC client connections
196 | p.Close()
197 |
198 | if cc.GetState() != connectivity.Shutdown {
199 | t.Errorf("Returned connection was not closed, underlying connection is not in shutdown state")
200 | }
201 | }
202 |
203 | func TestContextCancelation(t *testing.T) {
204 | ctx, cancel := context.WithCancel(context.Background())
205 | cancel()
206 |
207 | _, err := NewWithContext(ctx, func(ctx context.Context) (*grpc.ClientConn, error) {
208 | select {
209 | case <-ctx.Done():
210 | return nil, ctx.Err()
211 |
212 | default:
213 | return grpc.Dial("example.com", grpc.WithInsecure())
214 | }
215 |
216 | }, 1, 1, 0)
217 |
218 | if err != context.Canceled {
219 | t.Errorf("Returned error was not context.Canceled, but the context was canceled before the invocation")
220 | }
221 | }
222 | func TestContextTimeout(t *testing.T) {
223 | ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
224 | defer cancel()
225 |
226 | _, err := NewWithContext(ctx, func(ctx context.Context) (*grpc.ClientConn, error) {
227 | select {
228 | case <-ctx.Done():
229 | return nil, ctx.Err()
230 |
231 | // wait for the deadline to pass
232 | case <-time.After(time.Millisecond):
233 | return grpc.Dial("example.com", grpc.WithInsecure())
234 | }
235 |
236 | }, 1, 1, 0)
237 |
238 | if err != context.DeadlineExceeded {
239 | t.Errorf("Returned error was not context.DeadlineExceeded, but the context had timed out before the initialization")
240 | }
241 | }
242 |
243 | func TestGetContextTimeout(t *testing.T) {
244 | p, err := New(func() (*grpc.ClientConn, error) {
245 | return grpc.Dial("example.com", grpc.WithInsecure())
246 | }, 1, 1, 0)
247 |
248 | if err != nil {
249 | t.Errorf("The pool returned an error: %s", err.Error())
250 | }
251 |
252 | // keep the available conn busy
253 | _, _ = p.Get(context.Background())
254 |
255 | ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
256 | defer cancel()
257 |
258 | // wait for the deadline to pass
259 | time.Sleep(time.Millisecond)
260 | _, err = p.Get(ctx)
261 | if err != ErrTimeout { // arguably this should be context.DeadlineExceeded
262 | t.Errorf("Returned error was not ErrTimeout, but the context had timed out before the Get invocation")
263 | }
264 | }
265 |
266 | func TestGetContextFactoryTimeout(t *testing.T) {
267 | p, err := NewWithContext(context.Background(), func(ctx context.Context) (*grpc.ClientConn, error) {
268 | select {
269 | case <-ctx.Done():
270 | return nil, ctx.Err()
271 |
272 | // wait for the deadline to pass
273 | case <-time.After(time.Millisecond):
274 | return grpc.Dial("example.com", grpc.WithInsecure())
275 | }
276 |
277 | }, 1, 1, 0)
278 |
279 | if err != nil {
280 | t.Errorf("The pool returned an error: %s", err.Error())
281 | }
282 |
283 | // mark the available conn as unhealthy
284 | c, err := p.Get(context.Background())
285 | if err != nil {
286 | t.Errorf("Get returned an error: %s", err.Error())
287 | }
288 | c.Unhealthy()
289 | c.Close()
290 |
291 | ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
292 | defer cancel()
293 |
294 | _, err = p.Get(ctx)
295 | if err != context.DeadlineExceeded {
296 | t.Errorf("Returned error was not context.DeadlineExceeded, but the context had timed out before the Get invocation")
297 | }
298 | }
299 |
--------------------------------------------------------------------------------