├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── graceful.go ├── graceful_test.go ├── http2_test.go ├── keepalive_listener.go ├── limit_listen.go ├── signal.go ├── signal_appengine.go ├── test-fixtures ├── cert.crt └── key.pem └── tests └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go: 4 | - 1.7 5 | - 1.6.2 6 | - 1.5.4 7 | - 1.4.3 8 | - 1.3.3 9 | before_install: 10 | - go get github.com/mattn/goveralls 11 | - go get golang.org/x/tools/cmd/cover 12 | script: 13 | - $HOME/gopath/bin/goveralls -service=travis-ci 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Tyler Bunnell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of 
the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | graceful [![GoDoc](https://godoc.org/github.com/tylerb/graceful?status.png)](http://godoc.org/github.com/tylerb/graceful) [![Build Status](https://travis-ci.org/tylerb/graceful.svg?branch=master)](https://travis-ci.org/tylerb/graceful) [![Coverage Status](https://coveralls.io/repos/tylerb/graceful/badge.svg)](https://coveralls.io/r/tylerb/graceful) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/tylerb/graceful?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) 2 | ======== 3 | 4 | Graceful is a Go 1.3+ package enabling graceful shutdown of http.Handler servers. 5 | 6 | ## Using Go 1.8? 7 | 8 | If you are using Go 1.8, you may not need to use this library! Consider using `http.Server`'s built-in [Shutdown()](https://golang.org/pkg/net/http/#Server.Shutdown) 9 | method for graceful shutdowns. 10 | 11 | ## Installation 12 | 13 | To install, simply execute: 14 | 15 | ``` 16 | go get gopkg.in/tylerb/graceful.v1 17 | ``` 18 | 19 | I am using [gopkg.in](http://labix.org/gopkg.in) to control releases. 20 | 21 | ## Usage 22 | 23 | Using Graceful is easy. 
Simply create your http.Handler and pass it to the `Run` function: 24 | 25 | ```go 26 | package main 27 | 28 | import ( 29 | "gopkg.in/tylerb/graceful.v1" 30 | "net/http" 31 | "fmt" 32 | "time" 33 | ) 34 | 35 | func main() { 36 | mux := http.NewServeMux() 37 | mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { 38 | fmt.Fprintf(w, "Welcome to the home page!") 39 | }) 40 | 41 | graceful.Run(":3001",10*time.Second,mux) 42 | } 43 | ``` 44 | 45 | Another example, using [Negroni](https://github.com/codegangsta/negroni), functions in much the same manner: 46 | 47 | ```go 48 | package main 49 | 50 | import ( 51 | "github.com/codegangsta/negroni" 52 | "gopkg.in/tylerb/graceful.v1" 53 | "net/http" 54 | "fmt" 55 | "time" 56 | ) 57 | 58 | func main() { 59 | mux := http.NewServeMux() 60 | mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { 61 | fmt.Fprintf(w, "Welcome to the home page!") 62 | }) 63 | 64 | n := negroni.Classic() 65 | n.UseHandler(mux) 66 | //n.Run(":3000") 67 | graceful.Run(":3001",10*time.Second,n) 68 | } 69 | ``` 70 | 71 | In addition to Run there are the http.Server counterparts ListenAndServe, ListenAndServeTLS and Serve, which allow you to configure HTTPS, custom timeouts and error handling. 72 | Graceful may also be used by instantiating its Server type directly, which embeds an http.Server: 73 | 74 | ```go 75 | mux := // ... 76 | 77 | srv := &graceful.Server{ 78 | Timeout: 10 * time.Second, 79 | 80 | Server: &http.Server{ 81 | Addr: ":1234", 82 | Handler: mux, 83 | }, 84 | } 85 | 86 | srv.ListenAndServe() 87 | ``` 88 | 89 | This form allows you to set the ConnState callback, which works in the same way as in http.Server: 90 | 91 | ```go 92 | mux := // ... 
93 | 94 | srv := &graceful.Server{ 95 | Timeout: 10 * time.Second, 96 | 97 | ConnState: func(conn net.Conn, state http.ConnState) { 98 | // conn has a new state 99 | }, 100 | 101 | Server: &http.Server{ 102 | Addr: ":1234", 103 | Handler: mux, 104 | }, 105 | } 106 | 107 | srv.ListenAndServe() 108 | ``` 109 | 110 | ## Behaviour 111 | 112 | When Graceful is sent a SIGINT or SIGTERM (possibly from ^C or a kill command), it: 113 | 114 | 1. Disables keepalive connections. 115 | 2. Closes the listening socket, allowing another process to listen on that port immediately. 116 | 3. Starts a timer of `timeout` duration to give active requests a chance to finish. 117 | 4. When timeout expires, closes all active connections. 118 | 5. Closes the `stopChan`, waking up any blocking goroutines. 119 | 6. Returns from the function, allowing the server to terminate. 120 | 121 | ## Notes 122 | 123 | If the `timeout` argument to `Run` is 0, the server never times out, allowing all active requests to complete. 124 | 125 | If you wish to stop the server in some way other than an OS signal, you may call the `Stop()` function. 126 | This function stops the server, gracefully, using the new timeout value you provide. The `StopChan()` function 127 | returns a channel on which you can block while waiting for the server to stop. This channel will be closed when 128 | the server is stopped, allowing your execution to proceed. Multiple goroutines can block on this channel at the 129 | same time and all will be signalled when stopping is complete. 130 | 131 | ### Important things to note when setting `timeout` to 0: 132 | 133 | If you set the `timeout` to `0`, it waits for all connections to the server to disconnect before shutting down. 134 | This means that even though requests over a connection have finished, it is possible for the client to hold the 135 | connection open and block the server from shutting down indefinitely. 
136 | 137 | This is especially evident when graceful is used to run HTTP/2 servers. Clients like Chrome and Firefox have been 138 | observed to hold onto the open connection indefinitely over HTTP/2, preventing the server from shutting down. In 139 | addition, there is also the risk of malicious clients holding and keeping the connection alive. 140 | 141 | It is understandable that sometimes, you might want to wait for the client indefinitely because they might be 142 | uploading large files. In these type of cases, it is recommended that you set a reasonable timeout to kill the 143 | connection, and have the client perform resumable uploads. For example, the client can divide the file into chunks 144 | and reupload chunks that were in transit when the connection was terminated. 145 | 146 | ## Contributing 147 | 148 | If you would like to contribute, please: 149 | 150 | 1. Create a GitHub issue regarding the contribution. Features and bugs should be discussed beforehand. 151 | 2. Fork the repository. 152 | 3. Create a pull request with your solution. This pull request should reference and close the issues (Fix #2). 153 | 154 | All pull requests should: 155 | 156 | 1. Pass [gometalinter -t .](https://github.com/alecthomas/gometalinter) with no warnings. 157 | 2. Be `go fmt` formatted. 158 | -------------------------------------------------------------------------------- /graceful.go: -------------------------------------------------------------------------------- 1 | package graceful 2 | 3 | import ( 4 | "crypto/tls" 5 | "log" 6 | "net" 7 | "net/http" 8 | "os" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // Server wraps an http.Server with graceful connection handling. 14 | // It may be used directly in the same way as http.Server, or may 15 | // be constructed with the global functions in this package. 
16 | // 17 | // Example: 18 | // srv := &graceful.Server{ 19 | // Timeout: 5 * time.Second, 20 | // Server: &http.Server{Addr: ":1234", Handler: handler}, 21 | // } 22 | // srv.ListenAndServe() 23 | type Server struct { 24 | *http.Server 25 | 26 | // Timeout is the duration to allow outstanding requests to survive 27 | // before forcefully terminating them. 28 | Timeout time.Duration 29 | 30 | // Limit the number of outstanding requests 31 | ListenLimit int 32 | 33 | // TCPKeepAlive sets the TCP keep-alive timeouts on accepted 34 | // connections. It prunes dead TCP connections ( e.g. closing 35 | // laptop mid-download) 36 | TCPKeepAlive time.Duration 37 | 38 | // ConnState specifies an optional callback function that is 39 | // called when a client connection changes state. This is a proxy 40 | // to the underlying http.Server's ConnState, and the original 41 | // must not be set directly. 42 | ConnState func(net.Conn, http.ConnState) 43 | 44 | // BeforeShutdown is an optional callback function that is called 45 | // before the listener is closed. Returns true if shutdown is allowed 46 | BeforeShutdown func() bool 47 | 48 | // ShutdownInitiated is an optional callback function that is called 49 | // when shutdown is initiated. It can be used to notify the client 50 | // side of long lived connections (e.g. websockets) to reconnect. 51 | ShutdownInitiated func() 52 | 53 | // NoSignalHandling prevents graceful from automatically shutting down 54 | // on SIGINT and SIGTERM. If set to true, you must shut down the server 55 | // manually with Stop(). 56 | NoSignalHandling bool 57 | 58 | // Logger used to notify of errors on startup and on stop. 
59 | Logger *log.Logger 60 | 61 | // LogFunc can be assigned with a logging function of your choice, allowing 62 | // you to use whatever logging approach you would like 63 | LogFunc func(format string, args ...interface{}) 64 | 65 | // Interrupted is true if the server is handling a SIGINT or SIGTERM 66 | // signal and is thus shutting down. 67 | Interrupted bool 68 | 69 | // interrupt signals the listener to stop serving connections, 70 | // and the server to shut down. 71 | interrupt chan os.Signal 72 | 73 | // stopLock is used to protect against concurrent calls to Stop 74 | stopLock sync.Mutex 75 | 76 | // stopChan is the channel on which callers may block while waiting for 77 | // the server to stop. 78 | stopChan chan struct{} 79 | 80 | // chanLock is used to protect access to the various channel constructors. 81 | chanLock sync.RWMutex 82 | 83 | // connections holds all connections managed by graceful 84 | connections map[net.Conn]struct{} 85 | 86 | // idleConnections holds all idle connections managed by graceful 87 | idleConnections map[net.Conn]struct{} 88 | } 89 | 90 | // Run serves the http.Handler with graceful shutdown enabled. 91 | // 92 | // timeout is the duration to wait until killing active requests and stopping the server. 93 | // If timeout is 0, the server never times out. It waits for all active requests to finish. If serving fails with any error other than a *net.OpError whose Op is "accept", the error is logged to the default logger and the process exits with status 1. 94 | func Run(addr string, timeout time.Duration, n http.Handler) { 95 | srv := &Server{ 96 | Timeout: timeout, 97 | TCPKeepAlive: 3 * time.Minute, 98 | Server: &http.Server{Addr: addr, Handler: n}, 99 | Logger: DefaultLogger(), // without a Logger, srv.logf below is a no-op and os.Exit(1) would report nothing 100 | } 101 | 102 | if err := srv.ListenAndServe(); err != nil { 103 | if opErr, ok := err.(*net.OpError); !ok || opErr.Op != "accept" { 104 | srv.logf("%s", err) 105 | os.Exit(1) 106 | } 107 | } 108 | 109 | } 110 | 111 | // RunWithErr is an alternative version of Run function which can return error.
112 | // 113 | // Unlike Run this version will not exit the program if an error is encountered but will 114 | // return it instead. 115 | func RunWithErr(addr string, timeout time.Duration, n http.Handler) error { 116 | srv := &Server{ 117 | Timeout: timeout, 118 | TCPKeepAlive: 3 * time.Minute, 119 | Server: &http.Server{Addr: addr, Handler: n}, 120 | Logger: DefaultLogger(), 121 | } 122 | 123 | return srv.ListenAndServe() 124 | } 125 | 126 | // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. 127 | // 128 | // timeout is the duration to wait until killing active requests and stopping the server. 129 | // If timeout is 0, the server never times out. It waits for all active requests to finish. 130 | func ListenAndServe(server *http.Server, timeout time.Duration) error { 131 | srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 132 | return srv.ListenAndServe() 133 | } 134 | 135 | // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. 136 | func (srv *Server) ListenAndServe() error { 137 | // Create the listener so we can control their lifetime 138 | addr := srv.Addr 139 | if addr == "" { 140 | addr = ":http" 141 | } 142 | conn, err := srv.newTCPListener(addr) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | return srv.Serve(conn) 148 | } 149 | 150 | // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. 151 | // 152 | // timeout is the duration to wait until killing active requests and stopping the server. 153 | // If timeout is 0, the server never times out. It waits for all active requests to finish. 
154 | func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error { 155 | srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 156 | return srv.ListenAndServeTLS(certFile, keyFile) 157 | } 158 | 159 | // ListenTLS is a convenience method that creates an https listener using the 160 | // provided cert and key files. Use this method if you need access to the 161 | // listener object directly. When ready, pass it to the Serve method. On success the TLS config used by the listener (a copy of srv.TLSConfig, if set, with the loaded certificate and "h2" added to NextProtos) is also stored in srv.TLSConfig. 162 | func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) { 163 | // Create the listener ourselves so we can control its lifetime 164 | addr := srv.Addr 165 | if addr == "" { 166 | addr = ":https" 167 | } 168 | 169 | config := &tls.Config{} 170 | if srv.TLSConfig != nil { 171 | *config = *srv.TLSConfig // shallow copy: slice fields such as NextProtos still share backing arrays with the original srv.TLSConfig 172 | } 173 | 174 | var err error 175 | if certFile != "" && keyFile != "" { 176 | config.Certificates = make([]tls.Certificate, 1) 177 | config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) 178 | if err != nil { 179 | return nil, err 180 | } 181 | } 182 | 183 | // Enable http2 184 | enableHTTP2ForTLSConfig(config) 185 | 186 | conn, err := srv.newTCPListener(addr) 187 | if err != nil { 188 | return nil, err 189 | } 190 | 191 | srv.TLSConfig = config 192 | 193 | tlsListener := tls.NewListener(conn, config) 194 | return tlsListener, nil 195 | } 196 | 197 | // enableHTTP2ForTLSConfig explicitly enables http/2 for a TLS Config. This is due to changes in Go 1.7 where 198 | // http servers are no longer automatically configured to enable http/2 if the server's TLSConfig is set. It appends "h2" to t.NextProtos in place unless TLSConfigHasHTTP2Enabled reports it is already present. 199 | // See https://github.com/golang/go/issues/15908 200 | func enableHTTP2ForTLSConfig(t *tls.Config) { 201 | 202 | if TLSConfigHasHTTP2Enabled(t) { 203 | return 204 | } 205 | 206 | t.NextProtos = append(t.NextProtos, "h2") 207 | } 208 | 209 | // TLSConfigHasHTTP2Enabled checks to see if a given TLS Config has http2 enabled.
210 | func TLSConfigHasHTTP2Enabled(t *tls.Config) bool { 211 | for _, value := range t.NextProtos { 212 | if value == "h2" { 213 | return true 214 | } 215 | } 216 | return false 217 | } 218 | 219 | // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. 220 | func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { 221 | l, err := srv.ListenTLS(certFile, keyFile) 222 | if err != nil { 223 | return err 224 | } 225 | 226 | return srv.Serve(l) 227 | } 228 | 229 | // ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to 230 | // http.Server.ListenAndServeTLS with graceful shutdown enabled. config is stored in srv.TLSConfig and used as-is; unlike ListenTLS, this method does not itself add "h2" to config.NextProtos. 231 | func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error { 232 | addr := srv.Addr 233 | if addr == "" { 234 | addr = ":https" 235 | } 236 | 237 | conn, err := srv.newTCPListener(addr) 238 | if err != nil { 239 | return err 240 | } 241 | 242 | srv.TLSConfig = config 243 | 244 | tlsListener := tls.NewListener(conn, config) 245 | return srv.Serve(tlsListener) 246 | } 247 | 248 | // Serve is equivalent to http.Server.Serve with graceful shutdown enabled. 249 | // 250 | // timeout is the duration to wait until killing active requests and stopping the server. 251 | // If timeout is 0, the server never times out. It waits for all active requests to finish. 252 | func Serve(server *http.Server, l net.Listener, timeout time.Duration) error { 253 | srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} 254 | 255 | return srv.Serve(l) 256 | } 257 | 258 | // Serve is equivalent to http.Server.Serve with graceful shutdown enabled.
259 | func (srv *Server) Serve(listener net.Listener) error { 260 | 261 | if srv.ListenLimit != 0 { 262 | listener = LimitListener(listener, srv.ListenLimit) 263 | } 264 | 265 | // Make our stopchan 266 | srv.StopChan() 267 | 268 | // Track connection state 269 | add := make(chan net.Conn) 270 | idle := make(chan net.Conn) 271 | active := make(chan net.Conn) 272 | remove := make(chan net.Conn) 273 | 274 | srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { 275 | switch state { 276 | case http.StateNew: 277 | add <- conn 278 | case http.StateActive: 279 | active <- conn 280 | case http.StateIdle: 281 | idle <- conn 282 | case http.StateClosed, http.StateHijacked: 283 | remove <- conn 284 | } 285 | 286 | srv.stopLock.Lock() 287 | defer srv.stopLock.Unlock() 288 | 289 | if srv.ConnState != nil { 290 | srv.ConnState(conn, state) 291 | } 292 | } 293 | 294 | // Manage open connections 295 | shutdown := make(chan chan struct{}) 296 | kill := make(chan struct{}) 297 | go srv.manageConnections(add, idle, active, remove, shutdown, kill) 298 | 299 | interrupt := srv.interruptChan() 300 | // Set up the interrupt handler 301 | if !srv.NoSignalHandling { 302 | signalNotify(interrupt) 303 | } 304 | quitting := make(chan struct{}) 305 | go srv.handleInterrupt(interrupt, quitting, listener) 306 | 307 | // Serve with graceful listener. 308 | // Execution blocks here until listener.Close() is called, above. 309 | err := srv.Server.Serve(listener) 310 | if err != nil { 311 | // If the underlying listening is closed, Serve returns an error 312 | // complaining about listening on a closed socket. This is expected, so 313 | // let's ignore the error if we are the ones who explicitly closed the 314 | // socket. 
315 | select { 316 | case <-quitting: 317 | err = nil 318 | default: 319 | } 320 | } 321 | 322 | srv.shutdown(shutdown, kill) 323 | 324 | return err 325 | } 326 | 327 | // Stop instructs the type to halt operations and close 328 | // the stop channel when it is finished. 329 | // 330 | // timeout is grace period for which to wait before shutting 331 | // down the server. The timeout value passed here will override the 332 | // timeout given when constructing the server, as this is an explicit 333 | // command to stop the server. 334 | func (srv *Server) Stop(timeout time.Duration) { 335 | srv.stopLock.Lock() 336 | defer srv.stopLock.Unlock() 337 | 338 | srv.Timeout = timeout 339 | sendSignalInt(srv.interruptChan()) 340 | } 341 | 342 | // StopChan gets the stop channel which will block until 343 | // stopping has completed, at which point it is closed. 344 | // Callers should never close the stop channel. 345 | func (srv *Server) StopChan() <-chan struct{} { 346 | srv.chanLock.Lock() 347 | defer srv.chanLock.Unlock() 348 | 349 | if srv.stopChan == nil { 350 | srv.stopChan = make(chan struct{}) 351 | } 352 | return srv.stopChan 353 | } 354 | 355 | // DefaultLogger returns the logger used by Run, RunWithErr, ListenAndServe, ListenAndServeTLS and Serve. 356 | // The logger outputs to STDERR by default. 357 | func DefaultLogger() *log.Logger { 358 | return log.New(os.Stderr, "[graceful] ", 0) 359 | } 360 | 361 | func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) { 362 | var done chan struct{} 363 | srv.connections = map[net.Conn]struct{}{} 364 | srv.idleConnections = map[net.Conn]struct{}{} 365 | for { 366 | select { 367 | case conn := <-add: 368 | srv.connections[conn] = struct{}{} 369 | srv.idleConnections[conn] = struct{}{} // Newly-added connections are considered idle until they become active. 
370 | case conn := <-idle: 371 | srv.idleConnections[conn] = struct{}{} 372 | case conn := <-active: 373 | delete(srv.idleConnections, conn) 374 | case conn := <-remove: 375 | delete(srv.connections, conn) 376 | delete(srv.idleConnections, conn) 377 | if done != nil && len(srv.connections) == 0 { 378 | done <- struct{}{} 379 | return 380 | } 381 | case done = <-shutdown: 382 | if len(srv.connections) == 0 && len(srv.idleConnections) == 0 { 383 | done <- struct{}{} 384 | return 385 | } 386 | // a shutdown request has been received. if we have open idle 387 | // connections, we must close all of them now. this prevents idle 388 | // connections from holding the server open while waiting for them to 389 | // hit their idle timeout. 390 | for k := range srv.idleConnections { 391 | if err := k.Close(); err != nil { 392 | srv.logf("[ERROR] %s", err) 393 | } 394 | } 395 | case <-kill: 396 | srv.stopLock.Lock() 397 | defer srv.stopLock.Unlock() 398 | 399 | srv.Server.ConnState = nil 400 | for k := range srv.connections { 401 | if err := k.Close(); err != nil { 402 | srv.logf("[ERROR] %s", err) 403 | } 404 | } 405 | return 406 | } 407 | } 408 | } 409 | 410 | func (srv *Server) interruptChan() chan os.Signal { 411 | srv.chanLock.Lock() 412 | defer srv.chanLock.Unlock() 413 | 414 | if srv.interrupt == nil { 415 | srv.interrupt = make(chan os.Signal, 1) 416 | } 417 | 418 | return srv.interrupt 419 | } 420 | 421 | func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) { 422 | for _ = range interrupt { 423 | if srv.Interrupted { 424 | srv.logf("already shutting down") 425 | continue 426 | } 427 | srv.logf("shutdown initiated") 428 | srv.Interrupted = true 429 | if srv.BeforeShutdown != nil { 430 | if !srv.BeforeShutdown() { 431 | srv.Interrupted = false 432 | continue 433 | } 434 | } 435 | 436 | close(quitting) 437 | srv.SetKeepAlivesEnabled(false) 438 | if err := listener.Close(); err != nil { 439 | srv.logf("[ERROR] %s", 
err) 440 | } 441 | 442 | if srv.ShutdownInitiated != nil { 443 | srv.ShutdownInitiated() 444 | } 445 | } 446 | } 447 | 448 | func (srv *Server) logf(format string, args ...interface{}) { 449 | if srv.LogFunc != nil { 450 | srv.LogFunc(format, args...) 451 | } else if srv.Logger != nil { 452 | srv.Logger.Printf(format, args...) 453 | } 454 | } 455 | 456 | func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) { 457 | // Request done notification 458 | done := make(chan struct{}) 459 | shutdown <- done 460 | 461 | srv.stopLock.Lock() 462 | defer srv.stopLock.Unlock() 463 | if srv.Timeout > 0 { 464 | select { 465 | case <-done: 466 | case <-time.After(srv.Timeout): 467 | close(kill) 468 | } 469 | } else { 470 | <-done 471 | } 472 | // Close the stopChan to wake up any blocked goroutines. 473 | srv.chanLock.Lock() 474 | if srv.stopChan != nil { 475 | close(srv.stopChan) 476 | } 477 | srv.chanLock.Unlock() 478 | } 479 | 480 | func (srv *Server) newTCPListener(addr string) (net.Listener, error) { 481 | conn, err := net.Listen("tcp", addr) 482 | if err != nil { 483 | return conn, err 484 | } 485 | if srv.TCPKeepAlive != 0 { 486 | conn = keepAliveListener{conn, srv.TCPKeepAlive} 487 | } 488 | return conn, nil 489 | } 490 | -------------------------------------------------------------------------------- /graceful_test.go: -------------------------------------------------------------------------------- 1 | package graceful 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | "reflect" 13 | "strings" 14 | "sync" 15 | "syscall" 16 | "testing" 17 | "time" 18 | ) 19 | 20 | const ( 21 | // The tests will run a test server on this port. 
22 | port = 9654 23 | concurrentRequestN = 8 24 | killTime = 500 * time.Millisecond 25 | timeoutTime = 1000 * time.Millisecond 26 | waitTime = 100 * time.Millisecond 27 | ) 28 | 29 | func runQuery(t *testing.T, expected int, shouldErr bool, wg *sync.WaitGroup, once *sync.Once) { 30 | defer wg.Done() 31 | client := http.Client{} 32 | r, err := client.Get(fmt.Sprintf("http://localhost:%d", port)) 33 | if shouldErr && err == nil { 34 | once.Do(func() { 35 | t.Error("Expected an error but none was encountered.") 36 | }) 37 | } else if shouldErr && err != nil { 38 | if checkErr(t, err, once) { 39 | return 40 | } 41 | } 42 | if r != nil && r.StatusCode != expected { 43 | once.Do(func() { 44 | t.Errorf("Incorrect status code on response. Expected %d. Got %d", expected, r.StatusCode) 45 | }) 46 | } else if r == nil { 47 | once.Do(func() { 48 | t.Error("No response when a response was expected.") 49 | }) 50 | } 51 | } 52 | 53 | func checkErr(t *testing.T, err error, once *sync.Once) bool { 54 | if err.(*url.Error).Err == io.EOF { 55 | return true 56 | } 57 | var errno syscall.Errno 58 | switch oe := err.(*url.Error).Err.(type) { 59 | case *net.OpError: 60 | switch e := oe.Err.(type) { 61 | case syscall.Errno: 62 | errno = e 63 | case *os.SyscallError: 64 | errno = e.Err.(syscall.Errno) 65 | } 66 | if errno == syscall.ECONNREFUSED { 67 | return true 68 | } else if err != nil { 69 | once.Do(func() { 70 | t.Error("Error on Get:", err) 71 | }) 72 | } 73 | default: 74 | if strings.Contains(err.Error(), "transport closed before response was received") { 75 | return true 76 | } 77 | if strings.Contains(err.Error(), "server closed connection") { 78 | return true 79 | } 80 | fmt.Printf("unknown err: %s, %#v\n", err, err) 81 | } 82 | return false 83 | } 84 | 85 | func createListener(sleep time.Duration) (*http.Server, net.Listener, error) { 86 | mux := http.NewServeMux() 87 | mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { 88 | time.Sleep(sleep) 89 | 
rw.WriteHeader(http.StatusOK) 90 | }) 91 | 92 | server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} 93 | l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) 94 | return server, l, err 95 | } 96 | 97 | func launchTestQueries(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) { 98 | defer wg.Done() 99 | var once sync.Once 100 | 101 | for i := 0; i < concurrentRequestN; i++ { 102 | wg.Add(1) 103 | go runQuery(t, http.StatusOK, false, wg, &once) 104 | } 105 | 106 | time.Sleep(waitTime) 107 | c <- os.Interrupt 108 | time.Sleep(waitTime) 109 | 110 | for i := 0; i < concurrentRequestN; i++ { 111 | wg.Add(1) 112 | go runQuery(t, 0, true, wg, &once) 113 | } 114 | } 115 | 116 | func TestGracefulRun(t *testing.T) { 117 | var wg sync.WaitGroup 118 | defer wg.Wait() 119 | 120 | c := make(chan os.Signal, 1) 121 | server, l, err := createListener(killTime / 2) 122 | if err != nil { 123 | t.Fatal(err) 124 | } 125 | 126 | wg.Add(1) 127 | go func() { 128 | defer wg.Done() 129 | srv := &Server{Timeout: killTime, Server: server, interrupt: c} 130 | srv.Serve(l) 131 | }() 132 | 133 | wg.Add(1) 134 | go launchTestQueries(t, &wg, c) 135 | } 136 | 137 | func TestGracefulRunLimitKeepAliveListener(t *testing.T) { 138 | var wg sync.WaitGroup 139 | defer wg.Wait() 140 | 141 | c := make(chan os.Signal, 1) 142 | server, l, err := createListener(killTime / 2) 143 | if err != nil { 144 | t.Fatal(err) 145 | } 146 | 147 | wg.Add(1) 148 | go func() { 149 | defer wg.Done() 150 | srv := &Server{ 151 | Timeout: killTime, 152 | ListenLimit: concurrentRequestN, 153 | TCPKeepAlive: 1 * time.Second, 154 | Server: server, 155 | interrupt: c, 156 | } 157 | srv.Serve(l) 158 | }() 159 | 160 | wg.Add(1) 161 | go launchTestQueries(t, &wg, c) 162 | } 163 | 164 | func TestGracefulRunTimesOut(t *testing.T) { 165 | var wg sync.WaitGroup 166 | defer wg.Wait() 167 | 168 | c := make(chan os.Signal, 1) 169 | server, l, err := createListener(killTime * 10) 170 | if err != nil { 171 | t.Fatal(err) 172 
| } 173 | 174 | wg.Add(1) 175 | go func() { 176 | defer wg.Done() 177 | srv := &Server{Timeout: killTime, Server: server, interrupt: c} 178 | srv.Serve(l) 179 | }() 180 | 181 | wg.Add(1) 182 | go func() { 183 | defer wg.Done() 184 | var once sync.Once 185 | 186 | for i := 0; i < concurrentRequestN; i++ { 187 | wg.Add(1) 188 | go runQuery(t, 0, true, &wg, &once) 189 | } 190 | 191 | time.Sleep(waitTime) 192 | c <- os.Interrupt 193 | time.Sleep(waitTime) 194 | 195 | for i := 0; i < concurrentRequestN; i++ { 196 | wg.Add(1) 197 | go runQuery(t, 0, true, &wg, &once) 198 | } 199 | }() 200 | } 201 | 202 | func TestGracefulRunDoesntTimeOut(t *testing.T) { 203 | var wg sync.WaitGroup 204 | defer wg.Wait() 205 | 206 | c := make(chan os.Signal, 1) 207 | server, l, err := createListener(killTime * 2) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | 212 | wg.Add(1) 213 | go func() { 214 | defer wg.Done() 215 | srv := &Server{Timeout: 0, Server: server, interrupt: c} 216 | srv.Serve(l) 217 | }() 218 | 219 | wg.Add(1) 220 | go launchTestQueries(t, &wg, c) 221 | } 222 | 223 | func TestGracefulRunDoesntTimeOutAfterConnectionCreated(t *testing.T) { 224 | var wg sync.WaitGroup 225 | defer wg.Wait() 226 | 227 | c := make(chan os.Signal, 1) 228 | server, l, err := createListener(killTime) 229 | if err != nil { 230 | t.Fatal(err) 231 | } 232 | 233 | wg.Add(1) 234 | go func() { 235 | defer wg.Done() 236 | srv := &Server{Timeout: 0, Server: server, interrupt: c} 237 | srv.Serve(l) 238 | }() 239 | time.Sleep(waitTime) 240 | 241 | // Make a sample first request. The connection will be left idle. 242 | resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) 243 | if err != nil { 244 | panic(fmt.Sprintf("first request failed: %v", err)) 245 | } 246 | resp.Body.Close() 247 | 248 | wg.Add(1) 249 | go func() { 250 | defer wg.Done() 251 | 252 | // With idle connections improperly handled, the server doesn't wait for this 253 | // to complete and the request fails. 
It should be allowed to complete successfully. 254 | _, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) 255 | if err != nil { 256 | t.Errorf("Get failed: %v", err) 257 | } 258 | }() 259 | 260 | // Ensure the request goes out 261 | time.Sleep(waitTime) 262 | c <- os.Interrupt 263 | wg.Wait() 264 | } 265 | 266 | func TestGracefulRunNoRequests(t *testing.T) { 267 | var wg sync.WaitGroup 268 | defer wg.Wait() 269 | 270 | c := make(chan os.Signal, 1) 271 | server, l, err := createListener(killTime * 2) 272 | if err != nil { 273 | t.Fatal(err) 274 | } 275 | 276 | wg.Add(1) 277 | go func() { 278 | defer wg.Done() 279 | srv := &Server{Timeout: 0, Server: server, interrupt: c} 280 | srv.Serve(l) 281 | }() 282 | 283 | c <- os.Interrupt 284 | } 285 | 286 | func TestGracefulForwardsConnState(t *testing.T) { 287 | var stateLock sync.Mutex 288 | states := make(map[http.ConnState]int) 289 | connState := func(conn net.Conn, state http.ConnState) { 290 | stateLock.Lock() 291 | states[state]++ 292 | stateLock.Unlock() 293 | } 294 | 295 | var wg sync.WaitGroup 296 | defer wg.Wait() 297 | 298 | expected := map[http.ConnState]int{ 299 | http.StateNew: concurrentRequestN, 300 | http.StateActive: concurrentRequestN, 301 | http.StateClosed: concurrentRequestN, 302 | } 303 | 304 | c := make(chan os.Signal, 1) 305 | server, l, err := createListener(killTime / 2) 306 | if err != nil { 307 | t.Fatal(err) 308 | } 309 | 310 | wg.Add(1) 311 | go func() { 312 | defer wg.Done() 313 | srv := &Server{ 314 | ConnState: connState, 315 | Timeout: killTime, 316 | Server: server, 317 | interrupt: c, 318 | } 319 | srv.Serve(l) 320 | }() 321 | 322 | wg.Add(1) 323 | go launchTestQueries(t, &wg, c) 324 | wg.Wait() 325 | 326 | stateLock.Lock() 327 | if !reflect.DeepEqual(states, expected) { 328 | t.Errorf("Incorrect connection state tracking.\n actual: %v\nexpected: %v\n", states, expected) 329 | } 330 | stateLock.Unlock() 331 | } 332 | 333 | func TestGracefulExplicitStop(t *testing.T) { 334 | 
server, l, err := createListener(1 * time.Millisecond) 335 | if err != nil { 336 | t.Fatal(err) 337 | } 338 | 339 | srv := &Server{Timeout: killTime, Server: server} 340 | 341 | go func() { 342 | go srv.Serve(l) 343 | time.Sleep(waitTime) 344 | srv.Stop(killTime) 345 | }() 346 | 347 | // block on the stopChan until the server has shut down 348 | select { 349 | case <-srv.StopChan(): 350 | case <-time.After(timeoutTime): 351 | t.Fatal("Timed out while waiting for explicit stop to complete") 352 | } 353 | } 354 | 355 | func TestGracefulExplicitStopOverride(t *testing.T) { 356 | server, l, err := createListener(1 * time.Millisecond) 357 | if err != nil { 358 | t.Fatal(err) 359 | } 360 | 361 | srv := &Server{Timeout: killTime, Server: server} 362 | 363 | go func() { 364 | go srv.Serve(l) 365 | time.Sleep(waitTime) 366 | srv.Stop(killTime / 2) 367 | }() 368 | 369 | // block on the stopChan until the server has shut down 370 | select { 371 | case <-srv.StopChan(): 372 | case <-time.After(killTime): 373 | t.Fatal("Timed out while waiting for explicit stop to complete") 374 | } 375 | } 376 | 377 | func TestBeforeShutdownAndShutdownInitiatedCallbacks(t *testing.T) { 378 | var wg sync.WaitGroup 379 | defer wg.Wait() 380 | 381 | server, l, err := createListener(1 * time.Millisecond) 382 | if err != nil { 383 | t.Fatal(err) 384 | } 385 | 386 | beforeShutdownCalled := make(chan struct{}) 387 | cb1 := func() bool { close(beforeShutdownCalled); return true } 388 | shutdownInitiatedCalled := make(chan struct{}) 389 | cb2 := func() { close(shutdownInitiatedCalled) } 390 | 391 | wg.Add(2) 392 | srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2} 393 | go func() { 394 | defer wg.Done() 395 | srv.Serve(l) 396 | }() 397 | go func() { 398 | defer wg.Done() 399 | time.Sleep(waitTime) 400 | srv.Stop(killTime) 401 | }() 402 | 403 | beforeShutdown := false 404 | shutdownInitiated := false 405 | for i := 0; i < 2; i++ { 406 | select { 407 | case 
<-beforeShutdownCalled: 408 | beforeShutdownCalled = nil 409 | beforeShutdown = true 410 | case <-shutdownInitiatedCalled: 411 | shutdownInitiatedCalled = nil 412 | shutdownInitiated = true 413 | case <-time.After(killTime): 414 | t.Fatal("Timed out while waiting for ShutdownInitiated callback to be called") 415 | } 416 | } 417 | 418 | if !beforeShutdown { 419 | t.Fatal("beforeShutdown should be true") 420 | } 421 | if !shutdownInitiated { 422 | t.Fatal("shutdownInitiated should be true") 423 | } 424 | } 425 | 426 | func TestBeforeShutdownCanceled(t *testing.T) { 427 | var wg sync.WaitGroup 428 | wg.Add(1) 429 | 430 | server, l, err := createListener(1 * time.Millisecond) 431 | if err != nil { 432 | t.Fatal(err) 433 | } 434 | 435 | beforeShutdownCalled := make(chan struct{}) 436 | cb1 := func() bool { close(beforeShutdownCalled); return false } 437 | shutdownInitiatedCalled := make(chan struct{}) 438 | cb2 := func() { close(shutdownInitiatedCalled) } 439 | 440 | srv := &Server{Server: server, BeforeShutdown: cb1, ShutdownInitiated: cb2} 441 | go func() { 442 | srv.Serve(l) 443 | wg.Done() 444 | }() 445 | go func() { 446 | time.Sleep(waitTime) 447 | srv.Stop(killTime) 448 | }() 449 | 450 | beforeShutdown := false 451 | shutdownInitiated := false 452 | timeouted := false 453 | 454 | for i := 0; i < 2; i++ { 455 | select { 456 | case <-beforeShutdownCalled: 457 | beforeShutdownCalled = nil 458 | beforeShutdown = true 459 | case <-shutdownInitiatedCalled: 460 | shutdownInitiatedCalled = nil 461 | shutdownInitiated = true 462 | case <-time.After(killTime): 463 | timeouted = true 464 | } 465 | } 466 | 467 | if !beforeShutdown { 468 | t.Fatal("beforeShutdown should be true") 469 | } 470 | if !timeouted { 471 | t.Fatal("timeouted should be true") 472 | } 473 | if shutdownInitiated { 474 | t.Fatal("shutdownInitiated shouldn't be true") 475 | } 476 | 477 | srv.BeforeShutdown = func() bool { return true } 478 | srv.Stop(killTime) 479 | 480 | wg.Wait() 481 | } 482 | 483 | func 
hijackingListener(srv *Server) (*http.Server, net.Listener, error) { 484 | mux := http.NewServeMux() 485 | mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { 486 | conn, bufrw, err := rw.(http.Hijacker).Hijack() 487 | if err != nil { 488 | http.Error(rw, "webserver doesn't support hijacking", http.StatusInternalServerError) 489 | return 490 | } 491 | 492 | defer conn.Close() 493 | 494 | bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") 495 | bufrw.Flush() 496 | }) 497 | 498 | server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} 499 | l, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) 500 | return server, l, err 501 | } 502 | 503 | func TestNotifyClosed(t *testing.T) { 504 | var wg sync.WaitGroup 505 | defer wg.Wait() 506 | 507 | c := make(chan os.Signal, 1) 508 | srv := &Server{Timeout: killTime, interrupt: c} 509 | server, l, err := hijackingListener(srv) 510 | if err != nil { 511 | t.Fatal(err) 512 | } 513 | 514 | srv.Server = server 515 | 516 | wg.Add(1) 517 | go func() { 518 | defer wg.Done() 519 | srv.Serve(l) 520 | }() 521 | 522 | var once sync.Once 523 | for i := 0; i < concurrentRequestN; i++ { 524 | wg.Add(1) 525 | runQuery(t, http.StatusOK, false, &wg, &once) 526 | } 527 | 528 | srv.Stop(0) 529 | 530 | // block on the stopChan until the server has shut down 531 | select { 532 | case <-srv.StopChan(): 533 | case <-time.After(timeoutTime): 534 | t.Fatal("Timed out while waiting for explicit stop to complete") 535 | } 536 | 537 | if len(srv.connections) > 0 { 538 | t.Fatal("hijacked connections should not be managed") 539 | } 540 | 541 | } 542 | 543 | func TestStopDeadlock(t *testing.T) { 544 | var wg sync.WaitGroup 545 | defer wg.Wait() 546 | 547 | c := make(chan struct{}) 548 | server, l, err := createListener(1 * time.Millisecond) 549 | if err != nil { 550 | t.Fatal(err) 551 | } 552 | 553 | srv := &Server{Server: server, NoSignalHandling: true} 554 | 555 | wg.Add(2) 556 | go func() { 557 | defer wg.Done() 558 | 
time.Sleep(waitTime) 559 | srv.Serve(l) 560 | }() 561 | go func() { 562 | defer wg.Done() 563 | srv.Stop(0) 564 | close(c) 565 | }() 566 | 567 | select { 568 | case <-c: 569 | l.Close() 570 | case <-time.After(timeoutTime): 571 | t.Fatal("Timed out while waiting for explicit stop to complete") 572 | } 573 | } 574 | 575 | // Run with --race 576 | func TestStopRace(t *testing.T) { 577 | server, l, err := createListener(1 * time.Millisecond) 578 | if err != nil { 579 | t.Fatal(err) 580 | } 581 | 582 | srv := &Server{Timeout: killTime, Server: server} 583 | 584 | go func() { 585 | go srv.Serve(l) 586 | srv.Stop(killTime) 587 | }() 588 | srv.Stop(0) 589 | select { 590 | case <-srv.StopChan(): 591 | case <-time.After(timeoutTime): 592 | t.Fatal("Timed out while waiting for explicit stop to complete") 593 | } 594 | } 595 | 596 | func TestInterruptLog(t *testing.T) { 597 | c := make(chan os.Signal, 1) 598 | 599 | server, l, err := createListener(killTime * 10) 600 | if err != nil { 601 | t.Fatal(err) 602 | } 603 | 604 | var buf bytes.Buffer 605 | var tbuf bytes.Buffer 606 | logger := log.New(&buf, "", 0) 607 | expected := log.New(&tbuf, "", 0) 608 | 609 | srv := &Server{Timeout: killTime, Server: server, Logger: logger, interrupt: c} 610 | go func() { srv.Serve(l) }() 611 | 612 | stop := srv.StopChan() 613 | c <- os.Interrupt 614 | expected.Print("shutdown initiated") 615 | 616 | <-stop 617 | 618 | if buf.String() != tbuf.String() { 619 | t.Fatal("shutdown log incorrect - got '" + buf.String() + "'") 620 | } 621 | } 622 | 623 | func TestMultiInterrupts(t *testing.T) { 624 | c := make(chan os.Signal, 1) 625 | 626 | server, l, err := createListener(killTime * 10) 627 | if err != nil { 628 | t.Fatal(err) 629 | } 630 | 631 | var wg sync.WaitGroup 632 | var bu bytes.Buffer 633 | buf := SyncBuffer{&wg, &bu} 634 | var tbuf bytes.Buffer 635 | logger := log.New(&buf, "", 0) 636 | expected := log.New(&tbuf, "", 0) 637 | 638 | srv := &Server{Timeout: killTime, Server: server, Logger: 
logger, interrupt: c} 639 | go func() { srv.Serve(l) }() 640 | 641 | stop := srv.StopChan() 642 | buf.Add(1 + 10) // Expecting 11 log calls 643 | c <- os.Interrupt 644 | expected.Printf("shutdown initiated") 645 | for i := 0; i < 10; i++ { 646 | c <- os.Interrupt 647 | expected.Printf("already shutting down") 648 | } 649 | 650 | <-stop 651 | 652 | wg.Wait() 653 | bb, bt := buf.Bytes(), tbuf.Bytes() 654 | for i, b := range bb { 655 | if b != bt[i] { 656 | t.Fatal(fmt.Sprintf("shutdown log incorrect - got '%s', expected '%s'", buf.String(), tbuf.String())) 657 | } 658 | } 659 | } 660 | 661 | func TestLogFunc(t *testing.T) { 662 | c := make(chan os.Signal, 1) 663 | 664 | server, l, err := createListener(killTime * 10) 665 | if err != nil { 666 | t.Fatal(err) 667 | } 668 | var called bool 669 | srv := &Server{Timeout: killTime, Server: server, 670 | LogFunc: func(format string, args ...interface{}) { 671 | called = true 672 | }, interrupt: c} 673 | stop := srv.StopChan() 674 | go func() { srv.Serve(l) }() 675 | c <- os.Interrupt 676 | <-stop 677 | 678 | if called != true { 679 | t.Fatal("Expected LogFunc to be called.") 680 | } 681 | } 682 | 683 | // SyncBuffer calls Done on the embedded wait group after each call to Write. 
684 | type SyncBuffer struct { 685 | *sync.WaitGroup 686 | *bytes.Buffer 687 | } 688 | 689 | func (buf *SyncBuffer) Write(b []byte) (int, error) { 690 | defer buf.Done() 691 | return buf.Buffer.Write(b) 692 | } 693 | -------------------------------------------------------------------------------- /http2_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.6 2 | 3 | package graceful 4 | 5 | import ( 6 | "crypto/tls" 7 | "fmt" 8 | "net/http" 9 | "os" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "golang.org/x/net/http2" 15 | ) 16 | 17 | func createServer() *http.Server { 18 | mux := http.NewServeMux() 19 | mux.HandleFunc("/", func(rw http.ResponseWriter, r *http.Request) { 20 | rw.WriteHeader(http.StatusOK) 21 | }) 22 | 23 | server := &http.Server{Addr: fmt.Sprintf(":%d", port), Handler: mux} 24 | 25 | return server 26 | } 27 | 28 | func checkIfConnectionToServerIsHTTP2(t *testing.T, wg *sync.WaitGroup, c chan os.Signal) { 29 | 30 | defer wg.Done() 31 | 32 | tr := &http.Transport{ 33 | TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 34 | } 35 | 36 | err := http2.ConfigureTransport(tr) 37 | 38 | if err != nil { 39 | t.Fatal("Unable to upgrade client transport to HTTP/2") 40 | } 41 | 42 | client := http.Client{Transport: tr} 43 | r, err := client.Get(fmt.Sprintf("https://localhost:%d", port)) 44 | 45 | c <- os.Interrupt 46 | 47 | if err != nil { 48 | t.Fatalf("Error encountered while connecting to test server: %s", err) 49 | } 50 | 51 | if !r.ProtoAtLeast(2, 0) { 52 | t.Fatalf("Expected HTTP/2 connection to server, but connection was using %s", r.Proto) 53 | } 54 | } 55 | 56 | func TestHTTP2ListenAndServeTLS(t *testing.T) { 57 | 58 | c := make(chan os.Signal, 1) 59 | 60 | var wg sync.WaitGroup 61 | wg.Add(1) 62 | 63 | server := createServer() 64 | 65 | var srv *Server 66 | go func() { 67 | // set timeout of 0 to test idle connection closing 68 | srv = &Server{Timeout: 0, TCPKeepAlive: 1 * time.Minute, 
Server: server, interrupt: c} 69 | srv.ListenAndServeTLS("test-fixtures/cert.crt", "test-fixtures/key.pem") 70 | wg.Done() 71 | }() 72 | 73 | time.Sleep(waitTime) // Wait for the server to start 74 | 75 | wg.Add(1) 76 | go checkIfConnectionToServerIsHTTP2(t, &wg, c) 77 | wg.Wait() 78 | 79 | c <- os.Interrupt // kill the server to close idle connections 80 | 81 | // block on the stopChan until the server has shut down 82 | select { 83 | case <-srv.StopChan(): 84 | case <-time.After(killTime * 2): 85 | t.Fatal("Timed out while waiting for explicit stop to complete") 86 | } 87 | 88 | } 89 | 90 | func TestHTTP2ListenAndServeTLSConfig(t *testing.T) { 91 | 92 | c := make(chan os.Signal, 1) 93 | 94 | var wg sync.WaitGroup 95 | 96 | wg.Add(1) 97 | 98 | server2 := createServer() 99 | 100 | go func() { 101 | srv := &Server{Timeout: killTime, TCPKeepAlive: 1 * time.Minute, Server: server2, interrupt: c} 102 | 103 | cert, err := tls.LoadX509KeyPair("test-fixtures/cert.crt", "test-fixtures/key.pem") 104 | 105 | if err != nil { 106 | t.Fatalf("Unexpected error: %s", err) 107 | } 108 | 109 | tlsConf := &tls.Config{ 110 | Certificates: []tls.Certificate{cert}, 111 | NextProtos: []string{"h2"}, // We need to explicitly enable http/2 in Go 1.7+ 112 | } 113 | 114 | tlsConf.BuildNameToCertificate() 115 | 116 | srv.ListenAndServeTLSConfig(tlsConf) 117 | wg.Done() 118 | }() 119 | 120 | time.Sleep(waitTime) // Wait for the server to start 121 | 122 | wg.Add(1) 123 | go checkIfConnectionToServerIsHTTP2(t, &wg, c) 124 | wg.Wait() 125 | } 126 | -------------------------------------------------------------------------------- /keepalive_listener.go: -------------------------------------------------------------------------------- 1 | package graceful 2 | 3 | import ( 4 | "net" 5 | "time" 6 | ) 7 | 8 | type keepAliveConn interface { 9 | SetKeepAlive(bool) error 10 | SetKeepAlivePeriod(d time.Duration) error 11 | } 12 | 13 | // keepAliveListener sets TCP keep-alive timeouts on accepted 14 | // 
connections. It's used by ListenAndServe and ListenAndServeTLS so 15 | // dead TCP connections (e.g. closing laptop mid-download) eventually 16 | // go away. 17 | type keepAliveListener struct { 18 | net.Listener 19 | keepAlivePeriod time.Duration 20 | } 21 | 22 | func (ln keepAliveListener) Accept() (net.Conn, error) { 23 | c, err := ln.Listener.Accept() 24 | if err != nil { 25 | return nil, err 26 | } 27 | 28 | kac := c.(keepAliveConn) 29 | kac.SetKeepAlive(true) 30 | kac.SetKeepAlivePeriod(ln.keepAlivePeriod) 31 | return c, nil 32 | } 33 | -------------------------------------------------------------------------------- /limit_listen.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The etcd Authors 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package graceful 16 | 17 | import ( 18 | "errors" 19 | "net" 20 | "sync" 21 | "time" 22 | ) 23 | 24 | // ErrNotTCP indicates that network connection is not a TCP connection. 25 | var ErrNotTCP = errors.New("only tcp connections have keepalive") 26 | 27 | // LimitListener returns a Listener that accepts at most n simultaneous 28 | // connections from the provided Listener. 
29 | func LimitListener(l net.Listener, n int) net.Listener { 30 | return &limitListener{l, make(chan struct{}, n)} 31 | } 32 | 33 | type limitListener struct { 34 | net.Listener 35 | sem chan struct{} 36 | } 37 | 38 | func (l *limitListener) acquire() { l.sem <- struct{}{} } 39 | func (l *limitListener) release() { <-l.sem } 40 | 41 | func (l *limitListener) Accept() (net.Conn, error) { 42 | l.acquire() 43 | c, err := l.Listener.Accept() 44 | if err != nil { 45 | l.release() 46 | return nil, err 47 | } 48 | return &limitListenerConn{Conn: c, release: l.release}, nil 49 | } 50 | 51 | type limitListenerConn struct { 52 | net.Conn 53 | releaseOnce sync.Once 54 | release func() 55 | } 56 | 57 | func (l *limitListenerConn) Close() error { 58 | err := l.Conn.Close() 59 | l.releaseOnce.Do(l.release) 60 | return err 61 | } 62 | 63 | func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error { 64 | tcpc, ok := l.Conn.(*net.TCPConn) 65 | if !ok { 66 | return ErrNotTCP 67 | } 68 | return tcpc.SetKeepAlive(doKeepAlive) 69 | } 70 | 71 | func (l *limitListenerConn) SetKeepAlivePeriod(d time.Duration) error { 72 | tcpc, ok := l.Conn.(*net.TCPConn) 73 | if !ok { 74 | return ErrNotTCP 75 | } 76 | return tcpc.SetKeepAlivePeriod(d) 77 | } 78 | -------------------------------------------------------------------------------- /signal.go: -------------------------------------------------------------------------------- 1 | //+build !appengine 2 | 3 | package graceful 4 | 5 | import ( 6 | "os" 7 | "os/signal" 8 | "syscall" 9 | ) 10 | 11 | func signalNotify(interrupt chan<- os.Signal) { 12 | signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) 13 | } 14 | 15 | func sendSignalInt(interrupt chan<- os.Signal) { 16 | interrupt <- syscall.SIGINT 17 | } 18 | -------------------------------------------------------------------------------- /signal_appengine.go: -------------------------------------------------------------------------------- 1 | //+build appengine 2 | 3 | package 
graceful 4 | 5 | import "os" 6 | 7 | func signalNotify(interrupt chan<- os.Signal) { 8 | // Does not notify in the case of AppEngine. 9 | } 10 | 11 | func sendSignalInt(interrupt chan<- os.Signal) { 12 | // Does not send in the case of AppEngine. 13 | } 14 | -------------------------------------------------------------------------------- /test-fixtures/cert.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDhTCCAm2gAwIBAgIUDvdWhjUd/JS+E5bxZlmCM+giGHMwDQYJKoZIhvcNAQEL 3 | BQAwHzEdMBsGA1UEAxMUVGVzdCBJbnRlcm1lZGlhdGUgQ0EwHhcNMTYwNjAyMDMy 4 | MjA0WhcNMTkwNjAyMDMyMjM0WjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0G 5 | CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDDoyMTUK2OSp+XhKRXB/+uO6YAJE/W 6 | 2rzqARahWT6boHZMDhHXRtdwYxWwiUqoxlEeBrEerQ2qPFAqlWkDw8zliE/DWgXg 7 | BiW+Vq5DAn3F1jZ5WskLWr1iP48oK4/l+BXEsDd44MHZFoSZiWlr2Fi4iaIHJE7+ 8 | LGBqPVQXwBYTyc7Jvi3HY8I4/waaAwXoSo8vDPjRiMCD2wlg24Rimocf4goa/2Xs 9 | Z0NU76Uf2jPdsZ5MujjKRqwHDEAjiBq0aPvm6igkNGAGoZ6QYEptO+J4t1oFrbdP 10 | gYRlpqCa3ekr9gc+wg5AO/V9x8/cypbQ8tpwFwvvSYg2TJaUMZ5abc+HAgMBAAGj 11 | gcMwgcAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBQC 12 | R0Y69NLOfFCLRiB5N3uoacILXTAfBgNVHSMEGDAWgBRm0fFHSXtDCVHC8UW7/obv 13 | DLp9tTBJBggrBgEFBQcBAQQ9MDswOQYIKwYBBQUHMAKGLWh0dHA6Ly9sb2NhbGhv 14 | c3Qvc2VsZi1pc3N1ZWQtaW50ZXJtZWRpYXRlLmNydDAUBgNVHREEDTALgglsb2Nh 15 | bGhvc3QwDQYJKoZIhvcNAQELBQADggEBALAf/nowwB0NJ7lGGaoVKhmMHxBEQkd1 16 | K/jBAlJg9Kgmg1IJJ7zLE3SeYF8tGTNYATd4RLmqo1GakrMDaKWNXd74v3p/tWmb 17 | 4vqCh6WzFPHU1dpxDKtbbmaLt9Ije7s6DuQAz9bBXM0mN0vy5F0dORpx/j0h3u1B 18 | j7B5O8kLejPY2w/8pd+QECCb1Q5A6Xx1EEsJpzTlGXO0SBla/oCg+nvirsBGVpWr 19 | bGskAIwG9wNKuGfg4m5u1bL87iX80NemeLtWRWVM+Ry/RhfOokH59/EIFRAXeRz6 20 | gXjIWa0vcXnhW1MOvbD1GFYhO6AJAnDwWes48WfBHysOhq0RycdpGw0= 21 | -----END CERTIFICATE----- 22 | -----BEGIN CERTIFICATE----- 23 | MIIDjTCCAnWgAwIBAgIUMzpit8+j2dWxdk1PdMqGWYalZyIwDQYJKoZIhvcNAQEL 24 | BQAwFzEVMBMGA1UEAxMMVGVzdCBSb290IENBMB4XDTE2MDUyOTEwNDYwMFoXDTMx 25 
| MDUyNjEwNDYzMFowHzEdMBsGA1UEAxMUVGVzdCBJbnRlcm1lZGlhdGUgQ0EwggEi 26 | MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDs6kY6mHJWzupq5dsSavPZHuv6 27 | 0E9PczHbujWLuzv7+qbwzcAgfRvaeR0xgvf7q9pjMgJ7/kNANgneWGpwciLgHtiJ 28 | rSHii3RZfWlK4gdbCXya9EmHj8zO+9xGBHM0FrqfqA+IA70SimFcwGPrGHyERsdX 29 | +mqO64Z95yI5uJpoS8OBAUPU8i6xvNLZGmgUEF3CRhDDTYVGcTEtKAPcnnBuZzZU 30 | Ds+DrHf/MC7HHK0/l0auuRz3p+/GFNePGePG+FFbInS/vwHwrkMW2tzBKG41K+gD 31 | GfkTjVU8xBSiMYOiEja6YcJ4GuzEPcmu5LS+6BkLlsIbazDW5IM8p+7+8RKjAgMB 32 | AAGjgcgwgcUwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0O 33 | BBYEFGbR8UdJe0MJUcLxRbv+hu8Mun21MB8GA1UdIwQYMBaAFKmz0h3CW1HBO9uz 34 | uCzg+MNPGZtkMEEGCCsGAQUFBwEBBDUwMzAxBggrBgEFBQcwAoYlaHR0cDovL2xv 35 | Y2FsaG9zdC9zZWxmLWlzc3VlZC1yb290LmNydDAfBgNVHREEGDAWghRUZXN0IElu 36 | dGVybWVkaWF0ZSBDQTANBgkqhkiG9w0BAQsFAAOCAQEAaYVGqHbaE0c9F/kyIMgu 37 | S3HuNn4pBh2EwGcKIlPkDe43hqXjhS/+itmWk75rQz+Rw+acevGoxbpDR38abTIS 38 | RJd9L/3MA644z8F82er3pNjKqvS/vTre/wsvGYwmEM+GrgJw3HUcisc93qLgaWH2 39 | kjky208k9kOuzJDiY45eu9TfSSmjSHSMCtxk8p5wYKDcfVz+uqlBhVEiHGjQIc2E 40 | 66SituusiwgQv/mdtEW7y48EvMGdzxPfLFcvj06B3vTsZaaYyB6GyKwMcaPFvHRr 41 | V0yYaKRZgAh4X6LHlgPJqvIv3gjMdJR55durAO7tI9Pos0o5Lv5WJgi0g0KvMsco 42 | qQ== 43 | -----END CERTIFICATE----- -------------------------------------------------------------------------------- /test-fixtures/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpQIBAAKCAQEAw6MjE1Ctjkqfl4SkVwf/rjumACRP1tq86gEWoVk+m6B2TA4R 3 | 10bXcGMVsIlKqMZRHgaxHq0NqjxQKpVpA8PM5YhPw1oF4AYlvlauQwJ9xdY2eVrJ 4 | C1q9Yj+PKCuP5fgVxLA3eODB2RaEmYlpa9hYuImiByRO/ixgaj1UF8AWE8nOyb4t 5 | x2PCOP8GmgMF6EqPLwz40YjAg9sJYNuEYpqHH+IKGv9l7GdDVO+lH9oz3bGeTLo4 6 | ykasBwxAI4gatGj75uooJDRgBqGekGBKbTvieLdaBa23T4GEZaagmt3pK/YHPsIO 7 | QDv1fcfP3MqW0PLacBcL70mINkyWlDGeWm3PhwIDAQABAoIBAQC87HWa2XZAyt+D 8 | OpxZT2ghoYiU6nwPR/zXHWX1OnGzaCnVGGEyOz8hUQ5JBMwMYDdFf8DbltJzavsf 9 | 
pFldQWBE6HXeeLjjtgwM2zg9jdJXkp3YY0tyo5XvouFkMW0s735WCrYHDUUllxFG 10 | E+SyOKK00nSd4PpHiiMxdTgYF286exwOpzjhcJfAkn7oBNeOGc5VLOvcvakrSrdq 11 | OYBAJ25HSVFnSQbeAAsCzBEBZC0WLyB1BQGcidbtEn8sxyGnV8HWjbXY+MJQWHg+ 12 | q2iK+uvO4wtrE/WC6p4Ty44Myh+AB79s35HWKYd4okwKkpI1QdD543TIiZnkNEVI 13 | aS/uH13BAoGBAP/psBxKzIft59hw+U9NscH6N9/ze8iAtOtqsWdER/qXCrlUn8+j 14 | F/xquJR6gDj5GwGBt07asEuoG8CKJMQI0c3AeHF7XBcmUunBStktb9O97Zsp6bNJ 15 | olsrWlM4yvVuCVizEwIYjHrMBOS3YIPErM1LmAyDHmzx3+yz+3+WxRQLAoGBAMO0 16 | MaJDPisMC05pvieHRb91HlsiSrASeMkw1FmHI0b/gcC88mEnuXIze1ySoF6FE7B7 17 | xaEm6Lf5Snl0JgXPDSj6ukd51NdaU2VmpKvDOrvQ5QQE9mXaDkXv/i2B0YkCh+Hy 18 | bkziW1IKnWT2PTRAAEIJQ22oK51MdQnvCdmtsIP1AoGBAKnMiEl9Z9AZDmgSLZls 19 | 17D5MPGrQEp8+43oMOVv7MJcTYVCnPbMJDIbLXV3AnTK9Bw/0TzE5YyNcjyCbHqV 20 | z39RYZkKXMQPbZwj4GHRQA2iS3FUkfeft9X+IeRuHlxSMmlkCAyv9SXVELog4i0L 21 | 5gwhSDWlGh73LbiEgy7Y/tKZAoGBALTiMhYGDMoA4dpiBi3G7AKgH6SgN2QyTo22 22 | oi71pveSZb1dZrHB47fYOadApxV17tLqM6pVqjeRJPLJFfO8gi9kPxSdWMqLZBWP 23 | H5jaY8kAtQxYAd32A8dEoSwylxcJzcpbJvPNLBbSVNPifIN0vEhNA5OxIk7LQkoi 24 | NHqL/WCZAoGAPf3kb9Gw/NkBq4Cn86pQfP/xE0h7zcoNmFtLbdKIjId+DDDOPOeX 25 | 9tm33fZzw0SG4KlRQlsqgzFvm8aDD8rpW17341Z/rWlLo8uHNdRkMvbSabc34vPv 26 | 4lrs0rHSYW06MlqkJBNVraySRz7hmU4+n7YMvNI0Due9mVGmE1NU/vI= 27 | -----END RSA PRIVATE KEY----- -------------------------------------------------------------------------------- /tests/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | 7 | "github.com/urfave/negroni" 8 | "gopkg.in/tylerb/graceful.v1" 9 | ) 10 | 11 | func main() { 12 | 13 | var wg sync.WaitGroup 14 | 15 | wg.Add(3) 16 | go func() { 17 | n := negroni.New() 18 | fmt.Println("Launching server on :3000") 19 | graceful.Run(":3000", 0, n) 20 | fmt.Println("Terminated server on :3000") 21 | wg.Done() 22 | }() 23 | go func() { 24 | n := negroni.New() 25 | fmt.Println("Launching server on :3001") 26 | graceful.Run(":3001", 0, n) 27 | 
fmt.Println("Terminated server on :3001") 28 | wg.Done() 29 | }() 30 | go func() { 31 | n := negroni.New() 32 | fmt.Println("Launching server on :3002") 33 | graceful.Run(":3002", 0, n) 34 | fmt.Println("Terminated server on :3002") 35 | wg.Done() 36 | }() 37 | fmt.Println("Press ctrl+c. All servers should terminate.") 38 | wg.Wait() 39 | 40 | } 41 | --------------------------------------------------------------------------------