├── examples ├── composite │ ├── .gitignore │ ├── main.go │ └── README.md ├── http │ ├── .gitignore │ ├── main_test.go │ └── main.go ├── httpcluster │ ├── .gitignore │ ├── README.md │ └── main_test.go ├── custom_middleware │ ├── .gitignore │ ├── README.md │ └── example │ │ └── jsonenforcer.go └── README.md ├── .github ├── renovate.json └── workflows │ ├── dependency-review.yml │ └── go.yml ├── .gitignore ├── .gitattributes ├── go.mod ├── runnables ├── composite │ ├── options.go │ ├── errors.go │ ├── options_test.go │ ├── state.go │ ├── config.go │ ├── reload.go │ ├── config_test.go │ └── README.md ├── httpserver │ ├── middleware │ │ ├── state │ │ │ ├── state.go │ │ │ └── state_test.go │ │ ├── metrics │ │ │ ├── metrics.go │ │ │ └── metrics_test.go │ │ ├── logger │ │ │ ├── logger.go │ │ │ └── logger_test.go │ │ ├── recovery │ │ │ ├── recovery.go │ │ │ └── recovery_test.go │ │ ├── wildcard │ │ │ └── wildcard.go │ │ └── headers │ │ │ ├── headers.go │ │ │ └── options.go │ ├── errors.go │ ├── helpers_test.go │ ├── state.go │ ├── runner_readiness_test.go │ ├── options.go │ ├── request_processor.go │ ├── reload.go │ ├── runner_boot_test.go │ ├── mocks_test.go │ ├── response_writer.go │ ├── runner_context_test.go │ ├── routes.go │ ├── runner_race_test.go │ ├── integration_race_test.go │ ├── options_test.go │ ├── state_mocked_test.go │ └── config.go ├── httpcluster │ ├── state.go │ ├── interfaces.go │ ├── options.go │ ├── options_test.go │ ├── state_test.go │ └── runner_context_test.go └── mocks │ └── mocks.go ├── Makefile ├── go.sum ├── supervisor ├── shutdown.go ├── interfaces.go ├── reload.go ├── state_test.go ├── state_deduplication_test.go ├── state_monitoring_test.go └── reload_test.go ├── internal ├── networking │ ├── portfinder.go │ ├── port.go │ └── port_test.go └── finitestate │ └── machine.go ├── .golangci.yml └── README.md /examples/composite/.gitignore: -------------------------------------------------------------------------------- 1 | composite 
-------------------------------------------------------------------------------- /examples/http/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore the compiled binary 2 | http -------------------------------------------------------------------------------- /examples/httpcluster/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore the compiled binary 2 | httpcluster 3 | -------------------------------------------------------------------------------- /examples/custom_middleware/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore the compiled binary 2 | custom_middleware 3 | -------------------------------------------------------------------------------- /.github/renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:recommended" 5 | ], 6 | "postUpdateOptions": ["gomodTidy"] 7 | } 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | /bin 3 | config.toml 4 | *.log 5 | 6 | # Output of the go coverage tool 7 | *.out 8 | 9 | # Temporary files 10 | *swp 11 | .DS_Store 12 | 13 | # IDE/editor directories and files 14 | .vscode/ 15 | .idea/ 16 | *.iml 17 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Documentation 2 | *.md linguist-documentation 3 | 4 | # Go source files 5 | *.go text diff=go 6 | 7 | # Makefiles need to always use Unix-style line endings 8 | Makefile text eol=lf 9 | *.mk text eol=lf 10 | 11 | # Protobuf 12 | *.proto text diff=proto 13 | 14 | # Binary 
files 15 | *.pb.go binary 16 | 17 | # Scripts 18 | *.sh text eol=lf -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/robbyt/go-supervisor 2 | 3 | go 1.25.4 4 | 5 | require ( 6 | github.com/robbyt/go-fsm/v2 v2.3.0 7 | github.com/stretchr/testify v1.11.1 8 | ) 9 | 10 | require ( 11 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 12 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 13 | github.com/stretchr/objx v0.5.3 // indirect 14 | gopkg.in/yaml.v3 v3.0.1 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /runnables/composite/options.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "log/slog" 5 | ) 6 | 7 | // Option represents a functional option for configuring a composite Runner. 8 | type Option[T runnable] func(*Runner[T]) 9 | 10 | // WithLogHandler sets a custom slog handler for the CompositeRunner instance. A nil handler is ignored (the current logger is kept); otherwise log records are grouped under "composite.Runner". 11 | func WithLogHandler[T runnable](handler slog.Handler) Option[T] { 12 | return func(c *Runner[T]) { 13 | if handler != nil { 14 | c.logger = slog.New(handler.WithGroup("composite.Runner")) 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/state/state.go: -------------------------------------------------------------------------------- 1 | package state 2 | 3 | import "github.com/robbyt/go-supervisor/runnables/httpserver" 4 | 5 | // New creates a middleware that adds the current server state 6 | // to the X-Server-State response header and then calls the next handler. A nil stateProvider skips the header.
7 | func New(stateProvider func() string) httpserver.HandlerFunc { 8 | return func(rp *httpserver.RequestProcessor) { 9 | // Set X-Server-State only when a provider was supplied; the chain continues via rp.Next() either way 10 | if stateProvider != nil { 11 | state := stateProvider() 12 | rp.Writer().Header().Set("X-Server-State", state) 13 | } 14 | 15 | rp.Next() 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /runnables/composite/errors.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrCompositeRunnable is returned when there's a general error in the composite runnable 7 | ErrCompositeRunnable = errors.New("composite runnable error") 8 | 9 | // ErrRunnableFailed is returned when a child runnable fails 10 | ErrRunnableFailed = errors.New("child runnable failed") 11 | 12 | // ErrConfigMissing is returned when the config is missing 13 | ErrConfigMissing = errors.New("config is missing") 14 | 15 | // ErrOldConfig is returned when the config hasn't changed during a reload 16 | ErrOldConfig = errors.New("configuration unchanged") 17 | ) 18 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Several working examples of go-supervisor usage. 4 | 5 | ## [http](./http/) 6 | Basic HTTP server with graceful shutdown and configuration reloading. 7 | 8 | ## [custom_middleware](./custom_middleware/) 9 | HTTP server with middleware that transforms responses. 10 | 11 | ## [composite](./composite/) 12 | Multiple dynamic services managed as a single unit, using Generics. 13 | 14 | ## [httpcluster](./httpcluster/) 15 | Similar to composite, but designed specifically for running several `httpserver` instances, with a channel-based config "siphon" for dynamic updates.
16 | 17 | ## Running 18 | 19 | ```bash 20 | go run ./examples/http # or: ./examples/composite, ./examples/custom_middleware, ./examples/httpcluster 21 | ``` -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all 2 | all: help 3 | 4 | ## help: Display this help message 5 | .PHONY: help 6 | help: Makefile 7 | @echo 8 | @echo " Choose a make command to run" 9 | @echo 10 | @sed -n 's/^##//p' $< | column -t -s ':' | sed -e 's/^/ /' 11 | @echo 12 | 13 | ## test: Run tests with race detection and coverage 14 | .PHONY: test 15 | test: 16 | go test -timeout 3m -race -cover ./... 17 | 18 | ## bench: Run performance benchmarks 19 | .PHONY: bench 20 | bench: 21 | go test -timeout 2m -run=^$$ -bench=. -benchmem ./... 22 | 23 | ## lint: Run golangci-lint code quality checks 24 | .PHONY: lint 25 | lint: 26 | golangci-lint run ./... 27 | 28 | ## lint-fix: Run golangci-lint with auto-fix for common issues 29 | .PHONY: lint-fix 30 | lint-fix: 31 | golangci-lint fmt 32 | golangci-lint run --fix ./... 33 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/metrics/metrics.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "log/slog" 5 | "time" 6 | 7 | "github.com/robbyt/go-supervisor/runnables/httpserver" 8 | ) 9 | 10 | // New creates a middleware that collects metrics about HTTP requests. 11 | // This is a placeholder implementation that can be extended to integrate with 12 | // your metrics collection system.
13 | func New() httpserver.HandlerFunc { 14 | return func(rp *httpserver.RequestProcessor) { 15 | start := time.Now() 16 | 17 | // Run the rest of the chain first so the measured duration includes downstream handlers 18 | rp.Next() 19 | 20 | // Record metrics 21 | duration := time.Since(start) 22 | req := rp.Request() 23 | writer := rp.Writer() 24 | 25 | // Log through the package-level default slog logger at Debug level; these records are dropped unless the default logger is configured to enable Debug 26 | slog.Debug("HTTP request metrics", 27 | "method", req.Method, 28 | "path", req.URL.Path, 29 | "status", writer.Status(), 30 | "duration_ms", duration.Milliseconds(), 31 | ) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "log/slog" 5 | "time" 6 | 7 | "github.com/robbyt/go-supervisor/runnables/httpserver" 8 | ) 9 | 10 | // New creates a middleware that logs information about HTTP requests. 11 | // It logs the request method, path, status code, response time, response size, user agent and remote address at Info level; a nil handler falls back to slog.Default()'s handler, and records are grouped under "httpserver". 12 | func New(handler slog.Handler) httpserver.HandlerFunc { 13 | if handler == nil { 14 | handler = slog.Default().Handler() 15 | } 16 | 17 | logger := slog.New(handler).WithGroup("httpserver") 18 | 19 | return func(rp *httpserver.RequestProcessor) { 20 | start := time.Now() 21 | 22 | // Run the rest of the chain first so duration covers downstream handlers and status/size are read after they have written 23 | rp.Next() 24 | 25 | // Log after request is processed 26 | duration := time.Since(start) 27 | req := rp.Request() 28 | writer := rp.Writer() 29 | 30 | logger.Info("HTTP request", 31 | "method", req.Method, 32 | "path", req.URL.Path, 33 | "status", writer.Status(), 34 | "duration", duration, 35 | "size", writer.Size(), 36 | "user_agent", req.UserAgent(), 37 | "remote_addr", req.RemoteAddr, 38 | ) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /runnables/composite/options_test.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6
| "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // Mock runnable implementation for testing 12 | type mockRunnable struct{} 13 | 14 | // Required method implementations to satisfy the runnable interface constraint 15 | func (m *mockRunnable) Run(ctx context.Context) error { return nil } 16 | func (m *mockRunnable) Stop() {} 17 | func (m *mockRunnable) String() string { return "mockRunnable" } 18 | 19 | func TestWithLogHandler(t *testing.T) { 20 | t.Parallel() 21 | 22 | testHandler := slog.NewTextHandler(nil, nil) // nil writer is acceptable here: the test never emits a record through this handler 23 | runner := &Runner[*mockRunnable]{ 24 | logger: slog.Default(), 25 | } 26 | 27 | WithLogHandler[*mockRunnable](testHandler)(runner) 28 | assert.NotEqual(t, slog.Default(), runner.logger, "Logger should be changed") 29 | 30 | runner = &Runner[*mockRunnable]{ 31 | logger: slog.Default(), 32 | } 33 | WithLogHandler[*mockRunnable](nil)(runner) // a nil handler must leave the existing logger untouched 34 | assert.Equal(t, slog.Default(), runner.logger, "Logger should not change with nil handler") 35 | } 36 | -------------------------------------------------------------------------------- /runnables/httpserver/errors.go: -------------------------------------------------------------------------------- 1 | // Package httpserver provides a configurable, reloadable HTTP server implementation 2 | // that can be managed by the supervisor package.
3 | package httpserver 4 | 5 | import "errors" 6 | 7 | var ( 8 | ErrNoConfig = errors.New("no config provided") 9 | ErrNoHandlers = errors.New("no handlers provided") 10 | ErrGracefulShutdown = errors.New("graceful shutdown failed") 11 | ErrGracefulShutdownTimeout = errors.New("graceful shutdown deadline reached") 12 | ErrHttpServer = errors.New("http server error") 13 | ErrOldConfig = errors.New("config hasn't changed since last update") 14 | ErrRetrieveConfig = errors.New("failed to retrieve server configuration") 15 | ErrCreateConfig = errors.New("failed to create server configuration") 16 | ErrServerNotRunning = errors.New("http server is not running") 17 | ErrServerReadinessTimeout = errors.New("server readiness check timed out") 18 | ErrServerBoot = errors.New("failed to start HTTP server") 19 | ErrConfigCallbackNil = errors.New("config callback returned nil") 20 | ErrConfigCallback = errors.New("failed to load configuration from callback") 21 | ErrStateTransition = errors.New("state transition failed") 22 | ) 23 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/recovery/recovery.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | 7 | "github.com/robbyt/go-supervisor/runnables/httpserver" 8 | ) 9 | 10 | // New creates a middleware that recovers from panics in HTTP handlers. 11 | // It returns a 500 Internal Server Error response when panics occur. 12 | // If handler is provided, panics are logged. If handler is nil, recovery is silent. 13 | func New(handler slog.Handler) httpserver.HandlerFunc { 14 | // logger is a no-op by default so recovery stays silent when handler is nil; 15 | // with a handler it is replaced below by one that logs the panic at Error level.
16 | logger := func(err any, path string, method string) {} 17 | 18 | if handler != nil { 19 | slogger := slog.New(handler) // built once per middleware instead of on every recovered panic 20 | logger = func(err any, path string, method string) { 21 | slogger.Error("HTTP handler panic recovered", 22 | "error", err, 23 | "path", path, 24 | "method", method, 25 | ) 26 | } 27 | } 28 | 29 | return func(rp *httpserver.RequestProcessor) { 30 | defer func() { 31 | if err := recover(); err != nil { 32 | req := rp.Request() 33 | writer := rp.Writer() 34 | logger(err, req.URL.Path, req.Method) 35 | http.Error(writer, "Internal Server Error", http.StatusInternalServerError) // NOTE(review): if the handler already wrote a response before panicking, this 500 status likely cannot take effect 36 | rp.Abort() 37 | } 38 | }() 39 | 40 | rp.Next() 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 2 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 4 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 5 | github.com/robbyt/go-fsm/v2 v2.3.0 h1:XfgDozIDydgfsPulqwixrSpwhlnTxGAmNgyr+s2vQn8= 6 | github.com/robbyt/go-fsm/v2 v2.3.0/go.mod h1:rnc9GEyIJm7OjLQampMh47gXXhLuc2+hKjgwVqP2fx4= 7 | github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= 8 | github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= 9 | github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 10 | github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 12 |
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 13 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 14 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 15 | -------------------------------------------------------------------------------- /supervisor/shutdown.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import "sync" 4 | 5 | // startShutdownManager starts goroutines to listen for shutdown notifications 6 | // from any runnables that implement ShutdownSender. It blocks until p.ctx is done and every listener goroutine has returned. 7 | func (p *PIDZero) startShutdownManager() { 8 | p.logger.Debug("Starting shutdown manager...") 9 | 10 | var shutdownWg sync.WaitGroup 11 | 12 | // Find all runnables that can send shutdown signals and start listeners 13 | for _, r := range p.runnables { 14 | if sdSender, ok := r.(ShutdownSender); ok { 15 | shutdownWg.Add(1) 16 | // Pass both Runnable 'r' and ShutdownSender 's' for clarity 17 | go func(r Runnable, s ShutdownSender) { 18 | defer shutdownWg.Done() 19 | triggerChan := s.GetShutdownTrigger() 20 | for { // every select case returns, so this loop never iterates more than once 21 | select { 22 | case <-p.ctx.Done(): 23 | return 24 | case <-triggerChan: // fires on a send or when the channel is closed; a nil channel never fires 25 | p.logger.Info("Shutdown requested by runnable", "runnable", r) 26 | p.Shutdown() // Trigger supervisor shutdown 27 | return // Exit this goroutine after triggering shutdown 28 | } 29 | } 30 | }(r, sdSender) // Pass both variables 31 | } 32 | } 33 | 34 | // Block until context is done, then wait for listener goroutines to finish 35 | <-p.ctx.Done() 36 | p.logger.Debug( 37 | "Shutdown manager received context done signal, waiting for listeners to exit...", 38 | ) 39 | shutdownWg.Wait() 40 | p.logger.Debug("Shutdown manager complete.") 41 | } 42 | -------------------------------------------------------------------------------- /runnables/httpcluster/state.go:
-------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/robbyt/go-supervisor/internal/finitestate" 7 | ) 8 | 9 | // GetState returns the current state of the cluster. 10 | func (r *Runner) GetState() string { 11 | return r.fsm.GetState() 12 | } 13 | 14 | // GetStateChan returns a channel that receives state updates; it is closed when ctx is canceled. 15 | func (r *Runner) GetStateChan(ctx context.Context) <-chan string { 16 | return r.fsm.GetStateChan(ctx) 17 | } 18 | 19 | // GetStateChanWithTimeout returns a channel that emits state changes from the Runner. 20 | // The channel is closed when the provided context is canceled. Despite the name, no timeout is applied here: it delegates to fsm.GetStateChan exactly like GetStateChan, so bound its lifetime with a context deadline (e.g. context.WithTimeout). 21 | func (r *Runner) GetStateChanWithTimeout(ctx context.Context) <-chan string { 22 | return r.fsm.GetStateChan(ctx) 23 | } 24 | 25 | // IsRunning returns true if the cluster is in the Running state. 26 | func (r *Runner) IsRunning() bool { 27 | return r.fsm.GetState() == finitestate.StatusRunning 28 | } 29 | 30 | // setStateError transitions the state machine to the Error state, forcing it with SetState if the normal transition is rejected and falling back to Unknown as a last resort. 31 | func (r *Runner) setStateError() { 32 | if r.fsm.TransitionBool(finitestate.StatusError) { 33 | return 34 | } 35 | 36 | r.logger.Debug("Using SetState to force Error state") 37 | if err := r.fsm.SetState(finitestate.StatusError); err != nil { 38 | r.logger.Error("Failed to set Error state", "error", err) 39 | 40 | if err := r.fsm.SetState(finitestate.StatusUnknown); err != nil { 41 | r.logger.Error("Failed to set Unknown state", "error", err) 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /examples/httpcluster/README.md: -------------------------------------------------------------------------------- 1 | # HTTP Cluster Example 2 | 3 | This example demonstrates how to use the `httpcluster` runnable with the go-supervisor to manage multiple HTTP servers dynamically.
4 | 5 | ## Overview 6 | 7 | The example shows an interactive HTTP server that can reconfigure itself to listen on different ports based on POST requests. The server starts on port 8080 and can be instructed to move to a different port by sending a JSON payload. 8 | 9 | ## Running the Example 10 | 11 | ```bash 12 | # From the examples/httpcluster directory 13 | go run main.go 14 | ``` 15 | 16 | The server will start on port 8080. Visit http://localhost:8080 to see instructions. 17 | 18 | ## Using the Example 19 | 20 | 1. **Check current status**: 21 | ```bash 22 | curl http://localhost:8080/status 23 | ``` 24 | 25 | 2. **Change the port**: 26 | ```bash 27 | curl -X POST http://localhost:8080/ \ 28 | -H 'Content-Type: application/json' \ 29 | -d '{"port":":8081"}' 30 | ``` 31 | 32 | 3. **Verify the server moved**: 33 | ```bash 34 | curl http://localhost:8081/status 35 | ``` 36 | 37 | ## Key Features Demonstrated 38 | 39 | 1. **Dynamic Configuration**: Server port can be changed at runtime 40 | 2. **Channel-based Updates**: Configuration updates are sent through a channel 41 | 3. **Supervisor Integration**: The cluster is managed by go-supervisor for signal handling 42 | 4. **State Monitoring**: The example logs all state changes 43 | 5. 
**Middleware**: Each server uses logging, recovery, and metrics middleware -------------------------------------------------------------------------------- /internal/networking/portfinder.go: -------------------------------------------------------------------------------- 1 | package networking 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "sync" 7 | "testing" 8 | ) 9 | 10 | // usedPorts records every port handed out by GetRandomPort in this process (guarded by portMutex) to reduce the chance of port conflicts between parallel tests 11 | var ( 12 | portMutex = &sync.Mutex{} 13 | usedPorts = make(map[int]struct{}) 14 | ) 15 | 16 | // GetRandomPort finds an available port for a test by binding to port 0. The listener is closed before returning, so the port is only likely (not guaranteed) to still be free when the caller binds it. 17 | func GetRandomPort(tb testing.TB) int { 18 | tb.Helper() 19 | portMutex.Lock() 20 | listener, err := net.Listen("tcp", ":0") 21 | if err != nil { 22 | portMutex.Unlock() 23 | tb.Fatalf("Failed to get random port: %v", err) 24 | } 25 | 26 | err = listener.Close() 27 | if err != nil { 28 | portMutex.Unlock() 29 | tb.Fatalf("Failed to close listener: %v", err) 30 | } 31 | 32 | addr := listener.Addr().(*net.TCPAddr) 33 | p := addr.Port 34 | // Check if the port is already used 35 | if _, ok := usedPorts[p]; ok { 36 | portMutex.Unlock() // release before recursing; the retry re-acquires the lock 37 | return GetRandomPort(tb) 38 | } 39 | usedPorts[p] = struct{}{} 40 | portMutex.Unlock() 41 | return p 42 | } 43 | 44 | // GetRandomListeningPort finds an available port for a test by binding to port 0, and returns a string like localhost:PORT. If the chosen port cannot be re-bound it retries with a fresh port (recursively, with no retry limit). 45 | func GetRandomListeningPort(tb testing.TB) string { 46 | tb.Helper() 47 | p := GetRandomPort(tb) 48 | listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p)) 49 | if err != nil { 50 | return GetRandomListeningPort(tb) 51 | } 52 | err = listener.Close() 53 | if err != nil { 54 | tb.Fatalf("Failed to close listener: %v", err) 55 | } 56 | 57 | return fmt.Sprintf("localhost:%d", p) 58 | } 59 | -------------------------------------------------------------------------------- /runnables/httpserver/helpers_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2
| 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "testing" 7 | "time" 8 | 9 | "github.com/robbyt/go-supervisor/internal/networking" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // waitForState waits for the server to reach the expected state within the timeout, polling GetState every 10ms and failing the test via require.Eventually otherwise. 14 | func waitForState( 15 | t *testing.T, 16 | server interface{ GetState() string }, 17 | expectedState string, 18 | timeout time.Duration, 19 | message string, 20 | ) { 21 | t.Helper() 22 | require.Eventually(t, func() bool { 23 | return server.GetState() == expectedState 24 | }, timeout, 10*time.Millisecond, message) 25 | } 26 | 27 | // createTestServer creates a test server with the given handler, path, and drain timeout. 28 | // It returns a configured Runner instance (Run is not called here) and the listen address in ":PORT" form. 29 | func createTestServer( 30 | t *testing.T, 31 | handler http.HandlerFunc, 32 | path string, 33 | drainTimeout time.Duration, 34 | ) (*Runner, string) { 35 | t.Helper() 36 | 37 | // Get an available port 38 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 39 | 40 | // Create a route 41 | route, err := NewRouteFromHandlerFunc("test", path, handler) 42 | require.NoError(t, err) 43 | routes := Routes{*route} 44 | 45 | // The callback builds a fresh Config on every call from the captured port, routes and drain timeout 46 | configCallback := func() (*Config, error) { 47 | return NewConfig(port, routes, WithDrainTimeout(drainTimeout)) 48 | } 49 | 50 | // Create the runner 51 | runner, err := NewRunner(WithConfigCallback(configCallback)) 52 | require.NoError(t, err) 53 | 54 | return runner, port 55 | } 56 | -------------------------------------------------------------------------------- /examples/custom_middleware/README.md: -------------------------------------------------------------------------------- 1 | # Custom Middleware Example 2 | 3 | This example demonstrates how to create custom middleware for the httpserver package.
4 | 5 | ## What It Shows 6 | 7 | - Creating a custom middleware that transforms HTTP responses 8 | - Using built-in middleware from the httpserver package 9 | - Correct middleware ordering and composition 10 | - Separation of concerns between middleware layers 11 | 12 | ## Key Components 13 | 14 | ### JSON Enforcer Middleware 15 | A custom middleware that ensures all responses are JSON formatted. Non-JSON responses are wrapped in `{"response": "content"}` while valid JSON passes through unchanged. 16 | 17 | ### Headers Middleware 18 | Uses the built-in headers middleware to set Content-Type, CORS, and security headers. 19 | 20 | ## Running the Example 21 | 22 | ```bash 23 | go run ./examples/custom_middleware 24 | ``` 25 | 26 | The server starts on `:8081` with several endpoints to demonstrate the middleware behavior. 27 | 28 | ## Endpoints 29 | 30 | - `GET /` - Returns plain text (wrapped in JSON) 31 | - `GET /api/data` - Returns JSON (preserved as-is) 32 | - `GET /html` - Returns HTML (wrapped in JSON) 33 | - `GET /error` - Returns 404 error (wrapped in JSON) 34 | - `GET /panic` - Triggers panic recovery middleware 35 | 36 | ## Middleware Ordering 37 | 38 | The example demonstrates why middleware order matters: 39 | 40 | 1. **Recovery** - Must be first to catch panics 41 | 2. **Security** - Set security headers early 42 | 3. **Logging** - Log all requests 43 | 4. **Metrics** - Collect request metrics 44 | 5. **Headers** - Set response headers before handler 45 | 46 | See the code comments in `main.go`. -------------------------------------------------------------------------------- /.github/workflows/dependency-review.yml: -------------------------------------------------------------------------------- 1 | # Dependency Review Action 2 | # 3 | # This Action will scan dependency manifest files that change as part of a Pull Request, 4 | # surfacing known-vulnerable versions of the packages declared or updated in the PR. 
5 | # Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable 6 | # packages will be blocked from merging. 7 | # 8 | # Source repository: https://github.com/actions/dependency-review-action 9 | # Public documentation: https://docs.github.com/en/code-security/supply-chain-security/understanding-your-software-supply-chain/about-dependency-review#dependency-review-enforcement 10 | name: 'Dependency review' 11 | on: 12 | pull_request: 13 | branches: [ "main" ] 14 | 15 | # If using a dependency submission action in this workflow this permission will need to be set to: 16 | # 17 | # permissions: 18 | # contents: write 19 | # 20 | # https://docs.github.com/en/enterprise-cloud@latest/code-security/supply-chain-security/understanding-your-software-supply-chain/using-the-dependency-submission-api 21 | permissions: 22 | contents: read 23 | # Write permissions for pull-requests are required for using the `comment-summary-in-pr` option, comment out if you aren't using this option 24 | pull-requests: write 25 | 26 | jobs: 27 | dependency-review: 28 | runs-on: ubuntu-latest 29 | steps: 30 | - name: 'Checkout repository' 31 | uses: actions/checkout@v6 32 | - name: 'Dependency Review' 33 | uses: actions/dependency-review-action@v4 34 | # Commonly enabled options, see https://github.com/actions/dependency-review-action#configuration-options for all available options. 35 | with: 36 | comment-summary-in-pr: always 37 | -------------------------------------------------------------------------------- /runnables/httpserver/state.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/robbyt/go-supervisor/internal/finitestate" 7 | ) 8 | 9 | // setStateError transitions the state machine to the Error state, 10 | // falling back to alternative approaches if the transition fails. 
11 | func (r *Runner) setStateError() { 12 | // First try with normal transition 13 | if r.fsm.TransitionBool(finitestate.StatusError) { 14 | return 15 | } 16 | 17 | // If that fails, force the state using SetState 18 | r.logger.Debug("Using SetState to force Error state") 19 | if err := r.fsm.SetState(finitestate.StatusError); err != nil { 20 | r.logger.Error("Failed to set Error state", "error", err) 21 | 22 | // Last resort - try to set to Unknown 23 | if err := r.fsm.SetState(finitestate.StatusUnknown); err != nil { 24 | r.logger.Error("Failed to set Unknown state", "error", err) 25 | } 26 | } 27 | } 28 | 29 | // GetState returns the current state of the HTTP server's state machine. 30 | func (r *Runner) GetState() string { 31 | return r.fsm.GetState() 32 | } 33 | 34 | // GetStateChan returns a channel that emits the HTTP server's state whenever it changes. 35 | // The channel is closed when the provided context is canceled. 36 | func (r *Runner) GetStateChan(ctx context.Context) <-chan string { 37 | return r.fsm.GetStateChan(ctx) 38 | } 39 | 40 | // GetStateChanWithTimeout returns a channel that emits state changes. 41 | // The channel is closed when the provided context is canceled. Despite the name, no timeout is applied here: it delegates to fsm.GetStateChan exactly like GetStateChan, so bound its lifetime with a context deadline (e.g. context.WithTimeout). 42 | func (r *Runner) GetStateChanWithTimeout(ctx context.Context) <-chan string { 43 | return r.fsm.GetStateChan(ctx) 44 | } 45 | 46 | // IsRunning reports whether the HTTP server's state machine is currently in the Running state.
47 | func (r *Runner) IsRunning() bool { 48 | return r.fsm.GetState() == finitestate.StatusRunning 49 | } 50 | -------------------------------------------------------------------------------- /runnables/httpserver/runner_readiness_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/internal/networking" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestServerReadinessProbe(t *testing.T) { 17 | t.Parallel() 18 | 19 | route, err := NewRouteFromHandlerFunc( 20 | "test", 21 | "/test", 22 | func(w http.ResponseWriter, r *http.Request) { 23 | w.WriteHeader(http.StatusOK) 24 | }, 25 | ) 26 | require.NoError(t, err) 27 | 28 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 29 | callback := func() (*Config, error) { 30 | return NewConfig(port, Routes{*route}) 31 | } 32 | 33 | runner, err := NewRunner(WithConfigCallback(callback)) 34 | require.NoError(t, err) 35 | 36 | t.Run("probe_timeout", func(t *testing.T) { 37 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) 38 | defer cancel() 39 | 40 | err := runner.serverReadinessProbe(ctx, "test.invalid:80") // .invalid is a reserved TLD, so this host never resolves 41 | require.Error(t, err) 42 | assert.ErrorIs(t, err, ErrServerReadinessTimeout) 43 | }) 44 | 45 | t.Run("successful_probe", func(t *testing.T) { 46 | listener, err := net.Listen("tcp", "127.0.0.1:0") 47 | require.NoError(t, err) 48 | defer func() { 49 | err := listener.Close() 50 | require.NoError(t, err) 51 | }() 52 | 53 | _, portStr, err := net.SplitHostPort(listener.Addr().String()) 54 | require.NoError(t, err) 55 | addr := "127.0.0.1:" + portStr 56 | 57 | go func() { 58 | conn, err := listener.Accept() 59 | if err == nil { 60 | assert.NoError(t, conn.Close()) 61 | } 62 | }() 63 | 64 | err = runner.serverReadinessProbe(t.Context(),
addr) 65 | assert.NoError(t, err) 66 | }) 67 | } 68 | -------------------------------------------------------------------------------- /runnables/httpserver/options.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "log/slog" 5 | ) 6 | 7 | // Option represents a functional option for configuring Runner. 8 | type Option func(*Runner) 9 | 10 | // WithLogHandler sets a custom slog handler for the Runner instance; a nil handler is ignored and records are grouped under "httpserver.Runner". 11 | // For example, to use a custom JSON handler with debug level: 12 | // 13 | // handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) 14 | // runner, err := httpserver.NewRunner(httpserver.WithConfigCallback(configCallback), httpserver.WithLogHandler(handler)) 15 | func WithLogHandler(handler slog.Handler) Option { 16 | return func(r *Runner) { 17 | if handler != nil { 18 | r.logger = slog.New(handler.WithGroup("httpserver.Runner")) 19 | } 20 | } 21 | } 22 | 23 | // WithConfigCallback sets the function that will be called to load or reload configuration. 24 | // Either this option or WithConfig initializes the Runner instance by providing the 25 | // configuration for the HTTP server managed by the Runner. 26 | func WithConfigCallback(callback ConfigCallback) Option { 27 | return func(r *Runner) { 28 | r.configCallback = callback 29 | } 30 | } 31 | 32 | // WithConfig sets the initial configuration for the Runner instance. 33 | // It is equivalent to WithConfigCallback with a callback that always returns cfg, allowing you to pass a Config 34 | // instance directly instead of a callback function. This is useful when you have a static 35 | // configuration that doesn't require dynamic loading or reloading. 36 | func WithConfig(cfg *Config) Option { 37 | return func(r *Runner) { 38 | callback := func() (*Config, error) { 39 | return cfg, nil 40 | } 41 | r.configCallback = callback 42 | } 43 | } 44 | 45 | // WithName sets the name of the Runner instance.
46 | func WithName(name string) Option { 47 | return func(r *Runner) { 48 | r.name = name 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /internal/networking/port.go: -------------------------------------------------------------------------------- 1 | package networking 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // Common port validation errors 12 | var ( 13 | ErrEmptyPort = errors.New("port cannot be empty") 14 | ErrInvalidFormat = errors.New("invalid port format") 15 | ErrPortOutOfRange = errors.New("port number must be between 1 and 65535") 16 | ) 17 | 18 | // ValidatePort checks if the provided port string is valid and returns a normalized version. 19 | // It accepts formats like ":8080", "localhost:8080", "127.0.0.1:8080", or "8080". 20 | // Returns ":PORT" when no host is given, otherwise "HOST:PORT" (IPv6 hosts re-bracketed), or an error if the port is invalid. 21 | func ValidatePort(portStr string) (string, error) { 22 | if portStr == "" { 23 | return "", ErrEmptyPort 24 | } 25 | 26 | // Check for negative numbers in the input before processing 27 | if strings.Contains(portStr, ":-") || 28 | (strings.HasPrefix(portStr, "-") && !strings.Contains(portStr, ":")) { 29 | return "", fmt.Errorf("%w: negative port numbers are not allowed", ErrInvalidFormat) 30 | } 31 | 32 | // If it's a number without a colon, add a colon prefix 33 | if !strings.Contains(portStr, ":") { 34 | portStr = ":" + portStr 35 | } 36 | 37 | // Handle host:port format by extracting the port 38 | host, port, err := net.SplitHostPort(portStr) 39 | if err != nil { 40 | return "", fmt.Errorf("%w: %w", ErrInvalidFormat, err) 41 | } 42 | 43 | // Validate port is a number 44 | portNum, err := strconv.Atoi(port) 45 | if err != nil { 46 | return "", fmt.Errorf("%w: port must be a number", ErrInvalidFormat) 47 | } 48 | 49 | // Check port range 50 | if portNum < 1 || portNum > 65535 { 51 | return "", ErrPortOutOfRange 52 | } 53 | 54 | //
Return normalized form 55 | if host == "" { 56 | return ":" + port, nil 57 | } 58 | return net.JoinHostPort(host, port), nil 59 | } 60 | -------------------------------------------------------------------------------- /runnables/composite/state.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/robbyt/go-supervisor/internal/finitestate" 7 | "github.com/robbyt/go-supervisor/supervisor" 8 | ) 9 | 10 | // GetState returns the current state of the CompositeRunner. 11 | func (r *Runner[T]) GetState() string { 12 | return r.fsm.GetState() 13 | } 14 | 15 | // GetStateChan returns a channel that will receive state updates. 16 | func (r *Runner[T]) GetStateChan(ctx context.Context) <-chan string { 17 | return r.fsm.GetStateChan(ctx) 18 | } 19 | 20 | // GetStateChanWithTimeout is currently identical to GetStateChan and applies no timeout itself. 21 | // It's a pass-through to the underlying finite state machine. 22 | func (r *Runner[T]) GetStateChanWithTimeout(ctx context.Context) <-chan string { 23 | return r.fsm.GetStateChan(ctx) 24 | } 25 | 26 | // IsRunning returns true if the runner is in the Running state. 27 | func (r *Runner[T]) IsRunning() bool { 28 | return r.fsm.GetState() == finitestate.StatusRunning 29 | } 30 | 31 | // setStateError forces the FSM into the Error state via SetState, logging (not returning) any failure. 32 | func (r *Runner[T]) setStateError() { 33 | err := r.fsm.SetState(finitestate.StatusError) 34 | if err != nil { 35 | r.logger.Error("Failed to transition to Error state", "error", err) 36 | } 37 | } 38 | 39 | // GetChildStates returns a map of child runnable names to their states; children that do not implement supervisor.Stateable are reported as "unknown".
40 | func (r *Runner[T]) GetChildStates() map[string]string { 41 | // Runnables lock not required, reading config and querying state 42 | // does not modify any internal state 43 | 44 | states := make(map[string]string) 45 | cfg := r.getConfig() 46 | if cfg == nil { 47 | return states 48 | } 49 | 50 | for _, entry := range cfg.Entries { 51 | if s, ok := any(entry.Runnable).(supervisor.Stateable); ok { 52 | states[entry.Runnable.String()] = s.GetState() 53 | } else { 54 | states[entry.Runnable.String()] = "unknown" 55 | } 56 | } 57 | 58 | return states 59 | } 60 | -------------------------------------------------------------------------------- /runnables/httpserver/request_processor.go: -------------------------------------------------------------------------------- 1 | // This file defines the core middleware execution types. 2 | // 3 | // RequestProcessor manages the middleware chain execution and provides access 4 | // to request/response data. It handles the control flow ("when" middleware runs) 5 | // while ResponseWriter (response_writer.go) handles data capture ("what" data 6 | // is available).
7 | // 8 | // When writing custom middleware, use RequestProcessor methods to: 9 | // - Continue processing: rp.Next() 10 | // - Stop processing: rp.Abort() 11 | // - Access request: rp.Request() 12 | // - Access response: rp.Writer() 13 | package httpserver 14 | 15 | import ( 16 | "net/http" 17 | ) 18 | 19 | // HandlerFunc is the middleware/handler signature 20 | type HandlerFunc func(*RequestProcessor) 21 | 22 | // RequestProcessor carries the request/response and middleware chain 23 | type RequestProcessor struct { 24 | // Request/response state; unexported, exposed through Writer, Request and SetWriter 25 | writer ResponseWriter 26 | request *http.Request 27 | 28 | // Middleware chain and the index of the handler currently being executed 29 | handlers []HandlerFunc 30 | index int 31 | } 32 | 33 | // Next executes the remaining handlers in the chain 34 | func (rp *RequestProcessor) Next() { 35 | rp.index++ 36 | for rp.index < len(rp.handlers) { 37 | rp.handlers[rp.index](rp) 38 | rp.index++ 39 | } 40 | } 41 | 42 | // Abort prevents remaining handlers from being called by moving index past the last handler 43 | func (rp *RequestProcessor) Abort() { 44 | rp.index = len(rp.handlers) 45 | } 46 | 47 | // IsAborted returns true if the request processing was aborted. Note: it only checks index >= len(handlers), so it also returns true once Next has run every handler. 48 | func (rp *RequestProcessor) IsAborted() bool { 49 | return rp.index >= len(rp.handlers) 50 | } 51 | 52 | // Writer returns the ResponseWriter for the request 53 | func (rp *RequestProcessor) Writer() ResponseWriter { 54 | return rp.writer 55 | } 56 | 57 | // Request returns the HTTP request 58 | func (rp *RequestProcessor) Request() *http.Request { 59 | return rp.request 60 | } 61 | 62 | // SetWriter replaces the ResponseWriter for the request. 63 | // This allows middleware to intercept and transform responses.
64 | func (rp *RequestProcessor) SetWriter(w ResponseWriter) { 65 | rp.writer = w 66 | } 67 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | permissions: 10 | contents: read 11 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 12 | pull-requests: read 13 | # Optional: allow write access to checks to allow the action to annotate code in the PR. 14 | checks: write 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v6 21 | with: 22 | # Fetch all history for proper SonarQube analysis 23 | fetch-depth: 0 24 | 25 | - name: Set up Go 26 | uses: actions/setup-go@v6 27 | with: 28 | go-version-file: go.mod 29 | cache: true 30 | cache-dependency-path: go.sum 31 | 32 | - name: Display Go version 33 | run: go version 34 | 35 | - name: go mod tidy (fails if changes are needed) 36 | run: go mod tidy --diff 37 | 38 | - name: golangci-lint 39 | uses: golangci/golangci-lint-action@v9 40 | with: 41 | version: latest 42 | cache-invalidation-interval: 30 43 | 44 | - name: Go test 45 | run: | 46 | go test \ 47 | -cover -coverprofile=unit.coverage.out \ 48 | github.com/robbyt/go-supervisor/internal/... \ 49 | github.com/robbyt/go-supervisor/runnables/... \ 50 | github.com/robbyt/go-supervisor/supervisor/... 51 | 52 | - name: Build examples 53 | run: cd examples/http && go build . # NOTE(review): only examples/http is compiled by this step 54 | 55 | - name: SonarQube Scan 56 | uses: SonarSource/sonarqube-scan-action@v6 57 | env: 58 | SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} 59 | with: 60 | args: > 61 | -Dsonar.projectKey=robbyt_go-supervisor 62 | -Dsonar.organization=robbyt 63 | -Dsonar.go.coverage.reportPaths=unit.coverage.out 64 | -Dsonar.sources=.
65 | -Dsonar.coverage.exclusions=examples/** 66 | -Dsonar.exclusions=**/*_test.go 67 | -Dsonar.tests=. 68 | -Dsonar.test.inclusions=**/*_test.go 69 | -Dsonar.language=go 70 | -Dsonar.sourceEncoding=UTF-8 71 | -------------------------------------------------------------------------------- /runnables/httpcluster/interfaces.go: -------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/robbyt/go-supervisor/supervisor" 7 | ) 8 | 9 | // entriesManager defines the interface for managing server entries; its update methods are documented to return a new collection rather than modify the receiver. 10 | type entriesManager interface { 11 | // get returns a server entry by ID, or nil if not found. 12 | get(id string) *serverEntry 13 | 14 | // count returns the total number of server entries. 15 | count() int 16 | 17 | // countByAction returns the number of entries with the specified action. 18 | countByAction(a action) int 19 | 20 | // getPendingActions returns lists of server IDs grouped by their pending action. 21 | getPendingActions() (toStart, toStop []string) 22 | 23 | // commit creates a new entries collection with all actions marked as complete. 24 | // Called after all pending actions have been executed. 25 | // It removes entries marked for stop and clears all action flags. 26 | commit() entriesManager 27 | 28 | // setRuntime creates a new entries collection with updated runtime state for a server. 29 | // This is used during the commit phase to record that a server has been started. 30 | // Returns nil if the server doesn't exist. 31 | setRuntime( 32 | id string, 33 | runner httpServerRunner, 34 | ctx context.Context, 35 | cancel context.CancelFunc, 36 | ) entriesManager 37 | 38 | // clearRuntime creates a new entries collection with cleared runtime state for a server. 39 | // This is used during the commit phase to record that a server has been stopped. 40 | // Returns nil if the server doesn't exist.
41 | clearRuntime(id string) entriesManager 42 | 43 | // removeEntry creates a new entries collection with the specified entry removed. 44 | // This is used when a server fails to start and is removed from the collection. 45 | removeEntry(id string) entriesManager 46 | 47 | // buildPendingEntries creates a new entries collection based on the desired state and the previous state. 48 | // It marks the entries with the action needed during the commit phase. 49 | buildPendingEntries(desired entriesManager) entriesManager 50 | } 51 | 52 | // httpServerRunner is what the cluster requires of an HTTP server runner: supervisor.Runnable plus supervisor.Stateable. 53 | type httpServerRunner interface { 54 | supervisor.Runnable 55 | supervisor.Stateable 56 | } 57 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/metrics/metrics_test.go: -------------------------------------------------------------------------------- 1 | package metrics 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/httpserver" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // setupRequest creates a basic HTTP request for testing 14 | func setupRequest(t *testing.T, method, path string) (*httptest.ResponseRecorder, *http.Request) { 15 | t.Helper() 16 | req := httptest.NewRequest(method, path, nil) 17 | rec := httptest.NewRecorder() 18 | return rec, req 19 | } 20 | 21 | // createStatusHandler returns a handler that writes the given status code and message, asserting the write succeeds 22 | func createStatusHandler(t *testing.T, status int, message string) http.HandlerFunc { 23 | t.Helper() 24 | return func(w http.ResponseWriter, r *http.Request) { 25 | w.WriteHeader(status) 26 | n, err := w.Write([]byte(message)) 27 | assert.NoError(t, err) 28 | assert.Equal(t, len(message), n) 29 | } 30 | } 31 | 32 | // executeHandlerWithMetrics runs the provided handler with the MetricCollector middleware 33 | func
executeHandlerWithMetrics( 34 | t *testing.T, 35 | handler http.HandlerFunc, 36 | rec *httptest.ResponseRecorder, 37 | req *http.Request, 38 | ) { 39 | t.Helper() 40 | // Create a route with metrics middleware and the handler 41 | route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", handler, New()) 42 | require.NoError(t, err) 43 | route.ServeHTTP(rec, req) 44 | } 45 | 46 | func TestMetricCollector(t *testing.T) { 47 | t.Run("handles successful requests", func(t *testing.T) { 48 | // Setup 49 | rec, req := setupRequest(t, "GET", "/test") 50 | handler := createStatusHandler(t, http.StatusOK, "OK") 51 | 52 | // Execute 53 | executeHandlerWithMetrics(t, handler, rec, req) 54 | 55 | // Verify 56 | assert.Equal(t, http.StatusOK, rec.Code) 57 | assert.Equal(t, "OK", rec.Body.String()) 58 | }) 59 | 60 | t.Run("handles error requests", func(t *testing.T) { 61 | // Setup 62 | rec, req := setupRequest(t, "POST", "/error") 63 | handler := createStatusHandler(t, http.StatusInternalServerError, "Error") 64 | 65 | // Execute 66 | executeHandlerWithMetrics(t, handler, rec, req) 67 | 68 | // Verify 69 | assert.Equal(t, http.StatusInternalServerError, rec.Code) 70 | assert.Equal(t, "Error", rec.Body.String()) 71 | }) 72 | } 73 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | 3 | run: 4 | timeout: 5m 5 | go: '1.25' 6 | 7 | linters: 8 | default: standard 9 | disable: 10 | - godox # Detects usage of FIXME, TODO and other keywords inside comments 11 | - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value 12 | - noctx # Finds sending http request without context.Context 13 | - prealloc # Temporarily disable until slice allocation issues are fixed 14 | enable: 15 | - bodyclose # Ensure HTTP response bodies are closed 16 | - contextcheck # Flag functions that use a non-inherited context 17 | -
dupl # Detect duplicate code 18 | - dupword # Detect duplicate words in comments/strings 19 | - errcheck # Check for unchecked errors 20 | - errorlint # Enforce idiomatic error handling 21 | - govet # Report suspicious constructs 22 | - ineffassign # Detect ineffectual assignments (values that are never used) 23 | - misspell # Detect misspelled English words 24 | - nilerr # Detect returning nil after error checks 25 | - nolintlint # Check for invalid/missing nolint directives 26 | - reassign # Prevent package variable reassignment 27 | - staticcheck # Advanced static analysis 28 | - tagalign # Check struct tag alignment 29 | - tagliatelle # Enforce struct tag formatting 30 | - testifylint # Avoid common testify mistakes 31 | - thelper # Ensure test helpers use t.Helper() 32 | - unconvert # Remove unnecessary type conversions 33 | - unused # Detect unused code 34 | - whitespace # Detect unnecessary whitespace 35 | settings: 36 | errcheck: 37 | check-blank: true 38 | exclude-functions: 39 | - fmt.Fprintf 40 | - (*github.com/stretchr/testify/mock.Mock).Get 41 | - (net.Listener).Addr 42 | - (sync/atomic.Value).Load 43 | errorlint: 44 | errorf: true 45 | asserts: true 46 | comparison: true 47 | tagalign: 48 | strict: true 49 | order: 50 | - json 51 | - toml 52 | - yaml 53 | - xml 54 | - env_interpolation 55 | 56 | formatters: 57 | enable: 58 | - gci 59 | - gofmt 60 | - goimports 61 | - gofumpt 62 | 63 | issues: 64 | max-issues-per-linter: 20 65 | max-same-issues: 5 66 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/state/state_test.go: -------------------------------------------------------------------------------- 1 | package state 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/httpserver" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // setupRequest creates a basic HTTP request for testing 14 | func
setupRequest(t *testing.T, method, path string) (*httptest.ResponseRecorder, *http.Request) { 15 | t.Helper() 16 | req := httptest.NewRequest(method, path, nil) 17 | rec := httptest.NewRecorder() 18 | return rec, req 19 | } 20 | 21 | // createTestHandler returns a handler that writes 200 and "test response"; with checkResponse it asserts the write succeeded 22 | func createTestHandler(t *testing.T, checkResponse bool) http.HandlerFunc { 23 | t.Helper() 24 | return func(w http.ResponseWriter, r *http.Request) { 25 | w.WriteHeader(http.StatusOK) 26 | n, err := w.Write([]byte("test response")) 27 | if checkResponse { 28 | assert.NoError(t, err) 29 | assert.Equal(t, 13, n) 30 | } else if err != nil { 31 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 32 | return 33 | } 34 | } 35 | } 36 | 37 | // createStateProvider returns a function that provides a static state string 38 | func createStateProvider(t *testing.T, state string) func() string { 39 | t.Helper() 40 | return func() string { 41 | return state 42 | } 43 | } 44 | 45 | // executeHandlerWithState runs the provided handler with the StateDebugger middleware 46 | func executeHandlerWithState( 47 | t *testing.T, 48 | handler http.HandlerFunc, 49 | stateProvider func() string, 50 | rec *httptest.ResponseRecorder, 51 | req *http.Request, 52 | ) { 53 | t.Helper() 54 | // Create a route with state middleware and the handler 55 | route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", handler, New(stateProvider)) 56 | require.NoError(t, err) 57 | route.ServeHTTP(rec, req) 58 | } 59 | 60 | func TestStateMiddleware(t *testing.T) { 61 | // Setup 62 | rec, req := setupRequest(t, "GET", "/test") 63 | handler := createTestHandler(t, false) // createTestHandler is defined above in this file 64 | stateProvider := createStateProvider(t, "running") 65 | 66 | // Execute 67 | executeHandlerWithState(t, handler, stateProvider, rec, req) 68 | 69 | // Verify 70 | resp := rec.Result() 71 | assert.Equal(t, http.StatusOK, resp.StatusCode) 72 | assert.Equal(t, "running",
resp.Header.Get("X-Server-State")) 73 | } 74 | -------------------------------------------------------------------------------- /runnables/httpserver/reload.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/robbyt/go-supervisor/internal/finitestate" 8 | ) 9 | 10 | // reloadConfig fetches a new config from the config callback and stores it; it returns ErrOldConfig when the new config equals the current one. 11 | func (r *Runner) reloadConfig() error { 12 | newConfig, err := r.configCallback() 13 | if err != nil { 14 | return fmt.Errorf("%w: %w", ErrConfigCallback, err) 15 | } 16 | 17 | if newConfig == nil { 18 | return ErrConfigCallbackNil 19 | } 20 | 21 | oldConfig := r.getConfig() 22 | if oldConfig == nil { 23 | r.setConfig(newConfig) 24 | r.logger.Debug("Config loaded", "newConfig", newConfig) 25 | return nil 26 | } 27 | 28 | if newConfig.Equal(oldConfig) { 29 | // Config unchanged, skip reload and return early 30 | return ErrOldConfig 31 | } 32 | 33 | r.setConfig(newConfig) 34 | r.logger.Debug("Config reloaded", "newConfig", newConfig) 35 | return nil 36 | } 37 | 38 | // Reload refreshes the server configuration and restarts the HTTP server if necessary. 39 | // It is safe to call while the server is running: it holds r.mutex throughout, skips the restart when the config is unchanged, and calls setStateError if a step fails after entering the Reloading state.
40 | func (r *Runner) Reload() { 41 | r.logger.Debug("Reloading...") 42 | r.mutex.Lock() 43 | defer r.mutex.Unlock() 44 | 45 | if err := r.fsm.Transition(finitestate.StatusReloading); err != nil { 46 | r.logger.Error("Failed to transition to Reloading", "error", err) 47 | return 48 | } 49 | 50 | err := r.reloadConfig() 51 | switch { 52 | case err == nil: 53 | r.logger.Debug("Config reloaded") 54 | case errors.Is(err, ErrOldConfig): 55 | r.logger.Debug("Config unchanged, skipping reload") 56 | if stateErr := r.fsm.Transition(finitestate.StatusRunning); stateErr != nil { 57 | r.logger.Error("Failed to transition to Running", "error", stateErr) 58 | r.setStateError() 59 | } 60 | return 61 | default: 62 | r.logger.Error("Failed to reload configuration", "error", err) 63 | r.setStateError() 64 | return 65 | } 66 | 67 | if err := r.stopServer(r.ctx); err != nil { 68 | r.logger.Error("Failed to stop server during reload", "error", err) 69 | r.setStateError() 70 | return 71 | } 72 | 73 | if err := r.boot(); err != nil { 74 | r.logger.Error("Failed to boot server during reload", "error", err) 75 | r.setStateError() 76 | return 77 | } 78 | 79 | if err := r.fsm.Transition(finitestate.StatusRunning); err != nil { 80 | r.logger.Error("Failed to transition to Running", "error", err) 81 | r.setStateError() 82 | return 83 | } 84 | 85 | r.logger.Debug("Completed.") 86 | } 87 | -------------------------------------------------------------------------------- /supervisor/interfaces.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2024 Robert Terhaar 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package supervisor 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | ) 23 | 24 | // Runnable represents a service that can be run and stopped. 25 | type Runnable interface { 26 | fmt.Stringer // Runnables implement a String() method to be identifiable in logs 27 | 28 | // Run starts the service with the given context and returns an error if it fails. 29 | // Run is a blocking call that runs the work unit until it is stopped. 30 | Run(ctx context.Context) error 31 | // Stop signals the service to stop. 32 | // Stop is a blocking call that stops the work unit. 33 | Stop() 34 | } 35 | 36 | // Reloadable represents a service that can be reloaded. 37 | type Reloadable interface { 38 | // Reload signals the service to reload its configuration. 39 | // Reload is a blocking call that reloads the configuration of the work unit. 40 | Reload() 41 | } 42 | 43 | // Stateable represents a service that can report its state. 44 | type Stateable interface { 45 | Readiness 46 | 47 | // GetState returns the current state of the service. 48 | GetState() string 49 | 50 | // GetStateChan returns a channel that will receive the current state of the service. 51 | GetStateChan(context.Context) <-chan string 52 | } 53 | 54 | // Readiness provides a way to check if a Runnable is currently running, 55 | // used to determine if a Runnable is done with its startup phase. 56 | type Readiness interface { 57 | IsRunning() bool 58 | } 59 | 60 | // ReloadSender represents a service that can trigger reloads.
61 | type ReloadSender interface { 62 | // GetReloadTrigger returns a channel that emits signals when a reload is requested. 63 | GetReloadTrigger() <-chan struct{} 64 | } 65 | 66 | // ShutdownSender represents a service that can trigger system shutdown. 67 | type ShutdownSender interface { 68 | // GetShutdownTrigger returns a channel that emits signals when a shutdown is requested. 69 | GetShutdownTrigger() <-chan struct{} 70 | } 71 | -------------------------------------------------------------------------------- /runnables/httpserver/runner_boot_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "testing" 7 | "time" 8 | 9 | "github.com/robbyt/go-supervisor/internal/finitestate" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestBootConfigCreateFailure(t *testing.T) { 15 | t.Parallel() 16 | 17 | callback := func() (*Config, error) { 18 | return &Config{ 19 | ListenAddr: ":8000", 20 | Routes: Routes{}, 21 | DrainTimeout: 1 * time.Second, 22 | ServerCreator: DefaultServerCreator, 23 | }, nil 24 | } 25 | 26 | runner, err := NewRunner(WithConfigCallback(callback)) 27 | require.NoError(t, err) 28 | 29 | err = runner.boot() 30 | require.Error(t, err) 31 | require.ErrorIs(t, err, ErrCreateConfig) 32 | } 33 | 34 | // TestBootFailure tests various boot failure scenarios 35 | func TestBootFailure(t *testing.T) { 36 | t.Parallel() 37 | 38 | t.Run("Missing config callback", func(t *testing.T) { 39 | _, err := NewRunner() 40 | require.Error(t, err) 41 | assert.Contains(t, err.Error(), "config callback is required") 42 | }) 43 | 44 | t.Run("Config callback returns nil", func(t *testing.T) { 45 | callback := func() (*Config, error) { return nil, nil } 46 | runner, err := NewRunner( 47 | WithConfigCallback(callback), 48 | ) 49 | 50 | require.Error(t, err) 51 | assert.Nil(t, runner) 52 | require.ErrorIs(t, err, 
ErrConfigCallback) 53 | }) 54 | 55 | t.Run("Config callback returns error", func(t *testing.T) { 56 | callback := func() (*Config, error) { return nil, errors.New("failed to load config") } 57 | runner, err := NewRunner( 58 | WithConfigCallback(callback), 59 | ) 60 | 61 | require.Error(t, err) 62 | assert.Nil(t, runner) 63 | require.ErrorIs(t, err, ErrConfigCallback) 64 | }) 65 | 66 | t.Run("Server boot fails with invalid port", func(t *testing.T) { 67 | handler := func(w http.ResponseWriter, r *http.Request) {} 68 | route, err := NewRouteFromHandlerFunc("v1", "/", handler) 69 | require.NoError(t, err) 70 | 71 | callback := func() (*Config, error) { 72 | return &Config{ 73 | ListenAddr: "invalid-port", 74 | DrainTimeout: 1 * time.Second, 75 | Routes: Routes{*route}, 76 | }, nil 77 | } 78 | 79 | runner, err := NewRunner( 80 | WithConfigCallback(callback), 81 | ) 82 | 83 | require.NoError(t, err) 84 | assert.NotNil(t, runner) 85 | 86 | // Test actual run 87 | err = runner.Run(t.Context()) 88 | require.Error(t, err) 89 | // With our readiness probe, the error format is different but should be propagated properly 90 | require.ErrorIs(t, err, ErrServerBoot) 91 | assert.Equal(t, finitestate.StatusError, runner.GetState()) 92 | }) 93 | } 94 | -------------------------------------------------------------------------------- /runnables/httpcluster/options.go: -------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "fmt" 5 | "log/slog" 6 | "time" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/httpserver" 9 | ) 10 | 11 | // Option is a function that configures a Runner. 12 | type Option func(*Runner) error 13 | 14 | // WithLogger sets the logger for the cluster. 15 | func WithLogger(logger *slog.Logger) Option { 16 | return func(r *Runner) error { 17 | r.logger = logger 18 | return nil 19 | } 20 | } 21 | 22 | // WithLogHandler sets the log handler for the cluster. 
23 | func WithLogHandler(handler slog.Handler) Option { 24 | return func(r *Runner) error { 25 | r.logger = slog.New(handler) 26 | return nil 27 | } 28 | } 29 | 30 | // WithSiphonBuffer sets the buffer size for the configuration siphon channel. 31 | // A buffer of 0 (default) makes the channel synchronous, providing natural backpressure 32 | // and preventing rapid config updates that could cause server restart race conditions. 33 | // Values > 1 may cause race conditions during heavy update pressure and are not recommended. 34 | func WithSiphonBuffer(size int) Option { 35 | return func(r *Runner) error { 36 | if size > 1 { 37 | r.logger.Warn( 38 | "SiphonBuffer size > 1 may cause race conditions during heavy update pressure, keeping default 0 is recommended", 39 | "size", 40 | size, 41 | ) 42 | } 43 | r.configSiphon = make(chan map[string]*httpserver.Config, size) 44 | return nil 45 | } 46 | } 47 | 48 | // WithCustomSiphonChannel sets the custom configuration siphon channel for the cluster. 49 | func WithCustomSiphonChannel(channel chan map[string]*httpserver.Config) Option { 50 | return func(r *Runner) error { 51 | r.configSiphon = channel 52 | return nil 53 | } 54 | } 55 | 56 | // WithRunnerFactory sets the factory function for creating Runnable instances. 57 | func WithRunnerFactory( 58 | factory runnerFactory, 59 | ) Option { 60 | return func(r *Runner) error { 61 | r.runnerFactory = factory 62 | return nil 63 | } 64 | } 65 | 66 | // WithStateChanBufferSize sets the buffer size for state channels. 67 | // This helps prevent dropped state transitions in tests or when state changes happen rapidly. 68 | // Default is 10. Size of 0 creates an unbuffered channel. 
69 | func WithStateChanBufferSize(size int) Option { 70 | return func(r *Runner) error { 71 | if size < 0 { 72 | return fmt.Errorf("state channel buffer size cannot be negative: %d", size) 73 | } 74 | r.stateChanBufferSize = size 75 | return nil 76 | } 77 | } 78 | 79 | // WithRestartDelay sets the delay between server restarts when configs change. 80 | func WithRestartDelay(delay time.Duration) Option { 81 | return func(r *Runner) error { 82 | r.restartDelay = delay 83 | return nil 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /runnables/composite/config.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/robbyt/go-supervisor/supervisor" 9 | ) 10 | 11 | // runnable is a local alias constraining sub-runnables to implement the 12 | // supervisor.Runnable interface. 13 | type runnable interface { 14 | supervisor.Runnable 15 | } 16 | 17 | // RunnableEntry associates a runnable with its configuration 18 | type RunnableEntry[T runnable] struct { 19 | // Runnable is the component to be managed 20 | Runnable T 21 | 22 | // Config holds the configuration data for this specific runnable 23 | Config any 24 | } 25 | 26 | // Config represents the configuration for a CompositeRunner 27 | type Config[T runnable] struct { 28 | // Name is a human-readable identifier for this composite runner 29 | Name string 30 | 31 | // Entries is the list of runnables with their associated configurations 32 | Entries []RunnableEntry[T] 33 | } 34 | 35 | // NewConfig creates a new Config instance for a CompositeRunner 36 | func NewConfig[T runnable]( 37 | name string, 38 | entries []RunnableEntry[T], 39 | ) (*Config[T], error) { 40 | if name == "" { 41 | return nil, errors.New("name cannot be empty") 42 | } 43 | 44 | return &Config[T]{ 45 | Name: name, 46 | Entries: entries, 47 | }, nil 48 | } 49 | 50 | // 
NewConfigFromRunnables creates a Config from a list of runnables, all using the same config; sharedConfig is assigned to every entry 51 | func NewConfigFromRunnables[T runnable]( 52 | name string, 53 | runnables []T, 54 | sharedConfig any, 55 | ) (*Config[T], error) { 56 | if name == "" { 57 | return nil, errors.New("name cannot be empty") 58 | } 59 | 60 | entries := make([]RunnableEntry[T], len(runnables)) 61 | for i, r := range runnables { 62 | entries[i] = RunnableEntry[T]{ 63 | Runnable: r, 64 | Config: sharedConfig, 65 | } 66 | } 67 | 68 | return &Config[T]{ 69 | Name: name, 70 | Entries: entries, 71 | }, nil 72 | } 73 | 74 | // Equal reports whether two configs are equivalent. Two nil configs are equal, 75 | // and a nil config never equals a non-nil one. Entries are compared in order: 76 | // runnables by their String() output, configs with reflect.DeepEqual. 77 | func (c *Config[T]) Equal(other *Config[T]) bool { 78 | if c == nil || other == nil { 79 | return c == other 80 | } 81 | 82 | if c.Name != other.Name { 83 | return false 84 | } 85 | 86 | if len(c.Entries) != len(other.Entries) { 87 | return false 88 | } 89 | 90 | // Compare runnables and their configs 91 | for i, entry := range c.Entries { 92 | // Compare runnable by string representation 93 | if entry.Runnable.String() != other.Entries[i].Runnable.String() { 94 | return false 95 | } 96 | 97 | // For config, use reflection for comparison 98 | if !reflect.DeepEqual(entry.Config, other.Entries[i].Config) { 99 | return false 100 | } 101 | } 102 | 103 | return true 104 | } 105 | 106 | // String returns a string representation of the Config 107 | func (c *Config[T]) String() string { 108 | return fmt.Sprintf("Config{Name: %s, Entries: %d}", c.Name, len(c.Entries)) 109 | } 110 | -------------------------------------------------------------------------------- /supervisor/reload.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2024 Robert Terhaar 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package supervisor 18 | 19 | // ReloadAll requests a reload of every Reloadable runnable; it gives up if the supervisor context is done. 20 | func (p *PIDZero) ReloadAll() { 21 | select { 22 | case p.reloadListener <- struct{}{}: 23 | case <-p.ctx.Done(): 24 | } 25 | } 26 | 27 | // startReloadManager forwards reload triggers from ReloadSender runnables and runs one reload 28 | // pass per notification, sequentially, until the supervisor context is canceled. 29 | func (p *PIDZero) startReloadManager() { 30 | p.logger.Debug("Starting reload manager...") 31 | // Forward each ReloadSender's triggers through ReloadAll, which honors ctx cancellation. 32 | for _, run := range p.runnables { 33 | if rldSender, ok := run.(ReloadSender); ok { 34 | go func(r Runnable, s ReloadSender) { 35 | for { 36 | select { 37 | case <-p.ctx.Done(): 38 | return 39 | case <-s.GetReloadTrigger(): 40 | p.logger.Debug("Reload notifier received from runnable", "runnable", r) 41 | p.ReloadAll() 42 | } 43 | } 44 | }(run, rldSender) 45 | } 46 | } 47 | 48 | for { 49 | select { 50 | case <-p.ctx.Done(): 51 | return 52 | case <-p.reloadListener: 53 | reloads := p.reloadAllRunnables() 54 | p.logger.Info("Reload complete.", "runnablesReloaded", reloads) 55 | } 56 | } 57 | } 58 | 59 | // reloadAllRunnables calls the Reload method on all runnables that implement the Reloadable 60 | // interface.
61 | func (p *PIDZero) reloadAllRunnables() int { 62 | reloads := 0 63 | p.logger.Info("Starting Reload...") 64 | 65 | for _, r := range p.runnables { 66 | if reloader, ok := r.(Reloadable); ok { 67 | // Log the pre-reload state when the runnable also implements Stateable 68 | if stateable, ok := r.(Stateable); ok { 69 | preState := stateable.GetState() 70 | p.logger.Debug("Pre-reload state", "runnable", r, "state", preState) 71 | } 72 | 73 | p.logger.Debug("Reloading", "runnable", r) 74 | reloader.Reload() 75 | reloads++ 76 | // Record the post-reload state in p.stateMap and log it 77 | if stateable, ok := r.(Stateable); ok { 78 | postState := stateable.GetState() 79 | p.stateMap.Store(r, postState) 80 | p.logger.Debug("Post-reload state", "runnable", r, "state", postState) 81 | } 82 | 83 | continue 84 | } 85 | p.logger.Debug("Skipping Reload, not supported", "runnable", r) 86 | } 87 | return reloads 88 | } 89 | -------------------------------------------------------------------------------- /runnables/httpcluster/options_test.go: -------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "log/slog" 5 | "testing" 6 | 7 | "github.com/robbyt/go-supervisor/runnables/httpserver" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestWithLogger(t *testing.T) { 13 | t.Parallel() 14 | 15 | logger := slog.Default().WithGroup("test") 16 | runner, err := NewRunner(WithLogger(logger)) 17 | require.NoError(t, err) 18 | 19 | assert.Equal(t, logger, runner.logger) 20 | } 21 | 22 | func TestWithLogHandler(t *testing.T) { 23 | t.Parallel() 24 | 25 | handler := slog.Default().Handler() 26 | runner, err := NewRunner(WithLogHandler(handler)) 27 | require.NoError(t, err) 28 | 29 | // Only non-nil is checked: WithLogHandler wraps the handler in a new *slog.Logger, so there is no instance to compare by identity 30 | assert.NotNil(t, runner.logger) 31 | } 32 | 33 | func TestWithCustomSiphonChannel(t *testing.T) { 34 | t.Parallel() 35 | 36 | customChannel := make(chan map[string]*httpserver.Config, 5) 37 | runner, err :=
NewRunner(WithCustomSiphonChannel(customChannel)) 38 | require.NoError(t, err) 39 | 40 | assert.Equal(t, customChannel, runner.configSiphon) 41 | 42 | // Verify the channel capacity 43 | assert.Equal(t, 5, cap(runner.configSiphon)) 44 | } 45 | 46 | func TestWithStateChanBufferSize(t *testing.T) { 47 | t.Parallel() 48 | 49 | t.Run("valid buffer size", func(t *testing.T) { 50 | runner, err := NewRunner(WithStateChanBufferSize(20)) 51 | require.NoError(t, err) 52 | 53 | assert.Equal(t, 20, runner.stateChanBufferSize) 54 | }) 55 | 56 | t.Run("zero buffer size", func(t *testing.T) { 57 | runner, err := NewRunner(WithStateChanBufferSize(0)) 58 | require.NoError(t, err) 59 | 60 | assert.Equal(t, 0, runner.stateChanBufferSize) 61 | }) 62 | 63 | t.Run("negative buffer size returns error", func(t *testing.T) { 64 | _, err := NewRunner(WithStateChanBufferSize(-1)) 65 | require.Error(t, err) 66 | assert.Contains(t, err.Error(), "state channel buffer size cannot be negative") 67 | assert.Contains(t, err.Error(), "-1") 68 | }) 69 | } 70 | 71 | func TestOptionApplicationOrder(t *testing.T) { 72 | t.Parallel() 73 | 74 | // Several independent options combined in one NewRunner call should all take effect 75 | logger := slog.Default().WithGroup("test") 76 | 77 | runner, err := NewRunner( 78 | WithLogger(logger), 79 | WithStateChanBufferSize(15), 80 | WithSiphonBuffer(3), 81 | ) 82 | require.NoError(t, err) 83 | 84 | assert.Equal(t, logger, runner.logger) 85 | assert.Equal(t, 15, runner.stateChanBufferSize) 86 | assert.Equal(t, 3, cap(runner.configSiphon)) 87 | } 88 | 89 | func TestOptionError(t *testing.T) { 90 | t.Parallel() 91 | 92 | // Test that an option error is properly propagated 93 | errorOption := func(r *Runner) error { 94 | return assert.AnError // Use testify's standard error 95 | } 96 | 97 | _, err := NewRunner(errorOption) 98 | require.Error(t, err) 99 | assert.Contains(t, err.Error(), "failed to apply option") 100 | } 101 |
-------------------------------------------------------------------------------- /runnables/httpserver/middleware/wildcard/wildcard.go: -------------------------------------------------------------------------------- 1 | package wildcard 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | 7 | "github.com/robbyt/go-supervisor/runnables/httpserver" 8 | ) 9 | 10 | // New creates a middleware that handles requests with a prefix pattern, 11 | // stripping the prefix before passing to the handler. 12 | // 13 | // This middleware enables a single handler to manage multiple routes by removing 14 | // a common prefix from the request path. The handler receives the path with the 15 | // prefix stripped, making it easier to implement sub-routers or delegate handling 16 | // to other routing systems. 17 | // 18 | // Why use this middleware: 19 | // - Route delegation: Pass "/api/*" requests to a separate API handler 20 | // - Legacy integration: Wrap existing handlers that expect different path structures 21 | // - Microservice routing: Forward requests to handlers that manage their own sub-paths 22 | // - File serving: Strip prefixes when serving static files from subdirectories 23 | // 24 | // Examples: 25 | // 26 | // // API delegation - all /api/* requests go to apiHandler with prefix stripped 27 | // apiRoute, _ := httpserver.NewRouteFromHandlerFunc( 28 | // "api", 29 | // "/api/*", 30 | // apiHandler, 31 | // wildcard.New("/api/"), 32 | // ) 33 | // // Request to "/api/users/123" becomes "users/123" for apiHandler 34 | // 35 | // // File serving - serve files from ./static/ directory 36 | // fileHandler := http.FileServer(http.Dir("./static/")) 37 | // staticRoute, _ := httpserver.NewRouteFromHandlerFunc( 38 | // "static", 39 | // "/static/*", 40 | // fileHandler.ServeHTTP, 41 | // wildcard.New("/static/"), 42 | // ) 43 | // // Request to "/static/css/main.css" becomes "css/main.css" (http.FileServer re-adds the leading "/") 44 | // 45 | // // Legacy system integration - forward to old handler
expecting different paths 46 | // legacyRoute, _ := httpserver.NewRouteFromHandlerFunc( 47 | // "legacy", 48 | // "/v1/*", 49 | // legacySystemHandler, 50 | // wildcard.New("/v1/"), 51 | // ) 52 | // // Request to "/v1/old/endpoint" becomes "old/endpoint" for legacy handler 53 | // 54 | // Behavior: 55 | // - Requests that don't match the prefix return 404 Not Found 56 | // - The stripped path has no leading slash ("/api/users/123" -> "users/123") 57 | // - Exact prefix matches result in empty path being passed to handler 58 | // - The prefix without its trailing slash (e.g. "/api" for "/api/") does not match 59 | // - The request URL is modified in place (Path, and RawPath when set), so later 60 | // middleware, and code that runs after Next returns, see the stripped path 61 | // - Query parameters and request body are preserved unchanged 62 | // - All HTTP methods are supported 63 | func New(prefix string) httpserver.HandlerFunc { 64 | // Input validation and normalization 65 | if prefix == "" { 66 | prefix = "/" 67 | } 68 | 69 | // Ensure prefix starts with / 70 | if !strings.HasPrefix(prefix, "/") { 71 | prefix = "/" + prefix 72 | } 73 | 74 | // Ensure prefix ends with / (unless it's just "/") 75 | if prefix != "/" && !strings.HasSuffix(prefix, "/") { 76 | prefix = prefix + "/" 77 | } 78 | 79 | return func(rp *httpserver.RequestProcessor) { 80 | req := rp.Request() 81 | if !strings.HasPrefix(req.URL.Path, prefix) { 82 | http.NotFound(rp.Writer(), req) 83 | rp.Abort() 84 | return 85 | } 86 | req.URL.Path = strings.TrimPrefix(req.URL.Path, prefix) 87 | // Keep the escaped form in step, as http.StripPrefix does; a stale RawPath 88 | // would be ignored by URL.EscapedPath and the original escaping (e.g. %2F) lost. 89 | if req.URL.RawPath != "" { 90 | req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, prefix) 91 | } 92 | rp.Next() 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /runnables/httpserver/mocks_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/stretchr/testify/mock" 7 | ) 8 | 9 | // MockStateMachine is a mock implementation of the stateMachine interface 10 | // for testing purposes. It uses testify/mock to record and verify calls.
11 | type MockStateMachine struct { 12 | mock.Mock 13 | } 14 | 15 | // NewMockStateMachine returns a MockStateMachine with no expectations registered. 16 | func NewMockStateMachine() *MockStateMachine { 17 | return &MockStateMachine{} 18 | } 19 | 20 | // Transition mocks the Transition method of the stateMachine interface. 21 | // It records the call and returns the error configured via On("Transition", state).Return(...). 22 | func (m *MockStateMachine) Transition(state string) error { 23 | args := m.Called(state) 24 | return args.Error(0) 25 | } 26 | 27 | // TransitionBool mocks the TransitionBool method of the stateMachine interface. 28 | // It records the call and returns the bool configured via 29 | // On("TransitionBool", state).Return(...). 30 | func (m *MockStateMachine) TransitionBool(state string) bool { 31 | args := m.Called(state) 32 | return args.Bool(0) 33 | } 34 | 35 | // TransitionIfCurrentState mocks the TransitionIfCurrentState method of the stateMachine interface. 36 | // It records the call with both states and returns the configured error; the mock itself 37 | // performs no state comparison. 38 | func (m *MockStateMachine) TransitionIfCurrentState(currentState, newState string) error { 39 | args := m.Called(currentState, newState) 40 | return args.Error(0) 41 | } 42 | 43 | // SetState mocks the SetState method of the stateMachine interface. 44 | // It records the call and returns the configured error; no state is stored. 45 | func (m *MockStateMachine) SetState(state string) error { 46 | args := m.Called(state) 47 | return args.Error(0) 48 | } 49 | 50 | // GetState mocks the GetState method of the stateMachine interface. 51 | // It returns the string configured via On("GetState").Return(...). 52 | func (m *MockStateMachine) GetState() string { 53 | args := m.Called() 54 | return args.String(0) 55 | } 56 | 57 | // GetStateChan mocks the GetStateChan method of the stateMachine interface. 58 | // It returns the configured value, which must be a <-chan string (it is type-asserted, so anything else panics).
59 | func (m *MockStateMachine) GetStateChan(ctx context.Context) <-chan string { 60 | args := m.Called(ctx) 61 | return args.Get(0).(<-chan string) 62 | } 63 | 64 | // MockHttpServer is a mock implementation of the HttpServer interface 65 | type MockHttpServer struct { 66 | mock.Mock 67 | } 68 | 69 | // ListenAndServe mocks the ListenAndServe method of the HttpServer interface 70 | func (m *MockHttpServer) ListenAndServe() error { 71 | args := m.Called() 72 | return args.Error(0) 73 | } 74 | 75 | // Shutdown mocks the Shutdown method of the HttpServer interface 76 | func (m *MockHttpServer) Shutdown(ctx context.Context) error { 77 | args := m.Called(ctx) 78 | return args.Error(0) 79 | } 80 | -------------------------------------------------------------------------------- /examples/composite/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "math/rand" 7 | "os" 8 | "strings" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/runnables/composite" 12 | "github.com/robbyt/go-supervisor/supervisor" 13 | ) 14 | 15 | func randInterval() time.Duration { 16 | return time.Duration(rand.Intn(20)+1) * time.Second 17 | } 18 | 19 | func main() { 20 | // Configure the logger 21 | handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 22 | Level: slog.LevelDebug, 23 | }) 24 | logger := slog.New(handler) 25 | slog.SetDefault(logger) 26 | 27 | // Create base context with cancellation 28 | ctx, cancel := context.WithCancel(context.Background()) 29 | defer cancel() 30 | 31 | // Create workers with initial configuration 32 | worker1, err := NewWorker(WorkerConfig{ 33 | Interval: 5 * time.Second, 34 | JobName: "periodic-task-1", 35 | }, logger) 36 | if
err != nil { 37 | logger.Error("Failed to create worker 1", "error", err) 38 | os.Exit(1) 39 | } 40 | 41 | worker2, err := NewWorker(WorkerConfig{ 42 | Interval: 10 * time.Second, 43 | JobName: "periodic-task-2", 44 | }, logger) 45 | if err != nil { 46 | logger.Error("Failed to create worker 2", "error", err) 47 | os.Exit(1) 48 | } 49 | 50 | // Capture job names once, before anything runs: the callback below executes while the 51 | // workers are live, and reading workerN.config there would be an unsynchronized read. 52 | jobName1 := worker1.config.JobName 53 | jobName2 := worker2.config.JobName 54 | 55 | // Config callback that randomizes intervals on reload 56 | configCallback := func() (*composite.Config[*Worker], error) { 57 | // Create entries array with our workers 58 | newEntries := make([]composite.RunnableEntry[*Worker], 2) 59 | 60 | // Update worker1 config with random interval 61 | newEntries[0] = composite.RunnableEntry[*Worker]{ 62 | Runnable: worker1, 63 | Config: WorkerConfig{ 64 | Interval: randInterval(), 65 | JobName: jobName1, 66 | }, 67 | } 68 | 69 | // Update worker2 config with random interval 70 | newEntries[1] = composite.RunnableEntry[*Worker]{ 71 | Runnable: worker2, 72 | Config: WorkerConfig{ 73 | Interval: randInterval(), 74 | JobName: jobName2, 75 | }, 76 | } 77 | 78 | logger.Debug("Generated new config for workers") 79 | return composite.NewConfig("worker-composite", newEntries) 80 | } 81 | 82 | // Create composite runner 83 | runner, err := composite.NewRunner( 84 | configCallback, 85 | ) 86 | if err != nil { 87 | logger.Error("Failed to create composite runner", "error", err) 88 | os.Exit(1) 89 | } 90 | 91 | // Create supervisor and add the composite runner 92 | sv, err := supervisor.New( 93 | supervisor.WithContext(ctx), 94 | supervisor.WithRunnables(runner), 95 | ) 96 | if err != nil { 97 | logger.Error("Failed to create supervisor", "error", err) 98 | os.Exit(1) 99 | } 100 | 101 | // Start the supervisor - this will block until shutdown 102 | logger.Info("Starting supervisor with composite runner") 103 | logger.Info("Send this process a HUP signal to reload the configuration") 104 | logger.Info("Press Ctrl+C to quit") 105 | logger.Info(strings.Repeat("-", 79)) 106 | 107 |
if err := sv.Run(); err != nil { 108 | logger.Error("Supervisor failed", "error", err) 109 | os.Exit(1) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /runnables/httpserver/response_writer.go: -------------------------------------------------------------------------------- 1 | // Package httpserver provides HTTP server functionality with middleware support. 2 | // 3 | // This file defines ResponseWriter, which wraps the standard http.ResponseWriter 4 | // to capture response metadata that the standard interface doesn't expose. 5 | // 6 | // # Why This Wrapper Exists 7 | // 8 | // The standard http.ResponseWriter doesn't provide access to the status code 9 | // or byte count after they've been written. Middleware needs this information 10 | // for logging, metrics, and conditional processing. 11 | // 12 | // # Relationship to Middleware 13 | // 14 | // Middleware functions receive a RequestProcessor that contains this ResponseWriter. 15 | // The wrapper allows middleware to inspect response state after handlers execute. 16 | // 17 | // # Relationship to RequestProcessor (request_processor.go) 18 | // 19 | // RequestProcessor manages middleware execution flow and provides access to this 20 | // ResponseWriter through its Writer() method. The RequestProcessor handles the 21 | // "when" of middleware execution, while ResponseWriter handles the "what" of 22 | // response data capture.
23 | package httpserver 24 | 25 | import "net/http" 26 | 27 | // ResponseWriter wraps http.ResponseWriter with additional functionality 28 | type ResponseWriter interface { 29 | http.ResponseWriter 30 | 31 | // Status returns the HTTP status code 32 | Status() int 33 | 34 | // Written returns true if the response has been written 35 | Written() bool 36 | 37 | // Size returns the number of bytes written 38 | Size() int 39 | } 40 | 41 | type responseWriter struct { 42 | http.ResponseWriter 43 | status int 44 | written bool 45 | size int 46 | } 47 | 48 | // newResponseWriter creates a new ResponseWriter wrapper 49 | func newResponseWriter(w http.ResponseWriter) ResponseWriter { 50 | return &responseWriter{ 51 | ResponseWriter: w, 52 | status: 0, 53 | written: false, 54 | size: 0, 55 | } 56 | } 57 | 58 | // WriteHeader records the status code and forwards it on the first call only; later calls are dropped 59 | func (rw *responseWriter) WriteHeader(statusCode int) { 60 | if !rw.written { 61 | rw.status = statusCode 62 | rw.written = true 63 | rw.ResponseWriter.WriteHeader(statusCode) 64 | } 65 | } 66 | 67 | // Write captures that a response has been written, counts the bytes, and calls the underlying Write 68 | func (rw *responseWriter) Write(b []byte) (int, error) { 69 | if !rw.written { 70 | // The status will be StatusOK if WriteHeader has not been called yet 71 | rw.WriteHeader(http.StatusOK) 72 | } 73 | n, err := rw.ResponseWriter.Write(b) 74 | rw.size += n 75 | return n, err 76 | } 77 | 78 | // Status returns the HTTP status code that was written to the response, or 0 if nothing has been written yet 79 | func (rw *responseWriter) Status() int { 80 | if rw.status == 0 && rw.written { 81 | // If no explicit status was set but data was written, it's 200 82 | return http.StatusOK 83 | } 84 | return rw.status 85 | } 86 | 87 | // Written returns true if the response has been written 88 | func (rw *responseWriter) Written() bool { 89 | return rw.written 90 | } 91 | 92 | // Size returns the number of bytes written to the response body
93 | func (rw *responseWriter) Size() int { 94 | return rw.size 95 | } 96 | 97 | // Flush sends any buffered response data to the client. If nothing has been written 98 | // yet, a 200 status is recorded first (as Write does) so Status and Written stay 99 | // accurate. Beyond that it does nothing when the wrapped writer is not an http.Flusher. 100 | func (rw *responseWriter) Flush() { 101 | if !rw.written { 102 | rw.WriteHeader(http.StatusOK) 103 | } 104 | if f, ok := rw.ResponseWriter.(http.Flusher); ok { 105 | f.Flush() 106 | } 107 | } 108 | 109 | // Unwrap returns the wrapped http.ResponseWriter so http.ResponseController can 110 | // reach optional methods (Hijack, SetReadDeadline, ...) this wrapper does not implement. 111 | func (rw *responseWriter) Unwrap() http.ResponseWriter { 112 | return rw.ResponseWriter 113 | } 114 | -------------------------------------------------------------------------------- /examples/composite/README.md: -------------------------------------------------------------------------------- 1 | # Composite Runner Example 2 | 3 | This example demonstrates how to use the CompositeRunner to manage multiple Worker instances as a single logical unit. It provides a complete implementation you can use as a reference for your own applications. 4 | 5 | ## What This Example Shows 6 | 7 | 1. Creating a Worker type that implements both `Runnable` and `ReloadableWithConfig` interfaces 8 | 2. Managing multiple Worker instances with a Composite Runner 9 | 3. Handling dynamic configuration updates through the `ReloadWithConfig` method 10 | 4. Implementing a thread-safe periodic task with proper context cancellation 11 | 5. Coordinating multiple workers' lifecycles through a common supervisor 12 | 13 | ## Key Components 14 | 15 | ### Worker Implementation 16 | 17 | The `Worker` type (`worker.go`) demonstrates: 18 | 19 | ```go 20 | // Worker implements both Runnable and ReloadableWithConfig 21 | type Worker struct { 22 | name string 23 | mu sync.RWMutex 24 | config WorkerConfig 25 | nextConfig chan WorkerConfig 26 | // ...
other fields 27 | } 28 | 29 | // Run starts the worker's main loop 30 | func (w *Worker) Run(ctx context.Context) error { 31 | // Implementation handles context cancellation and config updates 32 | } 33 | 34 | // Stop signals the worker to gracefully shut down 35 | func (w *Worker) Stop() { 36 | // Implementation ensures clean shutdown 37 | } 38 | 39 | // ReloadWithConfig receives configuration updates 40 | func (w *Worker) ReloadWithConfig(config any) { 41 | // Implementation handles type conversion and validation 42 | } 43 | ``` 44 | 45 | ### Composite Configuration 46 | 47 | The example (`main.go`) demonstrates creating a composite runner: 48 | 49 | ```go 50 | // Create workers with initial configuration 51 | worker1, err := NewWorker(WorkerConfig{...}, logger) 52 | worker2, err := NewWorker(WorkerConfig{...}, logger) 53 | 54 | // Define config callback that returns configuration for both workers 55 | configCallback := func() (*composite.Config[*Worker], error) { 56 | newEntries := []composite.RunnableEntry[*Worker]{ 57 | {Runnable: worker1, Config: WorkerConfig{...}}, 58 | {Runnable: worker2, Config: WorkerConfig{...}}, 59 | } 60 | return composite.NewConfig("worker-composite", newEntries) 61 | } 62 | 63 | // Create composite runner with our callback 64 | runner, err := composite.NewRunner(configCallback) 65 | 66 | // Run it under a supervisor that carries our context 67 | sv, err := supervisor.New(supervisor.WithContext(ctx), supervisor.WithRunnables(runner)) 68 | ``` 69 | 70 | ## Running the Example 71 | 72 | ```bash 73 | go run . 74 | ``` 75 | 76 | When you run the example: 77 | 78 | 1. Two Worker instances start with different interval configurations (5s and 10s) 79 | 2. Each worker performs periodic tasks at its configured interval 80 | 3. You can send a SIGHUP signal to trigger configuration reload (random intervals) 81 | 4. Workers smoothly transition to new intervals without stopping 82 | 5.
Press Ctrl+C to trigger graceful shutdown 83 | 84 | ## Key Patterns Demonstrated 85 | 86 | - **Graceful Shutdown**: Proper context cancellation and resource cleanup 87 | - **Configuration Management**: Thread-safe handling of configuration updates 88 | - **Dynamic Reconfiguration**: Changing intervals without restarting workers 89 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/headers/headers.go: -------------------------------------------------------------------------------- 1 | package headers 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/robbyt/go-supervisor/runnables/httpserver" 7 | ) 8 | 9 | // New creates a middleware that adds the given HTTP headers to responses. 10 | // Values are appended with Header().Add before the request is processed, so a 11 | // later Header().Set in other middleware or handlers overrides them. 12 | // 13 | // Note: The Go standard library's http package will validate headers when 14 | // writing them to prevent protocol violations. This middleware does not 15 | // perform additional validation beyond what the standard library provides. 16 | func New(headers http.Header) httpserver.HandlerFunc { 17 | return func(rp *httpserver.RequestProcessor) { 18 | // Add (not Set) each value, so values already present for a key are kept 19 | for key, values := range headers { 20 | for _, value := range values { 21 | rp.Writer().Header().Add(key, value) 22 | } 23 | } 24 | 25 | // Continue processing 26 | rp.Next() 27 | } 28 | } 29 | 30 | // JSON creates a middleware that adds JSON-specific headers via New: 31 | // Content-Type: application/json and Cache-Control: no-cache. 32 | func JSON() httpserver.HandlerFunc { 33 | return New(http.Header{ 34 | "Content-Type": []string{"application/json"}, 35 | "Cache-Control": []string{"no-cache"}, 36 | }) 37 | } 38 | 39 | // CORS creates a middleware that sets CORS headers for cross-origin requests.
40 | // 41 | // Parameters: 42 | // - allowOrigin: Which origins can access the resource ("*" for any, or specific domain) 43 | // - allowMethods: Comma-separated list of allowed HTTP methods 44 | // - allowHeaders: Comma-separated list of allowed request headers 45 | // 46 | // Examples: 47 | // 48 | // // Allow any origin (useful for public APIs) 49 | // CORS("*", "GET,POST,PUT,DELETE", "Content-Type,Authorization") 50 | // 51 | // // Allow specific origin with credentials 52 | // CORS("https://app.example.com", "GET,POST", "Content-Type,Authorization") 53 | // 54 | // // Minimal read-only API 55 | // CORS("*", "GET,OPTIONS", "Content-Type") 56 | // 57 | // // Development setup: list headers explicitly, since "*" is not a wildcard for credentialed requests 58 | // CORS("http://localhost:3000", "GET,POST,PUT,PATCH,DELETE,OPTIONS", "Content-Type,Authorization") 59 | func CORS(allowOrigin, allowMethods, allowHeaders string) httpserver.HandlerFunc { 60 | corsHeaders := http.Header{ 61 | "Access-Control-Allow-Origin": []string{allowOrigin}, 62 | "Access-Control-Allow-Methods": []string{allowMethods}, 63 | "Access-Control-Allow-Headers": []string{allowHeaders}, 64 | } 65 | 66 | // Non-wildcard origins also get Allow-Credentials: true (browsers reject credentials with "*") 67 | if allowOrigin != "*" { 68 | corsHeaders["Access-Control-Allow-Credentials"] = []string{"true"} 69 | } 70 | 71 | return NewWithOperations(WithSet(corsHeaders)) 72 | } 73 | 74 | // Security creates a middleware that sets common security headers. 75 | func Security() httpserver.HandlerFunc { 76 | return New(http.Header{ 77 | "X-Content-Type-Options": []string{"nosniff"}, 78 | "X-Frame-Options": []string{"DENY"}, 79 | "X-XSS-Protection": []string{"0"}, // disable legacy XSS auditors, per current OWASP/MDN guidance 80 | "Referrer-Policy": []string{"strict-origin-when-cross-origin"}, 81 | }) 82 | } 83 | 84 | // Add creates a middleware that adds a single header. 85 | // It is shorthand for New with one header, so the value is appended with Header().Add.
86 | func Add(key, value string) httpserver.HandlerFunc { 87 | return New(http.Header{key: []string{value}}) 88 | } 89 | -------------------------------------------------------------------------------- /internal/finitestate/machine.go: -------------------------------------------------------------------------------- 1 | package finitestate 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "time" 7 | 8 | "github.com/robbyt/go-fsm/v2" 9 | "github.com/robbyt/go-fsm/v2/hooks" 10 | "github.com/robbyt/go-fsm/v2/hooks/broadcast" 11 | "github.com/robbyt/go-fsm/v2/transitions" 12 | ) 13 | 14 | const ( 15 | StatusNew = transitions.StatusNew 16 | StatusBooting = transitions.StatusBooting 17 | StatusRunning = transitions.StatusRunning 18 | StatusReloading = transitions.StatusReloading 19 | StatusStopping = transitions.StatusStopping 20 | StatusStopped = transitions.StatusStopped 21 | StatusError = transitions.StatusError 22 | StatusUnknown = transitions.StatusUnknown 23 | ) 24 | 25 | // TypicalTransitions re-exports go-fsm's transitions.Typical, the standard transition set used by New. 26 | var TypicalTransitions = transitions.Typical 27 | 28 | // Machine wraps a go-fsm v2 machine to provide v1 API compatibility. 29 | // It embeds *fsm.Machine and owns the broadcast.Manager that feeds GetStateChan. 30 | type Machine struct { 31 | *fsm.Machine 32 | broadcastManager *broadcast.Manager 33 | } 34 | 35 | // GetStateChan returns a channel that emits the state whenever it changes. 36 | // The channel is closed when the provided context is canceled. 37 | // For v1 API compatibility, the current state is sent immediately to the channel. 38 | // A 5-second broadcast timeout is used to prevent slow consumers from blocking state updates. 39 | func (s *Machine) GetStateChan(ctx context.Context) <-chan string { 40 | return s.getStateChanInternal(ctx, broadcast.WithTimeout(5*time.Second)) 41 | } 42 | 43 | // getStateChanInternal is a helper that creates a channel and sends the current state to it.
44 | // This maintains v1 API compatibility where GetStateChan immediately sends the current state. 45 | // The forwarding goroutine exits once ctx is canceled, even if the reader has stopped receiving. 46 | func (s *Machine) getStateChanInternal(ctx context.Context, opts ...broadcast.Option) <-chan string { 47 | wrappedCh := make(chan string, 1) 48 | 49 | userCh, err := s.broadcastManager.GetStateChan(ctx, opts...) 50 | if err != nil { 51 | // Subscription failed: return a closed channel so receivers see end-of-stream at once. 52 | close(wrappedCh) 53 | return wrappedCh 54 | } 55 | 56 | currentState := s.GetState() 57 | wrappedCh <- currentState 58 | 59 | go func() { 60 | defer close(wrappedCh) 61 | for state := range userCh { 62 | select { 63 | case wrappedCh <- state: 64 | case <-ctx.Done(): 65 | // The caller's context is gone; stop instead of blocking on an unread channel. 66 | return 67 | } 68 | } 69 | }() 70 | 71 | return wrappedCh 72 | } 73 | 74 | // New creates a new finite state machine with the specified logger using "standard" state transitions. 75 | // This function provides compatibility with the v1 API while using v2 under the hood. 76 | func New(handler slog.Handler) (*Machine, error) { 77 | registry, err := hooks.NewRegistry( 78 | hooks.WithLogHandler(handler), 79 | hooks.WithTransitions(TypicalTransitions), 80 | ) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | broadcastManager := broadcast.NewManager(handler) 86 | 87 | err = registry.RegisterPostTransitionHook(hooks.PostTransitionHookConfig{ 88 | Name: "broadcast", 89 | From: []string{"*"}, 90 | To: []string{"*"}, 91 | Action: broadcastManager.BroadcastHook, 92 | }) 93 | if err != nil { 94 | return nil, err 95 | } 96 | 97 | f, err := fsm.New( 98 | StatusNew, 99 | TypicalTransitions, 100 | fsm.WithLogHandler(handler), 101 | fsm.WithCallbackRegistry(registry), 102 | ) 103 | if err != nil { 104 | return nil, err 105 | } 106 | 107 | return &Machine{ 108 | Machine: f, 109 | broadcastManager: broadcastManager, 110 | }, nil 111 | } 112 | -------------------------------------------------------------------------------- /supervisor/state_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/robbyt/go-supervisor/runnables/mocks" 7 |
"github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // TestPIDZero_GetState tests GetCurrentState with Stateable and non-Stateable runnables. 12 | func TestPIDZero_GetState(t *testing.T) { 13 | t.Parallel() 14 | 15 | tests := []struct { 16 | name string 17 | setupMock func() Runnable 18 | expectedState string 19 | }{ 20 | { 21 | name: "runnable implements Stateable", 22 | setupMock: func() Runnable { 23 | mockService := mocks.NewMockRunnableWithStateable() 24 | mockService.On("GetState").Return("running").Once() 25 | mockService.On("String").Return("StateableService").Maybe() 26 | return mockService 27 | }, 28 | expectedState: "running", 29 | }, 30 | { 31 | name: "runnable does not implement Stateable", 32 | setupMock: func() Runnable { 33 | // Create a mock that only implements Runnable but not Stateable 34 | mockRunnable := mocks.NewMockRunnable() 35 | // Since we're not actually calling Run or Stop in this test, we only need String 36 | mockRunnable.On("String").Return("SimpleRunnable").Maybe() 37 | return mockRunnable 38 | }, 39 | expectedState: "unknown", 40 | }, 41 | } 42 | 43 | for _, tt := range tests { 44 | tt := tt // Capture range variable (only needed before Go 1.22) 45 | t.Run(tt.name, func(t *testing.T) { 46 | t.Parallel() 47 | 48 | // Create a supervisor with the mock runnable 49 | runnable := tt.setupMock() 50 | pidZero, err := New(WithRunnables(runnable)) 51 | require.NoError(t, err) 52 | 53 | // Get the state of the runnable 54 | state := pidZero.GetCurrentState(runnable) 55 | 56 | // Verify the state 57 | assert.Equal(t, tt.expectedState, state) 58 | 59 | // If using a mock, verify expectations. NOTE(review): confirm the Stateable mock's type also matches *mocks.Runnable, otherwise its GetState().Once() expectation goes unchecked 60 | if m, ok := runnable.(*mocks.Runnable); ok { 61 | m.AssertExpectations(t) 62 | } 63 | }) 64 | } 65 | } 66 | 67 | // TestPIDZero_GetStates tests GetCurrentStates with multiple runnables.
68 | func TestPIDZero_GetStates(t *testing.T) { 69 | t.Parallel() 70 | 71 | // Create mock services 72 | mockService1 := mocks.NewMockRunnableWithStateable() 73 | mockService1.On("GetState").Return("running").Once() 74 | mockService1.On("String").Return("MockService1").Maybe() 75 | 76 | mockService2 := mocks.NewMockRunnableWithStateable() 77 | mockService2.On("GetState").Return("stopped").Once() 78 | mockService2.On("String").Return("MockService2").Maybe() 79 | 80 | // Create a non-Stateable runnable that doesn't implement GetState 81 | nonStateableRunnable := mocks.NewMockRunnable() 82 | nonStateableRunnable.On("String").Return("NonStateableRunnable").Maybe() 83 | 84 | // Create a supervisor with the mock runnables 85 | pidZero, err := New(WithRunnables(mockService1, mockService2, nonStateableRunnable)) 86 | require.NoError(t, err) 87 | 88 | // Get the states of all runnables 89 | states := pidZero.GetCurrentStates() 90 | 91 | // Verify the states map 92 | assert.Len(t, states, 2) // Only Stateable runnables should be in the map 93 | assert.Equal(t, "running", states[mockService1]) 94 | assert.Equal(t, "stopped", states[mockService2]) 95 | 96 | // Verify the non-Stateable runnable is not in the map 97 | _, exists := states[nonStateableRunnable] 98 | assert.False(t, exists) 99 | 100 | // Verify expectations 101 | mockService1.AssertExpectations(t) 102 | mockService2.AssertExpectations(t) 103 | } 104 | -------------------------------------------------------------------------------- /internal/networking/port_test.go: -------------------------------------------------------------------------------- 1 | package networking 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestValidatePort(t *testing.T) { 12 | t.Parallel() 13 | 14 | tests := []struct { 15 | name string 16 | input string 17 | expectedOutput string 18 | expectedError error 19 | }{ 20 | { 21 | name: "valid 
colon port", 22 | input: ":8080", 23 | expectedOutput: ":8080", 24 | expectedError: nil, 25 | }, 26 | { 27 | name: "valid port without colon", 28 | input: "8080", 29 | expectedOutput: ":8080", 30 | expectedError: nil, 31 | }, 32 | { 33 | name: "valid host port", 34 | input: "localhost:8080", 35 | expectedOutput: "localhost:8080", 36 | expectedError: nil, 37 | }, 38 | { 39 | name: "valid IP port", 40 | input: "127.0.0.1:8080", 41 | expectedOutput: "127.0.0.1:8080", 42 | expectedError: nil, 43 | }, 44 | { 45 | name: "maximum valid port", 46 | input: ":65535", 47 | expectedOutput: ":65535", 48 | expectedError: nil, 49 | }, 50 | { 51 | name: "minimum valid port", 52 | input: ":1", 53 | expectedOutput: ":1", 54 | expectedError: nil, 55 | }, 56 | { 57 | name: "empty string", 58 | input: "", 59 | expectedOutput: "", 60 | expectedError: ErrEmptyPort, 61 | }, 62 | { 63 | name: "just colon", 64 | input: ":", 65 | expectedOutput: "", 66 | expectedError: ErrInvalidFormat, 67 | }, 68 | { 69 | name: "port out of range (too high)", 70 | input: ":65536", 71 | expectedOutput: "", 72 | expectedError: ErrPortOutOfRange, 73 | }, 74 | { 75 | name: "port out of range (too low)", 76 | input: ":0", 77 | expectedOutput: "", 78 | expectedError: ErrPortOutOfRange, 79 | }, 80 | { 81 | name: "port out of range (negative)", 82 | input: ":-1", 83 | expectedOutput: "", 84 | expectedError: ErrInvalidFormat, 85 | }, 86 | { 87 | name: "non-numeric port", 88 | input: ":abc", 89 | expectedOutput: "", 90 | expectedError: ErrInvalidFormat, 91 | }, 92 | { 93 | name: "invalid format with multiple colons", 94 | input: "localhost:8080:8081", 95 | expectedOutput: "", 96 | expectedError: ErrInvalidFormat, 97 | }, 98 | { 99 | name: "mixed numeric and letters in port", 100 | input: ":8080a", 101 | expectedOutput: "", 102 | expectedError: ErrInvalidFormat, 103 | }, 104 | } 105 | 106 | for _, tt := range tests { 107 | tt := tt // Capture range variable 108 | t.Run(tt.name, func(t *testing.T) { 109 | t.Parallel() 
110 | output, err := ValidatePort(tt.input) 111 | 112 | if tt.expectedError != nil { 113 | require.Error(t, err) 114 | if !errors.Is(tt.expectedError, ErrInvalidFormat) { 115 | // For specific errors, check exact match 116 | require.ErrorIs(t, err, tt.expectedError) 117 | } else { 118 | // For format errors, just check that it's a format error (details may vary) 119 | require.ErrorIs(t, err, ErrInvalidFormat, "Expected error to be a ErrInvalidFormat") 120 | } 121 | assert.Empty(t, output) 122 | } else { 123 | require.NoError(t, err) 124 | assert.Equal(t, tt.expectedOutput, output) 125 | } 126 | }) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /runnables/httpserver/runner_context_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/internal/finitestate" 12 | "github.com/robbyt/go-supervisor/internal/networking" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | // TestContextValuePropagation verifies that context values are properly propagated from the 18 | // Runner to the HTTP request handlers. 
19 | func TestContextValuePropagation(t *testing.T) { 20 | t.Parallel() 21 | 22 | // Create a parent context with a test value 23 | type contextKey string 24 | const testKey contextKey = "test-key" 25 | const testValue = "test-value" 26 | parentCtx := context.WithValue(t.Context(), testKey, testValue) 27 | 28 | // Create a cancellable context to test cancellation propagation 29 | ctx, cancel := context.WithCancel(parentCtx) 30 | defer cancel() 31 | 32 | // Create a handler that checks the context value 33 | contextValueReceived := make(chan string, 1) 34 | contextCancelReceived := make(chan struct{}, 1) 35 | 36 | handler := func(w http.ResponseWriter, r *http.Request) { 37 | // Check if context value is properly propagated 38 | if value, ok := r.Context().Value(testKey).(string); ok { 39 | contextValueReceived <- value 40 | } else { 41 | contextValueReceived <- "value-not-found" 42 | } 43 | 44 | // Also check if cancellation propagates 45 | go func() { 46 | <-r.Context().Done() 47 | contextCancelReceived <- struct{}{} 48 | }() 49 | 50 | // Send a response 51 | w.WriteHeader(http.StatusOK) 52 | _, err := w.Write([]byte("OK")) 53 | assert.NoError(t, err) 54 | } 55 | 56 | // Create a route with the test handler 57 | route, err := NewRouteFromHandlerFunc("test", "/test", handler) 58 | require.NoError(t, err) 59 | 60 | // Get a unique port for this test 61 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 62 | 63 | // Create config with our test context 64 | cfgCallback := func() (*Config, error) { 65 | return NewConfig(port, Routes{*route}, WithRequestContext(ctx)) 66 | } 67 | 68 | // Create the runner 69 | runner, err := NewRunner( 70 | WithConfigCallback(cfgCallback), 71 | ) 72 | require.NoError(t, err) 73 | 74 | // Run the server in a goroutine 75 | errChan := make(chan error, 1) 76 | go func() { 77 | err := runner.Run(ctx) 78 | errChan <- err 79 | }() 80 | 81 | // Wait for the server to start 82 | require.Eventually(t, func() bool { 83 | return 
runner.GetState() == finitestate.StatusRunning 84 | }, 2*time.Second, 10*time.Millisecond, "Server should reach Running state") 85 | 86 | // Make a request to the server 87 | client := &http.Client{Timeout: 1 * time.Second} 88 | resp, err := client.Get(fmt.Sprintf("http://localhost%s/test", port)) 89 | require.NoError(t, err, "Request to server should succeed") 90 | 91 | // Read and close response body 92 | body, err := io.ReadAll(resp.Body) 93 | require.NoError(t, err) 94 | require.Equal(t, "OK", string(body)) 95 | require.NoError(t, resp.Body.Close()) 96 | 97 | // Verify the context value was properly received 98 | select { 99 | case value := <-contextValueReceived: 100 | assert.Equal(t, testValue, value, "Context value should be propagated to handlers") 101 | case <-time.After(2 * time.Second): 102 | t.Fatal("Timed out waiting for context value") 103 | } 104 | 105 | // Cancel the context and check if the cancellation is propagated 106 | cancel() 107 | 108 | // Wait for cancellation signal 109 | select { 110 | case <-contextCancelReceived: 111 | // Success, cancellation was propagated 112 | case <-time.After(2 * time.Second): 113 | t.Fatal("Timed out waiting for context cancellation") 114 | } 115 | 116 | // Server should shut down 117 | select { 118 | case err := <-errChan: 119 | require.NoError(t, err, "Server should shut down cleanly") 120 | case <-time.After(2 * time.Second): 121 | t.Fatal("Timed out waiting for server to shut down") 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /runnables/httpserver/routes.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "slices" 8 | "strings" 9 | ) 10 | 11 | // Route represents a single HTTP route with a name, path, and handler chain. 
12 | type Route struct { 13 | name string // internal identifier for the route, used for equality checks 14 | Path string 15 | Handlers []HandlerFunc 16 | } 17 | 18 | // newRoute is a private/internal constructor used by the other route creation functions. 19 | func newRoute(name, path string, handlers ...HandlerFunc) (*Route, error) { 20 | if name == "" { 21 | return nil, errors.New("name cannot be empty") 22 | } 23 | if path == "" { 24 | return nil, errors.New("path cannot be empty") 25 | } 26 | if len(handlers) == 0 { 27 | return nil, errors.New("at least one handler required") 28 | } 29 | 30 | return &Route{ 31 | name: name, 32 | Path: path, 33 | Handlers: handlers, 34 | }, nil 35 | } 36 | 37 | // NewRouteFromHandlerFunc creates a new Route with the given name, path, and handler. Optionally, it can include middleware functions. 38 | // This is the preferred way to create routes in the httpserver package. The caller's middlewares slice is never modified. 39 | func NewRouteFromHandlerFunc( 40 | name string, 41 | path string, 42 | handler http.HandlerFunc, 43 | middlewares ...HandlerFunc, 44 | ) (*Route, error) { 45 | if handler == nil { 46 | return nil, errors.New("handler cannot be nil") 47 | } 48 | h := func(rp *RequestProcessor) { 49 | handler.ServeHTTP(rp.Writer(), rp.Request()) 50 | } 51 | middlewares = append(slices.Clip(middlewares), h) // Clip: never write h into a caller's spread slice (mws...) that has spare capacity 52 | return newRoute(name, path, middlewares...)
53 | } 54 | 55 | // ServeHTTP adapts the route to work with standard http.Handler 56 | func (r *Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { 57 | rp := &RequestProcessor{ 58 | writer: newResponseWriter(w), 59 | request: req, 60 | handlers: r.Handlers, 61 | index: -1, 62 | } 63 | 64 | rp.Next() 65 | } 66 | // Equal reports whether r and other have the same name and path; handlers are not compared. 67 | func (r Route) Equal(other Route) bool { 68 | if r.Path != other.Path { 69 | return false 70 | } 71 | 72 | if r.name != other.name { 73 | return false 74 | } 75 | 76 | return true 77 | } 78 | 79 | // Routes is a slice of Route values, each pairing a path with its handler chain. 80 | type Routes []Route 81 | 82 | // Equal reports whether r and other contain the same routes, regardless of order. 83 | // This works because we assume the route names uniquely identify the route. 84 | // For example, the route name could be based on a content hash or other unique identifier. 85 | func (r Routes) Equal(other Routes) bool { 86 | // First compare the lengths of both routes, if they are different they can't be equal 87 | if len(r) != len(other) { 88 | return false 89 | } 90 | 91 | // now compare the names of both routes, if they are different, they are not equal 92 | oldNames := make([]string, 0, len(r)) 93 | for _, route := range r { 94 | oldNames = append(oldNames, route.name) 95 | } 96 | slices.Sort(oldNames) 97 | 98 | newNames := make([]string, 0, len(other)) 99 | for _, route := range other { 100 | newNames = append(newNames, route.name) 101 | } 102 | slices.Sort(newNames) 103 | // Element-wise compare; formatting both with %v would conflate names that contain spaces. 104 | if !slices.Equal(oldNames, newNames) { 105 | return false 106 | } 107 | 108 | // now compare the paths of both routes, if they are different they are not equal 109 | routeMap := make(map[string]Route) 110 | for _, route := range r { 111 | routeMap[route.Path] = route 112 | } 113 | 114 | for _, otherRoute := range other { 115 | route, exists := routeMap[otherRoute.Path] 116 | if !exists || !route.Equal(otherRoute) { 117 | return false 118 | } 119 | } 120
| 121 | return true 122 | } 123 | 124 | // String returns a string representation of all routes, including their versions and paths. 125 | func (r Routes) String() string { 126 | if len(r) == 0 { 127 | return "Routes<>" 128 | } 129 | 130 | var routes []string 131 | for _, route := range r { 132 | routes = append(routes, fmt.Sprintf("Name: %s, Path: %s", route.name, route.Path)) 133 | } 134 | slices.Sort(routes) 135 | 136 | return fmt.Sprintf("Routes<%s>", strings.Join(routes, ", ")) 137 | } 138 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "log/slog" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "github.com/robbyt/go-supervisor/runnables/httpserver" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // setupRequest creates a basic HTTP request for testing 17 | func setupRequest(t *testing.T, method, path string) (*httptest.ResponseRecorder, *http.Request) { 18 | t.Helper() 19 | req := httptest.NewRequest(method, path, nil) 20 | rec := httptest.NewRecorder() 21 | return rec, req 22 | } 23 | 24 | // setupLogBuffer creates a handler that writes to a buffer for testing 25 | func setupLogBuffer(t *testing.T, level slog.Level) (*bytes.Buffer, slog.Handler) { 26 | t.Helper() 27 | buffer := &bytes.Buffer{} 28 | handler := slog.NewTextHandler(buffer, &slog.HandlerOptions{Level: level}) 29 | return buffer, handler 30 | } 31 | 32 | // createTestHandler returns a handler that writes a response 33 | func createTestHandler(t *testing.T, checkResponse bool) http.HandlerFunc { 34 | t.Helper() 35 | return func(w http.ResponseWriter, r *http.Request) { 36 | w.WriteHeader(http.StatusOK) 37 | n, err := w.Write([]byte("test response")) 38 | if checkResponse { 39 | assert.NoError(t, err) 
40 | assert.Equal(t, 13, n) 41 | } else if err != nil { 42 | http.Error(w, "Failed to write response", http.StatusInternalServerError) 43 | return 44 | } 45 | } 46 | } 47 | 48 | // executeHandlerWithLogger runs the provided handler with the Logger middleware 49 | func executeHandlerWithLogger( 50 | t *testing.T, 51 | handler http.HandlerFunc, 52 | logHandler slog.Handler, 53 | rec *httptest.ResponseRecorder, 54 | req *http.Request, 55 | ) { 56 | t.Helper() 57 | // Create a route with logger middleware and the handler 58 | route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", handler, New(logHandler)) 59 | require.NoError(t, err) 60 | route.ServeHTTP(rec, req) 61 | } 62 | 63 | // setupDetailedRequest creates a test HTTP request with user agent and remote addr 64 | func setupDetailedRequest( 65 | t *testing.T, 66 | method, path, userAgent, remoteAddr string, 67 | ) (*httptest.ResponseRecorder, *http.Request) { 68 | t.Helper() 69 | rec, req := setupRequest(t, method, path) 70 | req.Header.Set("User-Agent", userAgent) 71 | req.RemoteAddr = remoteAddr 72 | return rec, req 73 | } 74 | 75 | func TestLogger(t *testing.T) { 76 | t.Run("with custom handler", func(t *testing.T) { 77 | // Setup 78 | logBuffer, logHandler := setupLogBuffer(t, slog.LevelInfo) 79 | rec, req := setupDetailedRequest(t, "GET", "/test", "test-agent", "127.0.0.1:12345") 80 | handler := createTestHandler(t, false) 81 | 82 | // Execute 83 | executeHandlerWithLogger(t, handler, logHandler, rec, req) 84 | 85 | // Check response 86 | resp := rec.Result() 87 | body, err := io.ReadAll(resp.Body) 88 | require.NoError(t, err) 89 | assert.Equal(t, http.StatusOK, resp.StatusCode) 90 | assert.Equal(t, "test response", string(body)) 91 | 92 | // Verify log output contains expected info 93 | logOutput := logBuffer.String() 94 | assert.Contains(t, logOutput, "HTTP request") 95 | assert.Contains(t, logOutput, "method=GET") 96 | assert.Contains(t, logOutput, "path=/test") 97 | assert.Contains(t, logOutput, 
"status=200") 98 | assert.Contains(t, logOutput, "user_agent=test-agent") 99 | assert.Contains(t, logOutput, "remote_addr=127.0.0.1:12345") 100 | }) 101 | 102 | t.Run("with nil handler (uses default)", func(t *testing.T) { 103 | // Save and restore default logger 104 | defaultLogger := slog.Default() 105 | defer slog.SetDefault(defaultLogger) 106 | 107 | // Setup with default logger 108 | logBuffer, testHandler := setupLogBuffer(t, slog.LevelInfo) 109 | slog.SetDefault(slog.New(testHandler)) 110 | 111 | rec, req := setupRequest(t, "GET", "/test") 112 | handler := createTestHandler(t, true) 113 | 114 | // Execute with nil handler (will use default) 115 | executeHandlerWithLogger(t, handler, nil, rec, req) 116 | 117 | // Verify log output 118 | logOutput := logBuffer.String() 119 | assert.Contains(t, logOutput, "httpserver") 120 | assert.Contains(t, logOutput, "HTTP request") 121 | }) 122 | } 123 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/recovery/recovery_test.go: -------------------------------------------------------------------------------- 1 | package recovery 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "log/slog" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "github.com/robbyt/go-supervisor/runnables/httpserver" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // setupRequest creates a basic HTTP request for testing 17 | func setupRequest(t *testing.T, method, path string) (*httptest.ResponseRecorder, *http.Request) { 18 | t.Helper() 19 | req := httptest.NewRequest(method, path, nil) 20 | rec := httptest.NewRecorder() 21 | return rec, req 22 | } 23 | 24 | // setupLogBuffer creates a handler that writes to a buffer for testing 25 | func setupLogBuffer(t *testing.T, level slog.Level) (*bytes.Buffer, slog.Handler) { 26 | t.Helper() 27 | buffer := &bytes.Buffer{} 28 | handler := slog.NewTextHandler(buffer, 
&slog.HandlerOptions{Level: level}) 29 | return buffer, handler 30 | } 31 | 32 | // executeHandlerWithRecovery runs the provided handler with the PanicRecovery middleware 33 | func executeHandlerWithRecovery( 34 | t *testing.T, 35 | handler http.HandlerFunc, 36 | logHandler slog.Handler, 37 | rec *httptest.ResponseRecorder, 38 | req *http.Request, 39 | ) { 40 | t.Helper() 41 | // Create a route with recovery middleware and the handler 42 | route, err := httpserver.NewRouteFromHandlerFunc("test", "/test", handler, New(logHandler)) 43 | require.NoError(t, err) 44 | route.ServeHTTP(rec, req) 45 | } 46 | 47 | // createPanicHandler returns a handler that panics with the given message 48 | func createPanicHandler(t *testing.T, panicMsg string) http.HandlerFunc { 49 | t.Helper() 50 | return func(w http.ResponseWriter, r *http.Request) { 51 | panic(panicMsg) 52 | } 53 | } 54 | 55 | // createSuccessHandler returns a handler that returns a 200 OK with "Success" body 56 | func createSuccessHandler(t *testing.T) http.HandlerFunc { 57 | t.Helper() 58 | return func(w http.ResponseWriter, r *http.Request) { 59 | w.WriteHeader(http.StatusOK) 60 | n, err := w.Write([]byte("Success")) 61 | assert.NoError(t, err) 62 | assert.Equal(t, 7, n) 63 | } 64 | } 65 | 66 | func TestRecoveryMiddleware(t *testing.T) { 67 | t.Run("recovers from panic with custom handler", func(t *testing.T) { 68 | // Setup 69 | logBuffer, logHandler := setupLogBuffer(t, slog.LevelError) 70 | rec, req := setupRequest(t, "GET", "/test") 71 | handler := createPanicHandler(t, "test panic") 72 | 73 | // Execute 74 | executeHandlerWithRecovery(t, handler, logHandler, rec, req) 75 | 76 | // Verify response 77 | resp := rec.Result() 78 | body, err := io.ReadAll(resp.Body) 79 | require.NoError(t, err) 80 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 81 | assert.Equal(t, "Internal Server Error\n", string(body)) 82 | 83 | // Verify log output contains panic info 84 | logOutput := logBuffer.String() 85 | 
assert.Contains(t, logOutput, "HTTP handler panic recovered") 86 | assert.Contains(t, logOutput, "error=\"test panic\"") 87 | assert.Contains(t, logOutput, "path=/test") 88 | assert.Contains(t, logOutput, "method=GET") 89 | }) 90 | 91 | t.Run("recovers from panic silently with nil handler", func(t *testing.T) { 92 | // Setup 93 | rec, req := setupRequest(t, "POST", "/api/test") 94 | handler := createPanicHandler(t, "test panic with nil handler") 95 | 96 | // Execute with nil handler - should recover silently 97 | executeHandlerWithRecovery(t, handler, nil, rec, req) 98 | 99 | // Verify response - should still return 500 error 100 | resp := rec.Result() 101 | body, err := io.ReadAll(resp.Body) 102 | require.NoError(t, err) 103 | assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) 104 | assert.Equal(t, "Internal Server Error\n", string(body)) 105 | 106 | // No log verification since recovery should be silent 107 | }) 108 | 109 | t.Run("passes through normal requests", func(t *testing.T) { 110 | // Setup 111 | rec, req := setupRequest(t, "GET", "/test") 112 | handler := createSuccessHandler(t) 113 | 114 | // Execute 115 | executeHandlerWithRecovery(t, handler, nil, rec, req) 116 | 117 | // Verify response 118 | assert.Equal(t, http.StatusOK, rec.Code) 119 | assert.Equal(t, "Success", rec.Body.String()) 120 | }) 121 | } 122 | -------------------------------------------------------------------------------- /examples/http/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "log/slog" 8 | "net/http" 9 | "testing" 10 | "time" 11 | 12 | "github.com/robbyt/go-supervisor/runnables/httpserver" 13 | "github.com/robbyt/go-supervisor/supervisor" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | // TestRunServer tests that the HTTP server starts successfully 19 | func TestRunServer(t *testing.T) { 20 | 
t.Parallel() 21 | 22 | // Create a test logger that discards output 23 | logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) 24 | 25 | // Create a context with timeout for the test 26 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 27 | defer cancel() 28 | 29 | // Run the server 30 | routes, err := buildRoutes(logHandler) 31 | require.NoError(t, err, "Failed to build routes") 32 | require.NotEmpty(t, routes, "Routes should not be empty") 33 | 34 | sv, err := RunServer(ctx, logHandler, routes) 35 | require.NoError(t, err, "RunServer should not return an error") 36 | require.NotNil(t, sv, "Supervisor should not be nil") 37 | 38 | // Start the server in a goroutine to avoid blocking the test 39 | errCh := make(chan error, 1) 40 | go func() { 41 | errCh <- sv.Run() 42 | }() 43 | 44 | // Wait for the server to be ready by checking if it responds to requests 45 | assert.Eventually(t, func() bool { 46 | resp, err := http.Get("http://localhost:8080/status") 47 | if err != nil { 48 | return false 49 | } 50 | defer func() { assert.NoError(t, resp.Body.Close()) }() 51 | return resp.StatusCode == http.StatusOK 52 | }, 2*time.Second, 50*time.Millisecond, "Server should become ready") 53 | 54 | // Make a request to the server 55 | resp, err := http.Get("http://localhost:8080/status") 56 | require.NoError(t, err, "Failed to make GET request") 57 | 58 | body, err := io.ReadAll(resp.Body) 59 | require.NoError(t, err, "Failed to read response body") 60 | assert.NoError(t, resp.Body.Close()) 61 | assert.Equal(t, http.StatusOK, resp.StatusCode) 62 | assert.Equal(t, "Status: OK\n", string(body)) 63 | 64 | // Stop the supervisor 65 | sv.Shutdown() 66 | 67 | // Wait for Run() to complete and check the result 68 | select { 69 | case err := <-errCh: 70 | if err != nil && !errors.Is(err, context.Canceled) { 71 | require.NoError(t, err, "Run() should not return an error") 72 | } 73 | case <-time.After(2 * time.Second): 74 | 
t.Fatal("Run() should have completed within timeout") 75 | } 76 | } 77 | 78 | // TestRunServerInvalidPort tests error handling when an invalid port is specified 79 | func TestRunServerInvalidPort(t *testing.T) { 80 | t.Parallel() 81 | 82 | // Create a test logger that discards output 83 | logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) 84 | 85 | // Create a context with timeout for the test 86 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 87 | defer cancel() 88 | 89 | // Build routes - reuse the same routes from the main app 90 | routes, err := buildRoutes(logHandler) 91 | require.NoError(t, err, "Failed to build routes") 92 | 93 | // Create a config callback that uses an invalid port 94 | configCallback := func() (*httpserver.Config, error) { 95 | return httpserver.NewConfig(":-1", routes, httpserver.WithDrainTimeout(DrainTimeout)) 96 | } 97 | 98 | // Create HTTP server runner with invalid port 99 | runner, err := httpserver.NewRunner( 100 | httpserver.WithConfigCallback(configCallback), 101 | httpserver.WithLogHandler(logHandler.WithGroup("httpserver")), 102 | ) 103 | require.NoError(t, err, "Should be able to create runner even with invalid config") 104 | 105 | // Create supervisor with a reasonable timeout 106 | sv, err := supervisor.New( 107 | supervisor.WithContext(ctx), 108 | supervisor.WithRunnables(runner), 109 | supervisor.WithLogHandler(logHandler), 110 | supervisor.WithStartupTimeout(100*time.Millisecond), // Short timeout for tests 111 | ) 112 | require.NoError(t, err, "Failed to create supervisor") 113 | 114 | err = sv.Run() 115 | require.Error(t, err, "Run() should fail with invalid port") 116 | assert.ErrorIs(t, 117 | err, httpserver.ErrServerBoot, 118 | "httpserver.Runner.Run() should return ErrServerBoot", 119 | ) 120 | } 121 | -------------------------------------------------------------------------------- /examples/custom_middleware/example/jsonenforcer.go: 
-------------------------------------------------------------------------------- 1 | package example 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "net/http" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/httpserver" 9 | ) 10 | 11 | // ResponseBuffer captures response data for transformation 12 | // 13 | // This is a simple example for demonstration purposes and is not intended for 14 | // production use. Limitations: 15 | // - Does not preserve optional HTTP interfaces (http.Hijacker, http.Flusher, http.Pusher) 16 | // - Not safe for concurrent writes from multiple goroutines within the same request 17 | // - No memory limits on buffered content 18 | // 19 | // Each request gets its own ResponseBuffer instance, so different requests won't 20 | // interfere with each other. 21 | type ResponseBuffer struct { 22 | buffer *bytes.Buffer 23 | headers http.Header 24 | status int 25 | } 26 | 27 | // NewResponseBuffer creates a new response buffer 28 | func NewResponseBuffer() *ResponseBuffer { 29 | return &ResponseBuffer{ 30 | buffer: new(bytes.Buffer), 31 | headers: make(http.Header), 32 | status: 0, // 0 means not set yet 33 | } 34 | } 35 | 36 | // Header implements http.ResponseWriter 37 | func (rb *ResponseBuffer) Header() http.Header { 38 | return rb.headers 39 | } 40 | 41 | // Write implements http.ResponseWriter 42 | func (rb *ResponseBuffer) Write(data []byte) (int, error) { 43 | return rb.buffer.Write(data) 44 | } 45 | 46 | // WriteHeader implements http.ResponseWriter 47 | func (rb *ResponseBuffer) WriteHeader(statusCode int) { 48 | if rb.status == 0 { 49 | rb.status = statusCode 50 | } 51 | } 52 | 53 | // Status implements httpserver.ResponseWriter 54 | func (rb *ResponseBuffer) Status() int { 55 | if rb.status == 0 && rb.buffer.Len() > 0 { 56 | return http.StatusOK 57 | } 58 | return rb.status 59 | } 60 | 61 | // Written implements httpserver.ResponseWriter 62 | func (rb *ResponseBuffer) Written() bool { 63 | return rb.buffer.Len() > 0 || rb.status 
!= 0 64 | } 65 | 66 | // Size implements httpserver.ResponseWriter 67 | func (rb *ResponseBuffer) Size() int { 68 | return rb.buffer.Len() 69 | } 70 | 71 | // transformToJSON wraps non-JSON content in a JSON response 72 | func transformToJSON(data []byte) ([]byte, error) { 73 | // Use json.Valid for efficient validation without unmarshaling 74 | if json.Valid(data) { 75 | return data, nil // Valid JSON, return as-is 76 | } 77 | 78 | // If not valid JSON, wrap it 79 | response := map[string]string{ 80 | "response": string(data), 81 | } 82 | 83 | return json.Marshal(response) 84 | } 85 | 86 | // New creates a middleware that transforms all responses to JSON format. 87 | // Non-JSON responses are wrapped in {"response": "content"}. 88 | // Valid JSON responses are preserved as-is. 89 | func New() httpserver.HandlerFunc { 90 | return func(rp *httpserver.RequestProcessor) { 91 | // Store original writer before buffering 92 | originalWriter := rp.Writer() 93 | 94 | // Buffer the response to capture output 95 | buffer := NewResponseBuffer() 96 | rp.SetWriter(buffer) 97 | 98 | // Continue to next middleware/handler 99 | rp.Next() 100 | 101 | // RESPONSE PHASE: Transform response to JSON 102 | originalData := buffer.buffer.Bytes() 103 | statusCode := buffer.Status() 104 | if statusCode == 0 { 105 | statusCode = http.StatusOK 106 | } 107 | 108 | // Copy headers to original writer, minus Content-Length: the body may be rewritten below and change size 109 | for key, values := range buffer.Header() { 110 | for _, value := range values { 111 | originalWriter.Header().Add(key, value) 112 | } 113 | } 114 | originalWriter.Header().Del("Content-Length") 115 | // Check if this status code should have no body per HTTP spec 116 | // 204 No Content and 304 Not Modified MUST NOT have a message body 117 | if statusCode == http.StatusNoContent || statusCode == http.StatusNotModified { 118 | originalWriter.WriteHeader(statusCode) 119 | return 120 | } 121 | 122 | // Nothing was written and no status was set: leave the response to net/http defaults 123 | if len(originalData) == 0 && buffer.status == 0 { 124 | return 125 | } 126 | 127 | // Transform to JSON
if needed 128 | jsonData, err := transformToJSON(originalData) 129 | if err != nil { 130 | // Fallback: wrap error in JSON 131 | jsonData = []byte(`{"error":"Unable to encode response"}`) 132 | } 133 | 134 | // Ensure JSON content type 135 | originalWriter.Header().Set("Content-Type", "application/json") 136 | 137 | // Write status and transformed data 138 | originalWriter.WriteHeader(statusCode) 139 | if _, err := originalWriter.Write(jsonData); err != nil { 140 | // Response is already committed, cannot recover from write error 141 | return 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /examples/httpcluster/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "log/slog" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/robbyt/go-supervisor/internal/networking" 13 | "github.com/robbyt/go-supervisor/runnables/httpcluster" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | // TestRunCluster tests that the HTTP cluster can be created and configured 19 | func TestRunCluster(t *testing.T) { 20 | t.Parallel() 21 | 22 | logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) 23 | ctx := t.Context() 24 | 25 | sv, configMgr, err := createHTTPCluster(ctx, logHandler) 26 | require.NoError(t, err) 27 | require.NotNil(t, sv) 28 | require.NotNil(t, configMgr) 29 | 30 | // Test basic configuration 31 | assert.Equal(t, InitialPort, configMgr.getCurrentPort()) 32 | } 33 | 34 | // TestConfigManager tests the configuration manager 35 | func TestConfigManager(t *testing.T) { 36 | t.Parallel() 37 | 38 | logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) 39 | logger := slog.New(logHandler) 40 | 41 | cluster, err := httpcluster.NewRunner() 42 | 
require.NoError(t, err) 43 | 44 | configMgr := NewConfigManager(cluster, logger) 45 | 46 | // Test port tracking 47 | assert.Equal(t, InitialPort, configMgr.getCurrentPort()) 48 | } 49 | 50 | // TestPortValidation tests port validation logic 51 | func TestPortValidation(t *testing.T) { 52 | t.Parallel() 53 | 54 | tests := []struct { 55 | name string 56 | port string 57 | wantError bool 58 | }{ 59 | {"valid port", ":8080", false}, 60 | {"valid high port", ":65535", false}, 61 | {"port without colon", "8080", false}, // networking package adds colon automatically 62 | {"empty port", "", true}, 63 | {"just colon", ":", true}, 64 | {"negative port", ":-1", true}, 65 | {"port too high", ":99999", true}, 66 | {"non-numeric port", ":abc", true}, 67 | } 68 | 69 | for _, tt := range tests { 70 | tt := tt // Capture range variable 71 | t.Run(tt.name, func(t *testing.T) { 72 | t.Parallel() 73 | 74 | // Test just the validation logic, not actual server startup 75 | _, err := networking.ValidatePort(tt.port) 76 | if tt.wantError { 77 | assert.Error(t, err) 78 | } else { 79 | assert.NoError(t, err) 80 | } 81 | }) 82 | } 83 | } 84 | 85 | // TestHandlers tests HTTP handlers in isolation 86 | func TestHandlers(t *testing.T) { 87 | t.Parallel() 88 | 89 | logHandler := slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}) 90 | logger := slog.New(logHandler) 91 | 92 | cluster, err := httpcluster.NewRunner() 93 | require.NoError(t, err) 94 | configMgr := NewConfigManager(cluster, logger) 95 | 96 | t.Run("status handler", func(t *testing.T) { 97 | req := httptest.NewRequest(http.MethodGet, "/status", nil) 98 | w := httptest.NewRecorder() 99 | 100 | handler := configMgr.createStatusHandler() 101 | handler(w, req) 102 | 103 | assert.Equal(t, http.StatusOK, w.Code) 104 | assert.Equal(t, "application/json", w.Header().Get("Content-Type")) 105 | 106 | var status map[string]string 107 | err := json.Unmarshal(w.Body.Bytes(), &status) 108 | require.NoError(t, err) 109 | 
assert.Equal(t, "running", status["status"]) 110 | assert.Equal(t, InitialPort, status["port"]) 111 | }) 112 | 113 | t.Run("config handler GET", func(t *testing.T) { 114 | req := httptest.NewRequest(http.MethodGet, "/", nil) 115 | w := httptest.NewRecorder() 116 | 117 | handler := configMgr.createConfigHandler() 118 | handler(w, req) 119 | 120 | assert.Equal(t, http.StatusOK, w.Code) 121 | body := w.Body.String() 122 | assert.Contains(t, body, "HTTP Cluster Example") 123 | assert.Contains(t, body, "Current port: :8080") 124 | }) 125 | 126 | t.Run("config handler POST invalid JSON", func(t *testing.T) { 127 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("invalid json")) 128 | req.Header.Set("Content-Type", "application/json") 129 | w := httptest.NewRecorder() 130 | 131 | handler := configMgr.createConfigHandler() 132 | handler(w, req) 133 | 134 | assert.Equal(t, http.StatusBadRequest, w.Code) 135 | }) 136 | 137 | t.Run("config handler POST missing port", func(t *testing.T) { 138 | req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) 139 | req.Header.Set("Content-Type", "application/json") 140 | w := httptest.NewRecorder() 141 | 142 | handler := configMgr.createConfigHandler() 143 | handler(w, req) 144 | 145 | assert.Equal(t, http.StatusBadRequest, w.Code) 146 | }) 147 | } 148 | -------------------------------------------------------------------------------- /supervisor/state_deduplication_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "context" 5 | "maps" 6 | "testing" 7 | "time" 8 | 9 | "github.com/robbyt/go-supervisor/runnables/mocks" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // TestStateDeduplication tests that the startStateMonitor implementation 16 | // properly filters out duplicate state changes when it receives the same state 17 | 
// multiple times in a row through the state channel. 18 | func TestStateDeduplication(t *testing.T) { 19 | t.Parallel() 20 | 21 | // Create a context with a suitable timeout 22 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 23 | defer cancel() 24 | 25 | // Create a channel for sending state updates 26 | stateChan := make(chan string, 10) 27 | runnable := mocks.NewMockRunnableWithStateable() 28 | runnable.On("String").Return("test-runnable") 29 | runnable.On("GetStateChan", mock.Anything).Return(stateChan) 30 | runnable.On("GetState").Return("initial") 31 | 32 | // Create a new supervisor with our test runnable 33 | pidZero, err := New(WithContext(ctx), WithRunnables(runnable)) 34 | require.NoError(t, err) 35 | 36 | // Track the broadcasts that occur 37 | broadcasts := []StateMap{} 38 | broadcastChan := make(chan StateMap, 10) 39 | unsubscribe := pidZero.AddStateSubscriber(broadcastChan) 40 | defer unsubscribe() 41 | 42 | // Collect broadcasts in a background goroutine 43 | collectDone := make(chan struct{}) 44 | go func() { 45 | defer close(collectDone) 46 | for { 47 | select { 48 | case stateMap, ok := <-broadcastChan: 49 | if !ok { 50 | return 51 | } 52 | // Copy the map to avoid issues with concurrent modification 53 | copy := make(StateMap) 54 | maps.Copy(copy, stateMap) 55 | broadcasts = append(broadcasts, copy) 56 | t.Logf("Received broadcast: %+v", copy) 57 | case <-ctx.Done(): 58 | return 59 | } 60 | } 61 | }() 62 | 63 | // Store the initial state to match production behavior 64 | pidZero.stateMap.Store(runnable, "initial") 65 | 66 | // Start the state monitor 67 | pidZero.wg.Add(1) 68 | go pidZero.startStateMonitor() 69 | 70 | // Send the initial state to be discarded as per implementation 71 | t.Log("Sending 'initial' to be discarded") 72 | stateChan <- "initial" 73 | time.Sleep(50 * time.Millisecond) 74 | 75 | // Test sequence: 76 | // 1. Send "running" once - should trigger broadcast 77 | // 2. 
Send "running" twice more - should be ignored as duplicates 78 | // 3. Send "stopped" - should trigger broadcast 79 | // 4. Send "stopped" again - should be ignored as duplicate 80 | // 5. Send "error" - should trigger broadcast 81 | 82 | // First state change 83 | t.Log("Sending 'running' state") 84 | runnable.On("GetState").Return("running") 85 | stateChan <- "running" 86 | 87 | // Send duplicate states - should be ignored 88 | t.Log("Sending 'running' state again (should be ignored)") 89 | stateChan <- "running" 90 | 91 | t.Log("Sending 'running' state a third time (should be ignored)") 92 | stateChan <- "running" 93 | 94 | // Second state change 95 | t.Log("Sending 'stopped' state") 96 | runnable.On("GetState").Return("stopped") 97 | stateChan <- "stopped" 98 | 99 | // Another duplicate - should be ignored 100 | t.Log("Sending 'stopped' state again (should be ignored)") 101 | stateChan <- "stopped" 102 | 103 | // Third state change 104 | t.Log("Sending 'error' state") 105 | runnable.On("GetState").Return("error") 106 | stateChan <- "error" 107 | time.Sleep(100 * time.Millisecond) 108 | 109 | // Clean up and wait for collection to complete 110 | cancel() 111 | unsubscribe() 112 | close(broadcastChan) 113 | <-collectDone 114 | 115 | // Log final state for debugging 116 | t.Log("All broadcasts received:") 117 | for i, b := range broadcasts { 118 | t.Logf(" %d: %+v", i, b) 119 | } 120 | 121 | // Count number of each state broadcast received 122 | statesReceived := make(map[string]int) 123 | for _, broadcast := range broadcasts { 124 | // Look for the state of our test runnable 125 | if state, ok := broadcast[runnable.String()]; ok { 126 | statesReceived[state]++ 127 | } 128 | } 129 | 130 | // Log state counts 131 | t.Logf("State broadcast counts: %+v", statesReceived) 132 | 133 | // We should have unique state broadcasts (one each) 134 | // for running, stopped, and error states 135 | assert.Equal( 136 | t, 1, statesReceived["running"], 137 | "Should receive 
exactly one 'running' state broadcast", 138 | ) 139 | assert.Equal( 140 | t, 1, statesReceived["stopped"], 141 | "Should receive exactly one 'stopped' state broadcast", 142 | ) 143 | assert.Equal( 144 | t, 1, statesReceived["error"], 145 | "Should receive exactly one 'error' state broadcast", 146 | ) 147 | } 148 | -------------------------------------------------------------------------------- /examples/http/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log/slog" 7 | "net/http" 8 | "os" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/runnables/httpserver" 12 | "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/logger" 13 | "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/metrics" 14 | "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/recovery" 15 | "github.com/robbyt/go-supervisor/runnables/httpserver/middleware/wildcard" 16 | "github.com/robbyt/go-supervisor/supervisor" 17 | ) 18 | 19 | const ( 20 | // Port the http server binds to 21 | ListenOn = ":8080" 22 | 23 | // How long the supervisor waits for the HTTP server to drain before forcefully shutting down 24 | DrainTimeout = 5 * time.Second 25 | ) 26 | 27 | // buildRoutes will setup various HTTP routes for this example server 28 | func buildRoutes(logHandler slog.Handler) ([]httpserver.Route, error) { 29 | // Create HTTP handlers functions 30 | indexHandler := func(w http.ResponseWriter, r *http.Request) { 31 | fmt.Fprintf(w, "Welcome to the go-supervisor example HTTP server!\n") 32 | } 33 | 34 | statusHandler := func(w http.ResponseWriter, r *http.Request) { 35 | fmt.Fprintf(w, "Status: OK\n") 36 | } 37 | 38 | wildcardHandler := func(w http.ResponseWriter, r *http.Request) { 39 | fmt.Fprintf(w, "You requested: %s\n", r.URL.Path) 40 | } 41 | 42 | // Create middleware for the routes using the new middleware system 43 | loggingMw := 
logger.New(logHandler.WithGroup("http")) 44 | recoveryMw := recovery.New(logHandler.WithGroup("recovery")) 45 | metricsMw := metrics.New() 46 | 47 | // Common middleware stack for all routes (using new HandlerFunc pattern) 48 | commonMw := []httpserver.HandlerFunc{ 49 | recoveryMw, 50 | loggingMw, 51 | metricsMw, 52 | } 53 | 54 | // Create routes with common middleware attached to each 55 | indexRoute, err := httpserver.NewRouteFromHandlerFunc( 56 | "index", 57 | "/", 58 | indexHandler, 59 | commonMw..., 60 | ) 61 | if err != nil { 62 | return nil, fmt.Errorf("failed to create index route: %w", err) 63 | } 64 | 65 | // Status route to provide a health check 66 | statusRoute, err := httpserver.NewRouteFromHandlerFunc( 67 | "status", 68 | "/status", 69 | statusHandler, 70 | commonMw..., 71 | ) 72 | if err != nil { 73 | return nil, fmt.Errorf("failed to create status route: %w", err) 74 | } 75 | 76 | // API wildcard route using the new middleware system 77 | apiRoute, err := httpserver.NewRouteFromHandlerFunc( 78 | "api", 79 | "/api/*", 80 | wildcardHandler, 81 | wildcard.New("/api/"), 82 | ) 83 | if err != nil { 84 | return nil, fmt.Errorf("failed to create wildcard route: %w", err) 85 | } 86 | 87 | return httpserver.Routes{*indexRoute, *statusRoute, *apiRoute}, nil 88 | } 89 | 90 | // RunServer initializes and runs the HTTP server with supervisor 91 | func RunServer( 92 | ctx context.Context, 93 | logHandler slog.Handler, 94 | routes []httpserver.Route, 95 | ) (*supervisor.PIDZero, error) { 96 | // Create a config callback function that will be used by the runner 97 | configCallback := func() (*httpserver.Config, error) { 98 | return httpserver.NewConfig(ListenOn, routes, httpserver.WithDrainTimeout(DrainTimeout)) 99 | } 100 | 101 | // Create the HTTP server runner 102 | runner, err := httpserver.NewRunner( 103 | httpserver.WithConfigCallback(configCallback), 104 | httpserver.WithLogHandler(logHandler)) 105 | if err != nil { 106 | return nil, fmt.Errorf("failed to 
create HTTP server runner: %w", err) 107 | } 108 | 109 | // Create a PIDZero supervisor and add the runner 110 | sv, err := supervisor.New( 111 | supervisor.WithContext(ctx), 112 | supervisor.WithLogHandler(logHandler), 113 | supervisor.WithRunnables(runner)) 114 | if err != nil { 115 | return nil, fmt.Errorf("failed to create supervisor: %w", err) 116 | } 117 | 118 | return sv, nil 119 | } 120 | 121 | func main() { 122 | // Configure the custom logger 123 | handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 124 | Level: slog.LevelDebug, 125 | // AddSource: true, 126 | }) 127 | slog.SetDefault(slog.New(handler)) 128 | 129 | // Create base context 130 | ctx := context.Background() 131 | 132 | // Run the server 133 | routes, err := buildRoutes(handler) 134 | if err != nil { 135 | slog.Error("Failed to build routes", "error", err) 136 | os.Exit(1) 137 | } 138 | 139 | sv, err := RunServer(ctx, handler, routes) 140 | if err != nil { 141 | slog.Error("Failed to setup server", "error", err) 142 | os.Exit(1) 143 | } 144 | 145 | // Start the supervisor - this will block until shutdown 146 | slog.Info("Starting supervisor with HTTP server on " + ListenOn) 147 | if err := sv.Run(); err != nil { 148 | slog.Error("Supervisor failed", "error", err) 149 | os.Exit(1) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /runnables/httpserver/runner_race_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/internal/networking" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // TestConcurrentReloadsRaceCondition verifies that concurrent reloads don't cause race conditions 17 | func TestConcurrentReloadsRaceCondition(t *testing.T) { 18 | if testing.Short() { 19 | 
t.Skip("Skipping test in short mode") 20 | } 21 | handler := func(w http.ResponseWriter, r *http.Request) { 22 | w.WriteHeader(http.StatusOK) 23 | } 24 | 25 | route, err := NewRouteFromHandlerFunc("test", "/test", handler) 26 | require.NoError(t, err) 27 | 28 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 29 | 30 | configVersion := 0 31 | cfgCallback := func() (*Config, error) { 32 | configVersion++ 33 | updatedCfg, err := NewConfig( 34 | port, 35 | Routes{*route}, 36 | WithDrainTimeout(1*time.Second), 37 | WithIdleTimeout(time.Duration(configVersion)*time.Millisecond+1*time.Minute), 38 | ) 39 | return updatedCfg, err 40 | } 41 | 42 | runner, err := NewRunner(WithConfigCallback(cfgCallback)) 43 | require.NoError(t, err) 44 | 45 | errChan := make(chan error, 1) 46 | ctx, cancel := context.WithCancel(t.Context()) 47 | defer cancel() 48 | 49 | go func() { 50 | err := runner.Run(ctx) 51 | errChan <- err 52 | }() 53 | 54 | // Wait for server to reach Running state 55 | waitForState( 56 | t, 57 | runner, 58 | "Running", 59 | 10*time.Second, 60 | "Server should reach Running state before reloads", 61 | ) 62 | 63 | // Verify server is accepting connections before starting reloads 64 | require.Eventually(t, func() bool { 65 | // Check if the server is running and accepting connections 66 | resp, err := http.Head("http://localhost" + port + "/test") 67 | if err != nil { 68 | t.Logf("Initial connection attempt failed: %v", err) 69 | return false 70 | } 71 | defer func() { assert.NoError(t, resp.Body.Close()) }() 72 | return resp.StatusCode == http.StatusOK 73 | }, 5*time.Second, 100*time.Millisecond, "Server should be accepting connections before reloads") 74 | 75 | // Launch concurrent reloads 76 | var wg sync.WaitGroup 77 | 78 | for i := 0; i < 5; i++ { 79 | wg.Add(1) 80 | go func() { 81 | defer wg.Done() 82 | runner.Reload() 83 | }() 84 | } 85 | 86 | // Wait for all reloads to complete 87 | wg.Wait() 88 | 89 | // Verify server is still accepting connections 
after reloads 90 | require.Eventually(t, func() bool { 91 | resp, err := http.Get("http://localhost" + port + "/test") 92 | if err != nil { 93 | t.Logf("Final connection attempt failed: %v", err) 94 | return false 95 | } 96 | defer func() { assert.NoError(t, resp.Body.Close()) }() 97 | return resp.StatusCode == http.StatusOK 98 | }, 10*time.Second, 100*time.Millisecond, "Server should still be accepting connections after concurrent reloads") 99 | 100 | cancel() 101 | 102 | <-errChan 103 | } 104 | 105 | // TestRunnerRaceConditions verifies that there are no race conditions in the boot and stopServer methods 106 | func TestRunnerRaceConditions(t *testing.T) { 107 | if testing.Short() { 108 | t.Skip("Skipping test in short mode") 109 | } 110 | handler := func(w http.ResponseWriter, r *http.Request) { 111 | w.WriteHeader(http.StatusOK) 112 | } 113 | 114 | route, err := NewRouteFromHandlerFunc("test", "/test", handler) 115 | require.NoError(t, err) 116 | 117 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 118 | 119 | cfg, err := NewConfig(port, Routes{*route}, WithDrainTimeout(1*time.Second)) 120 | require.NoError(t, err) 121 | 122 | cfgCallback := func() (*Config, error) { 123 | return cfg, nil 124 | } 125 | 126 | runner, err := NewRunner(WithConfigCallback(cfgCallback)) 127 | require.NoError(t, err) 128 | 129 | errChan := make(chan error, 1) 130 | ctx, cancel := context.WithCancel(t.Context()) 131 | defer cancel() 132 | 133 | go func() { 134 | err := runner.Run(ctx) 135 | errChan <- err 136 | }() 137 | 138 | // Wait for server to reach Running state 139 | waitForState( 140 | t, 141 | runner, 142 | "Running", 143 | 5*time.Second, 144 | "Server should reach Running state before connection attempt", 145 | ) 146 | 147 | // Try connecting with retries to handle potential timing issues 148 | require.Eventually(t, func() bool { 149 | resp, err := http.Get("http://localhost" + port + "/test") 150 | if err != nil { 151 | t.Logf("Connection attempt failed: %v", err) 
152 | return false 153 | } 154 | defer func() { assert.NoError(t, resp.Body.Close()) }() 155 | return resp.StatusCode == http.StatusOK 156 | }, 5*time.Second, 100*time.Millisecond, "Server should be accepting connections") 157 | 158 | runner.Stop() 159 | 160 | <-errChan 161 | } 162 | -------------------------------------------------------------------------------- /runnables/httpserver/integration_race_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "testing" 9 | "time" 10 | 11 | "github.com/robbyt/go-supervisor/internal/networking" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // TestIntegration_NoRaceCondition verifies that when IsRunning() returns true, 17 | // the server is actually accepting TCP connections. 18 | func TestIntegration_NoRaceCondition(t *testing.T) { 19 | if testing.Short() { 20 | t.Skip("Skipping integration test in short mode") 21 | } 22 | 23 | const iterations = 10 24 | 25 | for i := 0; i < iterations; i++ { 26 | t.Run(fmt.Sprintf("iteration_%d", i), func(t *testing.T) { 27 | testSingleRunnerRaceCondition(t) 28 | }) 29 | } 30 | } 31 | 32 | func testSingleRunnerRaceCondition(t *testing.T) { 33 | t.Helper() 34 | 35 | route, err := NewRouteFromHandlerFunc( 36 | "test", 37 | "/health", 38 | func(w http.ResponseWriter, r *http.Request) { 39 | w.WriteHeader(http.StatusOK) 40 | _, err := w.Write([]byte("OK")) 41 | assert.NoError(t, err) 42 | }, 43 | ) 44 | require.NoError(t, err) 45 | 46 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 47 | 48 | callback := func() (*Config, error) { 49 | return NewConfig(port, Routes{*route}) 50 | } 51 | 52 | runner, err := NewRunner(WithConfigCallback(callback)) 53 | require.NoError(t, err) 54 | 55 | ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) 56 | defer cancel() 57 | runErr := make(chan error, 1) 58 
| go func() { 59 | runErr <- runner.Run(ctx) 60 | }() 61 | 62 | require.Eventually(t, func() bool { 63 | return runner.IsRunning() 64 | }, 5*time.Second, 10*time.Millisecond, "Server should report as running") 65 | 66 | conn, err := net.DialTimeout("tcp", port, 100*time.Millisecond) 67 | require.NoError(t, err, "TCP connection should succeed when IsRunning() returns true") 68 | 69 | if conn != nil { 70 | require.NoError(t, conn.Close()) 71 | } 72 | 73 | client := &http.Client{Timeout: 1 * time.Second} 74 | resp, err := client.Get("http://" + port + "/health") 75 | require.NoError(t, err, "HTTP request should succeed when IsRunning()=true") 76 | 77 | if resp != nil { 78 | require.NoError(t, resp.Body.Close()) 79 | assert.Equal(t, http.StatusOK, resp.StatusCode) 80 | } 81 | 82 | cancel() 83 | 84 | timeoutCtx, timeoutCancel := context.WithTimeout(t.Context(), 5*time.Second) 85 | defer timeoutCancel() 86 | select { 87 | case err := <-runErr: 88 | require.NoError(t, err) 89 | case <-timeoutCtx.Done(): 90 | t.Fatal("Server did not shutdown within timeout") 91 | } 92 | } 93 | 94 | // TestIntegration_FullLifecycle tests the complete lifecycle. 
95 | func TestIntegration_FullLifecycle(t *testing.T) { 96 | if testing.Short() { 97 | t.Skip("Skipping integration test in short mode") 98 | } 99 | 100 | route, err := NewRouteFromHandlerFunc( 101 | "lifecycle", 102 | "/status", 103 | func(w http.ResponseWriter, r *http.Request) { 104 | w.WriteHeader(http.StatusOK) 105 | _, err := w.Write([]byte("alive")) 106 | assert.NoError(t, err) 107 | }, 108 | ) 109 | require.NoError(t, err) 110 | 111 | port := fmt.Sprintf(":%d", networking.GetRandomPort(t)) 112 | 113 | callback := func() (*Config, error) { 114 | return NewConfig(port, Routes{*route}) 115 | } 116 | 117 | runner, err := NewRunner(WithConfigCallback(callback)) 118 | require.NoError(t, err) 119 | 120 | assert.Equal(t, "New", runner.GetState()) 121 | assert.False(t, runner.IsRunning()) 122 | 123 | ctx, cancel := context.WithCancel(t.Context()) 124 | runErr := make(chan error, 1) 125 | go func() { 126 | runErr <- runner.Run(ctx) 127 | }() 128 | 129 | assert.Eventually(t, func() bool { 130 | state := runner.GetState() 131 | return state == "Booting" || state == "Running" 132 | }, 2*time.Second, 50*time.Millisecond, "Should transition to Booting") 133 | 134 | assert.Eventually(t, func() bool { 135 | return runner.IsRunning() && runner.GetState() == "Running" 136 | }, 5*time.Second, 50*time.Millisecond, "Should transition to Running") 137 | 138 | conn, err := net.DialTimeout("tcp", port, 100*time.Millisecond) 139 | require.NoError(t, err, "TCP should be available when Running") 140 | require.NoError(t, conn.Close()) 141 | client := &http.Client{Timeout: 1 * time.Second} 142 | resp, err := client.Get("http://" + port + "/status") 143 | require.NoError(t, err) 144 | require.Equal(t, http.StatusOK, resp.StatusCode) 145 | require.NoError(t, resp.Body.Close()) 146 | 147 | cancel() 148 | 149 | assert.Eventually(t, func() bool { 150 | state := runner.GetState() 151 | return state == "Stopping" || state == "Stopped" 152 | }, 2*time.Second, 50*time.Millisecond, "Should 
transition to Stopping") 153 | 154 | timeoutCtx, timeoutCancel := context.WithTimeout(t.Context(), 5*time.Second) 155 | defer timeoutCancel() 156 | select { 157 | case err := <-runErr: 158 | require.NoError(t, err) 159 | case <-timeoutCtx.Done(): 160 | t.Fatal("Server did not shutdown within timeout") 161 | } 162 | 163 | assert.Eventually(t, func() bool { 164 | return runner.GetState() == "Stopped" 165 | }, 1*time.Second, 10*time.Millisecond, "Should be Stopped") 166 | 167 | assert.False(t, runner.IsRunning()) 168 | } 169 | -------------------------------------------------------------------------------- /runnables/httpserver/options_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/mock" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestWithLogHandler(t *testing.T) { 16 | t.Parallel() 17 | // Create a custom logger with a buffer for testing output 18 | var logBuffer strings.Builder 19 | customHandler := slog.NewTextHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}) 20 | // Create required route and config callback for Runner 21 | handler := func(w http.ResponseWriter, r *http.Request) {} 22 | route, err := NewRouteFromHandlerFunc("v1", "/", handler) 23 | require.NoError(t, err) 24 | hConfig := Routes{*route} 25 | cfgCallback := func() (*Config, error) { 26 | return NewConfig(":0", hConfig, WithDrainTimeout(1*time.Second)) 27 | } 28 | // Create a server with the custom logger 29 | server, err := NewRunner( 30 | WithConfigCallback(cfgCallback), 31 | WithLogHandler(customHandler), 32 | ) 33 | require.NoError(t, err) 34 | // Verify the custom logger was applied by checking that it's not the default logger 35 | assert.NotSame(t, slog.Default(), server.logger, "Server should use custom logger") 36 | // Log 
something and check if it appears in our buffer 37 | server.logger.Info("test message") 38 | logOutput := logBuffer.String() 39 | assert.Contains(t, logOutput, "test message", "Logger should write to our buffer") 40 | assert.Contains(t, logOutput, "httpserver.Runner", "Log should contain the correct group name") 41 | } 42 | 43 | func TestWithConfig(t *testing.T) { 44 | t.Parallel() 45 | // Create a test server config 46 | handler := func(w http.ResponseWriter, r *http.Request) {} 47 | route, err := NewRouteFromHandlerFunc("v1", "/test-route", handler) 48 | require.NoError(t, err) 49 | hConfig := Routes{*route} 50 | testAddr := ":8765" // Use a specific port for identification 51 | staticConfig, err := NewConfig(testAddr, hConfig, WithDrainTimeout(2*time.Second)) 52 | require.NoError(t, err) 53 | // Create a server with the static config 54 | server, err := NewRunner( 55 | WithConfig(staticConfig), 56 | ) 57 | require.NoError(t, err) 58 | // Verify the config callback was created and returns the correct config 59 | config, err := server.configCallback() 60 | require.NoError(t, err) 61 | assert.NotNil(t, config) 62 | // Verify config values match what we provided 63 | assert.Equal(t, testAddr, config.ListenAddr, "Config address should match") 64 | assert.Equal(t, 2*time.Second, config.DrainTimeout, "Config drain timeout should match") 65 | assert.Equal(t, "/test-route", config.Routes[0].Path, "Config route path should match") 66 | assert.Equal(t, "v1", config.Routes[0].name, "Config route name should match") 67 | // Verify we get the same config instance (not a copy) 68 | assert.Same(t, staticConfig, config, "Should return the exact same config instance") 69 | } 70 | 71 | // TestWithServerCreator verifies the WithServerCreator option works correctly 72 | func TestWithServerCreator(t *testing.T) { 73 | t.Parallel() 74 | // Create a mock server and track creation parameters 75 | mockServer := new(MockHttpServer) 76 | mockServer.On("ListenAndServe").Return(nil) 77 | 
mockServer.On("Shutdown", mock.Anything).Return(nil) 78 | var capturedAddr string 79 | var capturedHandler http.Handler 80 | // Custom server creator function that captures parameters 81 | customCreator := func(addr string, handler http.Handler, cfg *Config) HttpServer { 82 | capturedAddr = addr 83 | capturedHandler = handler 84 | return mockServer 85 | } 86 | // Create required route and config callback for Runner 87 | handler := func(w http.ResponseWriter, r *http.Request) {} 88 | route, err := NewRouteFromHandlerFunc("v1", "/", handler) 89 | require.NoError(t, err) 90 | hConfig := Routes{*route} 91 | testAddr := ":9876" // Use a specific port for identification 92 | cfgCallback := func() (*Config, error) { 93 | return NewConfig( 94 | testAddr, 95 | hConfig, 96 | WithDrainTimeout(1*time.Second), 97 | WithServerCreator(customCreator), 98 | ) 99 | } 100 | // Create a server with the config that has a custom server creator 101 | server, err := NewRunner( 102 | WithConfigCallback(cfgCallback), 103 | ) 104 | require.NoError(t, err) 105 | // Get the config to verify the server creator was set 106 | cfg := server.getConfig() 107 | assert.NotNil(t, cfg.ServerCreator, "Server creator should be set in config") 108 | // Call the ServerCreator directly to test it properly 109 | mux := http.NewServeMux() 110 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) 111 | server.server = cfg.ServerCreator(testAddr, mux, cfg) 112 | // Verify the custom creator was called with correct parameters 113 | assert.Equal(t, testAddr, capturedAddr, "Server creator should receive correct address") 114 | assert.NotNil(t, capturedHandler, "Server creator should receive a handler") 115 | // Verify the created server is our mock 116 | assert.Same(t, mockServer, server.server, "Server should be our mock instance") 117 | // The server is created but not started, so we don't need to stop it 118 | // The FSM would be in the 'New' state, not 'Running', so stopServer would fail 119 | 
} 120 | -------------------------------------------------------------------------------- /runnables/httpserver/middleware/headers/options.go: -------------------------------------------------------------------------------- 1 | package headers 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/robbyt/go-supervisor/runnables/httpserver" 7 | ) 8 | 9 | // HeaderOperation represents a single header manipulation operation 10 | type HeaderOperation func(*headerOperations) 11 | 12 | type headerOperations struct { 13 | setHeaders http.Header 14 | addHeaders http.Header 15 | removeHeaders []string 16 | setRequestHeaders http.Header 17 | addRequestHeaders http.Header 18 | removeRequestHeaders []string 19 | } 20 | 21 | // WithSet creates an operation to set (replace) headers 22 | func WithSet(headers http.Header) HeaderOperation { 23 | return func(ops *headerOperations) { 24 | if ops.setHeaders == nil { 25 | ops.setHeaders = make(http.Header) 26 | } 27 | for key, values := range headers { 28 | ops.setHeaders[key] = values 29 | } 30 | } 31 | } 32 | 33 | // WithSetHeader creates an operation to set a single header 34 | func WithSetHeader(key, value string) HeaderOperation { 35 | return func(ops *headerOperations) { 36 | if ops.setHeaders == nil { 37 | ops.setHeaders = make(http.Header) 38 | } 39 | ops.setHeaders.Set(key, value) 40 | } 41 | } 42 | 43 | // WithAdd creates an operation to add (append) headers 44 | func WithAdd(headers http.Header) HeaderOperation { 45 | return func(ops *headerOperations) { 46 | if ops.addHeaders == nil { 47 | ops.addHeaders = make(http.Header) 48 | } 49 | for key, values := range headers { 50 | for _, value := range values { 51 | ops.addHeaders.Add(key, value) 52 | } 53 | } 54 | } 55 | } 56 | 57 | // WithAddHeader creates an operation to add a single header 58 | func WithAddHeader(key, value string) HeaderOperation { 59 | return func(ops *headerOperations) { 60 | if ops.addHeaders == nil { 61 | ops.addHeaders = make(http.Header) 62 | } 63 | 
ops.addHeaders.Add(key, value) 64 | } 65 | } 66 | 67 | // WithRemove creates an operation to remove headers 68 | func WithRemove(headerNames ...string) HeaderOperation { 69 | return func(ops *headerOperations) { 70 | ops.removeHeaders = append(ops.removeHeaders, headerNames...) 71 | } 72 | } 73 | 74 | // WithSetRequest creates an operation to set (replace) request headers 75 | func WithSetRequest(headers http.Header) HeaderOperation { 76 | return func(ops *headerOperations) { 77 | if ops.setRequestHeaders == nil { 78 | ops.setRequestHeaders = make(http.Header) 79 | } 80 | for key, values := range headers { 81 | ops.setRequestHeaders[key] = values 82 | } 83 | } 84 | } 85 | 86 | // WithSetRequestHeader creates an operation to set a single request header 87 | func WithSetRequestHeader(key, value string) HeaderOperation { 88 | return func(ops *headerOperations) { 89 | if ops.setRequestHeaders == nil { 90 | ops.setRequestHeaders = make(http.Header) 91 | } 92 | ops.setRequestHeaders.Set(key, value) 93 | } 94 | } 95 | 96 | // WithAddRequest creates an operation to add (append) request headers 97 | func WithAddRequest(headers http.Header) HeaderOperation { 98 | return func(ops *headerOperations) { 99 | if ops.addRequestHeaders == nil { 100 | ops.addRequestHeaders = make(http.Header) 101 | } 102 | for key, values := range headers { 103 | for _, value := range values { 104 | ops.addRequestHeaders.Add(key, value) 105 | } 106 | } 107 | } 108 | } 109 | 110 | // WithAddRequestHeader creates an operation to add a single request header 111 | func WithAddRequestHeader(key, value string) HeaderOperation { 112 | return func(ops *headerOperations) { 113 | if ops.addRequestHeaders == nil { 114 | ops.addRequestHeaders = make(http.Header) 115 | } 116 | ops.addRequestHeaders.Add(key, value) 117 | } 118 | } 119 | 120 | // WithRemoveRequest creates an operation to remove request headers 121 | func WithRemoveRequest(headerNames ...string) HeaderOperation { 122 | return func(ops 
*headerOperations) { 123 | ops.removeRequestHeaders = append(ops.removeRequestHeaders, headerNames...) 124 | } 125 | } 126 | 127 | // NewWithOperations creates a middleware with full header control using functional options. 128 | // Operations are executed in order: remove → set → add (for both request and response headers) 129 | func NewWithOperations(operations ...HeaderOperation) httpserver.HandlerFunc { 130 | ops := &headerOperations{} 131 | for _, operation := range operations { 132 | operation(ops) 133 | } 134 | 135 | return func(rp *httpserver.RequestProcessor) { 136 | request := rp.Request() 137 | writer := rp.Writer() 138 | 139 | // Request header manipulation (before calling Next) 140 | // 1. Remove request headers first 141 | for _, key := range ops.removeRequestHeaders { 142 | request.Header.Del(key) 143 | } 144 | 145 | // 2. Set request headers (replace) 146 | for key, values := range ops.setRequestHeaders { 147 | request.Header.Del(key) 148 | for _, value := range values { 149 | request.Header.Add(key, value) 150 | } 151 | } 152 | 153 | // 3. Add request headers (append) 154 | for key, values := range ops.addRequestHeaders { 155 | for _, value := range values { 156 | request.Header.Add(key, value) 157 | } 158 | } 159 | 160 | // Response header manipulation (existing functionality) 161 | // 1. Remove response headers first 162 | for _, key := range ops.removeHeaders { 163 | writer.Header().Del(key) 164 | } 165 | 166 | // 2. Set response headers (replace) 167 | for key, values := range ops.setHeaders { 168 | writer.Header().Del(key) 169 | for _, value := range values { 170 | writer.Header().Add(key, value) 171 | } 172 | } 173 | 174 | // 3. 
Add response headers (append) 175 | for key, values := range ops.addHeaders { 176 | for _, value := range values { 177 | writer.Header().Add(key, value) 178 | } 179 | } 180 | 181 | rp.Next() 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /supervisor/state_monitoring_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestPIDZero_StartStateMonitor tests that the state monitor is started for stateable runnables. 15 | func TestPIDZero_StartStateMonitor(t *testing.T) { 16 | t.Parallel() 17 | 18 | // Create a mock stateable runnable 19 | mockStateable := mocks.NewMockRunnableWithStateable() 20 | mockStateable.On("String").Return("stateable-runnable").Maybe() 21 | mockStateable.On("Run", mock.Anything).Return(nil) 22 | mockStateable.On("Stop").Once() 23 | mockStateable.On("GetState").Return("initial").Once() // Initial state 24 | mockStateable.On("GetState").Return("running").Maybe() // Called during shutdown 25 | 26 | stateChan := make(chan string, 5) // Buffered to prevent blocking 27 | mockStateable.On("GetStateChan", mock.Anything).Return(stateChan).Once() 28 | 29 | // Will be called during startup verification 30 | mockStateable.On("IsRunning").Return(true).Once() 31 | 32 | // Create context with timeout to ensure test completion 33 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 34 | defer cancel() 35 | 36 | // Create supervisor with the mock runnable 37 | pid0, err := New( 38 | WithContext(ctx), 39 | WithRunnables(mockStateable), 40 | WithStartupTimeout(100*time.Millisecond), 41 | ) 42 | require.NoError(t, err) 43 | 44 | // Create a state subscriber to verify state 
broadcasts 45 | stateUpdates := make(chan StateMap, 5) 46 | unsubscribe := pid0.AddStateSubscriber(stateUpdates) 47 | defer unsubscribe() 48 | 49 | // Start the supervisor in a goroutine 50 | execDone := make(chan error, 1) 51 | go func() { 52 | execDone <- pid0.Run() 53 | }() 54 | 55 | // Allow time for initialization 56 | time.Sleep(50 * time.Millisecond) 57 | 58 | // Send state updates through the channel 59 | stateChan <- "initial" // This should be discarded as it's the initial state 60 | stateChan <- "running" // This will be processed 61 | stateChan <- "stopping" // Additional state change 62 | 63 | // Use require.Eventually to verify the state monitor receives and broadcasts states 64 | require.Eventually(t, func() bool { 65 | // Check if we have received at least one state update 66 | select { 67 | case stateMap := <-stateUpdates: 68 | // We don't check for specific values, just that broadcasts are happening 69 | return stateMap["stateable-runnable"] != "" 70 | default: 71 | return false 72 | } 73 | }, 500*time.Millisecond, 50*time.Millisecond, "No state updates received") 74 | 75 | // Cancel the context to shut down the supervisor 76 | cancel() 77 | 78 | // Verify the supervisor shuts down cleanly 79 | select { 80 | case err := <-execDone: 81 | require.NoError(t, err) 82 | case <-time.After(1 * time.Second): 83 | t.Fatal("Supervisor did not shut down in time") 84 | } 85 | 86 | // Verify expectations 87 | mockStateable.AssertExpectations(t) 88 | } 89 | 90 | // TestPIDZero_SubscribeStateChanges tests the SubscribeStateChanges functionality. 
91 | func TestPIDZero_SubscribeStateChanges(t *testing.T) { 92 | t.Parallel() 93 | 94 | // Create a context with cancel for cleanup 95 | ctx, cancel := context.WithCancel(context.Background()) 96 | defer cancel() 97 | 98 | // Create mock services that implement Stateable 99 | mockService := mocks.NewMockRunnableWithStateable() 100 | stateChan := make(chan string, 2) 101 | mockService.On("GetStateChan", mock.Anything).Return(stateChan).Once() 102 | mockService.On("String").Return("mock-service").Maybe() 103 | mockService.On("Run", mock.Anything).Return(nil).Maybe() 104 | mockService.On("Stop").Maybe() 105 | mockService.On("GetState").Return("initial").Maybe() 106 | mockService.On("IsRunning").Return(true).Maybe() 107 | 108 | // Create a supervisor with the mock runnable 109 | pid0, err := New(WithContext(ctx), WithRunnables(mockService)) 110 | require.NoError(t, err) 111 | 112 | // Store initial states manually in stateMap 113 | pid0.stateMap.Store(mockService, "initial") 114 | 115 | // Subscribe to state changes 116 | subCtx, subCancel := context.WithCancel(context.Background()) 117 | defer subCancel() 118 | stateMapChan := pid0.SubscribeStateChanges(subCtx) 119 | 120 | // Manually call startStateMonitor to avoid the full Run sequence 121 | pid0.wg.Add(1) 122 | go pid0.startStateMonitor() 123 | 124 | // Give state monitor a moment to start 125 | time.Sleep(100 * time.Millisecond) 126 | 127 | // Send an initial state update 128 | stateChan <- "initial" // Should be discarded 129 | 130 | // Send another state update 131 | stateChan <- "running" // Should be broadcast 132 | time.Sleep(100 * time.Millisecond) 133 | 134 | // Manually update state and trigger broadcast to ensure it happens 135 | pid0.stateMap.Store(mockService, "running") 136 | pid0.broadcastState() 137 | time.Sleep(100 * time.Millisecond) 138 | 139 | // Verify we receive state updates 140 | var stateMap StateMap 141 | var foundRunning bool 142 | timeout := time.After(500 * time.Millisecond) 143 | 144 | 
// Loop until we find the update we want or time out 145 | for !foundRunning { 146 | select { 147 | case stateMap = <-stateMapChan: 148 | if val, ok := stateMap["mock-service"]; ok && val == "running" { 149 | foundRunning = true 150 | } 151 | case <-timeout: 152 | t.Fatal("Did not receive running state update in time") 153 | } 154 | } 155 | 156 | assert.True(t, foundRunning, "Should have received a state map with running state") 157 | 158 | // Cancel the context to clean up goroutines 159 | cancel() 160 | time.Sleep(50 * time.Millisecond) 161 | 162 | // Verify expectations 163 | mockService.AssertExpectations(t) 164 | } 165 | -------------------------------------------------------------------------------- /runnables/httpcluster/state_test.go: -------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/robbyt/go-supervisor/internal/finitestate" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/mock" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestGetState(t *testing.T) { 14 | t.Parallel() 15 | 16 | runner, err := NewRunner() 17 | require.NoError(t, err) 18 | 19 | // Initial state should be New 20 | assert.Equal(t, finitestate.StatusNew, runner.GetState()) 21 | } 22 | 23 | func TestGetStateChan(t *testing.T) { 24 | t.Parallel() 25 | 26 | runner, err := NewRunner(WithStateChanBufferSize(5)) 27 | require.NoError(t, err) 28 | 29 | ctx, cancel := context.WithCancel(t.Context()) 30 | defer cancel() 31 | 32 | stateChan := runner.GetStateChan(ctx) 33 | require.NotNil(t, stateChan) 34 | 35 | // Should receive the initial state 36 | select { 37 | case state := <-stateChan: 38 | assert.Equal(t, finitestate.StatusNew, state) 39 | default: 40 | t.Fatal("Expected to receive initial state") 41 | } 42 | } 43 | 44 | func TestGetStateChanWithTimeout(t *testing.T) { 45 | t.Parallel() 46 | 47 | runner, err := 
NewRunner(WithStateChanBufferSize(5)) 48 | require.NoError(t, err) 49 | 50 | ctx, cancel := context.WithCancel(t.Context()) 51 | defer cancel() 52 | 53 | stateChan := runner.GetStateChanWithTimeout(ctx) 54 | require.NotNil(t, stateChan) 55 | 56 | // Should receive the initial state 57 | select { 58 | case state := <-stateChan: 59 | assert.Equal(t, finitestate.StatusNew, state) 60 | default: 61 | t.Fatal("Expected to receive initial state") 62 | } 63 | } 64 | 65 | func TestIsRunning(t *testing.T) { 66 | t.Parallel() 67 | 68 | runner, err := NewRunner() 69 | require.NoError(t, err) 70 | 71 | // Initially not running 72 | assert.False(t, runner.IsRunning()) 73 | 74 | // Transition to running state 75 | err = runner.fsm.Transition(finitestate.StatusBooting) 76 | require.NoError(t, err) 77 | err = runner.fsm.Transition(finitestate.StatusRunning) 78 | require.NoError(t, err) 79 | 80 | // Now should be running 81 | assert.True(t, runner.IsRunning()) 82 | } 83 | 84 | // MockFSMForStateError is a mock FSM that can simulate error conditions 85 | type MockFSMForStateError struct { 86 | mock.Mock 87 | } 88 | 89 | func (m *MockFSMForStateError) Transition(state string) error { 90 | args := m.Called(state) 91 | return args.Error(0) 92 | } 93 | 94 | func (m *MockFSMForStateError) TransitionBool(state string) bool { 95 | args := m.Called(state) 96 | return args.Bool(0) 97 | } 98 | 99 | func (m *MockFSMForStateError) TransitionIfCurrentState(currentState, newState string) error { 100 | args := m.Called(currentState, newState) 101 | return args.Error(0) 102 | } 103 | 104 | func (m *MockFSMForStateError) SetState(state string) error { 105 | args := m.Called(state) 106 | return args.Error(0) 107 | } 108 | 109 | func (m *MockFSMForStateError) GetState() string { 110 | args := m.Called() 111 | return args.String(0) 112 | } 113 | 114 | func (m *MockFSMForStateError) GetStateChan(ctx context.Context) <-chan string { 115 | args := m.Called(ctx) 116 | return args.Get(0).(<-chan string) 117 
| } 118 | 119 | func TestSetStateError(t *testing.T) { 120 | t.Parallel() 121 | 122 | t.Run("successful transition to error state", func(t *testing.T) { 123 | mockFSM := &MockFSMForStateError{} 124 | runner, err := NewRunner() 125 | require.NoError(t, err) 126 | 127 | // Replace FSM with mock 128 | runner.fsm = mockFSM 129 | 130 | // Mock successful transition to error state 131 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(true).Once() 132 | 133 | // Call setStateError - should succeed on first attempt 134 | runner.setStateError() 135 | 136 | mockFSM.AssertExpectations(t) 137 | }) 138 | 139 | t.Run("fallback to SetState when TransitionBool fails", func(t *testing.T) { 140 | mockFSM := &MockFSMForStateError{} 141 | runner, err := NewRunner() 142 | require.NoError(t, err) 143 | 144 | // Replace FSM with mock 145 | runner.fsm = mockFSM 146 | 147 | // Mock failed transition but successful SetState 148 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false).Once() 149 | mockFSM.On("SetState", finitestate.StatusError).Return(nil).Once() 150 | 151 | // Call setStateError - should use fallback 152 | runner.setStateError() 153 | 154 | mockFSM.AssertExpectations(t) 155 | }) 156 | 157 | t.Run("fallback to unknown state when both error state attempts fail", func(t *testing.T) { 158 | mockFSM := &MockFSMForStateError{} 159 | runner, err := NewRunner() 160 | require.NoError(t, err) 161 | 162 | // Replace FSM with mock 163 | runner.fsm = mockFSM 164 | 165 | // Mock all attempts failing except unknown 166 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false).Once() 167 | mockFSM.On("SetState", finitestate.StatusError).Return(assert.AnError).Once() 168 | mockFSM.On("SetState", finitestate.StatusUnknown).Return(nil).Once() 169 | 170 | // Call setStateError - should fallback to unknown state 171 | runner.setStateError() 172 | 173 | mockFSM.AssertExpectations(t) 174 | }) 175 | 176 | t.Run("all state setting attempts fail", func(t 
*testing.T) { 177 | mockFSM := &MockFSMForStateError{} 178 | runner, err := NewRunner() 179 | require.NoError(t, err) 180 | 181 | // Replace FSM with mock 182 | runner.fsm = mockFSM 183 | 184 | // Mock all attempts failing 185 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false).Once() 186 | mockFSM.On("SetState", finitestate.StatusError).Return(assert.AnError).Once() 187 | mockFSM.On("SetState", finitestate.StatusUnknown).Return(assert.AnError).Once() 188 | 189 | // Call setStateError - should handle all failures gracefully 190 | runner.setStateError() 191 | 192 | mockFSM.AssertExpectations(t) 193 | }) 194 | } 195 | -------------------------------------------------------------------------------- /runnables/composite/reload.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/robbyt/go-supervisor/internal/finitestate" 7 | "github.com/robbyt/go-supervisor/supervisor" 8 | ) 9 | 10 | // ReloadableWithConfig is an interface for sub-runnables that can reload with specific config 11 | type ReloadableWithConfig interface { 12 | ReloadWithConfig(config any) 13 | } 14 | 15 | // Reload updates the configuration and handles runnables appropriately. 16 | // If membership changes (different set of runnables), all existing runnables are stopped 17 | // and the new set of runnables is started. 
18 | func (r *Runner[T]) Reload() { 19 | logger := r.logger.WithGroup("Reload") 20 | logger.Debug("Reloading...") 21 | defer func() { 22 | logger.Debug("Completed.") 23 | }() 24 | 25 | // Transition to Reloading state 26 | if err := r.fsm.Transition(finitestate.StatusReloading); err != nil { 27 | logger.Error("Failed to transition to Reloading", "error", err) 28 | r.setStateError() 29 | return 30 | } 31 | 32 | // Get updated config from the callback function 33 | newConfig, err := r.configCallback() 34 | if err != nil { 35 | logger.Error("Failed to get updated config", "error", err) 36 | // TODO: consider removing the setStateError() call here 37 | r.setStateError() 38 | return 39 | } 40 | if newConfig == nil { 41 | logger.Error("Config callback returned nil during reload") 42 | // TODO: consider removing the setStateError() call here 43 | r.setStateError() 44 | return 45 | } 46 | 47 | // Get the old config to compare 48 | oldConfig := r.getConfig() 49 | if oldConfig == nil { 50 | logger.Warn("Failed to get current config during reload, using empty config") 51 | oldConfig = &Config[T]{} 52 | } 53 | 54 | // Check if membership has changed by comparing runnable identities 55 | if hasMembershipChanged(oldConfig, newConfig) { 56 | logger.Debug( 57 | "Membership change detected, stopping all existing runnables before updating membership and config", 58 | ) 59 | if err := r.reloadWithRestart(newConfig); err != nil { 60 | logger.Error("Failed to reload runnables due to membership change", "error", err) 61 | r.setStateError() 62 | return 63 | } 64 | logger.Debug("Reloaded runnables due to membership change") 65 | } else { 66 | r.reloadSkipRestart(newConfig) 67 | logger.Debug("Reloaded runnables without membership change") 68 | } 69 | 70 | // Transition back to Running 71 | if err := r.fsm.Transition(finitestate.StatusRunning); err != nil { 72 | logger.Error("Failed to transition to Running", "error", err) 73 | r.setStateError() 74 | return 75 | } 76 | } 77 | 78 | // 
reloadWithRestart handles the case where the membership of runnables has changed. 79 | func (r *Runner[T]) reloadWithRestart(newConfig *Config[T]) error { 80 | logger := r.logger.WithGroup("reloadWithRestart") 81 | logger.Debug("Reloading runnables due to membership change") 82 | defer logger.Debug("Completed.") 83 | 84 | // Stop all existing runnables while we still have the old config 85 | // This acquires the runnables mutex 86 | if err := r.stopAllRunnables(); err != nil { 87 | return fmt.Errorf("%w: failed to stop existing runnables during membership change", err) 88 | } 89 | // Now update the stored config after stopping old runnables 90 | // Lock the config mutex for writing 91 | logger.Debug("Updating config after stopping existing runnables") 92 | r.configMu.Lock() 93 | r.setConfig(newConfig) 94 | r.configMu.Unlock() 95 | 96 | // Start all runnables from the new config 97 | // This acquires the runnables mutex 98 | if err := r.boot(r.ctx); err != nil { 99 | return fmt.Errorf("%w: failed to start new runnables during membership change", err) 100 | } 101 | return nil 102 | } 103 | 104 | // reloadSkipRestart handles the case where the membership of runnables has not changed. 
105 | func (r *Runner[T]) reloadSkipRestart(newConfig *Config[T]) { 106 | logger := r.logger.WithGroup("reloadSkipRestart") 107 | logger.Debug("Reloading runnables without membership change") 108 | defer logger.Debug("Completed.") 109 | 110 | logger.Debug("Updating config") 111 | r.configMu.Lock() 112 | r.setConfig(newConfig) 113 | r.configMu.Unlock() 114 | 115 | logger.Debug("Reloading configs of existing runnables") 116 | // Reload configs of existing runnables 117 | // Runnables mutex not locked as membership is not changing 118 | for _, entry := range newConfig.Entries { 119 | logger := logger.With("runnable", entry.Runnable.String()) 120 | 121 | if reloadableWithConfig, ok := any(entry.Runnable).(ReloadableWithConfig); ok { 122 | // If the runnable implements our ReloadableWithConfig interface, use that to pass the new config 123 | logger.Debug("Reloading child runnable with config") 124 | reloadableWithConfig.ReloadWithConfig(entry.Config) 125 | } else if reloadable, ok := any(entry.Runnable).(supervisor.Reloadable); ok { 126 | // Fall back to standard Reloadable interface, assume the configCallback 127 | // has somehow updated the runnable's internal state 128 | logger.Debug("Reloading child runnable") 129 | reloadable.Reload() 130 | } else { 131 | logger.Warn("Child runnable does not implement Reloadable or ReloadableWithConfig") 132 | } 133 | } 134 | } 135 | 136 | // hasMembershipChanged checks if the set of runnables has changed between configurations 137 | func hasMembershipChanged[T runnable](oldConfig, newConfig *Config[T]) bool { 138 | if len(oldConfig.Entries) != len(newConfig.Entries) { 139 | // Different number of entries means membership changed 140 | return true 141 | } 142 | 143 | // Create a map of old runnables by their string representation 144 | oldMap := make(map[string]bool) 145 | for _, entry := range oldConfig.Entries { 146 | oldMap[entry.Runnable.String()] = true 147 | } 148 | 149 | // Check if any new runnable is not in the old set 150 
| for _, entry := range newConfig.Entries { 151 | if !oldMap[entry.Runnable.String()] { 152 | return true 153 | } 154 | } 155 | 156 | return false 157 | } 158 | -------------------------------------------------------------------------------- /runnables/httpserver/state_mocked_test.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "errors" 5 | "log/slog" 6 | "net/http" 7 | "testing" 8 | "time" 9 | 10 | "github.com/robbyt/go-supervisor/internal/finitestate" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // TestSetStateError_FullIntegration does an end-to-end verification of setStateError 17 | // using a real FSM to complement the mocked tests 18 | func TestSetStateError_FullIntegration(t *testing.T) { 19 | t.Parallel() 20 | 21 | server, listenPort := createTestServer(t, 22 | func(w http.ResponseWriter, r *http.Request) {}, "/test", 1*time.Second) 23 | t.Logf("Server listening on port %s", listenPort) 24 | t.Cleanup(func() { 25 | server.Stop() 26 | }) 27 | 28 | // Set up initial conditions with FSM in Stopped state 29 | err := server.fsm.SetState(finitestate.StatusStopped) 30 | require.NoError(t, err) 31 | 32 | // Call the function under test which should fall back to SetState 33 | server.setStateError() 34 | 35 | // Verify we ended up in Error state 36 | assert.Equal(t, finitestate.StatusError, server.fsm.GetState()) 37 | } 38 | 39 | // TestSetStateError_Mocked tests the error state setting functionality using mocks 40 | func TestSetStateError_Mocked(t *testing.T) { 41 | t.Parallel() 42 | 43 | // Test successful TransitionBool path 44 | t.Run("Success with TransitionBool", func(t *testing.T) { 45 | // Create mock state machine 46 | mockFSM := NewMockStateMachine() 47 | 48 | // Setup the TransitionBool to return success 49 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(true) 50 | 
51 | // Create runner with mocked FSM 52 | r := &Runner{ 53 | fsm: mockFSM, 54 | logger: slog.Default().WithGroup("httpserver.Runner"), 55 | } 56 | 57 | // Call the function under test 58 | r.setStateError() 59 | 60 | // Verify our expectations 61 | mockFSM.AssertExpectations(t) 62 | 63 | // TransitionBool should have been called once, but SetState should not be called 64 | mockFSM.AssertCalled(t, "TransitionBool", finitestate.StatusError) 65 | mockFSM.AssertNotCalled(t, "SetState", mock.Anything) 66 | }) 67 | 68 | // Test fallback to SetState when TransitionBool fails 69 | t.Run("Fallback to SetState when TransitionBool fails", func(t *testing.T) { 70 | // Create mock state machine 71 | mockFSM := NewMockStateMachine() 72 | 73 | // Setup the TransitionBool to return failure 74 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false) 75 | 76 | // Setup the SetState to succeed 77 | mockFSM.On("SetState", finitestate.StatusError).Return(nil) 78 | 79 | // Create runner with mocked FSM 80 | r := &Runner{ 81 | fsm: mockFSM, 82 | logger: slog.Default().WithGroup("httpserver.Runner"), 83 | } 84 | 85 | // Call the function under test 86 | r.setStateError() 87 | 88 | // Verify our expectations 89 | mockFSM.AssertExpectations(t) 90 | 91 | // Both TransitionBool and SetState should have been called 92 | mockFSM.AssertCalled(t, "TransitionBool", finitestate.StatusError) 93 | mockFSM.AssertCalled(t, "SetState", finitestate.StatusError) 94 | 95 | // The Unknown state should not have been used 96 | mockFSM.AssertNotCalled(t, "SetState", finitestate.StatusUnknown) 97 | }) 98 | 99 | // Test fallback to StatusUnknown when both TransitionBool and the first SetState fail 100 | t.Run("Fallback to StatusUnknown when both previous methods fail", func(t *testing.T) { 101 | // Create mock state machine 102 | mockFSM := NewMockStateMachine() 103 | 104 | // Create a test error 105 | testErr := errors.New("cannot set error state") 106 | 107 | // Setup the TransitionBool to return 
failure 108 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false) 109 | 110 | // Setup the first SetState to fail 111 | mockFSM.On("SetState", finitestate.StatusError).Return(testErr) 112 | 113 | // Setup the fallback SetState to succeed 114 | mockFSM.On("SetState", finitestate.StatusUnknown).Return(nil) 115 | 116 | // Create runner with mocked FSM 117 | r := &Runner{ 118 | fsm: mockFSM, 119 | logger: slog.Default().WithGroup("httpserver.Runner"), 120 | } 121 | 122 | // Call the function under test 123 | r.setStateError() 124 | 125 | // Verify our expectations 126 | mockFSM.AssertExpectations(t) 127 | 128 | // All three method calls should have been made 129 | mockFSM.AssertCalled(t, "TransitionBool", finitestate.StatusError) 130 | mockFSM.AssertCalled(t, "SetState", finitestate.StatusError) 131 | mockFSM.AssertCalled(t, "SetState", finitestate.StatusUnknown) 132 | }) 133 | 134 | // Test complete failure case where all attempts fail 135 | t.Run("Complete failure when all state transitions fail", func(t *testing.T) { 136 | // Create mock state machine 137 | mockFSM := NewMockStateMachine() 138 | 139 | // Create test errors 140 | errorStateErr := errors.New("cannot set error state") 141 | unknownStateErr := errors.New("cannot set unknown state either") 142 | 143 | // Setup the TransitionBool to return failure 144 | mockFSM.On("TransitionBool", finitestate.StatusError).Return(false) 145 | 146 | // Setup the first SetState to fail 147 | mockFSM.On("SetState", finitestate.StatusError).Return(errorStateErr) 148 | 149 | // Setup the fallback SetState to also fail 150 | mockFSM.On("SetState", finitestate.StatusUnknown).Return(unknownStateErr) 151 | 152 | // Create runner with mocked FSM 153 | r := &Runner{ 154 | fsm: mockFSM, 155 | logger: slog.Default().WithGroup("httpserver.Runner"), 156 | } 157 | 158 | // Call the function under test 159 | r.setStateError() 160 | 161 | // Verify our expectations 162 | mockFSM.AssertExpectations(t) 163 | 164 | // All 
three method calls should have been made 165 | mockFSM.AssertCalled(t, "TransitionBool", finitestate.StatusError) 166 | mockFSM.AssertCalled(t, "SetState", finitestate.StatusError) 167 | mockFSM.AssertCalled(t, "SetState", finitestate.StatusUnknown) 168 | }) 169 | } 170 | -------------------------------------------------------------------------------- /runnables/mocks/mocks.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package mocks provides mock implementations of all supervisor interfaces for testing. 3 | These mocks implement the Runnable, Reloadable, Stateable, and ReloadSender interfaces 4 | with configurable delays to simulate real service behavior in tests. 5 | 6 | Example: 7 | ```go 8 | import ( 9 | 10 | "context" 11 | "testing" 12 | "time" 13 | 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/mock" 16 | 17 | "github.com/robbyt/go-supervisor" 18 | "github.com/robbyt/go-supervisor/runnables/mocks" 19 | 20 | ) 21 | 22 | func TestMyComponent(t *testing.T) { 23 | // Create a mock service 24 | mockRunnable := mocks.NewMockRunnable() 25 | 26 | // Set expectations 27 | mockRunnable.On("Run", mock.Anything).Return(nil) 28 | mockRunnable.On("Stop").Once() 29 | 30 | // For state-based tests 31 | stateCh := make(chan string) 32 | mockRunnable.On("GetStateChan", mock.Anything).Return(stateCh) 33 | 34 | // Create supervisor with mock 35 | super := supervisor.New([]supervisor.Runnable{mockRunnable}) 36 | 37 | // Run test... 38 | 39 | // Verify expectations 40 | mockRunnable.AssertExpectations(t) 41 | } 42 | 43 | ``` 44 | */ 45 | package mocks 46 | 47 | import ( 48 | "context" 49 | "time" 50 | 51 | "github.com/stretchr/testify/mock" 52 | ) 53 | 54 | const defaultDelay = 1 * time.Millisecond 55 | 56 | // Runnable is a mock implementation of the Runnable, Reloadable, and Stateable interfaces 57 | // using testify/mock. 
It allows for configurable delays in method responses to simulate 58 | // service behavior. 59 | type Runnable struct { 60 | mock.Mock 61 | DelayRun time.Duration // Delay before Run returns 62 | DelayStop time.Duration // Delay before Stop returns 63 | DelayReload time.Duration // Delay before Reload returns 64 | } 65 | 66 | // NewMockRunnable creates a new Runnable mock with default delays. 67 | func NewMockRunnable() *Runnable { 68 | return &Runnable{ 69 | DelayRun: defaultDelay, 70 | DelayStop: defaultDelay, 71 | DelayReload: defaultDelay, 72 | } 73 | } 74 | 75 | // Run mocks the Run method of the Runnable interface. 76 | // It sleeps for DelayRun duration before returning the mocked error result. 77 | func (m *Runnable) Run(ctx context.Context) error { 78 | time.Sleep(m.DelayRun) 79 | args := m.Called(ctx) 80 | return args.Error(0) 81 | } 82 | 83 | // Stop mocks the Stop method of the Runnable interface. 84 | // It sleeps for DelayStop duration before recording the call. 85 | func (m *Runnable) Stop() { 86 | time.Sleep(m.DelayStop) 87 | m.Called() 88 | } 89 | 90 | // Reload mocks the Reload method of the Reloadable interface. 91 | // It sleeps for DelayReload duration before recording the call. 92 | func (m *Runnable) Reload() { 93 | time.Sleep(m.DelayReload) 94 | m.Called() 95 | } 96 | 97 | // String returns a string representation of the mock service. 98 | // It can be mocked by doing mock.On("String").Return("customValue") in tests. 99 | func (m *Runnable) String() string { 100 | if mock := m.Called(); mock.Get(0) != nil { 101 | return mock.String(0) 102 | } 103 | return "Runnable" 104 | } 105 | 106 | // MockRunnableWithStateable extends Runnable to also implement the Stateable interface. 107 | type MockRunnableWithStateable struct { 108 | *Runnable 109 | DelayGetState time.Duration // Delay before GetState and GetStateChan return 110 | } 111 | 112 | // GetState mocks the GetState method of the Stateable interface. 
113 | // It sleeps for m.DelayGetState, then returns the state configured in test expectations. 114 | func (m *MockRunnableWithStateable) GetState() string { 115 | time.Sleep(m.DelayGetState) 116 | args := m.Called() 117 | return args.String(0) 118 | } 119 | 120 | // GetStateChan mocks the GetStateChan method of the Stateable interface. 121 | // It returns the channel configured in test expectations; that value must be a bidirectional chan string, since the type assertion below panics on any other type (including a receive-only channel). 122 | func (m *MockRunnableWithStateable) GetStateChan(ctx context.Context) <-chan string { 123 | args := m.Called(ctx) 124 | return args.Get(0).(chan string) 125 | } 126 | 127 | // IsRunning mocks the IsRunning method of the Stateable interface. 128 | // It returns true if the service is currently running, as configured in test expectations. 129 | func (m *MockRunnableWithStateable) IsRunning() bool { 130 | args := m.Called() 131 | return args.Bool(0) 132 | } 133 | 134 | // NewMockRunnableWithStateable creates a new MockRunnableWithStateable with default delays. 135 | func NewMockRunnableWithStateable() *MockRunnableWithStateable { 136 | return &MockRunnableWithStateable{ 137 | Runnable: NewMockRunnable(), 138 | DelayGetState: defaultDelay, 139 | } 140 | } 141 | 142 | // MockRunnableWithReloadSender extends Runnable to also implement the ReloadSender interface. 143 | type MockRunnableWithReloadSender struct { 144 | *Runnable 145 | } 146 | 147 | // GetReloadTrigger implements the ReloadSender interface. 148 | // It returns the channel configured in test expectations, which must be a bidirectional chan struct{} (the type assertion below panics otherwise). 149 | func (m *MockRunnableWithReloadSender) GetReloadTrigger() <-chan struct{} { 150 | args := m.Called() 151 | return args.Get(0).(chan struct{}) 152 | } 153 | 154 | // NewMockRunnableWithReloadSender creates a new MockRunnableWithReloadSender wrapping a new mock Runnable from NewMockRunnable.
155 | func NewMockRunnableWithReloadSender() *MockRunnableWithReloadSender { 156 | return &MockRunnableWithReloadSender{ 157 | Runnable: NewMockRunnable(), 158 | } 159 | } 160 | 161 | // MockRunnableWithShutdownSender extends Runnable to also implement the ShutdownSender interface. 162 | type MockRunnableWithShutdownSender struct { 163 | *Runnable 164 | } 165 | 166 | // GetShutdownTrigger mocks the GetShutdownTrigger method of the ShutdownSender interface. 167 | // It returns the channel configured in test expectations, which must be a bidirectional chan struct{} (the type assertion below panics otherwise). 168 | func (m *MockRunnableWithShutdownSender) GetShutdownTrigger() <-chan struct{} { 169 | args := m.Called() 170 | return args.Get(0).(chan struct{}) 171 | } 172 | 173 | // NewMockRunnableWithShutdownSender creates a new MockRunnableWithShutdownSender wrapping a new mock Runnable from NewMockRunnable. 174 | func NewMockRunnableWithShutdownSender() *MockRunnableWithShutdownSender { 175 | return &MockRunnableWithShutdownSender{ 176 | Runnable: NewMockRunnable(), 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /runnables/composite/config_test.go: -------------------------------------------------------------------------------- 1 | package composite 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/robbyt/go-supervisor/runnables/mocks" 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | func TestNewConfig(t *testing.T) { 12 | t.Parallel() 13 | 14 | tests := []struct { 15 | name string 16 | configName string 17 | entries []RunnableEntry[*mocks.Runnable] 18 | expectError bool 19 | }{ 20 | { 21 | name: "valid config", 22 | configName: "test-config", 23 | entries: []RunnableEntry[*mocks.Runnable]{ 24 | { 25 | Runnable: &mocks.Runnable{}, 26 | Config: map[string]string{"key": "value"}, 27 | }, 28 | }, 29 | expectError: false, 30 | }, 31 | { 32 | name: "empty name", 33 | configName: "", 34 | entries: []RunnableEntry[*mocks.Runnable]{}, 35 | expectError: true, 36 | }, 37 | } 38 | 39 | for _, 
tt := range tests { 40 | tt := tt // Capture range variable for parallel subtests; redundant on Go 1.22+ (per-iteration loop variables) 41 | t.Run(tt.name, func(t *testing.T) { 42 | t.Parallel() 43 | 44 | cfg, err := NewConfig(tt.configName, tt.entries) 45 | 46 | if tt.expectError { 47 | require.Error(t, err) 48 | assert.Nil(t, cfg) 49 | } else { 50 | require.NoError(t, err) 51 | assert.NotNil(t, cfg) 52 | assert.Equal(t, tt.configName, cfg.Name) 53 | assert.Equal(t, tt.entries, cfg.Entries) 54 | } 55 | }) 56 | } 57 | } 58 | 59 | func TestNewConfigFromRunnables(t *testing.T) { 60 | t.Parallel() 61 | 62 | mockRunnable1 := &mocks.Runnable{} 63 | mockRunnable2 := &mocks.Runnable{} 64 | sharedConfig := map[string]string{"key": "value"} 65 | 66 | cfg, err := NewConfigFromRunnables( 67 | "test", 68 | []*mocks.Runnable{mockRunnable1, mockRunnable2}, 69 | sharedConfig, 70 | ) 71 | require.NoError(t, err) 72 | 73 | assert.Equal(t, "test", cfg.Name) 74 | assert.Len(t, cfg.Entries, 2) 75 | assert.Equal(t, mockRunnable1, cfg.Entries[0].Runnable) 76 | assert.Equal(t, mockRunnable2, cfg.Entries[1].Runnable) 77 | assert.Equal(t, sharedConfig, cfg.Entries[0].Config) 78 | assert.Equal(t, sharedConfig, cfg.Entries[1].Config) 79 | } 80 | 81 | func TestConfig_Equal(t *testing.T) { 82 | t.Parallel() 83 | 84 | mockRunnable1 := &mocks.Runnable{} 85 | mockRunnable1.On("String").Return("runnable1") 86 | 87 | mockRunnable2 := &mocks.Runnable{} 88 | mockRunnable2.On("String").Return("runnable2") 89 | 90 | mockRunnable3 := &mocks.Runnable{} 91 | mockRunnable3.On("String").Return("runnable1") // Intentionally same name as mockRunnable1 92 | 93 | // Use the same reference for configs that should be equal 94 | runtimeConfig1 := map[string]string{"key": "value"} 95 | 96 | // Use a different value for configs that should be different 97 | runtimeConfig2 := map[string]string{"key": "different"} 98 | 99 | // Create entries 100 | entries1 := []RunnableEntry[*mocks.Runnable]{ 101 | {Runnable: mockRunnable1, Config: runtimeConfig1}, 102 | {Runnable: 
mockRunnable2, Config: runtimeConfig1}, 103 | } 104 | 105 | entries2 := []RunnableEntry[*mocks.Runnable]{ 106 | {Runnable: mockRunnable1, Config: runtimeConfig1}, 107 | {Runnable: mockRunnable2, Config: runtimeConfig1}, 108 | } 109 | 110 | entries3 := []RunnableEntry[*mocks.Runnable]{ 111 | {Runnable: mockRunnable2, Config: runtimeConfig1}, 112 | {Runnable: mockRunnable1, Config: runtimeConfig1}, 113 | } 114 | 115 | entries4 := []RunnableEntry[*mocks.Runnable]{ 116 | {Runnable: mockRunnable3, Config: runtimeConfig1}, 117 | {Runnable: mockRunnable2, Config: runtimeConfig1}, 118 | } 119 | 120 | entries5 := []RunnableEntry[*mocks.Runnable]{ 121 | {Runnable: mockRunnable1, Config: runtimeConfig2}, 122 | {Runnable: mockRunnable2, Config: runtimeConfig1}, 123 | } 124 | 125 | entries6 := []RunnableEntry[*mocks.Runnable]{ 126 | {Runnable: mockRunnable1, Config: nil}, 127 | {Runnable: mockRunnable2, Config: nil}, 128 | } 129 | 130 | entries7 := []RunnableEntry[*mocks.Runnable]{ 131 | {Runnable: mockRunnable1, Config: nil}, 132 | {Runnable: mockRunnable2, Config: nil}, 133 | } 134 | 135 | cfg1, err := NewConfig("test", entries1) 136 | require.NoError(t, err) 137 | 138 | // Same config with same entries 139 | cfg2, err := NewConfig("test", entries2) 140 | require.NoError(t, err) 141 | 142 | // Different name 143 | cfg3, err := NewConfig("different", entries1) 144 | require.NoError(t, err) 145 | 146 | // Same runnables in a different order 147 | cfg4, err := NewConfig("test", entries3) 148 | require.NoError(t, err) 149 | 150 | // Different runnable instance with the same String() value 151 | cfg5, err := NewConfig("test", entries4) 152 | require.NoError(t, err) 153 | 154 | // Different config for one runnable 155 | cfg6, err := NewConfig("test", entries5) 156 | require.NoError(t, err) 157 | 158 | // Nil configs 159 | cfg7, err := NewConfig("test", entries6) 160 | require.NoError(t, err) 161 | 162 | // Another config with nil runnable configs 163 | cfg8, err := NewConfig("test", entries7) 164 | 
require.NoError(t, err) 165 |
166 | assert.True(t, cfg1.Equal(cfg2), "identical configs should be equal") 167 | assert.False(t, cfg1.Equal(cfg3), "configs with different names should not be equal") 168 | assert.False( 169 | t, 170 | cfg1.Equal(cfg4), 171 | "configs with runnables in different order should not be equal", 172 | ) 173 | assert.True(t, cfg1.Equal(cfg5), "configs with runnables with same String() should be equal") 174 | assert.False(t, cfg1.Equal(cfg6), "configs with different runnable configs should not be equal") 175 | assert.False( 176 | t, 177 | cfg1.Equal(cfg7), 178 | "config with runnable configs should not equal config with nil configs", 179 | ) 180 | assert.True(t, cfg7.Equal(cfg8), "configs with nil runnable configs should be equal") 181 | } 182 | 183 | func TestConfig_String(t *testing.T) { 184 | t.Parallel() 185 | 186 | mockRunnable1 := &mocks.Runnable{} 187 | mockRunnable2 := &mocks.Runnable{} 188 | 189 | entries := []RunnableEntry[*mocks.Runnable]{ 190 | {Runnable: mockRunnable1, Config: nil}, 191 | {Runnable: mockRunnable2, Config: nil}, 192 | } 193 | 194 | cfg, err := NewConfig("test-config", entries) 195 | require.NoError(t, err) 196 | 197 | str := cfg.String() 198 | assert.Contains(t, str, "test-config") 199 | assert.Contains(t, str, "2") 200 | } 201 | -------------------------------------------------------------------------------- /runnables/composite/README.md: -------------------------------------------------------------------------------- 1 | # Composite Runner 2 | 3 | The Composite Runner manages multiple runnables as a single logical unit. This "runnable" implements several `go-supervisor` core interfaces: `Runnable`, `Reloadable`, and `Stateable`.
4 | 5 | ## Features 6 | 7 | - Group and manage multiple runnables as a single service (all the same type) 8 | - Provide individual configuration for each runnable or shared configuration (with hot-reload) 9 | - Support dynamic membership changes during reloads (sub-runnables can be added or removed) 10 | - Propagate errors from child runnables to the supervisor 11 | - Monitor state of individual child runnables 12 | - Manage configuration updates to the sub-runnables with a callback function 13 | 14 | ## Quick Start Example 15 | 16 | ```go 17 | // Create runnable entries with their configs 18 | entries := []composite.RunnableEntry[*myapp.SomeRunnable]{ 19 | {Runnable: runnable1, Config: map[string]any{"timeout": 10 * time.Second}}, 20 | {Runnable: runnable2, Config: map[string]any{"maxConnections": 100}}, 21 | } 22 | 23 | // Define a config callback, used for dynamic membership changes and reloads 24 | configCallback := func() (*composite.Config[*myapp.SomeRunnable], error) { 25 | return composite.NewConfig("MyRunnableGroup", entries) 26 | } 27 | 28 | // Create a composite runner 29 | runner, err := composite.NewRunner( 30 | composite.WithConfigCallback(configCallback), 31 | ) 32 | if err != nil { 33 | log.Fatalf("Failed to create runner: %v", err) 34 | } 35 | 36 | // Load the composite runner into a supervisor 37 | super, err := supervisor.New(supervisor.WithRunnables(runner)) 38 | if err != nil { 39 | log.Fatalf("Failed to create supervisor: %v", err) 40 | } 41 | 42 | if err := super.Run(); err != nil { 43 | log.Fatalf("Supervisor failed: %v", err) 44 | } 45 | ``` 46 | 47 | ### Shared Configuration Example 48 | 49 | When all runnables share the same configuration: 50 | 51 | ```go 52 | runnables := []*myapp.SomeRunnable{runnable1, runnable2, runnable3} 53 | configCallback := func() (*composite.Config[*myapp.SomeRunnable], error) { 54 | return composite.NewConfigFromRunnables( 55 | "MyRunnableGroup", 56 | runnables, 57 | map[string]any{"timeout": 30 * time.Second}, 
58 | ) 59 | } 60 | runner, err := composite.NewRunner( 61 | composite.WithConfigCallback(configCallback), 62 | ) 63 | ``` 64 | 65 | ## Dynamic Configuration 66 | 67 | The config callback function provides the configuration for the Composite Runner: 68 | 69 | - Returns the current configuration when requested 70 | - Called during initialization and reloads 71 | - Used to determine if runnable membership has changed 72 | - Should return quickly as it may be called frequently 73 | 74 | ```go 75 | // Example config callback that loads from file 76 | configCallback := func() (*composite.Config[*myapp.SomeRunnable], error) { 77 | // Read config from file or other source 78 | config, err := loadConfigFromFile("config.json") 79 | if err != nil { 80 | return nil, err 81 | } 82 | 83 | // Create entries based on loaded config 84 | var entries []composite.RunnableEntry[*myapp.SomeRunnable] 85 | for name, cfg := range config.Services { 86 | runnable := getOrCreateRunnable(name) 87 | entries = append(entries, composite.RunnableEntry[*myapp.SomeRunnable]{ 88 | Runnable: runnable, 89 | Config: cfg, 90 | }) 91 | } 92 | 93 | return composite.NewConfig("MyServices", entries) 94 | } 95 | ``` 96 | 97 | ## ReloadableWithConfig Interface 98 | 99 | Implement the `ReloadableWithConfig` interface to receive type-specific configuration updates: 100 | 101 | ```go 102 | type ConfigurableRunnable struct { 103 | timeout time.Duration 104 | // other fields 105 | } 106 | 107 | // Run implements the Runnable interface 108 | func (r *ConfigurableRunnable) Run(ctx context.Context) error { 109 | // implementation 110 | } 111 | 112 | // Stop implements the Runnable interface 113 | func (r *ConfigurableRunnable) Stop() { 114 | // implementation 115 | } 116 | 117 | // ReloadWithConfig receives configuration updates during reloads 118 | func (r *ConfigurableRunnable) ReloadWithConfig(config any) { 119 | if cfg, ok := config.(map[string]any); ok { 120 | if timeout, ok := cfg["timeout"].(time.Duration); ok 
{ 121 | r.timeout = timeout 122 | } 123 | // Handle other config parameters 124 | } 125 | } 126 | ``` 127 | 128 | The Composite Runner will prioritize calling `ReloadWithConfig` over the standard `Reload()` method when a runnable implements both. 129 | 130 | ## Monitoring Child States 131 | 132 | Monitor the states of individual child runnables: 133 | 134 | ```go 135 | // Get a map of all child runnable states 136 | states := compositeRunner.GetChildStates() 137 | 138 | // Log the current state of each runnable 139 | for name, state := range states { 140 | logger.Infof("Service %s is in state %s", name, state) 141 | } 142 | 143 | // Check if a specific service is ready 144 | if states["database"] == "running" { 145 | // Database service is ready 146 | } 147 | ``` 148 | 149 | ## Managing Lifecycle 150 | 151 | The Composite Runner coordinates the lifecycle of all contained runnables: 152 | 153 | - Starts runnables in the order they are defined (async) 154 | - Stops runnables in reverse order 155 | - Propagates errors from any child runnable 156 | - Handles clean shutdown when context is canceled 157 | - Manages state transitions (New → Booting → Running → Stopping → Stopped) 158 | 159 | ## Best Practices 160 | 161 | - **Unique Identifiers**: Ensure each runnable's `String()` method returns a consistent, unique identifier 162 | - **Stateful Configuration**: Store your latest configuration for reuse if the config source becomes temporarily unavailable 163 | - **Error Handling**: Check errors returned from `Run()` to detect failures in any child runnable 164 | - **Context Management**: Pass a cancellable context to `Run()` for controlled shutdown 165 | - **Membership Changes**: Be aware that changes in membership will cause all runnables to restart 166 | - **Type Safety**: Use the same concrete type for all runnables in a composite to leverage Go's type system 167 | 168 | --- 169 | 170 | See the [examples directory](../../examples/composite/) for complete working examples.
-------------------------------------------------------------------------------- /supervisor/reload_test.go: -------------------------------------------------------------------------------- 1 | package supervisor 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/robbyt/go-supervisor/runnables/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | // TestPIDZero_ReloadManager tests reload handling: ReloadSender triggers, shutdown via context cancellation, and manual ReloadAll calls. 15 | func TestPIDZero_ReloadManager(t *testing.T) { 16 | t.Run("handles reload notifications", func(t *testing.T) { 17 | // Setup mock 18 | sender := mocks.NewMockRunnableWithReloadSender() 19 | reloadTrigger := make(chan struct{}) 20 | stateChan := make(chan string) 21 | 22 | sender.On("GetReloadTrigger").Return(reloadTrigger) 23 | sender.On("Run", mock.Anything).Return(nil) 24 | sender.On("Reload").Return() 25 | sender.On("Stop").Return() 26 | sender.On("GetState").Return("running").Maybe() 27 | sender.On("GetStateChan", mock.Anything).Return(stateChan).Maybe() 28 | 29 | p, err := New(WithContext(context.Background()), WithRunnables(sender)) 30 | require.NoError(t, err) 31 | 32 | done := make(chan struct{}) 33 | go func() { 34 | err := p.Run() 35 | assert.NoError(t, err) 36 | close(done) 37 | }() 38 | 39 | // Trigger reload 40 | reloadTrigger <- struct{}{} 41 | 42 | // Allow reload to process (fixed sleep, so this is timing-dependent) 43 | time.Sleep(100 * time.Millisecond) 44 | 45 | // Verify reload was called once 46 | sender.AssertCalled(t, "Reload") 47 | sender.AssertNumberOfCalls(t, "Reload", 1) 48 | 49 | p.Shutdown() 50 | <-done 51 | 52 | sender.AssertExpectations(t) 53 | }) 54 | 55 | t.Run("handles multiple reloads", func(t *testing.T) { 56 | sender1 := mocks.NewMockRunnableWithReloadSender() 57 | sender2 := mocks.NewMockRunnableWithReloadSender() 58 | 59 | reloadTrigger1 := make(chan struct{}) 60 | reloadTrigger2 := make(chan struct{}) 61 | stateChan1 := make(chan
string) 62 | stateChan2 := make(chan string) 63 | 64 | sender1.On("GetReloadTrigger").Return(reloadTrigger1) 65 | sender2.On("GetReloadTrigger").Return(reloadTrigger2) 66 | 67 | sender1.On("Run", mock.Anything).Return(nil) 68 | sender2.On("Run", mock.Anything).Return(nil) 69 | 70 | sender1.On("Reload").Return() 71 | sender2.On("Reload").Return() 72 | 73 | sender1.On("Stop").Return() 74 | sender2.On("Stop").Return() 75 | 76 | sender1.On("GetState").Return("running").Maybe() 77 | sender2.On("GetState").Return("running").Maybe() 78 | sender1.On("GetStateChan", mock.Anything).Return(stateChan1).Maybe() 79 | sender2.On("GetStateChan", mock.Anything).Return(stateChan2).Maybe() 80 | 81 | p, err := New(WithContext(context.Background()), WithRunnables(sender1, sender2)) 82 | require.NoError(t, err) 83 | 84 | done := make(chan struct{}) 85 | go func() { 86 | err := p.Run() 87 | assert.NoError(t, err) 88 | close(done) 89 | }() 90 | 91 | // Two triggers => each service should reload twice, because each trigger reloads every runnable 92 | reloadTrigger1 <- struct{}{} 93 | reloadTrigger2 <- struct{}{} 94 | 95 | time.Sleep(100 * time.Millisecond) 96 | 97 | // Expect each service to have Reload() called twice 98 | sender1.AssertNumberOfCalls(t, "Reload", 2) 99 | sender2.AssertNumberOfCalls(t, "Reload", 2) 100 | 101 | p.Shutdown() 102 | <-done 103 | 104 | sender1.AssertExpectations(t) 105 | sender2.AssertExpectations(t) 106 | }) 107 | 108 | t.Run("graceful shutdown", func(t *testing.T) { 109 | ctx, cancel := context.WithCancel(context.Background()) 110 | 111 | sender := mocks.NewMockRunnableWithReloadSender() 112 | reloadTrigger := make(chan struct{}) 113 | stateChan := make(chan string) 114 | 115 | sender.On("GetReloadTrigger").Return(reloadTrigger) 116 | sender.On("Run", mock.Anything).Return(nil) 117 | sender.On("Stop").Return() 118 | sender.On("GetState").Return("running").Maybe() 119 | sender.On("GetStateChan", mock.Anything).Return(stateChan).Maybe() 120 | 121 | p, err := New(WithContext(ctx), 
WithRunnables(sender)) 122
| require.NoError(t, err) 123 | 124 | done := make(chan struct{}) 125 | go func() { 126 | err := p.Run() 127 | assert.NoError(t, err) 128 | close(done) 129 | }() 130 | 131 | // Cancel context to trigger shutdown 132 | cancel() 133 | 134 | select { 135 | case <-done: 136 | sender.AssertCalled(t, "Stop") 137 | case <-time.After(time.Second): 138 | t.Fatal("shutdown timed out") 139 | } 140 | 141 | sender.AssertExpectations(t) 142 | }) 143 | 144 | // State monitoring is covered in state_monitoring_test.go 145 | 146 | t.Run("manual ReloadAll call", func(t *testing.T) { 147 | // Setup multiple reloadable services that aren't reload senders 148 | mockService1 := mocks.NewMockRunnable() 149 | mockService2 := mocks.NewMockRunnable() 150 | 151 | mockService1.On("String").Return("ReloadableService1").Maybe() 152 | mockService2.On("String").Return("ReloadableService2").Maybe() 153 | 154 | mockService1.On("Reload").Once() 155 | mockService2.On("Reload").Once() 156 | 157 | mockService1.On("Run", mock.Anything).Return(nil).Once() 158 | mockService2.On("Run", mock.Anything).Return(nil).Once() 159 | 160 | mockService1.On("Stop").Once() 161 | mockService2.On("Stop").Once() 162 | 163 | // Create supervisor with both services 164 | ctx, cancel := context.WithCancel(context.Background()) 165 | defer cancel() 166 | 167 | pid0, err := New(WithContext(ctx), WithRunnables(mockService1, mockService2)) 168 | require.NoError(t, err) 169 | 170 | // Start supervisor 171 | execDone := make(chan error, 1) 172 | go func() { 173 | execDone <- pid0.Run() 174 | }() 175 | 176 | // Allow time for services to start 177 | time.Sleep(10 * time.Millisecond) 178 | 179 | // Manually trigger reload via API call 180 | pid0.ReloadAll() 181 | 182 | // Allow time for reload to complete 183 | time.Sleep(50 * time.Millisecond) 184 | 185 | // Verify both services were reloaded 186 | mockService1.AssertNumberOfCalls(t, "Reload", 1) 187 | mockService2.AssertNumberOfCalls(t, "Reload", 1) 188 | 189 | // Shutdown
and wait for completion 190 | pid0.Shutdown() 191 | select { 192 | case err := <-execDone: 193 | require.NoError(t, err) 194 | case <-time.After(time.Second): 195 | t.Fatal("shutdown timed out") 196 | } 197 | 198 | // Verify expectations 199 | mockService1.AssertExpectations(t) 200 | mockService2.AssertExpectations(t) 201 | }) 202 | } 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-supervisor 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/robbyt/go-supervisor.svg)](https://pkg.go.dev/github.com/robbyt/go-supervisor) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/robbyt/go-supervisor)](https://goreportcard.com/report/github.com/robbyt/go-supervisor) 5 | [![Coverage](https://sonarcloud.io/api/project_badges/measure?project=robbyt_go-supervisor&metric=coverage)](https://sonarcloud.io/summary/new_code?id=robbyt_go-supervisor) 6 | [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) 7 | 8 | A service supervisor for Go applications that manages lifecycle for multiple services. Handles graceful shutdown with OS signal support (SIGINT, SIGTERM), configuration hot reloading (SIGHUP), and state monitoring. Service capabilities are added by implementing optional interfaces.
9 | 10 | ## Features 11 | 12 | - **Service Lifecycle Management**: Start, stop, and monitor multiple services 13 | - **Graceful Shutdown**: Handle OS signals (SIGINT, SIGTERM) for clean termination 14 | - **Hot Reloading**: Reload service configurations with SIGHUP or programmatically 15 | - **State Monitoring**: Track and query the state of running services 16 | - **Context Propagation**: Pass context through service lifecycle for proper cancellation 17 | - **Structured Logging**: Integrated with Go's `slog` package 18 | - **Flexible Configuration**: Functional options pattern for easy customization 19 | 20 | ## Installation 21 | 22 | ```bash 23 | go get github.com/robbyt/go-supervisor 24 | ``` 25 | 26 | ## Quick Start 27 | 28 | Define a runnable service by implementing the Runnable interface: `Run(ctx context.Context) error`, `Stop()`, and `String() string` (Runnable embeds `fmt.Stringer`). Additional capabilities (Reloadable, Stateable, ReloadSender) can be implemented as needed. See `supervisor/interfaces.go` for interface details. 29 | 30 | ```go 31 | package main 32 | 33 | import ( 34 | "context" 35 | "fmt" 36 | "log/slog" 37 | "os" 38 | "time" 39 | 40 | "github.com/robbyt/go-supervisor/supervisor" 41 | ) 42 | 43 | // Example service that implements Runnable interface 44 | type MyService struct { 45 | name string 46 | } 47 | 48 | // Interface guard, ensuring that MyService implements Runnable 49 | var _ supervisor.Runnable = (*MyService)(nil) 50 | 51 | func (s *MyService) Run(ctx context.Context) error { 52 | fmt.Printf("%s: Starting\n", s.name) 53 | 54 | ticker := time.NewTicker(1 * time.Second) 55 | defer ticker.Stop() 56 | 57 | for { 58 | select { 59 | case <-ctx.Done(): 60 | fmt.Printf("%s: Context canceled\n", s.name) 61 | return nil 62 | case <-ticker.C: 63 | fmt.Printf("%s: Tick\n", s.name) 64 | } 65 | } 66 | } 67 | 68 | func (s *MyService) Stop() { 69 | fmt.Printf("%s: Stopping\n", s.name) 70 | // Perform cleanup if needed 71 | } 72 | 73 | func (s *MyService) String() string { 74 | return s.name 75 | } 76
| 77 | func main() { 78 | // Create some services 79 | service1 := &MyService{name: "Service1"} 80 | service2 := &MyService{name: "Service2"} 81 | 82 | // Create a custom logger 83 | handler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 84 | Level: slog.LevelDebug, 85 | }) 86 | 87 | // Create a supervisor with our services and custom logger 88 | super, err := supervisor.New( 89 | supervisor.WithRunnables(service1, service2), 90 | supervisor.WithLogHandler(handler), 91 | ) 92 | if err != nil { 93 | fmt.Printf("Error creating supervisor: %v\n", err) 94 | os.Exit(1) 95 | } 96 | 97 | // Blocking call to Run(), starts listening to signals and starts all Runnables 98 | if err := super.Run(); err != nil { 99 | fmt.Printf("Error: %v\n", err) 100 | os.Exit(1) 101 | } 102 | } 103 | ``` 104 | 105 | ## Core Interfaces 106 | 107 | The package is built around the following interfaces. A "Runnable" is any service that can be 108 | started and stopped, while "Reloadable" and "Stateable" services can be reloaded or report 109 | their state, respectively. The supervisor will discover the capabilities of each service 110 | and manage them accordingly. 
111 | 112 | ```go 113 | // Runnable represents a service that can be run and stopped 114 | type Runnable interface { 115 | fmt.Stringer 116 | Run(ctx context.Context) error 117 | Stop() 118 | } 119 | 120 | // Reloadable represents a service that can be reloaded 121 | type Reloadable interface { 122 | Reload() 123 | } 124 | 125 | // Stateable represents a service that can report its state 126 | type Stateable interface { 127 | GetState() string 128 | GetStateChan(context.Context) <-chan string 129 | } 130 | 131 | // ReloadSender represents a service that can trigger reloads 132 | type ReloadSender interface { 133 | GetReloadTrigger() <-chan struct{} 134 | } 135 | ``` 136 | 137 | ## Advanced Usage 138 | 139 | ### Implementing a Reloadable Service 140 | 141 | ```go 142 | type ConfigurableService struct { 143 | MyService 144 | config *Config 145 | mu sync.Mutex 146 | } 147 | 148 | // Interface guards, ensuring that ConfigurableService implements Runnable and Reloadable 149 | var _ supervisor.Runnable = (*ConfigurableService)(nil) 150 | var _ supervisor.Reloadable = (*ConfigurableService)(nil) 151 | 152 | type Config struct { 153 | Interval time.Duration 154 | } 155 | 156 | func (s *ConfigurableService) Reload() { 157 | s.mu.Lock() 158 | defer s.mu.Unlock() 159 | 160 | // Load new config from file or environment 161 | newConfig := loadConfig() 162 | s.config = newConfig 163 | 164 | fmt.Printf("%s: Configuration reloaded\n", s.name) 165 | } 166 | ``` 167 | 168 | ## Example Runnables 169 | 170 | The package includes the following runnable implementations: 171 | 172 | - **HTTP Server**: A configurable HTTP server with routing and middleware support (`runnables/httpserver`) 173 | - **Composite**: A container for managing multiple Runnables using generics (`runnables/composite`) 174 | - **HTTP Cluster**: Dynamic management of multiple HTTP servers with hot-reload support using channel-based configuration (`runnables/httpcluster`) 175 | 176 | Each runnable has its own 
documentation in its directory (e.g., `runnables/httpserver/README.md`). 177 | 178 | ## License 179 | 180 | Apache License 2.0 - See [LICENSE](LICENSE) for details. 181 | -------------------------------------------------------------------------------- /runnables/httpcluster/runner_context_test.go: -------------------------------------------------------------------------------- 1 | package httpcluster 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/robbyt/go-supervisor/internal/finitestate" 11 | "github.com/robbyt/go-supervisor/runnables/httpserver" 12 | "github.com/robbyt/go-supervisor/runnables/mocks" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/mock" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | // TestRunnerContextPersistence ensures that server contexts are not prematurely canceled 19 | func TestRunnerContextPersistence(t *testing.T) { 20 | t.Parallel() 21 | 22 | t.Run("servers should persist after config update", func(t *testing.T) { 23 | // This test would have caught the bug where executeActions was canceling 24 | // the context immediately after starting servers 25 | var serverContexts []context.Context 26 | var mu sync.Mutex 27 | 28 | mockFactory := func(ctx context.Context, id string, cfg *httpserver.Config, handler slog.Handler) (httpServerRunner, error) { 29 | mu.Lock() 30 | serverContexts = append(serverContexts, ctx) 31 | mu.Unlock() 32 | 33 | mockServer := mocks.NewMockRunnableWithStateable() 34 | 35 | // Server should run until its context is canceled 36 | serverRunning := make(chan struct{}) 37 | mockServer.On("Run", mock.Anything).Run(func(args mock.Arguments) { 38 | close(serverRunning) 39 | <-ctx.Done() // Wait for context cancellation 40 | }).Return(nil) 41 | 42 | mockServer.On("Stop").Return().Maybe() 43 | mockServer.On("GetState").Return(finitestate.StatusRunning) 44 | mockServer.On("IsRunning").Return(true) 45 | 46 | stateChan := make(chan
string, 1) 47 | stateChan <- finitestate.StatusRunning 48 | mockServer.On("GetStateChan", mock.Anything).Return(stateChan) 49 | 50 | return mockServer, nil 51 | } 52 | 53 | runner, err := NewRunner(WithRunnerFactory(mockFactory)) 54 | require.NoError(t, err) 55 | 56 | ctx, cancel := context.WithCancel(t.Context()) 57 | defer cancel() 58 | 59 | // Start runner 60 | go func() { 61 | if err := runner.Run(ctx); err != nil { 62 | t.Logf("Runner error: %v", err) 63 | } 64 | }() 65 | 66 | // Wait for running 67 | require.Eventually(t, func() bool { 68 | return runner.IsRunning() 69 | }, time.Second, 10*time.Millisecond) 70 | 71 | // Send config to create a server 72 | configs := map[string]*httpserver.Config{ 73 | "server1": createTestHTTPConfig(t, ":8001"), 74 | } 75 | runner.configSiphon <- configs 76 | 77 | // Wait for server to be created 78 | require.Eventually(t, func() bool { 79 | return runner.GetServerCount() == 1 80 | }, time.Second, 10*time.Millisecond) 81 | 82 | // Verify the server context is still valid (not canceled) 83 | mu.Lock() 84 | require.Len(t, serverContexts, 1, "Should have created one server") 85 | serverCtx := serverContexts[0] 86 | mu.Unlock() 87 | 88 | assert.Never(t, func() bool { 89 | select { 90 | case <-serverCtx.Done(): 91 | return true 92 | default: 93 | return false 94 | } 95 | }, 100*time.Millisecond, 10*time.Millisecond) 96 | 97 | // Verify server is still running 98 | assert.Equal(t, 1, runner.GetServerCount(), "Server should still be running") 99 | 100 | // Send another config update as a fresh map, so the test never writes to the map already handed to the runner goroutine 101 | configs = map[string]*httpserver.Config{"server1": configs["server1"], "server2": createTestHTTPConfig(t, ":8002")} 102 | runner.configSiphon <- configs 103 | 104 | // Wait for second server 105 | require.Eventually(t, func() bool { 106 | return runner.GetServerCount() == 2 107 | }, time.Second, 10*time.Millisecond) 108 | 109 | // Original server context should still be valid 110 | select { 
111 | case <-serverCtx.Done(): 112 | t.Fatal("First server context was canceled during second config update") 113 |
default: 114 | // Good - context is still active 115 | } 116 | }) 117 | 118 | t.Run("context hierarchy is maintained correctly", func(t *testing.T) { 119 | // Track context relationships. NOTE(review): serverCtx is written in the factory without the mutex used in the subtest above; the reads below presumably rely on GetServerCount for visibility - confirm under -race 120 | var serverCtx context.Context 121 | var runnerCtx context.Context 122 | 123 | mockFactory := func(ctx context.Context, id string, cfg *httpserver.Config, handler slog.Handler) (httpServerRunner, error) { 124 | serverCtx = ctx 125 | 126 | mockServer := mocks.NewMockRunnableWithStateable() 127 | mockServer.On("Run", mock.Anything).Run(func(args mock.Arguments) { 128 | <-ctx.Done() 129 | }).Return(nil) 130 | mockServer.On("Stop").Return().Maybe() 131 | mockServer.On("GetState").Return(finitestate.StatusRunning) 132 | mockServer.On("IsRunning").Return(true) 133 | 134 | stateChan := make(chan string, 1) 135 | stateChan <- finitestate.StatusRunning 136 | mockServer.On("GetStateChan", mock.Anything).Return(stateChan) 137 | 138 | return mockServer, nil 139 | } 140 | 141 | runner, err := NewRunner(WithRunnerFactory(mockFactory)) 142 | require.NoError(t, err) 143 | 144 | ctx, cancel := context.WithCancel(t.Context()) 145 | runnerCtx = ctx 146 | 147 | // Start runner 148 | go func() { 149 | if err := runner.Run(ctx); err != nil { 150 | t.Logf("Runner error: %v", err) 151 | } 152 | }() 153 | 154 | // Wait for running 155 | require.Eventually(t, func() bool { 156 | return runner.IsRunning() 157 | }, time.Second, 10*time.Millisecond) 158 | 159 | // Create a server 160 | configs := map[string]*httpserver.Config{ 161 | "server1": createTestHTTPConfig(t, ":8001"), 162 | } 163 | runner.configSiphon <- configs 164 | 165 | // Wait for server 166 | require.Eventually(t, func() bool { 167 | return runner.GetServerCount() == 1 168 | }, time.Second, 10*time.Millisecond) 169 | 170 | // Only presence is asserted here; derivation from the runner's context 171 | // is 
verified below by canceling it and waiting for serverCtx.Done() 172 | assert.NotNil(t, serverCtx, "Server context should be set") 173 | 174 | // Server context should not be
the same as the runner context 175 | // (it should be a child context for individual server lifecycle) 176 | assert.NotEqual(t, runnerCtx, serverCtx, "Server should have its own context") 177 | 178 | // When we cancel the runner context, server should eventually stop 179 | cancel() 180 | 181 | // Server context should be canceled when runner stops 182 | assert.Eventually(t, func() bool { 183 | select { 184 | case <-serverCtx.Done(): 185 | return true 186 | default: 187 | return false 188 | } 189 | }, time.Second, 10*time.Millisecond) 190 | }) 191 | } 192 | -------------------------------------------------------------------------------- /runnables/httpserver/config.go: -------------------------------------------------------------------------------- 1 | package httpserver 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "time" 10 | ) 11 | 12 | const ( 13 | defaultDrainTimeout = 30 * time.Second 14 | defaultReadTimeout = 15 * time.Second 15 | defaultWriteTimeout = 15 * time.Second 16 | defaultIdleTimeout = 1 * time.Minute 17 | ) 18 | 19 | // ServerCreator is a function type that creates an HttpServer instance 20 | type ServerCreator func(addr string, handler http.Handler, cfg *Config) HttpServer 21 | 22 | // DefaultServerCreator creates a standard http.Server instance with the settings from Config 23 | func DefaultServerCreator(addr string, handler http.Handler, cfg *Config) HttpServer { 24 | // Fall back to context.Background() when no request context was configured 25 | ctx := cfg.context 26 | if ctx == nil { 27 | ctx = context.Background() 28 | } 29 | 30 | return &http.Server{ 31 | Addr: addr, 32 | Handler: handler, 33 | ReadTimeout: cfg.ReadTimeout, 34 | WriteTimeout: cfg.WriteTimeout, 35 | IdleTimeout: cfg.IdleTimeout, 36 | BaseContext: func(_ net.Listener) context.Context { return ctx }, 37 | } 38 | } 39 | 40 | // Config is the main configuration struct for the HTTP server 41 | type Config struct { 42 | // Core configuration 43 | ListenAddr string 44 | DrainTimeout
time.Duration 45 | Routes Routes 46 | 47 | // Server settings 48 | ReadTimeout time.Duration 49 | WriteTimeout time.Duration 50 | IdleTimeout time.Duration 51 | 52 | // Server creation callback function 53 | ServerCreator ServerCreator 54 | 55 | // Base context for request handlers (set via WithRequestContext; DefaultServerCreator returns it from BaseContext) 56 | context context.Context 57 | } 58 | 59 | // ConfigOption defines a functional option for configuring Config 60 | type ConfigOption func(*Config) 61 | 62 | // WithDrainTimeout sets the drain timeout for graceful shutdown 63 | func WithDrainTimeout(timeout time.Duration) ConfigOption { 64 | return func(c *Config) { 65 | c.DrainTimeout = timeout 66 | } 67 | } 68 | 69 | // WithReadTimeout sets the read timeout for the HTTP server 70 | func WithReadTimeout(timeout time.Duration) ConfigOption { 71 | return func(c *Config) { 72 | c.ReadTimeout = timeout 73 | } 74 | } 75 | 76 | // WithWriteTimeout sets the write timeout for the HTTP server 77 | func WithWriteTimeout(timeout time.Duration) ConfigOption { 78 | return func(c *Config) { 79 | c.WriteTimeout = timeout 80 | } 81 | } 82 | 83 | // WithIdleTimeout sets the idle timeout for the HTTP server 84 | func WithIdleTimeout(timeout time.Duration) ConfigOption { 85 | return func(c *Config) { 86 | c.IdleTimeout = timeout 87 | } 88 | } 89 | 90 | // WithServerCreator sets a custom server creator for the HTTP server; a nil creator is ignored 91 | func WithServerCreator(creator ServerCreator) ConfigOption { 92 | return func(c *Config) { 93 | if creator != nil { 94 | c.ServerCreator = creator 95 | } 96 | } 97 | } 98 | 99 | // WithRequestContext sets the context that will be propagated to all request handlers 100 | // via http.Server's BaseContext (when DefaultServerCreator is used). This allows handlers to be aware of server shutdown. A nil ctx is ignored.
101 | func WithRequestContext(ctx context.Context) ConfigOption { 102 | return func(c *Config) { 103 | if ctx != nil { 104 | c.context = ctx 105 | } 106 | } 107 | } 108 | 109 | // WithConfigCopy creates a ConfigOption that copies most settings from the source config 110 | // except for ListenAddr and Routes which are provided directly to NewConfig. 111 | func WithConfigCopy(src *Config) ConfigOption { 112 | return func(dst *Config) { 113 | if src == nil { 114 | return 115 | } 116 | 117 | // Copy timeout settings 118 | dst.DrainTimeout = src.DrainTimeout 119 | dst.ReadTimeout = src.ReadTimeout 120 | dst.WriteTimeout = src.WriteTimeout 121 | dst.IdleTimeout = src.IdleTimeout 122 | 123 | // Copy other settings 124 | dst.ServerCreator = src.ServerCreator 125 | dst.context = src.context 126 | } 127 | } 128 | 129 | // NewConfig creates a new Config with the address and routes 130 | // plus any optional configuration via functional options. It returns an error if routes is empty. 131 | func NewConfig(addr string, routes Routes, opts ...ConfigOption) (*Config, error) { 132 | if len(routes) == 0 { 133 | return nil, errors.New("routes cannot be empty") 134 | } 135 | 136 | // Use constants for default values 137 | c := &Config{ 138 | ListenAddr: addr, 139 | Routes: routes, 140 | DrainTimeout: defaultDrainTimeout, 141 | ReadTimeout: defaultReadTimeout, 142 | WriteTimeout: defaultWriteTimeout, 143 | IdleTimeout: defaultIdleTimeout, 144 | ServerCreator: DefaultServerCreator, 145 | context: context.Background(), 146 | } 147 | 148 | // Apply overrides from the functional options 149 | for _, opt := range opts { 150 | opt(c) 151 | } 152 | 153 | return c, nil 154 | } 155 | 156 | // String returns a human-readable representation of the Config: listen address, drain timeout, routes, and server timeouts 157 | func (c *Config) String() string { 158 | return fmt.Sprintf( 159 | "Config{addr=%s, drainTimeout=%s, routes=%v, readTimeout=%s, writeTimeout=%s, idleTimeout=%s}", 160 | c.ListenAddr, 161 | 
c.DrainTimeout, 162 | c.Routes, 163 | c.ReadTimeout, 164 | c.WriteTimeout, 165 | c.IdleTimeout, 166 | ) 167 | } 168 | 169 | // Equal compares this Config with another and
returns true if they are equivalent. 170 | func (c *Config) Equal(other *Config) bool { 171 | if other == nil { 172 | return false 173 | } 174 | 175 | if c.ListenAddr != other.ListenAddr { 176 | return false 177 | } 178 | 179 | if c.DrainTimeout != other.DrainTimeout { 180 | return false 181 | } 182 | 183 | if !c.Routes.Equal(other.Routes) { 184 | return false 185 | } 186 | 187 | // Compare server settings 188 | if c.ReadTimeout != other.ReadTimeout { 189 | return false 190 | } 191 | 192 | if c.WriteTimeout != other.WriteTimeout { 193 | return false 194 | } 195 | 196 | if c.IdleTimeout != other.IdleTimeout { 197 | return false 198 | } 199 | 200 | // Note: We don't compare ServerCreator functions as they're not directly comparable 201 | 202 | return true 203 | } 204 | 205 | // getMux creates and returns a new http.ServeMux with all configured routes registered. 206 | // Each route's Path is mapped to its handler chain via ServeHTTP. 207 | func (c *Config) getMux() *http.ServeMux { 208 | mux := http.NewServeMux() 209 | for _, route := range c.Routes { 210 | mux.Handle(route.Path, &route) 211 | } 212 | return mux 213 | } 214 | 215 | // createServer creates an HTTP server using the configuration's settings 216 | func (c *Config) createServer() HttpServer { 217 | addr := c.ListenAddr 218 | mux := c.getMux() 219 | creator := c.ServerCreator 220 | if creator == nil { 221 | creator = DefaultServerCreator 222 | } 223 | 224 | return creator(addr, mux, c) 225 | } 226 | --------------------------------------------------------------------------------