├── .idea ├── .gitignore ├── modules.xml ├── pool.iml └── vcs.xml ├── README.md ├── client.go ├── conn.go ├── default.go ├── example └── main.go ├── go.mod ├── go.sum ├── log.go ├── option.go └── pool.go /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pool.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Connection Pool 2 | 3 | > A pool of reusable client connections (TCP or WebSocket) with a configurable size, a minimum number of idle connections, and reaping of stale connections.
4 | 5 | 6 | ## Usage 7 | 8 | ```go 9 | package main 10 | 11 | import ( 12 | "context" 13 | "github.com/bean-du/pool" 14 | "log" 15 | "os" 16 | "os/signal" 17 | "time" 18 | ) 19 | 20 | func main() { 21 | 22 | sig := make(chan os.Signal, 1) 23 | signal.Notify(sig, os.Interrupt, os.Kill) 24 | 25 | // init a pool with options 26 | client := pool.NewClient( 27 | pool.WebsocketDialer("ws://127.0.0.1:8081/ws"), 28 | // log instance 29 | nil, 30 | // set pool size 31 | pool.WithPoolSize(50), 32 | // set write func default is tcp writer 33 | pool.WithWriteFunc(pool.WsWriter), 34 | // set read func, must be set 35 | pool.WithReadFunc(pool.WebsocketReadFunc(dataHandleFunc)), 36 | // set min idle connections 37 | pool.WithMinIdleConns(10), 38 | // set idle check duration 39 | pool.WithIdleCheckFrequency(time.Second*10), 40 | ) 41 | 42 | for i := 0; i < 10; i++ { 43 | go func() { 44 | if err := client.Send(context.Background(), []byte("hello")); err != nil { 45 | log.Println(err) 46 | } 47 | }() 48 | } 49 | 50 | select { 51 | case <-sig: 52 | client.Close() 53 | } 54 | } 55 | 56 | func dataHandleFunc(p []byte) { 57 | go func() { 58 | log.Println(string(p)) 59 | }() 60 | } 61 | 62 | ``` -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | type Client struct { 8 | p Pooler 9 | opts *Options 10 | } 11 | 12 | func NewClient(dialFunc DialFunc, log Logger, o ...Option) *Client { 13 | opts := &Options{Dialer: dialFunc} 14 | for _, fn := range o { 15 | fn(opts) 16 | } 17 | 18 | p := NewConnPool(opts, log) 19 | return &Client{ 20 | p: p, 21 | opts: opts, 22 | } 23 | } 24 | 25 | func (c *Client) Send(ctx context.Context, data []byte) error { 26 | var ( 27 | conn *Conn 28 | err error 29 | ) 30 | if conn, err = c.p.Get(ctx); err != nil { 31 | return err 32 | } 33 | defer c.p.Put(ctx, conn) 34 | 35 | if 
c.opts.WriteFunc != nil { 36 | err = conn.WithWriter(ctx, 0, c.opts.WriteFunc(data)) 37 | } else { 38 | _, err = conn.Write(data) 39 | } 40 | return err 41 | } 42 | 43 | func (c *Client) Close() error { 44 | return c.p.Close() 45 | } 46 | 47 | type PoolStats Stats 48 | 49 | // PoolStats returns connection pool stats. 50 | func (c *Client) PoolStats() *PoolStats { 51 | stats := c.p.Stats() 52 | return (*PoolStats)(stats) 53 | } 54 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "io" 7 | "net" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | ) 12 | 13 | var noDeadline = time.Time{} 14 | 15 | type Conn struct { 16 | netConn net.Conn 17 | writer *bufio.Writer 18 | reader *bufio.Reader 19 | Inited bool // 是否完成初始化 20 | pooled bool // 是否放进连接池 21 | createdAt time.Time // 创建时间 22 | usedAt int64 // 使用时间 23 | ioLock sync.Mutex 24 | } 25 | 26 | func NewConn(netConn net.Conn) *Conn { 27 | conn := &Conn{ 28 | netConn: netConn, 29 | writer: bufio.NewWriter(netConn), 30 | reader: bufio.NewReader(netConn), 31 | createdAt: time.Now(), 32 | } 33 | 34 | conn.SetUseAt(time.Now()) 35 | return conn 36 | } 37 | 38 | func (c *Conn) UsedAt() time.Time { 39 | unix := atomic.LoadInt64(&c.usedAt) 40 | return time.Unix(unix, 0) 41 | } 42 | 43 | func (c *Conn) SetUseAt(now time.Time) { 44 | atomic.StoreInt64(&c.usedAt, now.Unix()) 45 | } 46 | 47 | func (c *Conn) SetNetConn(netConn net.Conn) { 48 | c.netConn = netConn 49 | c.reader = bufio.NewReader(netConn) 50 | c.writer = bufio.NewWriter(netConn) 51 | } 52 | 53 | func (c *Conn) Write(b []byte) (int, error) { 54 | return c.netConn.Write(b) 55 | } 56 | 57 | func (c *Conn) RemoteAddr() net.Addr { 58 | if c.netConn != nil { 59 | return c.netConn.RemoteAddr() 60 | } 61 | return nil 62 | } 63 | 64 | func (c *Conn) WithReader(ctx context.Context, timeout 
time.Duration, fn func(rd net.Conn) error) error { 65 | if timeout != 0 { 66 | if err := c.netConn.SetReadDeadline(c.deadline(ctx, timeout)); err != nil { 67 | return err 68 | } 69 | } 70 | return fn(c.netConn) 71 | } 72 | 73 | func (c *Conn) WithWriter(ctx context.Context, timeout time.Duration, fn func(writer io.Writer) error) error { 74 | if timeout != 0 { 75 | if err := c.netConn.SetWriteDeadline(c.deadline(ctx, timeout)); err != nil { 76 | return err 77 | } 78 | } 79 | 80 | if c.writer.Buffered() > 0 { 81 | c.writer.Reset(c.netConn) 82 | } 83 | 84 | if err := fn(c.writer); err != nil { 85 | return err 86 | } 87 | 88 | return c.writer.Flush() 89 | } 90 | 91 | func (c *Conn) Close() error { 92 | return c.netConn.Close() 93 | } 94 | 95 | func (c *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { 96 | tm := time.Now() 97 | c.SetUseAt(tm) 98 | 99 | if timeout > 0 { 100 | tm.Add(timeout) 101 | } 102 | 103 | if ctx != nil { 104 | deadline, ok := ctx.Deadline() 105 | if ok { 106 | if timeout == 0 { 107 | return deadline 108 | } 109 | 110 | if deadline.Before(tm) { 111 | return deadline 112 | } 113 | return tm 114 | } 115 | } 116 | 117 | if timeout > 0 { 118 | return tm 119 | } 120 | return noDeadline 121 | } 122 | -------------------------------------------------------------------------------- /default.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | "github.com/gobwas/ws" 6 | "github.com/gobwas/ws/wsutil" 7 | "io" 8 | "net" 9 | ) 10 | 11 | func WebsocketDialer(dialAddr string) DialFunc { 12 | return func(ctx context.Context) (net.Conn, error) { 13 | c, _, _, err := ws.Dial(ctx, dialAddr) 14 | return c, err 15 | } 16 | } 17 | 18 | func TcpDialer(addr string) DialFunc { 19 | return func(ctx context.Context) (net.Conn, error) { 20 | return net.Dial("tcp", addr) 21 | } 22 | } 23 | 24 | func WsWriter(data []byte) func(writer io.Writer) error { 25 | return func(w 
io.Writer) error { 26 | writer := wsutil.NewWriter(w, ws.StateClientSide, ws.OpText) 27 | if _, err := writer.Write(data); err != nil { 28 | return err 29 | } 30 | return writer.Flush() 31 | } 32 | } 33 | 34 | func TcpWriter(p []byte) func(writer io.Writer) error { 35 | return func(w io.Writer) error { 36 | _, err := w.Write(p) 37 | return err 38 | } 39 | } 40 | 41 | func WebsocketReadFunc(fn func(p []byte)) ReadFunc { 42 | return func(conn net.Conn) error { 43 | h, r, err := wsutil.NextReader(conn, ws.StateClientSide) 44 | if err != nil { 45 | return err 46 | } 47 | if h.OpCode.IsControl() { 48 | return wsutil.ControlFrameHandler(conn, ws.StateClientSide)(h, r) 49 | } 50 | 51 | data := make([]byte, h.Length) 52 | if _, err = r.Read(data); err != nil && err != io.EOF { 53 | return err 54 | } 55 | 56 | fn(data) 57 | return nil 58 | } 59 | } 60 | 61 | func TcpReadFunc(fn func(p []byte)) ReadFunc { 62 | return func(conn net.Conn) error { 63 | data := make([]byte, 1024) 64 | n, err := conn.Read(data) 65 | if err != nil && err != io.EOF { 66 | return err 67 | } 68 | fn(data[:n]) 69 | return nil 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "github.com/bean-du/pool" 6 | "log" 7 | "os" 8 | "os/signal" 9 | "time" 10 | ) 11 | 12 | func main() { 13 | 14 | sig := make(chan os.Signal, 1) 15 | signal.Notify(sig, os.Interrupt, os.Kill) 16 | 17 | // init a pool with options 18 | client := pool.NewClient( 19 | pool.WebsocketDialer("ws://127.0.0.1:8081/ws"), 20 | // log instance 21 | nil, 22 | // set pool size 23 | pool.WithPoolSize(50), 24 | // set write func default is tcp writer 25 | pool.WithWriteFunc(pool.WsWriter), 26 | // set read func, must be set 27 | pool.WithReadFunc(pool.WebsocketReadFunc(dataHandleFunc)), 28 | // set min idle connections 29 | 
pool.WithMinIdleConns(10), 30 | 31 | pool.WithPoolTimeout(5*time.Second), 32 | // set idle check duration 33 | pool.WithIdleCheckFrequency(time.Second*10), 34 | ) 35 | 36 | for i := 0; i < 10; i++ { 37 | go func() { 38 | if err := client.Send(context.Background(), []byte("hello")); err != nil { 39 | log.Println(err) 40 | } 41 | }() 42 | } 43 | 44 | select { 45 | case <-sig: 46 | client.Close() 47 | } 48 | } 49 | 50 | func dataHandleFunc(p []byte) { 51 | go func() { 52 | log.Println(string(p)) 53 | }() 54 | } 55 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bean-du/pool 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/gobwas/ws v1.1.0 7 | github.com/mailru/easygo v0.0.0-20190618140210-3c14a0dc985f 8 | ) 9 | 10 | require ( 11 | github.com/gobwas/httphead v0.1.0 // indirect 12 | github.com/gobwas/pool v0.2.1 // indirect 13 | golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d // indirect 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= 2 | github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= 3 | github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= 4 | github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= 5 | github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA= 6 | github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0= 7 | github.com/mailru/easygo v0.0.0-20190618140210-3c14a0dc985f h1:4+gHs0jJFJ06bfN8PshnM6cHcxGjRUVRLo5jndDiKRQ= 8 | github.com/mailru/easygo v0.0.0-20190618140210-3c14a0dc985f/go.mod h1:tHCZHV8b2A90ObojrEAzY0Lb03gxUxjDHr5IJyAh4ew= 9 | golang.org/x/sys 
v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 10 | golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d h1:/m5NbqQelATgoSPVC2Z23sR4kVNokFwDDyWh/3rGY+I= 11 | golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 12 | -------------------------------------------------------------------------------- /log.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | type Fields map[string]interface{} 4 | 5 | type Logger interface { 6 | Trace(args ...interface{}) 7 | Tracef(format string, args ...interface{}) 8 | Debug(args ...interface{}) 9 | Debugf(format string, args ...interface{}) 10 | Info(args ...interface{}) 11 | Infof(format string, args ...interface{}) 12 | Warn(args ...interface{}) 13 | Warnf(format string, args ...interface{}) 14 | Error(args ...interface{}) 15 | Errorf(format string, args ...interface{}) 16 | Panic(args ...interface{}) 17 | Panicf(format string, args ...interface{}) 18 | Fatal(args ...interface{}) 19 | Fatalf(format string, args ...interface{}) 20 | } 21 | -------------------------------------------------------------------------------- /option.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net" 7 | "time" 8 | ) 9 | 10 | type DialFunc func(context.Context) (net.Conn, error) 11 | 12 | type ReadFunc func(conn net.Conn) error 13 | 14 | type WriteFunc func(p []byte) func(w io.Writer) error 15 | 16 | type ReceiveHandler func(context.Context, []byte) 17 | 18 | type KeepAliveFunc func(conn net.Conn) 19 | 20 | type Option func(o *Options) 21 | 22 | type Options struct { 23 | Dialer DialFunc 24 | OnClose func(*Conn) error 25 | ReceiveHandler ReceiveHandler 26 | ReadFunc ReadFunc 27 | Keepalive KeepAliveFunc 28 | WriteFunc WriteFunc 29 | 30 | PoolFIFO bool 31 | PoolSize int 32 | MinIdleConns int 33 | MaxConnAge 
time.Duration 34 | PoolTimeout time.Duration 35 | IdleTimeout time.Duration 36 | IdleCheckFrequency time.Duration 37 | } 38 | 39 | func WithReadFunc(fn ReadFunc) Option { 40 | return func(o *Options) { 41 | o.ReadFunc = fn 42 | } 43 | } 44 | 45 | func WithWriteFunc(fn WriteFunc) Option { 46 | return func(o *Options) { 47 | o.WriteFunc = fn 48 | } 49 | } 50 | 51 | func WithKeepAlive(fn KeepAliveFunc) Option { 52 | return func(o *Options) { 53 | o.Keepalive = fn 54 | } 55 | } 56 | 57 | func WithReceiveHandle(fn ReceiveHandler) Option { 58 | return func(o *Options) { 59 | o.ReceiveHandler = fn 60 | } 61 | } 62 | 63 | func WithPoolFIFO(b bool) Option { 64 | return func(o *Options) { 65 | o.PoolFIFO = b 66 | } 67 | } 68 | 69 | func WithPoolSize(i int) Option { 70 | return func(o *Options) { 71 | o.PoolSize = i 72 | } 73 | } 74 | 75 | func WithMinIdleConns(i int) Option { 76 | return func(o *Options) { 77 | o.MinIdleConns = i 78 | } 79 | } 80 | func WithMaxConnAge(d time.Duration) Option { 81 | return func(o *Options) { 82 | o.MaxConnAge = d 83 | } 84 | } 85 | 86 | func WithPoolTimeout(d time.Duration) Option { 87 | return func(o *Options) { 88 | o.PoolTimeout = d 89 | } 90 | } 91 | 92 | func WithIdleTimeout(d time.Duration) Option { 93 | return func(o *Options) { 94 | o.IdleTimeout = d 95 | } 96 | } 97 | 98 | func WithIdleCheckFrequency(d time.Duration) Option { 99 | return func(o *Options) { 100 | o.IdleCheckFrequency = d 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "sync" 9 | "sync/atomic" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/mailru/easygo/netpoll" 14 | ) 15 | 16 | var ( 17 | // ErrClosed performs any operation on the closed client will return this error. 
18 | ErrClosed = errors.New("connection is closed") 19 | 20 | // ErrPoolTimeout timed out waiting to get a connection from the connection pool. 21 | ErrPoolTimeout = errors.New("connection pool timeout") 22 | 23 | errUnexpectedRead = errors.New("unexpected read from socket") 24 | ) 25 | 26 | var timers = sync.Pool{ 27 | New: func() interface{} { 28 | t := time.NewTimer(time.Hour) 29 | t.Stop() 30 | return t 31 | }, 32 | } 33 | 34 | // Stats contains pool state information and accumulated stats. 35 | type Stats struct { 36 | Hits uint32 // number of times free connection was found in the pool 37 | Misses uint32 // number of times free connection was NOT found in the pool 38 | Timeouts uint32 // number of times a wait timeout occurred 39 | 40 | TotalConns uint32 // number of total connections in the pool 41 | IdleConns uint32 // number of idle connections in the pool 42 | StaleConns uint32 // number of stale connections removed from the pool 43 | } 44 | 45 | type Pooler interface { 46 | NewConn(context.Context) (*Conn, error) 47 | CloseConn(*Conn) error 48 | 49 | Get(context.Context) (*Conn, error) 50 | Put(context.Context, *Conn) 51 | Remove(context.Context, *Conn, error) 52 | 53 | Len() int 54 | IdleLen() int 55 | Stats() *Stats 56 | 57 | Close() error 58 | } 59 | 60 | type lastDialErrorWrap struct { 61 | err error 62 | } 63 | 64 | type ConnPool struct { 65 | opt *Options 66 | 67 | poller netpoll.Poller 68 | registerMu sync.Mutex 69 | registeredDesc map[string]*netpoll.Desc 70 | dialErrorsNum uint32 // atomic 71 | 72 | lastDialError atomic.Value 73 | 74 | queue chan struct{} 75 | 76 | connsMu sync.Mutex 77 | conns []*Conn 78 | idleConns []*Conn 79 | poolSize int 80 | idleConnsLen int 81 | receiver chan []byte 82 | 83 | stats Stats 84 | 85 | _closed uint32 // atomic 86 | closedCh chan struct{} 87 | 88 | log Logger 89 | } 90 | 91 | var _ Pooler = (*ConnPool)(nil) 92 | 93 | func NewConnPool(opt *Options, logger Logger) *ConnPool { 94 | p := &ConnPool{ 95 | opt: opt, 
96 | 97 | queue: make(chan struct{}, opt.PoolSize), 98 | conns: make([]*Conn, 0, opt.PoolSize), 99 | closedCh: make(chan struct{}), 100 | receiver: make(chan []byte, 1), 101 | idleConns: make([]*Conn, 0, opt.PoolSize), 102 | registeredDesc: make(map[string]*netpoll.Desc), 103 | log: logger, 104 | } 105 | 106 | p.connsMu.Lock() 107 | go p.checkMinIdleConns() 108 | p.connsMu.Unlock() 109 | 110 | poller, _ := netpoll.New(nil) 111 | p.poller = poller 112 | 113 | if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { 114 | go p.reaper(opt.IdleCheckFrequency) 115 | } 116 | 117 | if p.opt.ReceiveHandler != nil { 118 | go p.handleReceiveDate() 119 | } 120 | 121 | return p 122 | } 123 | 124 | func (p *ConnPool) handleReceiveDate() { 125 | if p.opt.ReadFunc != nil { 126 | return 127 | } 128 | for { 129 | select { 130 | case <-p.closedCh: 131 | return 132 | case data := <-p.receiver: 133 | p.opt.ReceiveHandler(context.Background(), data) 134 | } 135 | } 136 | } 137 | 138 | func (p *ConnPool) checkMinIdleConns() { 139 | if p.opt.MinIdleConns == 0 { 140 | return 141 | } 142 | for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { 143 | p.poolSize++ 144 | p.idleConnsLen++ 145 | 146 | go func() { 147 | err := p.addIdleConn() 148 | if err != nil && err != ErrClosed { 149 | p.connsMu.Lock() 150 | p.poolSize-- 151 | p.idleConnsLen-- 152 | p.connsMu.Unlock() 153 | } 154 | }() 155 | } 156 | } 157 | 158 | func (p *ConnPool) addIdleConn() error { 159 | cn, err := p.dialConn(context.TODO(), true) 160 | if err != nil { 161 | p.log.Error("addIdleConn error:", err) 162 | return err 163 | } 164 | 165 | p.connsMu.Lock() 166 | defer p.connsMu.Unlock() 167 | 168 | // It is not allowed to add new connections to the closed connection pool. 
169 | if p.closed() { 170 | _ = cn.Close() 171 | return ErrClosed 172 | } 173 | 174 | p.conns = append(p.conns, cn) 175 | p.idleConns = append(p.idleConns, cn) 176 | return nil 177 | } 178 | 179 | func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { 180 | return p.newConn(ctx, false) 181 | } 182 | 183 | func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { 184 | cn, err := p.dialConn(ctx, pooled) 185 | if err != nil { 186 | p.log.Error("dialConn error:", err) 187 | return nil, err 188 | } 189 | 190 | p.connsMu.Lock() 191 | defer p.connsMu.Unlock() 192 | 193 | // It is not allowed to add new connections to the closed connection pool. 194 | if p.closed() { 195 | _ = cn.Close() 196 | return nil, ErrClosed 197 | } 198 | 199 | p.conns = append(p.conns, cn) 200 | if pooled { 201 | // If pool is full remove the cn on next Put. 202 | if p.poolSize >= p.opt.PoolSize { 203 | cn.pooled = false 204 | } else { 205 | p.poolSize++ 206 | } 207 | } 208 | 209 | return cn, nil 210 | } 211 | 212 | func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { 213 | if p.closed() { 214 | return nil, ErrClosed 215 | } 216 | 217 | if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { 218 | return nil, p.getLastDialError() 219 | } 220 | 221 | netConn, err := p.opt.Dialer(ctx) 222 | if err != nil { 223 | p.setLastDialError(err) 224 | if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { 225 | go p.tryDial() 226 | } 227 | return nil, err 228 | } 229 | cn := NewConn(netConn) 230 | cn.pooled = pooled 231 | 232 | if err = p.addPoller(ctx, cn); err != nil { 233 | p.log.Error("addPoller error:", err) 234 | return nil, err 235 | } 236 | 237 | if p.opt.Keepalive != nil { 238 | p.opt.Keepalive(cn.netConn) 239 | } 240 | 241 | return cn, err 242 | } 243 | 244 | func (p *ConnPool) addPoller(ctx context.Context, cn *Conn) error { 245 | var ( 246 | err error 247 | desc = netpoll.Must(netpoll.HandleRead(cn.netConn)) 248 | ) 
249 | 250 | addr := cn.netConn.LocalAddr().String() 251 | p.registerMu.Lock() 252 | p.registeredDesc[addr] = desc 253 | p.registerMu.Unlock() 254 | 255 | eventHandle := func(event netpoll.Event) { 256 | if event&(netpoll.EventHup|netpoll.EventReadHup) != 0 { 257 | p.log.Error("connection closed by peer") 258 | err = p.closeConn(cn) 259 | if err != nil { 260 | p.log.Error("closeConn error:", err) 261 | } 262 | return 263 | } 264 | 265 | if err = cn.WithReader(ctx, 0, p.opt.ReadFunc); err != nil { 266 | err = p.poller.Stop(desc) 267 | if err != nil { 268 | p.log.Error("poller Stop error:", err) 269 | } 270 | return 271 | } 272 | } 273 | return p.poller.Start(desc, eventHandle) 274 | } 275 | 276 | func (p *ConnPool) tryDial() { 277 | for { 278 | if p.closed() { 279 | return 280 | } 281 | 282 | conn, err := p.opt.Dialer(context.Background()) 283 | if err != nil { 284 | p.setLastDialError(err) 285 | time.Sleep(time.Second) 286 | continue 287 | } 288 | 289 | atomic.StoreUint32(&p.dialErrorsNum, 0) 290 | _ = conn.Close() 291 | return 292 | } 293 | } 294 | 295 | func (p *ConnPool) setLastDialError(err error) { 296 | p.lastDialError.Store(&lastDialErrorWrap{err: err}) 297 | } 298 | 299 | func (p *ConnPool) getLastDialError() error { 300 | err, _ := p.lastDialError.Load().(*lastDialErrorWrap) 301 | if err != nil { 302 | return err.err 303 | } 304 | return nil 305 | } 306 | 307 | // Get returns existed connection from the pool or creates a new one. 
308 | func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { 309 | if p.closed() { 310 | return nil, ErrClosed 311 | } 312 | 313 | if err := p.waitTurn(ctx); err != nil { 314 | return nil, err 315 | } 316 | 317 | for { 318 | p.connsMu.Lock() 319 | cn, err := p.popIdle() 320 | p.connsMu.Unlock() 321 | 322 | if err != nil { 323 | return nil, err 324 | } 325 | 326 | if cn == nil { 327 | break 328 | } 329 | 330 | if p.isStaleConn(cn) { 331 | _ = p.CloseConn(cn) 332 | continue 333 | } 334 | 335 | atomic.AddUint32(&p.stats.Hits, 1) 336 | return cn, nil 337 | } 338 | 339 | atomic.AddUint32(&p.stats.Misses, 1) 340 | 341 | newcn, err := p.newConn(ctx, true) 342 | if err != nil { 343 | p.freeTurn() 344 | return nil, err 345 | } 346 | 347 | return newcn, nil 348 | } 349 | 350 | func (p *ConnPool) getTurn() { 351 | p.queue <- struct{}{} 352 | } 353 | 354 | func (p *ConnPool) waitTurn(ctx context.Context) error { 355 | select { 356 | case <-ctx.Done(): 357 | return ctx.Err() 358 | default: 359 | } 360 | 361 | select { 362 | case p.queue <- struct{}{}: 363 | return nil 364 | default: 365 | } 366 | 367 | timer := timers.Get().(*time.Timer) 368 | timer.Reset(p.opt.PoolTimeout) 369 | 370 | select { 371 | case <-ctx.Done(): 372 | if !timer.Stop() { 373 | <-timer.C 374 | } 375 | timers.Put(timer) 376 | return ctx.Err() 377 | case p.queue <- struct{}{}: 378 | if !timer.Stop() { 379 | <-timer.C 380 | } 381 | timers.Put(timer) 382 | return nil 383 | case <-timer.C: 384 | timers.Put(timer) 385 | atomic.AddUint32(&p.stats.Timeouts, 1) 386 | return ErrPoolTimeout 387 | } 388 | } 389 | 390 | func (p *ConnPool) freeTurn() { 391 | <-p.queue 392 | } 393 | 394 | func (p *ConnPool) popIdle() (*Conn, error) { 395 | if p.closed() { 396 | return nil, ErrClosed 397 | } 398 | n := len(p.idleConns) 399 | if n == 0 { 400 | return nil, nil 401 | } 402 | 403 | var cn *Conn 404 | if p.opt.PoolFIFO { 405 | cn = p.idleConns[0] 406 | copy(p.idleConns, p.idleConns[1:]) 407 | p.idleConns = 
p.idleConns[:n-1] 408 | } else { 409 | idx := n - 1 410 | cn = p.idleConns[idx] 411 | p.idleConns = p.idleConns[:idx] 412 | } 413 | p.idleConnsLen-- 414 | p.checkMinIdleConns() 415 | return cn, nil 416 | } 417 | 418 | func (p *ConnPool) Put(ctx context.Context, cn *Conn) { 419 | if cn.reader.Buffered() > 0 { 420 | p.log.Error("Conn has unread data") 421 | p.Remove(ctx, cn, errors.New("conn is in a bad sate")) 422 | return 423 | } 424 | 425 | if !cn.pooled { 426 | p.Remove(ctx, cn, nil) 427 | return 428 | } 429 | 430 | p.connsMu.Lock() 431 | p.idleConns = append(p.idleConns, cn) 432 | p.idleConnsLen++ 433 | p.connsMu.Unlock() 434 | p.freeTurn() 435 | } 436 | 437 | func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { 438 | p.removeConnWithLock(cn) 439 | p.freeTurn() 440 | _ = p.closeConn(cn) 441 | } 442 | 443 | func (p *ConnPool) CloseConn(cn *Conn) error { 444 | p.removeConnWithLock(cn) 445 | return p.closeConn(cn) 446 | } 447 | 448 | func (p *ConnPool) removeConnWithLock(cn *Conn) { 449 | p.connsMu.Lock() 450 | p.removeConn(cn) 451 | p.connsMu.Unlock() 452 | } 453 | 454 | func (p *ConnPool) removeConn(cn *Conn) { 455 | for i, c := range p.conns { 456 | if c == cn { 457 | p.conns = append(p.conns[:i], p.conns[i+1:]...) 458 | if cn.pooled { 459 | p.poolSize-- 460 | p.checkMinIdleConns() 461 | } 462 | return 463 | } 464 | } 465 | } 466 | 467 | func (p *ConnPool) closeConn(cn *Conn) error { 468 | var ( 469 | ok bool 470 | err error 471 | desc *netpoll.Desc 472 | ) 473 | if p.opt.OnClose != nil { 474 | _ = p.opt.OnClose(cn) 475 | } 476 | 477 | p.registerMu.Lock() 478 | defer p.registerMu.Unlock() 479 | if desc, ok = p.registeredDesc[cn.netConn.LocalAddr().String()]; ok { 480 | if err = p.poller.Stop(desc); err != nil { 481 | p.log.Error("netpoll stop error:", err) 482 | } 483 | 484 | delete(p.registeredDesc, cn.netConn.LocalAddr().String()) 485 | } 486 | 487 | return cn.Close() 488 | } 489 | 490 | // Len returns total number of connections. 
491 | func (p *ConnPool) Len() int { 492 | p.connsMu.Lock() 493 | n := len(p.conns) 494 | p.connsMu.Unlock() 495 | return n 496 | } 497 | 498 | // IdleLen returns number of idle connections. 499 | func (p *ConnPool) IdleLen() int { 500 | p.connsMu.Lock() 501 | n := p.idleConnsLen 502 | p.connsMu.Unlock() 503 | return n 504 | } 505 | 506 | func (p *ConnPool) Stats() *Stats { 507 | idleLen := p.IdleLen() 508 | return &Stats{ 509 | Hits: atomic.LoadUint32(&p.stats.Hits), 510 | Misses: atomic.LoadUint32(&p.stats.Misses), 511 | Timeouts: atomic.LoadUint32(&p.stats.Timeouts), 512 | 513 | TotalConns: uint32(p.Len()), 514 | IdleConns: uint32(idleLen), 515 | StaleConns: atomic.LoadUint32(&p.stats.StaleConns), 516 | } 517 | } 518 | 519 | func (p *ConnPool) closed() bool { 520 | return atomic.LoadUint32(&p._closed) == 1 521 | } 522 | 523 | func (p *ConnPool) Filter(fn func(*Conn) bool) error { 524 | p.connsMu.Lock() 525 | defer p.connsMu.Unlock() 526 | 527 | var firstErr error 528 | for _, cn := range p.conns { 529 | if fn(cn) { 530 | if err := p.closeConn(cn); err != nil && firstErr == nil { 531 | firstErr = err 532 | } 533 | } 534 | } 535 | return firstErr 536 | } 537 | 538 | func (p *ConnPool) Close() error { 539 | if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { 540 | return ErrClosed 541 | } 542 | close(p.closedCh) 543 | 544 | var firstErr error 545 | p.connsMu.Lock() 546 | for _, cn := range p.conns { 547 | if err := p.closeConn(cn); err != nil && firstErr == nil { 548 | firstErr = err 549 | } 550 | } 551 | p.conns = nil 552 | p.poolSize = 0 553 | p.idleConns = nil 554 | p.idleConnsLen = 0 555 | p.connsMu.Unlock() 556 | 557 | return firstErr 558 | } 559 | 560 | func (p *ConnPool) reaper(frequency time.Duration) { 561 | ticker := time.NewTicker(frequency) 562 | defer ticker.Stop() 563 | 564 | for { 565 | select { 566 | case <-ticker.C: 567 | // It is possible that ticker and closedCh arrive together, 568 | // and select pseudo-randomly pick ticker case, we double 569 
| // check here to prevent being executed after closed. 570 | if p.closed() { 571 | return 572 | } 573 | _, err := p.ReapStaleConns() 574 | if err != nil { 575 | p.log.Error("ReapStaleConns failed:", err) 576 | continue 577 | } 578 | case <-p.closedCh: 579 | return 580 | } 581 | } 582 | } 583 | 584 | func (p *ConnPool) ReapStaleConns() (int, error) { 585 | var n int 586 | for { 587 | p.getTurn() 588 | 589 | p.connsMu.Lock() 590 | cn := p.reapStaleConn() 591 | p.connsMu.Unlock() 592 | 593 | p.freeTurn() 594 | 595 | if cn != nil { 596 | _ = p.closeConn(cn) 597 | n++ 598 | } else { 599 | break 600 | } 601 | } 602 | atomic.AddUint32(&p.stats.StaleConns, uint32(n)) 603 | return n, nil 604 | } 605 | 606 | func (p *ConnPool) reapStaleConn() *Conn { 607 | if len(p.idleConns) == 0 { 608 | return nil 609 | } 610 | 611 | cn := p.idleConns[0] 612 | if !p.isStaleConn(cn) { 613 | return nil 614 | } 615 | 616 | p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) 617 | p.idleConnsLen-- 618 | p.removeConn(cn) 619 | 620 | return cn 621 | } 622 | 623 | func (p *ConnPool) isStaleConn(cn *Conn) bool { 624 | if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { 625 | return connCheck(cn.netConn) != nil 626 | } 627 | 628 | now := time.Now() 629 | if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { 630 | return true 631 | } 632 | if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { 633 | return true 634 | } 635 | 636 | return connCheck(cn.netConn) != nil 637 | } 638 | 639 | func connCheck(conn net.Conn) error { 640 | // Reset previous timeout. 
641 | _ = conn.SetDeadline(time.Time{}) 642 | 643 | sysConn, ok := conn.(syscall.Conn) 644 | if !ok { 645 | return nil 646 | } 647 | rawConn, err := sysConn.SyscallConn() 648 | if err != nil { 649 | return err 650 | } 651 | 652 | var sysErr error 653 | err = rawConn.Read(func(fd uintptr) bool { 654 | var buf [1]byte 655 | n, err := syscall.Read(int(fd), buf[:]) 656 | switch { 657 | case n == 0 && err == nil: 658 | sysErr = io.EOF 659 | case n > 0: 660 | sysErr = errUnexpectedRead 661 | case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: 662 | sysErr = nil 663 | default: 664 | sysErr = err 665 | } 666 | return true 667 | }) 668 | if err != nil { 669 | return err 670 | } 671 | 672 | return sysErr 673 | } 674 | --------------------------------------------------------------------------------