├── .travis.yml ├── gracedemo └── demo.go ├── gracehttp ├── http.go ├── http_test.go └── testbin_test.go ├── gracenet ├── net.go └── net_test.go ├── license └── readme.md /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.7 5 | 6 | before_install: 7 | - go get -v github.com/golang/lint/golint 8 | 9 | install: 10 | - go install -race -v std 11 | - go get -race -t -v ./... 12 | - go install -race -v ./... 13 | 14 | script: 15 | - go vet ./... 16 | - $HOME/gopath/bin/golint . 17 | - go test -cpu=2 -race -v ./... 18 | - go test -cpu=2 -covermode=atomic ./... 19 | -------------------------------------------------------------------------------- /gracedemo/demo.go: -------------------------------------------------------------------------------- 1 | // Command gracedemo implements a demo server showing how to gracefully 2 | // terminate an HTTP server using grace. 3 | package main 4 | 5 | import ( 6 | "flag" 7 | "fmt" 8 | "net/http" 9 | "os" 10 | "time" 11 | 12 | "github.com/facebookgo/grace/gracehttp" 13 | ) 14 | 15 | var ( 16 | address0 = flag.String("a0", ":48567", "Zero address to bind to.") 17 | address1 = flag.String("a1", ":48568", "First address to bind to.") 18 | address2 = flag.String("a2", ":48569", "Second address to bind to.") 19 | now = time.Now() 20 | ) 21 | 22 | func main() { 23 | flag.Parse() 24 | gracehttp.Serve( 25 | &http.Server{Addr: *address0, Handler: newHandler("Zero ")}, 26 | &http.Server{Addr: *address1, Handler: newHandler("First ")}, 27 | &http.Server{Addr: *address2, Handler: newHandler("Second")}, 28 | ) 29 | } 30 | 31 | func newHandler(name string) http.Handler { 32 | mux := http.NewServeMux() 33 | mux.HandleFunc("/sleep/", func(w http.ResponseWriter, r *http.Request) { 34 | duration, err := time.ParseDuration(r.FormValue("duration")) 35 | if err != nil { 36 | http.Error(w, err.Error(), 400) 37 | return 38 | } 39 | time.Sleep(duration) 40 | fmt.Fprintf( 41 | w, 
42 | "%s started at %s slept for %d nanoseconds from pid %d.\n", 43 | name, 44 | now, 45 | duration.Nanoseconds(), 46 | os.Getpid(), 47 | ) 48 | }) 49 | return mux 50 | } 51 | -------------------------------------------------------------------------------- /gracehttp/http.go: -------------------------------------------------------------------------------- 1 | // Package gracehttp provides easy to use graceful restart 2 | // functionality for HTTP server. 3 | package gracehttp 4 | 5 | import ( 6 | "bytes" 7 | "crypto/tls" 8 | "fmt" 9 | "log" 10 | "net" 11 | "net/http" 12 | "os" 13 | "os/signal" 14 | "sync" 15 | "syscall" 16 | 17 | "github.com/facebookgo/grace/gracenet" 18 | "github.com/facebookgo/httpdown" 19 | ) 20 | 21 | var ( 22 | logger *log.Logger 23 | didInherit = os.Getenv("LISTEN_FDS") != "" 24 | ppid = os.Getppid() 25 | ) 26 | 27 | type option func(*app) 28 | 29 | // An app contains one or more servers and associated configuration. 30 | type app struct { 31 | servers []*http.Server 32 | http *httpdown.HTTP 33 | net *gracenet.Net 34 | listeners []net.Listener 35 | sds []httpdown.Server 36 | preStartProcess func() error 37 | errors chan error 38 | } 39 | 40 | func newApp(servers []*http.Server) *app { 41 | return &app{ 42 | servers: servers, 43 | http: &httpdown.HTTP{}, 44 | net: &gracenet.Net{}, 45 | listeners: make([]net.Listener, 0, len(servers)), 46 | sds: make([]httpdown.Server, 0, len(servers)), 47 | 48 | preStartProcess: func() error { return nil }, 49 | // 2x num servers for possible Close or Stop errors + 1 for possible 50 | // StartProcess error. 
51 | errors: make(chan error, 1+(len(servers)*2)), 52 | } 53 | } 54 | 55 | func (a *app) listen() error { 56 | for _, s := range a.servers { 57 | // TODO: default addresses 58 | l, err := a.net.Listen("tcp", s.Addr) 59 | if err != nil { 60 | return err 61 | } 62 | if s.TLSConfig != nil { 63 | l = tls.NewListener(l, s.TLSConfig) 64 | } 65 | a.listeners = append(a.listeners, l) 66 | } 67 | return nil 68 | } 69 | 70 | func (a *app) serve() { 71 | for i, s := range a.servers { 72 | a.sds = append(a.sds, a.http.Serve(s, a.listeners[i])) 73 | } 74 | } 75 | 76 | func (a *app) wait() { 77 | var wg sync.WaitGroup 78 | wg.Add(len(a.sds) * 2) // Wait & Stop 79 | go a.signalHandler(&wg) 80 | for _, s := range a.sds { 81 | go func(s httpdown.Server) { 82 | defer wg.Done() 83 | if err := s.Wait(); err != nil { 84 | a.errors <- err 85 | } 86 | }(s) 87 | } 88 | wg.Wait() 89 | } 90 | 91 | func (a *app) term(wg *sync.WaitGroup) { 92 | for _, s := range a.sds { 93 | go func(s httpdown.Server) { 94 | defer wg.Done() 95 | if err := s.Stop(); err != nil { 96 | a.errors <- err 97 | } 98 | }(s) 99 | } 100 | } 101 | 102 | func (a *app) signalHandler(wg *sync.WaitGroup) { 103 | ch := make(chan os.Signal, 10) 104 | signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGUSR2) 105 | for { 106 | sig := <-ch 107 | switch sig { 108 | case syscall.SIGINT, syscall.SIGTERM: 109 | // this ensures a subsequent INT/TERM will trigger standard go behaviour of 110 | // terminating. 111 | signal.Stop(ch) 112 | a.term(wg) 113 | return 114 | case syscall.SIGUSR2: 115 | err := a.preStartProcess() 116 | if err != nil { 117 | a.errors <- err 118 | } 119 | // we only return here if there's an error, otherwise the new process 120 | // will send us a TERM when it's ready to trigger the actual shutdown. 
121 | if _, err := a.net.StartProcess(); err != nil { 122 | a.errors <- err 123 | } 124 | } 125 | } 126 | } 127 | 128 | func (a *app) run() error { 129 | // Acquire Listeners 130 | if err := a.listen(); err != nil { 131 | return err 132 | } 133 | 134 | // Some useful logging. 135 | if logger != nil { 136 | if didInherit { 137 | if ppid == 1 { 138 | logger.Printf("Listening on init activated %s", pprintAddr(a.listeners)) 139 | } else { 140 | const msg = "Graceful handoff of %s with new pid %d and old pid %d" 141 | logger.Printf(msg, pprintAddr(a.listeners), os.Getpid(), ppid) 142 | } 143 | } else { 144 | const msg = "Serving %s with pid %d" 145 | logger.Printf(msg, pprintAddr(a.listeners), os.Getpid()) 146 | } 147 | } 148 | 149 | // Start serving. 150 | a.serve() 151 | 152 | // Close the parent if we inherited and it wasn't init that started us. 153 | if didInherit && ppid != 1 { 154 | if err := syscall.Kill(ppid, syscall.SIGTERM); err != nil { 155 | return fmt.Errorf("failed to close parent: %s", err) 156 | } 157 | } 158 | 159 | waitdone := make(chan struct{}) 160 | go func() { 161 | defer close(waitdone) 162 | a.wait() 163 | }() 164 | 165 | select { 166 | case err := <-a.errors: 167 | if err == nil { 168 | panic("unexpected nil error") 169 | } 170 | return err 171 | case <-waitdone: 172 | if logger != nil { 173 | logger.Printf("Exiting pid %d.", os.Getpid()) 174 | } 175 | return nil 176 | } 177 | } 178 | 179 | // ServeWithOptions does the same as Serve, but takes a set of options to 180 | // configure the app struct. 181 | func ServeWithOptions(servers []*http.Server, options ...option) error { 182 | a := newApp(servers) 183 | for _, opt := range options { 184 | opt(a) 185 | } 186 | return a.run() 187 | } 188 | 189 | // Serve will serve the given http.Servers and will monitor for signals 190 | // allowing for graceful termination (SIGTERM) or restart (SIGUSR2). 
191 | func Serve(servers ...*http.Server) error { 192 | a := newApp(servers) 193 | return a.run() 194 | } 195 | 196 | // PreStartProcess configures a callback to trigger during graceful restart 197 | // directly before starting the successor process. This allows the current 198 | // process to release holds on resources that the new process will need. 199 | func PreStartProcess(hook func() error) option { 200 | return func(a *app) { 201 | a.preStartProcess = hook 202 | } 203 | } 204 | 205 | // Used for pretty printing addresses. 206 | func pprintAddr(listeners []net.Listener) []byte { 207 | var out bytes.Buffer 208 | for i, l := range listeners { 209 | if i != 0 { 210 | fmt.Fprint(&out, ", ") 211 | } 212 | fmt.Fprint(&out, l.Addr()) 213 | } 214 | return out.Bytes() 215 | } 216 | 217 | // SetLogger sets logger to be able to grab some useful logs 218 | func SetLogger(l *log.Logger) { 219 | logger = l 220 | } 221 | -------------------------------------------------------------------------------- /gracehttp/http_test.go: -------------------------------------------------------------------------------- 1 | package gracehttp_test 2 | 3 | import ( 4 | "bufio" 5 | "crypto/tls" 6 | "encoding/json" 7 | "flag" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | "os" 13 | "os/exec" 14 | "strconv" 15 | "sync" 16 | "syscall" 17 | "testing" 18 | "time" 19 | 20 | "github.com/facebookgo/freeport" 21 | ) 22 | 23 | const ( 24 | testPreStartProcess = iota 25 | ) 26 | 27 | // Debug logging. 28 | var debugLog = flag.Bool("debug", false, "enable debug logging") 29 | 30 | func debug(format string, a ...interface{}) { 31 | if *debugLog { 32 | println(fmt.Sprintf(format, a...)) 33 | } 34 | } 35 | 36 | // State for the test run. 37 | type harness struct { 38 | T *testing.T // The test instance. 39 | httpAddr string // The address for the http server. 40 | httpsAddr string // The address for the https server. 41 | Process []*os.Process // The server commands, oldest to newest. 
42 | ProcessMutex sync.Mutex // The mutex to guard Process manipulation. 43 | RequestWaitGroup sync.WaitGroup // The wait group for the HTTP requests. 44 | newProcess chan bool // A bool is sent on start/restart. 45 | requestCount int 46 | requestCountMutex sync.Mutex 47 | serveOption int 48 | } 49 | 50 | // Find 3 free ports and setup addresses. 51 | func (h *harness) setupAddr() { 52 | port, err := freeport.Get() 53 | if err != nil { 54 | h.T.Fatalf("Failed to find a free port: %s", err) 55 | } 56 | h.httpAddr = fmt.Sprintf("127.0.0.1:%d", port) 57 | 58 | port, err = freeport.Get() 59 | if err != nil { 60 | h.T.Fatalf("Failed to find a free port: %s", err) 61 | } 62 | h.httpsAddr = fmt.Sprintf("127.0.0.1:%d", port) 63 | debug("Addresses %s & %s", h.httpAddr, h.httpsAddr) 64 | } 65 | 66 | // Start a fresh server and wait for pid updates on restart. 67 | func (h *harness) Start() { 68 | h.setupAddr() 69 | cmd := exec.Command(os.Args[0], "-http", h.httpAddr, "-https", h.httpsAddr, "-testOption", strconv.Itoa(h.serveOption)) 70 | stderr, err := cmd.StderrPipe() 71 | if err != nil { 72 | h.T.Fatal(err) 73 | } 74 | go func() { 75 | reader := bufio.NewReader(stderr) 76 | for { 77 | line, isPrefix, err := reader.ReadLine() 78 | if err == io.EOF { 79 | return 80 | } 81 | if err != nil { 82 | println(fmt.Sprintf("Failed to read line from server process: %s", err)) 83 | } 84 | if isPrefix { 85 | println(fmt.Sprintf("Deal with isPrefix for line: %s", line)) 86 | } 87 | res := &response{} 88 | err = json.Unmarshal([]byte(line), res) 89 | if err != nil { 90 | println(fmt.Sprintf("Could not parse json from stderr %s: %s", line, err)) 91 | } 92 | if res.Error != "" { 93 | println(fmt.Sprintf("Got error from process: %v", res)) 94 | } 95 | process, err := os.FindProcess(res.Pid) 96 | if err != nil { 97 | println(fmt.Sprintf("Could not find process with pid: %d", res.Pid)) 98 | } 99 | h.ProcessMutex.Lock() 100 | h.Process = append(h.Process, process) 101 | h.ProcessMutex.Unlock() 
102 | h.newProcess <- true 103 | } 104 | }() 105 | err = cmd.Start() 106 | if err != nil { 107 | h.T.Fatalf("Failed to start command: %s", err) 108 | } 109 | <-h.newProcess 110 | } 111 | 112 | // Restart the most recent server. 113 | func (h *harness) Restart() { 114 | err := h.MostRecentProcess().Signal(syscall.SIGUSR2) 115 | if err != nil { 116 | h.T.Fatalf("Failed to send SIGUSR2 and restart process: %s", err) 117 | } 118 | <-h.newProcess 119 | } 120 | 121 | // Graceful termination of the most recent server. 122 | func (h *harness) Stop() { 123 | err := h.MostRecentProcess().Signal(syscall.SIGTERM) 124 | if err != nil { 125 | h.T.Fatalf("Failed to send SIGTERM and stop process: %s", err) 126 | } 127 | } 128 | 129 | // Returns the most recent server process. 130 | func (h *harness) MostRecentProcess() *os.Process { 131 | h.ProcessMutex.Lock() 132 | defer h.ProcessMutex.Unlock() 133 | l := len(h.Process) 134 | if l == 0 { 135 | h.T.Fatalf("Most recent command requested before command was created.") 136 | } 137 | return h.Process[l-1] 138 | } 139 | 140 | // Get the global request count and increment it. 141 | func (h *harness) RequestCount() int { 142 | h.requestCountMutex.Lock() 143 | defer h.requestCountMutex.Unlock() 144 | c := h.requestCount 145 | h.requestCount++ 146 | return c 147 | } 148 | 149 | // Helper for sending a single request. 
150 | func (h *harness) SendOne(dialgroup *sync.WaitGroup, url string, pid int) { 151 | defer h.RequestWaitGroup.Done() 152 | count := h.RequestCount() 153 | debug("Send %02d pid=%d url=%s", count, pid, url) 154 | client := &http.Client{ 155 | Transport: &http.Transport{ 156 | Dial: func(network, addr string) (net.Conn, error) { 157 | defer func() { 158 | time.Sleep(50 * time.Millisecond) 159 | dialgroup.Done() 160 | }() 161 | return net.Dial(network, addr) 162 | }, 163 | TLSClientConfig: &tls.Config{ 164 | InsecureSkipVerify: true, 165 | }, 166 | }, 167 | } 168 | r, err := client.Get(url) 169 | if err != nil { 170 | h.T.Fatalf("Failed request %02d to %s pid=%d: %s", count, url, pid, err) 171 | } 172 | debug("Body %02d pid=%d url=%s", count, pid, url) 173 | defer r.Body.Close() 174 | res := &response{} 175 | err = json.NewDecoder(r.Body).Decode(res) 176 | if err != nil { 177 | h.T.Fatalf("Failed to ready decode json response body pid=%d: %s", pid, err) 178 | } 179 | if pid != res.Pid { 180 | for _, old := range h.Process[0 : len(h.Process)-1] { 181 | if res.Pid == old.Pid { 182 | h.T.Logf("Found old pid %d, ignoring the discrepancy", res.Pid) 183 | return 184 | } 185 | } 186 | h.T.Fatalf("Didn't get expected pid %d instead got %d", pid, res.Pid) 187 | } 188 | debug("Done %02d pid=%d url=%s", count, pid, url) 189 | } 190 | 191 | // Send test HTTP request. 
192 | func (h *harness) SendRequest() { 193 | pid := h.MostRecentProcess().Pid 194 | httpFastURL := fmt.Sprintf("http://%s/sleep/?duration=0", h.httpAddr) 195 | httpSlowURL := fmt.Sprintf("http://%s/sleep/?duration=2s", h.httpAddr) 196 | httpsFastURL := fmt.Sprintf("https://%s/sleep/?duration=0", h.httpsAddr) 197 | httpsSlowURL := fmt.Sprintf("https://%s/sleep/?duration=2s", h.httpsAddr) 198 | 199 | var dialgroup sync.WaitGroup 200 | h.RequestWaitGroup.Add(4) 201 | dialgroup.Add(4) 202 | go h.SendOne(&dialgroup, httpFastURL, pid) 203 | go h.SendOne(&dialgroup, httpSlowURL, pid) 204 | go h.SendOne(&dialgroup, httpsFastURL, pid) 205 | go h.SendOne(&dialgroup, httpsSlowURL, pid) 206 | debug("Added Requests pid=%d", pid) 207 | dialgroup.Wait() 208 | debug("Dialed Requests pid=%d", pid) 209 | } 210 | 211 | // Wait for everything. 212 | func (h *harness) Wait() { 213 | h.RequestWaitGroup.Wait() 214 | } 215 | 216 | func newHarness(t *testing.T) *harness { 217 | return &harness{ 218 | T: t, 219 | newProcess: make(chan bool), 220 | serveOption: -1, 221 | } 222 | } 223 | 224 | // The main test case. 
225 | func TestComplex(t *testing.T) { 226 | t.Parallel() 227 | debug("Started TestComplex") 228 | h := newHarness(t) 229 | debug("Initial Start") 230 | h.Start() 231 | debug("Send Request 1") 232 | h.SendRequest() 233 | debug("Restart 1") 234 | h.Restart() 235 | debug("Send Request 2") 236 | h.SendRequest() 237 | debug("Restart 2") 238 | h.Restart() 239 | debug("Send Request 3") 240 | h.SendRequest() 241 | debug("Stopping") 242 | h.Stop() 243 | debug("Waiting") 244 | h.Wait() 245 | } 246 | 247 | func TestComplexAgain(t *testing.T) { 248 | t.Parallel() 249 | debug("Started TestComplex") 250 | h := newHarness(t) 251 | debug("Initial Start") 252 | h.Start() 253 | debug("Send Request 1") 254 | h.SendRequest() 255 | debug("Restart 1") 256 | h.Restart() 257 | debug("Send Request 2") 258 | h.SendRequest() 259 | debug("Restart 2") 260 | h.Restart() 261 | debug("Send Request 3") 262 | h.SendRequest() 263 | debug("Stopping") 264 | h.Stop() 265 | debug("Waiting") 266 | h.Wait() 267 | } 268 | 269 | func TestPreStartProcess(t *testing.T) { 270 | t.Parallel() 271 | debug("Started TestPreStartProcess") 272 | h := newHarness(t) 273 | h.serveOption = testPreStartProcess 274 | debug("Initial Start") 275 | h.Start() 276 | debug("Send Request 1") 277 | h.SendRequest() 278 | debug("Restart 1") 279 | h.Restart() 280 | debug("Send Request 2") 281 | h.SendRequest() 282 | debug("Restart 2") 283 | h.Restart() 284 | debug("Send Request 3") 285 | h.SendRequest() 286 | debug("Stopping") 287 | h.Stop() 288 | debug("Waiting") 289 | h.Wait() 290 | } 291 | 292 | func TestPreStartProcessAgain(t *testing.T) { 293 | t.Parallel() 294 | debug("Started TestPreStartProcessAgain") 295 | h := newHarness(t) 296 | h.serveOption = testPreStartProcess 297 | debug("Initial Start") 298 | h.Start() 299 | debug("Send Request 1") 300 | h.SendRequest() 301 | debug("Restart 1") 302 | h.Restart() 303 | debug("Send Request 2") 304 | h.SendRequest() 305 | debug("Restart 2") 306 | h.Restart() 307 | debug("Send Request 
3") 308 | h.SendRequest() 309 | debug("Stopping") 310 | h.Stop() 311 | debug("Waiting") 312 | h.Wait() 313 | } 314 | -------------------------------------------------------------------------------- /gracehttp/testbin_test.go: -------------------------------------------------------------------------------- 1 | package gracehttp_test 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/json" 6 | "flag" 7 | "fmt" 8 | "log" 9 | "net/http" 10 | "os" 11 | "strings" 12 | "sync" 13 | "testing" 14 | "time" 15 | 16 | "github.com/facebookgo/grace/gracehttp" 17 | ) 18 | 19 | const preStartProcessEnv = "GRACEHTTP_PRE_START_PROCESS" 20 | 21 | func TestMain(m *testing.M) { 22 | const ( 23 | testbinKey = "GRACEHTTP_TEST_BIN" 24 | testbinValue = "1" 25 | ) 26 | if os.Getenv(testbinKey) == testbinValue { 27 | testbinMain() 28 | return 29 | } 30 | if err := os.Setenv(testbinKey, testbinValue); err != nil { 31 | panic(err) 32 | } 33 | os.Exit(m.Run()) 34 | } 35 | 36 | type response struct { 37 | Sleep time.Duration 38 | Pid int 39 | Error string `json:",omitempty"` 40 | } 41 | 42 | // Wait for 10 consecutive responses from our own pid. 43 | // 44 | // This prevents flaky tests that arise from the fact that we have the 45 | // perfectly acceptable (read: not a bug) condition where both the new and the 46 | // old servers are accepting requests. In fact the amount of time both are 47 | // accepting at the same time and the number of requests that flip flop between 48 | // them is unbounded and in the hands of the various kernels our code tends to 49 | // run on. 50 | // 51 | // In order to combat this, we wait for 10 successful responses from our own 52 | // pid. This is a somewhat reliable way to ensure the old server isn't 53 | // serving anymore. 
54 | func wait(wg *sync.WaitGroup, url string) { 55 | var success int 56 | defer wg.Done() 57 | for { 58 | res, err := http.Get(url) 59 | if err == nil { 60 | // ensure it isn't a response from a previous instance 61 | defer res.Body.Close() 62 | var r response 63 | if err := json.NewDecoder(res.Body).Decode(&r); err != nil { 64 | log.Fatalf("Error decoding json: %s", err) 65 | } 66 | if r.Pid == os.Getpid() { 67 | success++ 68 | if success == 10 { 69 | return 70 | } 71 | continue 72 | } 73 | } else { 74 | success = 0 75 | // we expect connection refused 76 | if !strings.HasSuffix(err.Error(), "connection refused") { 77 | e2 := json.NewEncoder(os.Stderr).Encode(&response{ 78 | Error: err.Error(), 79 | Pid: os.Getpid(), 80 | }) 81 | if e2 != nil { 82 | log.Fatalf("Error writing error json: %s", e2) 83 | } 84 | } 85 | } 86 | } 87 | } 88 | 89 | func httpsServer(addr string) *http.Server { 90 | cert, err := tls.X509KeyPair(localhostCert, localhostKey) 91 | if err != nil { 92 | log.Fatalf("error loading cert: %v", err) 93 | } 94 | return &http.Server{ 95 | Addr: addr, 96 | Handler: newHandler(), 97 | TLSConfig: &tls.Config{ 98 | NextProtos: []string{"http/1.1"}, 99 | Certificates: []tls.Certificate{cert}, 100 | }, 101 | } 102 | } 103 | 104 | func testbinMain() { 105 | var httpAddr, httpsAddr string 106 | var testOption int 107 | flag.StringVar(&httpAddr, "http", ":48560", "http address to bind to") 108 | flag.StringVar(&httpsAddr, "https", ":48561", "https address to bind to") 109 | flag.IntVar(&testOption, "testOption", -1, "which option to test on ServeWithOptions") 110 | flag.Parse() 111 | 112 | // we have self signed certs 113 | http.DefaultTransport = &http.Transport{ 114 | DisableKeepAlives: true, 115 | TLSClientConfig: &tls.Config{ 116 | InsecureSkipVerify: true, 117 | }, 118 | } 119 | 120 | // print json to stderr once we can successfully connect to all three 121 | // addresses. the ensures we only print the line once we're ready to serve. 
122 | go func() { 123 | var wg sync.WaitGroup 124 | wg.Add(2) 125 | go wait(&wg, fmt.Sprintf("http://%s/sleep/?duration=1ms", httpAddr)) 126 | go wait(&wg, fmt.Sprintf("https://%s/sleep/?duration=1ms", httpsAddr)) 127 | wg.Wait() 128 | 129 | err := json.NewEncoder(os.Stderr).Encode(&response{Pid: os.Getpid()}) 130 | if err != nil { 131 | log.Fatalf("Error writing startup json: %s", err) 132 | } 133 | }() 134 | 135 | servers := []*http.Server{ 136 | &http.Server{Addr: httpAddr, Handler: newHandler()}, 137 | httpsServer(httpsAddr), 138 | } 139 | 140 | if testOption == -1 { 141 | err := gracehttp.Serve(servers...) 142 | if err != nil { 143 | log.Fatalf("Error in gracehttp.Serve: %s", err) 144 | } 145 | } else { 146 | if testOption == testPreStartProcess { 147 | switch os.Getenv(preStartProcessEnv) { 148 | case "": 149 | err := os.Setenv(preStartProcessEnv, "READY") 150 | if err != nil { 151 | log.Fatalf("testbin (first incarnation) could not set %v to 'ready': %v", preStartProcessEnv, err) 152 | } 153 | case "FIRED": 154 | // all good, reset for next round 155 | err := os.Setenv(preStartProcessEnv, "READY") 156 | if err != nil { 157 | log.Fatalf("testbin (second incarnation) could not reset %v to 'ready': %v", preStartProcessEnv, err) 158 | } 159 | case "READY": 160 | log.Fatalf("failure to update startup hook before new process started") 161 | default: 162 | log.Fatalf("something strange happened with %v: it ended up as %v, which is not '', 'FIRED', or 'READY'", preStartProcessEnv, os.Getenv(preStartProcessEnv)) 163 | } 164 | 165 | err := gracehttp.ServeWithOptions( 166 | servers, 167 | gracehttp.PreStartProcess(func() error { 168 | err := os.Setenv(preStartProcessEnv, "FIRED") 169 | if err != nil { 170 | log.Fatalf("startup hook could not set %v to 'fired': %v", preStartProcessEnv, err) 171 | } 172 | return nil 173 | }), 174 | ) 175 | if err != nil { 176 | log.Fatalf("Error in gracehttp.Serve: %s", err) 177 | } 178 | } 179 | } 180 | } 181 | 182 | func newHandler() 
http.Handler { 183 | mux := http.NewServeMux() 184 | mux.HandleFunc("/sleep/", func(w http.ResponseWriter, r *http.Request) { 185 | duration, err := time.ParseDuration(r.FormValue("duration")) 186 | if err != nil { 187 | http.Error(w, err.Error(), 400) 188 | } 189 | time.Sleep(duration) 190 | err = json.NewEncoder(w).Encode(&response{ 191 | Sleep: duration, 192 | Pid: os.Getpid(), 193 | }) 194 | if err != nil { 195 | log.Fatalf("Error encoding json: %s", err) 196 | } 197 | }) 198 | return mux 199 | } 200 | 201 | // localhostCert is a PEM-encoded TLS cert with SAN IPs 202 | // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end 203 | // of ASN.1 time). 204 | // generated from src/pkg/crypto/tls: 205 | // go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h 206 | var localhostCert = []byte(`-----BEGIN CERTIFICATE----- 207 | MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD 208 | bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj 209 | bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBALyCfqwwip8BvTKgVKGdmjZTU8DD 210 | ndR+WALmFPIRqn89bOU3s30olKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEA 211 | AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud 212 | EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA 213 | AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAr/09uy108p51rheIOSnz4zgduyTl 214 | M+4AmRo8/U1twEZLgfAGG/GZjREv2y4mCEUIM3HebCAqlA5jpRg76Rf8jw== 215 | -----END CERTIFICATE-----`) 216 | 217 | // localhostKey is the private key for localhostCert. 
218 | var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- 219 | MIIBOQIBAAJBALyCfqwwip8BvTKgVKGdmjZTU8DDndR+WALmFPIRqn89bOU3s30o 220 | lKiqYEju/SFoEvMyFRT/TWEhXHDaufThqaMCAwEAAQJAPXuWUxTV8XyAt8VhNQER 221 | LgzJcUKb9JVsoS1nwXgPksXnPDKnL9ax8VERrdNr+nZbj2Q9cDSXBUovfdtehcdP 222 | qQIhAO48ZsPylbTrmtjDEKiHT2Ik04rLotZYS2U873J6I7WlAiEAypDjYxXyafv/ 223 | Yo1pm9onwcetQKMW8CS3AjuV9Axzj6cCIEx2Il19fEMG4zny0WPlmbrcKvD/DpJQ 224 | 4FHrzsYlIVTpAiAas7S1uAvneqd0l02HlN9OxQKKlbUNXNme+rnOnOGS2wIgS0jW 225 | zl1jvrOSJeP1PpAHohWz6LOhEr8uvltWkN6x3vE= 226 | -----END RSA PRIVATE KEY-----`) 227 | -------------------------------------------------------------------------------- /gracenet/net.go: -------------------------------------------------------------------------------- 1 | // Package gracenet provides a family of Listen functions that either open a 2 | // fresh connection or provide an inherited connection from when the process 3 | // was started. They behave like their counterparts in the net package, but 4 | // transparently provide support for graceful restarts without dropping 5 | // connections. This is provided in a systemd socket activation compatible form 6 | // to allow using socket activation. 7 | // 8 | // BUG: Doesn't handle closing of listeners. 9 | package gracenet 10 | 11 | import ( 12 | "fmt" 13 | "net" 14 | "os" 15 | "os/exec" 16 | "strconv" 17 | "strings" 18 | "sync" 19 | ) 20 | 21 | const ( 22 | // Used to indicate a graceful restart in the new process. 23 | envCountKey = "LISTEN_FDS" 24 | envCountKeyPrefix = envCountKey + "=" 25 | ) 26 | 27 | // In order to keep the working directory the same as when we started we record 28 | // it at startup. 29 | var originalWD, _ = os.Getwd() 30 | 31 | // Net provides the family of Listen functions and maintains the associated 32 | // state. Typically you will have only one instance of Net per application. 
33 | type Net struct { 34 | inherited []net.Listener 35 | active []net.Listener 36 | mutex sync.Mutex 37 | inheritOnce sync.Once 38 | 39 | // used in tests to override the default behavior of starting from fd 3. 40 | fdStart int 41 | } 42 | 43 | func (n *Net) inherit() error { 44 | var retErr error 45 | n.inheritOnce.Do(func() { 46 | n.mutex.Lock() 47 | defer n.mutex.Unlock() 48 | countStr := os.Getenv(envCountKey) 49 | if countStr == "" { 50 | return 51 | } 52 | count, err := strconv.Atoi(countStr) 53 | if err != nil { 54 | retErr = fmt.Errorf("found invalid count value: %s=%s", envCountKey, countStr) 55 | return 56 | } 57 | 58 | // In tests this may be overridden. 59 | fdStart := n.fdStart 60 | if fdStart == 0 { 61 | // In normal operations if we are inheriting, the listeners will begin at 62 | // fd 3. 63 | fdStart = 3 64 | } 65 | 66 | for i := fdStart; i < fdStart+count; i++ { 67 | file := os.NewFile(uintptr(i), "listener") 68 | l, err := net.FileListener(file) 69 | if err != nil { 70 | file.Close() 71 | retErr = fmt.Errorf("error inheriting socket fd %d: %s", i, err) 72 | return 73 | } 74 | if err := file.Close(); err != nil { 75 | retErr = fmt.Errorf("error closing inherited socket fd %d: %s", i, err) 76 | return 77 | } 78 | n.inherited = append(n.inherited, l) 79 | } 80 | }) 81 | return retErr 82 | } 83 | 84 | // Listen announces on the local network address laddr. The network net must be 85 | // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It 86 | // returns an inherited net.Listener for the matching network and address, or 87 | // creates a new one using net.Listen. 
88 | func (n *Net) Listen(nett, laddr string) (net.Listener, error) { 89 | switch nett { 90 | default: 91 | return nil, net.UnknownNetworkError(nett) 92 | case "tcp", "tcp4", "tcp6": 93 | addr, err := net.ResolveTCPAddr(nett, laddr) 94 | if err != nil { 95 | return nil, err 96 | } 97 | return n.ListenTCP(nett, addr) 98 | case "unix", "unixpacket", "invalid_unix_net_for_test": 99 | addr, err := net.ResolveUnixAddr(nett, laddr) 100 | if err != nil { 101 | return nil, err 102 | } 103 | return n.ListenUnix(nett, addr) 104 | } 105 | } 106 | 107 | // ListenTCP announces on the local network address laddr. The network net must 108 | // be: "tcp", "tcp4" or "tcp6". It returns an inherited net.Listener for the 109 | // matching network and address, or creates a new one using net.ListenTCP. 110 | func (n *Net) ListenTCP(nett string, laddr *net.TCPAddr) (*net.TCPListener, error) { 111 | if err := n.inherit(); err != nil { 112 | return nil, err 113 | } 114 | 115 | n.mutex.Lock() 116 | defer n.mutex.Unlock() 117 | 118 | // look for an inherited listener 119 | for i, l := range n.inherited { 120 | if l == nil { // we nil used inherited listeners 121 | continue 122 | } 123 | if isSameAddr(l.Addr(), laddr) { 124 | n.inherited[i] = nil 125 | n.active = append(n.active, l) 126 | return l.(*net.TCPListener), nil 127 | } 128 | } 129 | 130 | // make a fresh listener 131 | l, err := net.ListenTCP(nett, laddr) 132 | if err != nil { 133 | return nil, err 134 | } 135 | n.active = append(n.active, l) 136 | return l, nil 137 | } 138 | 139 | // ListenUnix announces on the local network address laddr. The network net 140 | // must be a: "unix" or "unixpacket". It returns an inherited net.Listener for 141 | // the matching network and address, or creates a new one using net.ListenUnix. 
142 | func (n *Net) ListenUnix(nett string, laddr *net.UnixAddr) (*net.UnixListener, error) { 143 | if err := n.inherit(); err != nil { 144 | return nil, err 145 | } 146 | 147 | n.mutex.Lock() 148 | defer n.mutex.Unlock() 149 | 150 | // look for an inherited listener 151 | for i, l := range n.inherited { 152 | if l == nil { // we nil used inherited listeners 153 | continue 154 | } 155 | if isSameAddr(l.Addr(), laddr) { 156 | n.inherited[i] = nil 157 | n.active = append(n.active, l) 158 | return l.(*net.UnixListener), nil 159 | } 160 | } 161 | 162 | // make a fresh listener 163 | l, err := net.ListenUnix(nett, laddr) 164 | if err != nil { 165 | return nil, err 166 | } 167 | n.active = append(n.active, l) 168 | return l, nil 169 | } 170 | 171 | // activeListeners returns a snapshot copy of the active listeners. 172 | func (n *Net) activeListeners() ([]net.Listener, error) { 173 | n.mutex.Lock() 174 | defer n.mutex.Unlock() 175 | ls := make([]net.Listener, len(n.active)) 176 | copy(ls, n.active) 177 | return ls, nil 178 | } 179 | 180 | func isSameAddr(a1, a2 net.Addr) bool { 181 | if a1.Network() != a2.Network() { 182 | return false 183 | } 184 | a1s := a1.String() 185 | a2s := a2.String() 186 | if a1s == a2s { 187 | return true 188 | } 189 | 190 | // This allows for ipv6 vs ipv4 local addresses to compare as equal. This 191 | // scenario is common when listening on localhost. 192 | const ipv6prefix = "[::]" 193 | a1s = strings.TrimPrefix(a1s, ipv6prefix) 194 | a2s = strings.TrimPrefix(a2s, ipv6prefix) 195 | const ipv4prefix = "0.0.0.0" 196 | a1s = strings.TrimPrefix(a1s, ipv4prefix) 197 | a2s = strings.TrimPrefix(a2s, ipv4prefix) 198 | return a1s == a2s 199 | } 200 | 201 | // StartProcess starts a new process passing it the active listeners. It 202 | // doesn't fork, but starts a new process using the same environment and 203 | // arguments as when it was originally started. This allows for a newly 204 | // deployed binary to be started. 
It returns the pid of the newly started 205 | // process when successful. 206 | func (n *Net) StartProcess() (int, error) { 207 | listeners, err := n.activeListeners() 208 | if err != nil { 209 | return 0, err 210 | } 211 | 212 | // Extract the fds from the listeners. 213 | files := make([]*os.File, len(listeners)) 214 | for i, l := range listeners { 215 | files[i], err = l.(filer).File() 216 | if err != nil { 217 | return 0, err 218 | } 219 | defer files[i].Close() 220 | } 221 | 222 | // Use the original binary location. This works with symlinks such that if 223 | // the file it points to has been changed we will use the updated symlink. 224 | argv0, err := exec.LookPath(os.Args[0]) 225 | if err != nil { 226 | return 0, err 227 | } 228 | 229 | // Pass on the environment and replace the old count key with the new one. 230 | var env []string 231 | for _, v := range os.Environ() { 232 | if !strings.HasPrefix(v, envCountKeyPrefix) { 233 | env = append(env, v) 234 | } 235 | } 236 | env = append(env, fmt.Sprintf("%s%d", envCountKeyPrefix, len(listeners))) 237 | 238 | allFiles := append([]*os.File{os.Stdin, os.Stdout, os.Stderr}, files...) 
239 | process, err := os.StartProcess(argv0, os.Args, &os.ProcAttr{ 240 | Dir: originalWD, 241 | Env: env, 242 | Files: allFiles, 243 | }) 244 | if err != nil { 245 | return 0, err 246 | } 247 | return process.Pid, nil 248 | } 249 | 250 | type filer interface { 251 | File() (*os.File, error) 252 | } 253 | -------------------------------------------------------------------------------- /gracenet/net_test.go: -------------------------------------------------------------------------------- 1 | package gracenet 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "net" 7 | "os" 8 | "path/filepath" 9 | "regexp" 10 | "syscall" 11 | "testing" 12 | 13 | "github.com/facebookgo/ensure" 14 | "github.com/facebookgo/freeport" 15 | ) 16 | 17 | func TestEmptyCountEnvVariable(t *testing.T) { 18 | var n Net 19 | os.Setenv(envCountKey, "") 20 | ensure.Nil(t, n.inherit()) 21 | } 22 | 23 | func TestZeroCountEnvVariable(t *testing.T) { 24 | var n Net 25 | os.Setenv(envCountKey, "0") 26 | ensure.Nil(t, n.inherit()) 27 | } 28 | 29 | func TestInvalidCountEnvVariable(t *testing.T) { 30 | var n Net 31 | os.Setenv(envCountKey, "a") 32 | expected := regexp.MustCompile("^found invalid count value: LISTEN_FDS=a$") 33 | ensure.Err(t, n.inherit(), expected) 34 | } 35 | 36 | func TestInvalidFileInherit(t *testing.T) { 37 | var n Net 38 | tmpfile, err := ioutil.TempFile("", "TestInvalidFileInherit-") 39 | ensure.Nil(t, err) 40 | defer os.Remove(tmpfile.Name()) 41 | n.fdStart = dup(t, int(tmpfile.Fd())) 42 | os.Setenv(envCountKey, "1") 43 | ensure.Err(t, n.inherit(), regexp.MustCompile("^error inheriting socket fd")) 44 | ensure.DeepEqual(t, len(n.inherited), 0) 45 | ensure.Nil(t, tmpfile.Close()) 46 | } 47 | 48 | func TestInheritErrorOnListenTCPWithInvalidCount(t *testing.T) { 49 | var n Net 50 | os.Setenv(envCountKey, "a") 51 | _, err := n.Listen("tcp", ":0") 52 | ensure.NotNil(t, err) 53 | } 54 | 55 | func TestInheritErrorOnListenUnixWithInvalidCount(t *testing.T) { 56 | var n Net 57 | 
os.Setenv(envCountKey, "a") 58 | tmpdir, err := ioutil.TempDir("", "TestInheritErrorOnListenUnixWithInvalidCount-") 59 | ensure.Nil(t, err) 60 | ensure.Nil(t, os.RemoveAll(tmpdir)) 61 | _, err = n.Listen("unix", filepath.Join(tmpdir, "socket")) 62 | ensure.NotNil(t, err) 63 | } 64 | 65 | func TestOneTcpInherit(t *testing.T) { 66 | var n Net 67 | l, err := net.Listen("tcp", ":0") 68 | ensure.Nil(t, err) 69 | file, err := l.(*net.TCPListener).File() 70 | ensure.Nil(t, err) 71 | ensure.Nil(t, l.Close()) 72 | n.fdStart = dup(t, int(file.Fd())) 73 | ensure.Nil(t, file.Close()) 74 | os.Setenv(envCountKey, "1") 75 | ensure.Nil(t, n.inherit()) 76 | ensure.DeepEqual(t, len(n.inherited), 1) 77 | l, err = n.Listen("tcp", l.Addr().String()) 78 | ensure.Nil(t, err) 79 | ensure.DeepEqual(t, len(n.active), 1) 80 | ensure.DeepEqual(t, n.inherited[0], nil) 81 | active, err := n.activeListeners() 82 | ensure.Nil(t, err) 83 | ensure.DeepEqual(t, len(active), 1) 84 | ensure.Nil(t, l.Close()) 85 | } 86 | 87 | func TestSecondTcpListen(t *testing.T) { 88 | var n Net 89 | os.Setenv(envCountKey, "") 90 | l, err := n.Listen("tcp", ":0") 91 | ensure.Nil(t, err) 92 | _, err = n.Listen("tcp", l.Addr().String()) 93 | ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) 94 | ensure.Nil(t, l.Close()) 95 | } 96 | 97 | func TestSecondTcpListenInherited(t *testing.T) { 98 | var n Net 99 | l, err := net.Listen("tcp", ":0") 100 | ensure.Nil(t, err) 101 | file, err := l.(*net.TCPListener).File() 102 | ensure.Nil(t, err) 103 | ensure.Nil(t, l.Close()) 104 | n.fdStart = dup(t, int(file.Fd())) 105 | ensure.Nil(t, file.Close()) 106 | os.Setenv(envCountKey, "1") 107 | ensure.Nil(t, n.inherit()) 108 | ensure.DeepEqual(t, len(n.inherited), 1) 109 | l, err = n.Listen("tcp", l.Addr().String()) 110 | ensure.Nil(t, err) 111 | ensure.DeepEqual(t, len(n.active), 1) 112 | ensure.DeepEqual(t, n.inherited[0], nil) 113 | _, err = n.Listen("tcp", l.Addr().String()) 114 | ensure.Err(t, err, 
regexp.MustCompile("bind: address already in use$")) 115 | ensure.Nil(t, l.Close()) 116 | } 117 | 118 | func TestInvalidNetwork(t *testing.T) { 119 | var n Net 120 | os.Setenv(envCountKey, "") 121 | _, err := n.Listen("foo", "") 122 | ensure.Err(t, err, regexp.MustCompile("^unknown network foo$")) 123 | } 124 | 125 | func TestInvalidNetworkUnix(t *testing.T) { 126 | var n Net 127 | os.Setenv(envCountKey, "") 128 | _, err := n.Listen("invalid_unix_net_for_test", "") 129 | ensure.Err(t, err, regexp.MustCompile("^unknown network invalid_unix_net_for_test$")) 130 | } 131 | 132 | func TestWithTcp0000(t *testing.T) { 133 | var n Net 134 | port, err := freeport.Get() 135 | ensure.Nil(t, err) 136 | addr := fmt.Sprintf("0.0.0.0:%d", port) 137 | l, err := net.Listen("tcp", addr) 138 | ensure.Nil(t, err) 139 | file, err := l.(*net.TCPListener).File() 140 | ensure.Nil(t, err) 141 | ensure.Nil(t, l.Close()) 142 | n.fdStart = dup(t, int(file.Fd())) 143 | ensure.Nil(t, file.Close()) 144 | os.Setenv(envCountKey, "1") 145 | ensure.Nil(t, n.inherit()) 146 | ensure.DeepEqual(t, len(n.inherited), 1) 147 | l, err = n.Listen("tcp", addr) 148 | ensure.Nil(t, err) 149 | ensure.DeepEqual(t, len(n.active), 1) 150 | ensure.DeepEqual(t, n.inherited[0], nil) 151 | ensure.Nil(t, l.Close()) 152 | } 153 | 154 | func TestWithTcpIPv6Loal(t *testing.T) { 155 | var n Net 156 | l, err := net.Listen("tcp", "[::]:0") 157 | ensure.Nil(t, err) 158 | file, err := l.(*net.TCPListener).File() 159 | ensure.Nil(t, err) 160 | ensure.Nil(t, l.Close()) 161 | n.fdStart = dup(t, int(file.Fd())) 162 | ensure.Nil(t, file.Close()) 163 | os.Setenv(envCountKey, "1") 164 | ensure.Nil(t, n.inherit()) 165 | ensure.DeepEqual(t, len(n.inherited), 1) 166 | l, err = n.Listen("tcp", l.Addr().String()) 167 | ensure.Nil(t, err) 168 | ensure.DeepEqual(t, len(n.active), 1) 169 | ensure.DeepEqual(t, n.inherited[0], nil) 170 | ensure.Nil(t, l.Close()) 171 | } 172 | 173 | func TestOneUnixInherit(t *testing.T) { 174 | var n Net 175 | 
tmpfile, err := ioutil.TempFile("", "TestOneUnixInherit-") 176 | ensure.Nil(t, err) 177 | ensure.Nil(t, tmpfile.Close()) 178 | ensure.Nil(t, os.Remove(tmpfile.Name())) 179 | defer os.Remove(tmpfile.Name()) 180 | l, err := net.Listen("unix", tmpfile.Name()) 181 | ensure.Nil(t, err) 182 | file, err := l.(*net.UnixListener).File() 183 | ensure.Nil(t, err) 184 | ensure.Nil(t, l.Close()) 185 | n.fdStart = dup(t, int(file.Fd())) 186 | ensure.Nil(t, file.Close()) 187 | os.Setenv(envCountKey, "1") 188 | ensure.Nil(t, n.inherit()) 189 | ensure.DeepEqual(t, len(n.inherited), 1) 190 | l, err = n.Listen("unix", tmpfile.Name()) 191 | ensure.Nil(t, err) 192 | ensure.DeepEqual(t, len(n.active), 1) 193 | ensure.DeepEqual(t, n.inherited[0], nil) 194 | ensure.Nil(t, l.Close()) 195 | } 196 | 197 | func TestInvalidTcpAddr(t *testing.T) { 198 | var n Net 199 | os.Setenv(envCountKey, "") 200 | _, err := n.Listen("tcp", "abc") 201 | ensure.Err(t, err, regexp.MustCompile("^missing port in address abc$")) 202 | } 203 | 204 | func TestTwoTCP(t *testing.T) { 205 | var n Net 206 | 207 | port1, err := freeport.Get() 208 | ensure.Nil(t, err) 209 | addr1 := fmt.Sprintf(":%d", port1) 210 | l1, err := net.Listen("tcp", addr1) 211 | ensure.Nil(t, err) 212 | 213 | port2, err := freeport.Get() 214 | ensure.Nil(t, err) 215 | addr2 := fmt.Sprintf(":%d", port2) 216 | l2, err := net.Listen("tcp", addr2) 217 | ensure.Nil(t, err) 218 | 219 | file1, err := l1.(*net.TCPListener).File() 220 | ensure.Nil(t, err) 221 | file2, err := l2.(*net.TCPListener).File() 222 | ensure.Nil(t, err) 223 | 224 | // assign both to prevent GC from kicking in the finalizer 225 | fds := []int{dup(t, int(file1.Fd())), dup(t, int(file2.Fd()))} 226 | n.fdStart = fds[0] 227 | os.Setenv(envCountKey, "2") 228 | 229 | // Close these after to ensure we get coalesced file descriptors. 
230 | ensure.Nil(t, l1.Close()) 231 | ensure.Nil(t, l2.Close()) 232 | 233 | ensure.Nil(t, n.inherit()) 234 | ensure.DeepEqual(t, len(n.inherited), 2) 235 | 236 | l1, err = n.Listen("tcp", addr1) 237 | ensure.Nil(t, err) 238 | ensure.DeepEqual(t, len(n.active), 1) 239 | ensure.DeepEqual(t, n.inherited[0], nil) 240 | ensure.Nil(t, l1.Close()) 241 | ensure.Nil(t, file1.Close()) 242 | 243 | l2, err = n.Listen("tcp", addr2) 244 | ensure.Nil(t, err) 245 | ensure.DeepEqual(t, len(n.active), 2) 246 | ensure.DeepEqual(t, n.inherited[1], nil) 247 | ensure.Nil(t, l2.Close()) 248 | ensure.Nil(t, file2.Close()) 249 | } 250 | 251 | func TestOneUnixAndOneTcpInherit(t *testing.T) { 252 | var n Net 253 | 254 | tmpfile, err := ioutil.TempFile("", "TestOneUnixAndOneTcpInherit-") 255 | ensure.Nil(t, err) 256 | ensure.Nil(t, tmpfile.Close()) 257 | ensure.Nil(t, os.Remove(tmpfile.Name())) 258 | defer os.Remove(tmpfile.Name()) 259 | unixL, err := net.Listen("unix", tmpfile.Name()) 260 | ensure.Nil(t, err) 261 | 262 | port, err := freeport.Get() 263 | ensure.Nil(t, err) 264 | addr := fmt.Sprintf(":%d", port) 265 | tcpL, err := net.Listen("tcp", addr) 266 | ensure.Nil(t, err) 267 | 268 | tcpF, err := tcpL.(*net.TCPListener).File() 269 | ensure.Nil(t, err) 270 | unixF, err := unixL.(*net.UnixListener).File() 271 | ensure.Nil(t, err) 272 | 273 | // assign both to prevent GC from kicking in the finalizer 274 | fds := []int{dup(t, int(tcpF.Fd())), dup(t, int(unixF.Fd()))} 275 | n.fdStart = fds[0] 276 | os.Setenv(envCountKey, "2") 277 | 278 | // Close these after to ensure we get coalesced file descriptors. 
279 | ensure.Nil(t, tcpL.Close()) 280 | ensure.Nil(t, unixL.Close()) 281 | 282 | ensure.Nil(t, n.inherit()) 283 | ensure.DeepEqual(t, len(n.inherited), 2) 284 | 285 | unixL, err = n.Listen("unix", tmpfile.Name()) 286 | ensure.Nil(t, err) 287 | ensure.DeepEqual(t, len(n.active), 1) 288 | ensure.DeepEqual(t, n.inherited[1], nil) 289 | ensure.Nil(t, unixL.Close()) 290 | ensure.Nil(t, unixF.Close()) 291 | 292 | tcpL, err = n.Listen("tcp", addr) 293 | ensure.Nil(t, err) 294 | ensure.DeepEqual(t, len(n.active), 2) 295 | ensure.DeepEqual(t, n.inherited[0], nil) 296 | ensure.Nil(t, tcpL.Close()) 297 | ensure.Nil(t, tcpF.Close()) 298 | } 299 | 300 | func TestSecondUnixListen(t *testing.T) { 301 | var n Net 302 | os.Setenv(envCountKey, "") 303 | tmpfile, err := ioutil.TempFile("", "TestSecondUnixListen-") 304 | ensure.Nil(t, err) 305 | ensure.Nil(t, tmpfile.Close()) 306 | ensure.Nil(t, os.Remove(tmpfile.Name())) 307 | defer os.Remove(tmpfile.Name()) 308 | l, err := n.Listen("unix", tmpfile.Name()) 309 | ensure.Nil(t, err) 310 | _, err = n.Listen("unix", tmpfile.Name()) 311 | ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) 312 | ensure.Nil(t, l.Close()) 313 | } 314 | 315 | func TestSecondUnixListenInherited(t *testing.T) { 316 | var n Net 317 | tmpfile, err := ioutil.TempFile("", "TestSecondUnixListenInherited-") 318 | ensure.Nil(t, err) 319 | ensure.Nil(t, tmpfile.Close()) 320 | ensure.Nil(t, os.Remove(tmpfile.Name())) 321 | defer os.Remove(tmpfile.Name()) 322 | l1, err := net.Listen("unix", tmpfile.Name()) 323 | ensure.Nil(t, err) 324 | file, err := l1.(*net.UnixListener).File() 325 | ensure.Nil(t, err) 326 | n.fdStart = dup(t, int(file.Fd())) 327 | ensure.Nil(t, file.Close()) 328 | os.Setenv(envCountKey, "1") 329 | ensure.Nil(t, n.inherit()) 330 | ensure.DeepEqual(t, len(n.inherited), 1) 331 | l2, err := n.Listen("unix", tmpfile.Name()) 332 | ensure.Nil(t, err) 333 | ensure.DeepEqual(t, len(n.active), 1) 334 | ensure.DeepEqual(t, n.inherited[0], 
nil) 335 | _, err = n.Listen("unix", tmpfile.Name()) 336 | ensure.Err(t, err, regexp.MustCompile("bind: address already in use$")) 337 | ensure.Nil(t, l1.Close()) 338 | ensure.Nil(t, l2.Close()) 339 | } 340 | 341 | func TestPortZeroTwice(t *testing.T) { 342 | var n Net 343 | os.Setenv(envCountKey, "") 344 | l1, err := n.Listen("tcp", ":0") 345 | ensure.Nil(t, err) 346 | l2, err := n.Listen("tcp", ":0") 347 | ensure.Nil(t, err) 348 | ensure.Nil(t, l1.Close()) 349 | ensure.Nil(t, l2.Close()) 350 | } 351 | 352 | // We dup file descriptors because the os.Files are closed by a finalizer when 353 | // they are GCed, which interacts badly with the fact that the OS reuses fds, 354 | // and that we emulate inheriting the fd by its integer value in our tests. 355 | func dup(t *testing.T, fd int) int { 356 | nfd, err := syscall.Dup(fd) 357 | ensure.Nil(t, err) 358 | return nfd 359 | } 360 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2013-present, Facebook, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | grace [![Build Status](https://secure.travis-ci.org/facebookgo/grace.png)](https://travis-ci.org/facebookgo/grace) 2 | ===== 3 | 4 | Package grace provides a library that makes it easy to build socket 5 | based servers that can be gracefully terminated & restarted (that is, 6 | without dropping any connections). 7 | 8 | It provides a convenient API for HTTP servers including support for TLS, 9 | especially if you need to listen on multiple ports (for example a secondary 10 | internal only admin server). Additionally it is implemented using the same API 11 | as systemd providing [socket 12 | activation](http://0pointer.de/blog/projects/socket-activation.html) 13 | compatibility to also provide lazy activation of the server. 14 | 15 | 16 | Usage 17 | ----- 18 | 19 | Demo HTTP Server with graceful termination and restart: 20 | https://github.com/facebookgo/grace/blob/master/gracedemo/demo.go 21 | 22 | 1. Install the demo application 23 | 24 | go get github.com/facebookgo/grace/gracedemo 25 | 26 | 1. Start it in the first terminal 27 | 28 | gracedemo 29 | 30 | This will output something like: 31 | 32 | 2013/03/25 19:07:33 Serving [::]:48567, [::]:48568, [::]:48569 with pid 14642. 33 | 34 | 1. In a second terminal start a slow HTTP request 35 | 36 | curl 'http://localhost:48567/sleep/?duration=20s' 37 | 38 | 1. In a third terminal trigger a graceful server restart (using the pid from your output): 39 | 40 | kill -USR2 14642 41 | 42 | 1. 
Trigger another shorter request that finishes before the earlier request: 43 | 44 | curl 'http://localhost:48567/sleep/?duration=0s' 45 | 46 | 47 | If done quickly enough, this shows the second quick request will be served by 48 | the new process (as indicated by the PID) while the slow first request will be 49 | served by the first server. It shows how the active connection was gracefully 50 | served before the server was shut down. It also shows that at one point 51 | both the new and the old server were running at the same time. 52 | 53 | 54 | Documentation 55 | ------------- 56 | 57 | `http.Server` graceful termination and restart: 58 | https://godoc.org/github.com/facebookgo/grace/gracehttp 59 | 60 | `net.Listener` graceful termination and restart: 61 | https://godoc.org/github.com/facebookgo/grace/gracenet 62 | --------------------------------------------------------------------------------