├── .travis.yml ├── httpdown.go ├── httpdown_example └── main.go ├── httpdown_test.go ├── license └── readme.md /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.6 5 | 6 | before_install: 7 | - go get -v golang.org/x/tools/cmd/vet 8 | - go get -v golang.org/x/tools/cmd/cover 9 | - go get -v github.com/golang/lint/golint 10 | 11 | install: 12 | - go install -race -v std 13 | - go get -race -t -v ./... 14 | - go install -race -v ./... 15 | 16 | script: 17 | - go vet ./... 18 | - $HOME/gopath/bin/golint . 19 | - go test -cpu=2 -race -v ./... 20 | - go test -cpu=2 -covermode=atomic -coverprofile=coverage.txt ./ 21 | 22 | after_success: 23 | - bash <(curl -s https://codecov.io/bash) 24 | -------------------------------------------------------------------------------- /httpdown.go: -------------------------------------------------------------------------------- 1 | // Package httpdown provides http.ConnState enabled graceful termination of 2 | // http.Server. 3 | package httpdown 4 | 5 | import ( 6 | "crypto/tls" 7 | "fmt" 8 | "net" 9 | "net/http" 10 | "os" 11 | "os/signal" 12 | "sync" 13 | "syscall" 14 | "time" 15 | 16 | "github.com/facebookgo/clock" 17 | "github.com/facebookgo/stats" 18 | ) 19 | 20 | const ( 21 | defaultStopTimeout = time.Minute 22 | defaultKillTimeout = time.Minute 23 | ) 24 | 25 | // A Server allows encapsulates the process of accepting new connections and 26 | // serving them, and gracefully shutting down the listener without dropping 27 | // active connections. 28 | type Server interface { 29 | // Wait waits for the serving loop to finish. This will happen when Stop is 30 | // called, at which point it returns no error, or if there is an error in the 31 | // serving loop. You must call Wait after calling Serve or ListenAndServe. 32 | Wait() error 33 | 34 | // Stop stops the listener. It will block until all connections have been 35 | // closed. 
36 | Stop() error 37 | } 38 | 39 | // HTTP defines the configuration for serving a http.Server. Multiple calls to 40 | // Serve or ListenAndServe can be made on the same HTTP instance. The default 41 | // timeouts of 1 minute each result in a maximum of 2 minutes before a Stop() 42 | // returns. 43 | type HTTP struct { 44 | // StopTimeout is the duration before we begin force closing connections. 45 | // Defaults to 1 minute. 46 | StopTimeout time.Duration 47 | 48 | // KillTimeout is the duration before which we completely give up and abort 49 | // even though we still have connected clients. This is useful when a large 50 | // number of client connections exist and closing them can take a long time. 51 | // Note, this is in addition to the StopTimeout. Defaults to 1 minute. 52 | KillTimeout time.Duration 53 | 54 | // Stats is optional. If provided, it will be used to record various metrics. 55 | Stats stats.Client 56 | 57 | // Clock allows for testing timing related functionality. Do not specify this 58 | // in production code. 59 | Clock clock.Clock 60 | } 61 | 62 | // Serve provides the low-level API which is useful if you're creating your own 63 | // net.Listener. 
64 | func (h HTTP) Serve(s *http.Server, l net.Listener) Server { 65 | stopTimeout := h.StopTimeout 66 | if stopTimeout == 0 { 67 | stopTimeout = defaultStopTimeout 68 | } 69 | killTimeout := h.KillTimeout 70 | if killTimeout == 0 { 71 | killTimeout = defaultKillTimeout 72 | } 73 | klock := h.Clock 74 | if klock == nil { 75 | klock = clock.New() 76 | } 77 | 78 | ss := &server{ 79 | stopTimeout: stopTimeout, 80 | killTimeout: killTimeout, 81 | stats: h.Stats, 82 | clock: klock, 83 | oldConnState: s.ConnState, 84 | listener: l, 85 | server: s, 86 | serveDone: make(chan struct{}), 87 | serveErr: make(chan error, 1), 88 | new: make(chan net.Conn), 89 | active: make(chan net.Conn), 90 | idle: make(chan net.Conn), 91 | closed: make(chan net.Conn), 92 | stop: make(chan chan struct{}), 93 | kill: make(chan chan struct{}), 94 | } 95 | s.ConnState = ss.connState 96 | go ss.manage() 97 | go ss.serve() 98 | return ss 99 | } 100 | 101 | // ListenAndServe returns a Server for the given http.Server. It is equivalent 102 | // to ListenAndServe from the standard library, but returns immediately. 103 | // Requests will be accepted in a background goroutine. If the http.Server has 104 | // a non-nil TLSConfig, a TLS enabled listener will be setup. 105 | func (h HTTP) ListenAndServe(s *http.Server) (Server, error) { 106 | addr := s.Addr 107 | if addr == "" { 108 | if s.TLSConfig == nil { 109 | addr = ":http" 110 | } else { 111 | addr = ":https" 112 | } 113 | } 114 | l, err := net.Listen("tcp", addr) 115 | if err != nil { 116 | stats.BumpSum(h.Stats, "listen.error", 1) 117 | return nil, err 118 | } 119 | if s.TLSConfig != nil { 120 | l = tls.NewListener(l, s.TLSConfig) 121 | } 122 | return h.Serve(s, l), nil 123 | } 124 | 125 | // server manages the serving process and allows for gracefully stopping it. 
126 | type server struct { 127 | stopTimeout time.Duration 128 | killTimeout time.Duration 129 | stats stats.Client 130 | clock clock.Clock 131 | 132 | oldConnState func(net.Conn, http.ConnState) 133 | server *http.Server 134 | serveDone chan struct{} 135 | serveErr chan error 136 | listener net.Listener 137 | 138 | new chan net.Conn 139 | active chan net.Conn 140 | idle chan net.Conn 141 | closed chan net.Conn 142 | stop chan chan struct{} 143 | kill chan chan struct{} 144 | 145 | stopOnce sync.Once 146 | stopErr error 147 | } 148 | 149 | func (s *server) connState(c net.Conn, cs http.ConnState) { 150 | if s.oldConnState != nil { 151 | s.oldConnState(c, cs) 152 | } 153 | 154 | switch cs { 155 | case http.StateNew: 156 | s.new <- c 157 | case http.StateActive: 158 | s.active <- c 159 | case http.StateIdle: 160 | s.idle <- c 161 | case http.StateHijacked, http.StateClosed: 162 | s.closed <- c 163 | } 164 | } 165 | 166 | func (s *server) manage() { 167 | defer func() { 168 | close(s.new) 169 | close(s.active) 170 | close(s.idle) 171 | close(s.closed) 172 | close(s.stop) 173 | close(s.kill) 174 | }() 175 | 176 | var stopDone chan struct{} 177 | 178 | conns := map[net.Conn]http.ConnState{} 179 | var countNew, countActive, countIdle float64 180 | 181 | // decConn decrements the count associated with the current state of the 182 | // given connection. 183 | decConn := func(c net.Conn) { 184 | switch conns[c] { 185 | default: 186 | panic(fmt.Errorf("unknown existing connection: %s", c)) 187 | case http.StateNew: 188 | countNew-- 189 | case http.StateActive: 190 | countActive-- 191 | case http.StateIdle: 192 | countIdle-- 193 | } 194 | } 195 | 196 | // setup a ticker to report various values every minute. if we don't have a 197 | // Stats implementation provided, we Stop it so it never ticks. 
198 | statsTicker := s.clock.Ticker(time.Minute) 199 | if s.stats == nil { 200 | statsTicker.Stop() 201 | } 202 | 203 | for { 204 | select { 205 | case <-statsTicker.C: 206 | // we'll only get here when s.stats is not nil 207 | s.stats.BumpAvg("http-state.new", countNew) 208 | s.stats.BumpAvg("http-state.active", countActive) 209 | s.stats.BumpAvg("http-state.idle", countIdle) 210 | s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle) 211 | case c := <-s.new: 212 | conns[c] = http.StateNew 213 | countNew++ 214 | case c := <-s.active: 215 | decConn(c) 216 | countActive++ 217 | 218 | conns[c] = http.StateActive 219 | case c := <-s.idle: 220 | decConn(c) 221 | countIdle++ 222 | 223 | conns[c] = http.StateIdle 224 | 225 | // if we're already stopping, close it 226 | if stopDone != nil { 227 | c.Close() 228 | } 229 | case c := <-s.closed: 230 | stats.BumpSum(s.stats, "conn.closed", 1) 231 | decConn(c) 232 | delete(conns, c) 233 | 234 | // if we're waiting to stop and are all empty, we just closed the last 235 | // connection and we're done. 236 | if stopDone != nil && len(conns) == 0 { 237 | close(stopDone) 238 | return 239 | } 240 | case stopDone = <-s.stop: 241 | // if we're already all empty, we're already done 242 | if len(conns) == 0 { 243 | close(stopDone) 244 | return 245 | } 246 | 247 | // close current idle connections right away 248 | for c, cs := range conns { 249 | if cs == http.StateIdle { 250 | c.Close() 251 | } 252 | } 253 | 254 | // continue the loop and wait for all the ConnState updates which will 255 | // eventually close(stopDone) and return from this goroutine. 256 | 257 | case killDone := <-s.kill: 258 | // force close all connections 259 | stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns))) 260 | for c := range conns { 261 | c.Close() 262 | } 263 | 264 | // don't block the kill. 
265 | close(killDone) 266 | 267 | // continue the loop and we wait for all the ConnState updates and will 268 | // return from this goroutine when we're all done. otherwise we'll try to 269 | // send those ConnState updates on closed channels. 270 | 271 | } 272 | } 273 | } 274 | 275 | func (s *server) serve() { 276 | stats.BumpSum(s.stats, "serve", 1) 277 | s.serveErr <- s.server.Serve(s.listener) 278 | close(s.serveDone) 279 | close(s.serveErr) 280 | } 281 | 282 | func (s *server) Wait() error { 283 | if err := <-s.serveErr; !isUseOfClosedError(err) { 284 | return err 285 | } 286 | return nil 287 | } 288 | 289 | func (s *server) Stop() error { 290 | s.stopOnce.Do(func() { 291 | defer stats.BumpTime(s.stats, "stop.time").End() 292 | stats.BumpSum(s.stats, "stop", 1) 293 | 294 | // first disable keep-alive for new connections 295 | s.server.SetKeepAlivesEnabled(false) 296 | 297 | // then close the listener so new connections can't connect come thru 298 | closeErr := s.listener.Close() 299 | <-s.serveDone 300 | 301 | // then trigger the background goroutine to stop and wait for it 302 | stopDone := make(chan struct{}) 303 | s.stop <- stopDone 304 | 305 | // wait for stop 306 | select { 307 | case <-stopDone: 308 | case <-s.clock.After(s.stopTimeout): 309 | defer stats.BumpTime(s.stats, "kill.time").End() 310 | stats.BumpSum(s.stats, "kill", 1) 311 | 312 | // stop timed out, wait for kill 313 | killDone := make(chan struct{}) 314 | s.kill <- killDone 315 | select { 316 | case <-killDone: 317 | case <-s.clock.After(s.killTimeout): 318 | // kill timed out, give up 319 | stats.BumpSum(s.stats, "kill.timeout", 1) 320 | } 321 | } 322 | 323 | if closeErr != nil && !isUseOfClosedError(closeErr) { 324 | stats.BumpSum(s.stats, "listener.close.error", 1) 325 | s.stopErr = closeErr 326 | } 327 | }) 328 | return s.stopErr 329 | } 330 | 331 | func isUseOfClosedError(err error) bool { 332 | if err == nil { 333 | return false 334 | } 335 | if opErr, ok := err.(*net.OpError); ok { 
336 | err = opErr.Err 337 | } 338 | return err.Error() == "use of closed network connection" 339 | } 340 | 341 | // ListenAndServe is a convenience function to serve and wait for a SIGTERM 342 | // or SIGINT before shutting down. 343 | func ListenAndServe(s *http.Server, hd *HTTP) error { 344 | if hd == nil { 345 | hd = &HTTP{} 346 | } 347 | hs, err := hd.ListenAndServe(s) 348 | if err != nil { 349 | return err 350 | } 351 | 352 | waiterr := make(chan error, 1) 353 | go func() { 354 | defer close(waiterr) 355 | waiterr <- hs.Wait() 356 | }() 357 | 358 | signals := make(chan os.Signal, 10) 359 | signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) 360 | 361 | select { 362 | case err := <-waiterr: 363 | if err != nil { 364 | return err 365 | } 366 | case <-signals: 367 | signal.Stop(signals) 368 | if err := hs.Stop(); err != nil { 369 | return err 370 | } 371 | if err := <-waiterr; err != nil { 372 | return err 373 | } 374 | } 375 | return nil 376 | } 377 | -------------------------------------------------------------------------------- /httpdown_example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "net/http" 7 | "os" 8 | "time" 9 | 10 | "github.com/facebookgo/httpdown" 11 | ) 12 | 13 | func handler(w http.ResponseWriter, r *http.Request) { 14 | duration, err := time.ParseDuration(r.FormValue("duration")) 15 | if err != nil { 16 | http.Error(w, err.Error(), 400) 17 | return 18 | } 19 | fmt.Fprintf(w, "going to sleep %s with pid %d\n", duration, os.Getpid()) 20 | w.(http.Flusher).Flush() 21 | time.Sleep(duration) 22 | fmt.Fprintf(w, "slept %s with pid %d\n", duration, os.Getpid()) 23 | } 24 | 25 | func main() { 26 | server := &http.Server{ 27 | Addr: "127.0.0.1:8080", 28 | Handler: http.HandlerFunc(handler), 29 | } 30 | hd := &httpdown.HTTP{ 31 | StopTimeout: 10 * time.Second, 32 | KillTimeout: 1 * time.Second, 33 | } 34 | 35 | flag.StringVar(&server.Addr, 
"addr", server.Addr, "http address") 36 | flag.DurationVar(&hd.StopTimeout, "stop-timeout", hd.StopTimeout, "stop timeout") 37 | flag.DurationVar(&hd.KillTimeout, "kill-timeout", hd.KillTimeout, "kill timeout") 38 | flag.Parse() 39 | 40 | if err := httpdown.ListenAndServe(server, hd); err != nil { 41 | panic(err) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /httpdown_test.go: -------------------------------------------------------------------------------- 1 | package httpdown_test 2 | 3 | import ( 4 | "bytes" 5 | "crypto/tls" 6 | "errors" 7 | "fmt" 8 | "io/ioutil" 9 | "net" 10 | "net/http" 11 | "os" 12 | "regexp" 13 | "sync" 14 | "sync/atomic" 15 | "testing" 16 | "time" 17 | 18 | "github.com/facebookgo/clock" 19 | "github.com/facebookgo/ensure" 20 | "github.com/facebookgo/freeport" 21 | "github.com/facebookgo/httpdown" 22 | "github.com/facebookgo/stats" 23 | ) 24 | 25 | type onCloseListener struct { 26 | net.Listener 27 | mutex sync.Mutex 28 | onClose chan struct{} 29 | } 30 | 31 | func (o *onCloseListener) Close() error { 32 | // Listener is closed twice, once by Grace, and once by the http library, so 33 | // we guard against a double close of the chan. 
34 | defer func() { 35 | o.mutex.Lock() 36 | defer o.mutex.Unlock() 37 | if o.onClose != nil { 38 | close(o.onClose) 39 | o.onClose = nil 40 | } 41 | }() 42 | return o.Listener.Close() 43 | } 44 | 45 | func NewOnCloseListener(l net.Listener) (net.Listener, chan struct{}) { 46 | c := make(chan struct{}) 47 | return &onCloseListener{Listener: l, onClose: c}, c 48 | } 49 | 50 | type closeErrListener struct { 51 | net.Listener 52 | err error 53 | } 54 | 55 | func (c *closeErrListener) Close() error { 56 | c.Listener.Close() 57 | return c.err 58 | } 59 | 60 | type acceptErrListener struct { 61 | net.Listener 62 | err chan error 63 | } 64 | 65 | func (c *acceptErrListener) Accept() (net.Conn, error) { 66 | return nil, <-c.err 67 | } 68 | 69 | type closeErrConn struct { 70 | net.Conn 71 | unblockClose chan chan struct{} 72 | } 73 | 74 | func (c *closeErrConn) Close() error { 75 | ch := <-c.unblockClose 76 | 77 | // Close gets called multiple times, but only the first one gets this ch 78 | if ch != nil { 79 | defer close(ch) 80 | } 81 | 82 | return c.Conn.Close() 83 | } 84 | 85 | type closeErrConnListener struct { 86 | net.Listener 87 | unblockClose chan chan struct{} 88 | } 89 | 90 | func (l *closeErrConnListener) Accept() (net.Conn, error) { 91 | c, err := l.Listener.Accept() 92 | if err != nil { 93 | return c, err 94 | } 95 | return &closeErrConn{Conn: c, unblockClose: l.unblockClose}, nil 96 | } 97 | 98 | func TestHTTPStopWithNoRequest(t *testing.T) { 99 | t.Parallel() 100 | listener, err := net.Listen("tcp", "127.0.0.1:0") 101 | ensure.Nil(t, err) 102 | 103 | statsDone := make(chan struct{}, 2) 104 | hc := &stats.HookClient{ 105 | BumpSumHook: func(key string, val float64) { 106 | if key == "serve" && val == 1 { 107 | statsDone <- struct{}{} 108 | } 109 | if key == "stop" && val == 1 { 110 | statsDone <- struct{}{} 111 | } 112 | }, 113 | } 114 | 115 | server := &http.Server{} 116 | down := &httpdown.HTTP{Stats: hc} 117 | s := down.Serve(server, listener) 118 | 
ensure.Nil(t, s.Stop()) 119 | <-statsDone 120 | <-statsDone 121 | } 122 | 123 | func TestHTTPStopWithFinishedRequest(t *testing.T) { 124 | t.Parallel() 125 | hello := []byte("hello") 126 | fin := make(chan struct{}) 127 | okHandler := func(w http.ResponseWriter, r *http.Request) { 128 | defer close(fin) 129 | w.Write(hello) 130 | } 131 | 132 | listener, err := net.Listen("tcp", "127.0.0.1:0") 133 | ensure.Nil(t, err) 134 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 135 | transport := &http.Transport{} 136 | client := &http.Client{Transport: transport} 137 | down := &httpdown.HTTP{} 138 | s := down.Serve(server, listener) 139 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 140 | ensure.Nil(t, err) 141 | actualBody, err := ioutil.ReadAll(res.Body) 142 | ensure.Nil(t, err) 143 | ensure.DeepEqual(t, actualBody, hello) 144 | ensure.Nil(t, res.Body.Close()) 145 | 146 | // At this point the request is finished, and the connection should be alive 147 | // but idle (because we have keep alive enabled by default in our Transport). 
148 | ensure.Nil(t, s.Stop()) 149 | <-fin 150 | 151 | ensure.Nil(t, s.Wait()) 152 | } 153 | 154 | func TestHTTPStopWithActiveRequest(t *testing.T) { 155 | t.Parallel() 156 | const count = 10000 157 | hello := []byte("hello") 158 | finOkHandler := make(chan struct{}) 159 | okHandler := func(w http.ResponseWriter, r *http.Request) { 160 | defer close(finOkHandler) 161 | w.WriteHeader(200) 162 | for i := 0; i < count; i++ { 163 | w.Write(hello) 164 | } 165 | } 166 | 167 | listener, err := net.Listen("tcp", "127.0.0.1:0") 168 | ensure.Nil(t, err) 169 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 170 | transport := &http.Transport{} 171 | client := &http.Client{Transport: transport} 172 | down := &httpdown.HTTP{} 173 | s := down.Serve(server, listener) 174 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 175 | ensure.Nil(t, err) 176 | 177 | finStop := make(chan struct{}) 178 | go func() { 179 | defer close(finStop) 180 | ensure.Nil(t, s.Stop()) 181 | }() 182 | 183 | actualBody, err := ioutil.ReadAll(res.Body) 184 | ensure.Nil(t, err) 185 | ensure.DeepEqual(t, actualBody, bytes.Repeat(hello, count)) 186 | ensure.Nil(t, res.Body.Close()) 187 | <-finOkHandler 188 | <-finStop 189 | } 190 | 191 | func TestNewRequestAfterStop(t *testing.T) { 192 | t.Parallel() 193 | const count = 10000 194 | hello := []byte("hello") 195 | finOkHandler := make(chan struct{}) 196 | unblockOkHandler := make(chan struct{}) 197 | okHandler := func(w http.ResponseWriter, r *http.Request) { 198 | defer close(finOkHandler) 199 | w.WriteHeader(200) 200 | const diff = 500 201 | for i := 0; i < count-diff; i++ { 202 | w.Write(hello) 203 | } 204 | <-unblockOkHandler 205 | for i := 0; i < diff; i++ { 206 | w.Write(hello) 207 | } 208 | } 209 | 210 | listener, err := net.Listen("tcp", "127.0.0.1:0") 211 | listener, onClose := NewOnCloseListener(listener) 212 | ensure.Nil(t, err) 213 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 214 | transport 
:= &http.Transport{} 215 | client := &http.Client{Transport: transport} 216 | down := &httpdown.HTTP{} 217 | s := down.Serve(server, listener) 218 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 219 | ensure.Nil(t, err) 220 | 221 | finStop := make(chan struct{}) 222 | go func() { 223 | defer close(finStop) 224 | ensure.Nil(t, s.Stop()) 225 | }() 226 | 227 | // Wait until the listener is closed. 228 | <-onClose 229 | 230 | // Now the next request should not be able to connect as the listener is 231 | // now closed. 232 | _, err = client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 233 | 234 | // We should just get "connection refused" here, but sometimes, very rarely, 235 | // we get a "connection reset" instead. Unclear why this happens. 236 | ensure.Err(t, err, regexp.MustCompile("(connection refused|connection reset by peer)$")) 237 | 238 | // Unblock the handler and ensure we finish writing the rest of the body 239 | // successfully. 240 | close(unblockOkHandler) 241 | actualBody, err := ioutil.ReadAll(res.Body) 242 | ensure.Nil(t, err) 243 | ensure.DeepEqual(t, actualBody, bytes.Repeat(hello, count)) 244 | ensure.Nil(t, res.Body.Close()) 245 | <-finOkHandler 246 | <-finStop 247 | } 248 | 249 | func TestHTTPListenerCloseError(t *testing.T) { 250 | t.Parallel() 251 | expectedError := errors.New("foo") 252 | listener, err := net.Listen("tcp", "127.0.0.1:0") 253 | listener = &closeErrListener{Listener: listener, err: expectedError} 254 | ensure.Nil(t, err) 255 | server := &http.Server{} 256 | down := &httpdown.HTTP{} 257 | s := down.Serve(server, listener) 258 | ensure.DeepEqual(t, s.Stop(), expectedError) 259 | } 260 | 261 | func TestHTTPServeError(t *testing.T) { 262 | t.Parallel() 263 | expectedError := errors.New("foo") 264 | listener, err := net.Listen("tcp", "127.0.0.1:0") 265 | errChan := make(chan error) 266 | listener = &acceptErrListener{Listener: listener, err: errChan} 267 | ensure.Nil(t, err) 268 | server := 
&http.Server{} 269 | down := &httpdown.HTTP{} 270 | s := down.Serve(server, listener) 271 | errChan <- expectedError 272 | ensure.DeepEqual(t, s.Wait(), expectedError) 273 | ensure.Nil(t, s.Stop()) 274 | } 275 | 276 | func TestHTTPWithinStopTimeout(t *testing.T) { 277 | t.Parallel() 278 | hello := []byte("hello") 279 | finOkHandler := make(chan struct{}) 280 | okHandler := func(w http.ResponseWriter, r *http.Request) { 281 | defer close(finOkHandler) 282 | w.WriteHeader(200) 283 | w.Write(hello) 284 | } 285 | 286 | listener, err := net.Listen("tcp", "127.0.0.1:0") 287 | ensure.Nil(t, err) 288 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 289 | transport := &http.Transport{} 290 | client := &http.Client{Transport: transport} 291 | down := &httpdown.HTTP{StopTimeout: time.Minute} 292 | s := down.Serve(server, listener) 293 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 294 | ensure.Nil(t, err) 295 | 296 | finStop := make(chan struct{}) 297 | go func() { 298 | defer close(finStop) 299 | ensure.Nil(t, s.Stop()) 300 | }() 301 | 302 | actualBody, err := ioutil.ReadAll(res.Body) 303 | ensure.Nil(t, err) 304 | ensure.DeepEqual(t, actualBody, hello) 305 | ensure.Nil(t, res.Body.Close()) 306 | <-finOkHandler 307 | <-finStop 308 | } 309 | 310 | func TestHTTPStopTimeoutMissed(t *testing.T) { 311 | t.Parallel() 312 | 313 | klock := clock.NewMock() 314 | 315 | const count = 10000 316 | hello := []byte("hello") 317 | finOkHandler := make(chan struct{}) 318 | unblockOkHandler := make(chan struct{}) 319 | okHandler := func(w http.ResponseWriter, r *http.Request) { 320 | defer close(finOkHandler) 321 | w.Header().Set("Content-Length", fmt.Sprint(len(hello)*count)) 322 | w.WriteHeader(200) 323 | for i := 0; i < count/2; i++ { 324 | w.Write(hello) 325 | } 326 | <-unblockOkHandler 327 | for i := 0; i < count/2; i++ { 328 | w.Write(hello) 329 | } 330 | } 331 | 332 | listener, err := net.Listen("tcp", "127.0.0.1:0") 333 | ensure.Nil(t, err) 
334 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 335 | transport := &http.Transport{} 336 | client := &http.Client{Transport: transport} 337 | down := &httpdown.HTTP{ 338 | StopTimeout: time.Minute, 339 | Clock: klock, 340 | } 341 | s := down.Serve(server, listener) 342 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 343 | ensure.Nil(t, err) 344 | 345 | finStop := make(chan struct{}) 346 | go func() { 347 | defer close(finStop) 348 | ensure.Nil(t, s.Stop()) 349 | }() 350 | 351 | klock.Wait(clock.Calls{After: 1}) // wait for Stop to call After 352 | klock.Add(down.StopTimeout) 353 | 354 | _, err = ioutil.ReadAll(res.Body) 355 | ensure.Err(t, err, regexp.MustCompile("^unexpected EOF$")) 356 | ensure.Nil(t, res.Body.Close()) 357 | close(unblockOkHandler) 358 | <-finOkHandler 359 | <-finStop 360 | } 361 | 362 | func TestHTTPKillTimeout(t *testing.T) { 363 | t.Parallel() 364 | 365 | klock := clock.NewMock() 366 | 367 | statsDone := make(chan struct{}, 1) 368 | hc := &stats.HookClient{ 369 | BumpSumHook: func(key string, val float64) { 370 | if key == "kill" && val == 1 { 371 | statsDone <- struct{}{} 372 | } 373 | }, 374 | } 375 | 376 | const count = 10000 377 | hello := []byte("hello") 378 | finOkHandler := make(chan struct{}) 379 | unblockOkHandler := make(chan struct{}) 380 | okHandler := func(w http.ResponseWriter, r *http.Request) { 381 | defer close(finOkHandler) 382 | w.Header().Set("Content-Length", fmt.Sprint(len(hello)*count)) 383 | w.WriteHeader(200) 384 | for i := 0; i < count/2; i++ { 385 | w.Write(hello) 386 | } 387 | <-unblockOkHandler 388 | for i := 0; i < count/2; i++ { 389 | w.Write(hello) 390 | } 391 | } 392 | 393 | listener, err := net.Listen("tcp", "127.0.0.1:0") 394 | ensure.Nil(t, err) 395 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 396 | transport := &http.Transport{} 397 | client := &http.Client{Transport: transport} 398 | down := &httpdown.HTTP{ 399 | StopTimeout: time.Minute, 
400 | KillTimeout: time.Minute, 401 | Stats: hc, 402 | Clock: klock, 403 | } 404 | s := down.Serve(server, listener) 405 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 406 | ensure.Nil(t, err) 407 | 408 | finStop := make(chan struct{}) 409 | go func() { 410 | defer close(finStop) 411 | ensure.Nil(t, s.Stop()) 412 | }() 413 | 414 | klock.Wait(clock.Calls{After: 1}) // wait for Stop to call After 415 | klock.Add(down.StopTimeout) 416 | 417 | _, err = ioutil.ReadAll(res.Body) 418 | ensure.Err(t, err, regexp.MustCompile("^unexpected EOF$")) 419 | ensure.Nil(t, res.Body.Close()) 420 | close(unblockOkHandler) 421 | <-finOkHandler 422 | <-finStop 423 | <-statsDone 424 | } 425 | 426 | func TestHTTPKillTimeoutMissed(t *testing.T) { 427 | t.Parallel() 428 | 429 | klock := clock.NewMock() 430 | 431 | statsDone := make(chan struct{}, 1) 432 | hc := &stats.HookClient{ 433 | BumpSumHook: func(key string, val float64) { 434 | if key == "kill.timeout" && val == 1 { 435 | statsDone <- struct{}{} 436 | } 437 | }, 438 | } 439 | 440 | const count = 10000 441 | hello := []byte("hello") 442 | finOkHandler := make(chan struct{}) 443 | unblockOkHandler := make(chan struct{}) 444 | okHandler := func(w http.ResponseWriter, r *http.Request) { 445 | defer close(finOkHandler) 446 | w.Header().Set("Content-Length", fmt.Sprint(len(hello)*count)) 447 | w.WriteHeader(200) 448 | for i := 0; i < count/2; i++ { 449 | w.Write(hello) 450 | } 451 | <-unblockOkHandler 452 | for i := 0; i < count/2; i++ { 453 | w.Write(hello) 454 | } 455 | } 456 | 457 | listener, err := net.Listen("tcp", "127.0.0.1:0") 458 | ensure.Nil(t, err) 459 | unblockConnClose := make(chan chan struct{}, 1) 460 | listener = &closeErrConnListener{ 461 | Listener: listener, 462 | unblockClose: unblockConnClose, 463 | } 464 | 465 | server := &http.Server{Handler: http.HandlerFunc(okHandler)} 466 | transport := &http.Transport{} 467 | client := &http.Client{Transport: transport} 468 | down := 
&httpdown.HTTP{ 469 | StopTimeout: time.Minute, 470 | KillTimeout: time.Minute, 471 | Stats: hc, 472 | Clock: klock, 473 | } 474 | s := down.Serve(server, listener) 475 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 476 | ensure.Nil(t, err) 477 | 478 | // Start the Stop process. 479 | finStop := make(chan struct{}) 480 | go func() { 481 | defer close(finStop) 482 | ensure.Nil(t, s.Stop()) 483 | }() 484 | 485 | klock.Wait(clock.Calls{After: 1}) // wait for Stop to call After 486 | klock.Add(down.StopTimeout) // trigger stop timeout 487 | klock.Wait(clock.Calls{After: 2}) // wait for Kill to call After 488 | klock.Add(down.KillTimeout) // trigger kill timeout 489 | 490 | // We hit both the StopTimeout & the KillTimeout. 491 | <-finStop 492 | 493 | // Then we unblock the Close, so we get an unexpected EOF since we close 494 | // before we finish writing the response. 495 | connCloseDone := make(chan struct{}) 496 | unblockConnClose <- connCloseDone 497 | <-connCloseDone 498 | close(unblockConnClose) 499 | 500 | // Then we unblock the handler which tries to write the rest of the data. 
501 | close(unblockOkHandler) 502 | 503 | _, err = ioutil.ReadAll(res.Body) 504 | ensure.Err(t, err, regexp.MustCompile("^unexpected EOF$")) 505 | ensure.Nil(t, res.Body.Close()) 506 | <-finOkHandler 507 | <-statsDone 508 | } 509 | 510 | func TestDoubleStop(t *testing.T) { 511 | t.Parallel() 512 | listener, err := net.Listen("tcp", "127.0.0.1:0") 513 | ensure.Nil(t, err) 514 | server := &http.Server{} 515 | down := &httpdown.HTTP{} 516 | s := down.Serve(server, listener) 517 | ensure.Nil(t, s.Stop()) 518 | ensure.Nil(t, s.Stop()) 519 | } 520 | 521 | func TestExistingConnState(t *testing.T) { 522 | t.Parallel() 523 | hello := []byte("hello") 524 | fin := make(chan struct{}) 525 | okHandler := func(w http.ResponseWriter, r *http.Request) { 526 | defer close(fin) 527 | w.Write(hello) 528 | } 529 | 530 | var called int32 531 | listener, err := net.Listen("tcp", "127.0.0.1:0") 532 | ensure.Nil(t, err) 533 | server := &http.Server{ 534 | Handler: http.HandlerFunc(okHandler), 535 | ConnState: func(c net.Conn, s http.ConnState) { 536 | atomic.AddInt32(&called, 1) 537 | }, 538 | } 539 | transport := &http.Transport{} 540 | client := &http.Client{Transport: transport} 541 | down := &httpdown.HTTP{} 542 | s := down.Serve(server, listener) 543 | res, err := client.Get(fmt.Sprintf("http://%s/", listener.Addr().String())) 544 | ensure.Nil(t, err) 545 | actualBody, err := ioutil.ReadAll(res.Body) 546 | ensure.Nil(t, err) 547 | ensure.DeepEqual(t, actualBody, hello) 548 | ensure.Nil(t, res.Body.Close()) 549 | 550 | ensure.Nil(t, s.Stop()) 551 | <-fin 552 | 553 | ensure.True(t, atomic.LoadInt32(&called) > 0) 554 | } 555 | 556 | func TestHTTPDefaultListenError(t *testing.T) { 557 | if os.Getuid() == 0 { 558 | t.Skip("cant run this test as root") 559 | } 560 | 561 | statsDone := make(chan struct{}, 1) 562 | hc := &stats.HookClient{ 563 | BumpSumHook: func(key string, val float64) { 564 | if key == "listen.error" && val == 1 { 565 | statsDone <- struct{}{} 566 | } 567 | }, 568 | } 569 
| 570 | t.Parallel() 571 | down := &httpdown.HTTP{Stats: hc} 572 | _, err := down.ListenAndServe(&http.Server{}) 573 | ensure.Err(t, err, regexp.MustCompile("listen tcp :80: bind: permission denied")) 574 | <-statsDone 575 | } 576 | 577 | func TestHTTPSDefaultListenError(t *testing.T) { 578 | if os.Getuid() == 0 { 579 | t.Skip("cant run this test as root") 580 | } 581 | t.Parallel() 582 | 583 | cert, err := tls.X509KeyPair(localhostCert, localhostKey) 584 | if err != nil { 585 | t.Fatalf("error loading cert: %v", err) 586 | } 587 | 588 | down := &httpdown.HTTP{} 589 | _, err = down.ListenAndServe(&http.Server{ 590 | TLSConfig: &tls.Config{ 591 | NextProtos: []string{"http/1.1"}, 592 | Certificates: []tls.Certificate{cert}, 593 | }, 594 | }) 595 | ensure.Err(t, err, regexp.MustCompile("listen tcp :443: bind: permission denied")) 596 | } 597 | 598 | func TestTLS(t *testing.T) { 599 | t.Parallel() 600 | port, err := freeport.Get() 601 | ensure.Nil(t, err) 602 | 603 | cert, err := tls.X509KeyPair(localhostCert, localhostKey) 604 | if err != nil { 605 | t.Fatalf("error loading cert: %v", err) 606 | } 607 | const count = 10000 608 | hello := []byte("hello") 609 | finOkHandler := make(chan struct{}) 610 | okHandler := func(w http.ResponseWriter, r *http.Request) { 611 | defer close(finOkHandler) 612 | w.WriteHeader(200) 613 | for i := 0; i < count; i++ { 614 | w.Write(hello) 615 | } 616 | } 617 | 618 | server := &http.Server{ 619 | Addr: fmt.Sprintf("0.0.0.0:%d", port), 620 | Handler: http.HandlerFunc(okHandler), 621 | TLSConfig: &tls.Config{ 622 | NextProtos: []string{"http/1.1"}, 623 | Certificates: []tls.Certificate{cert}, 624 | }, 625 | } 626 | transport := &http.Transport{ 627 | TLSClientConfig: &tls.Config{ 628 | InsecureSkipVerify: true, 629 | }, 630 | } 631 | client := &http.Client{Transport: transport} 632 | down := &httpdown.HTTP{} 633 | s, err := down.ListenAndServe(server) 634 | ensure.Nil(t, err) 635 | res, err := client.Get(fmt.Sprintf("https://%s/", 
server.Addr)) 636 | ensure.Nil(t, err) 637 | 638 | finStop := make(chan struct{}) 639 | go func() { 640 | defer close(finStop) 641 | ensure.Nil(t, s.Stop()) 642 | }() 643 | 644 | actualBody, err := ioutil.ReadAll(res.Body) 645 | ensure.Nil(t, err) 646 | ensure.DeepEqual(t, actualBody, bytes.Repeat(hello, count)) 647 | ensure.Nil(t, res.Body.Close()) 648 | <-finOkHandler 649 | <-finStop 650 | } 651 | 652 | // localhostCert is a PEM-encoded TLS cert with SAN IPs 653 | // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end 654 | // of ASN.1 time). 655 | // generated from src/pkg/crypto/tls: 656 | // go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h 657 | var localhostCert = []byte(`-----BEGIN CERTIFICATE----- 658 | MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD 659 | bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj 660 | bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBALyCfqwwip8BvTKgVKGdmjZTU8DD 661 | ndR+WALmFPIRqn89bOU3s30olKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEA 662 | AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud 663 | EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA 664 | AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAr/09uy108p51rheIOSnz4zgduyTl 665 | M+4AmRo8/U1twEZLgfAGG/GZjREv2y4mCEUIM3HebCAqlA5jpRg76Rf8jw== 666 | -----END CERTIFICATE-----`) 667 | 668 | // localhostKey is the private key for localhostCert. 
669 | var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- 670 | MIIBOQIBAAJBALyCfqwwip8BvTKgVKGdmjZTU8DDndR+WALmFPIRqn89bOU3s30o 671 | lKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEAAQJAPXuWUxTV8XyAt8VhNQER 672 | LgzJcUKb9JVsoS1nwXgPksXnPDKnL9ax8VERrdNr+nZbj2Q9cDSXBUovfdtehcdP 673 | qQIhAO48ZsPylbTrmtjDEKiHT2Ik04rLotZYS2U873J6I7WlAiEAypDjYxXyafv/ 674 | Yo1pm9onwcetQKMW8CS3AjuV9Axzj6cCIEx2Il19fEMG4zny0WPlmbrcKvD/DpJQ 675 | 4FHrzsYlIVTpAiAas7S1uAvneqd0l02HlN9OxQKKlbUNXNme+rnOnOGS2wIgS0jW 676 | zl1jvrOSJeP1PpAHohWz6LOhEr8uvltWkN6x3vE= 677 | -----END RSA PRIVATE KEY-----`) 678 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2013-present, Facebook, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | httpdown [![Build Status](https://secure.travis-ci.org/facebookgo/httpdown.png)](https://travis-ci.org/facebookgo/httpdown) 2 | ======== 3 | 4 | Documentation: https://godoc.org/github.com/facebookgo/httpdown 5 | 6 | Package httpdown provides a library that makes it easy to build an HTTP server 7 | that can be shut down gracefully (that is, without dropping any connections). 8 | 9 | If you want graceful restart and not just graceful shutdown, look at the 10 | [grace](https://github.com/facebookgo/grace) package, which uses this package 11 | underneath but also provides graceful restart. 12 | 13 | Usage 14 | ----- 15 | 16 | Demo HTTP server with graceful termination: 17 | https://github.com/facebookgo/httpdown/blob/master/httpdown_example/main.go 18 | 19 | 1. Install the demo application: 20 | 21 | go get github.com/facebookgo/httpdown/httpdown_example 22 | 23 | 1. Start it in the first terminal: 24 | 25 | httpdown_example 26 | 27 | This will output something like: 28 | 29 | 2014/11/18 21:57:50 serving on http://127.0.0.1:8080/ with pid 17 30 | 31 | 1. In a second terminal, start a slow HTTP request: 32 | 33 | curl 'http://localhost:8080/?duration=20s' 34 | 35 | 1. In a third terminal, trigger a graceful shutdown (using the pid from your output): 36 | 37 | kill -TERM 17 38 | 39 | This will demonstrate that the slow request was served before the server was 40 | shut down. You could also have used `Ctrl-C` instead of `kill`, as the example 41 | application triggers graceful shutdown on TERM or INT signals. 42 | --------------------------------------------------------------------------------