├── internal ├── middleware │ ├── context.go │ ├── rate_limiter.go │ ├── cors.go │ ├── security.go │ ├── circuit_breaker.go │ ├── compression.go │ ├── logging.go │ └── middleware.go ├── crypto │ ├── ciphers.go │ └── certmanager.go ├── pool │ ├── buffer_pool.go │ ├── options.go │ ├── transport.go │ ├── headers.go │ ├── backend.go │ └── url_rewriter.go ├── svcache │ └── cache.go ├── admin │ ├── middleware │ │ ├── host.go │ │ ├── log.go │ │ └── ip.go │ ├── validatior.go │ └── api.go ├── auth │ ├── errors.go │ ├── models │ │ └── models.go │ ├── middleware │ │ └── auth_middleware.go │ ├── validation │ │ └── password.go │ └── handlers │ │ └── auth_handler.go ├── config │ └── api.config.go ├── stls │ └── stls.go ├── server │ ├── util.go │ └── vhost.go ├── cerr │ └── backend.go ├── health │ └── checker.go └── service │ └── manager.go ├── .gitignore ├── docker-compose.yml ├── api.config.yaml ├── Dockerfile ├── pkg ├── plugin │ ├── action.go │ ├── handler.go │ ├── result.go │ └── manager.go ├── algorithm │ ├── least_connections.go │ ├── round_robin.go │ ├── ip_hash.go │ ├── bounded_least_conn.go │ ├── algorithm.go │ ├── weighted_round_robin.go │ ├── least_response_time.go │ ├── adaptive.go │ └── session_affinity.go ├── trace │ └── request_id.go ├── shutdown │ └── shutdown.go ├── logger │ ├── adapter.go │ ├── sanitizer.go │ ├── defaults.go │ ├── manager.go │ ├── async_core.go │ └── logger.go └── proxy │ └── websocket.go ├── CHANGELOG.md ├── .github └── workflows │ └── go.yml ├── cmd └── terraster │ ├── config.go │ ├── server.go │ └── main.go ├── LICENSE ├── go.mod ├── log.config.json ├── tools └── benchmark │ └── main.go ├── config.yaml ├── scripts └── database │ └── api_util.go └── plugins ├── README.md └── example └── main.go /internal/middleware/context.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | type contextKey int 4 | 5 | const ( 6 | BackendKey contextKey = iota 7 | RetryKey 8 | ) 9 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.exe
*.exe~
*.dll
*.so
*.dylib

*.test
*.out

go.work
go.work.sum

.env

# Build outputs at the repository root. The leading "/" anchors each pattern;
# an unanchored "terraster" also matched the cmd/terraster source directory.
/load-balancer
/terraster
# A leading "./" never matches in .gitignore; "/" anchors to this directory.
/tmp/
config.yaml
/certs/
# NOTE(review): "./plugins" is a no-op pattern for the reason above, which is why
# plugins/README.md and plugins/example are tracked. Change it to "/plugins/"
# only if new files under plugins/ should really be ignored.
./plugins
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: '3.8'

services:
  terraster:
    build: .
    ports:
      - "8080:8080"
      - "8081:8081"
      - "9090:9090"
    volumes:
      - ./config.yaml:/root/config.yaml
      - ./certs:/etc/certs
    restart: unless-stopped
--------------------------------------------------------------------------------
/api.config.yaml:
--------------------------------------------------------------------------------
api:
  enabled: true
  host: admin.domain.com
  port: 8085
  insecure: true # must be enabled to serve the Admin API over plain HTTP (NOT RECOMMENDED)

database:
  path: "./auth.db"

auth:
  # Example value only: replace it with a long random secret before deploying.
  jwt_secret: "HelloFormTheOtherSide"
  token_cleanup_interval: "7h"
  password_expiry_days: 3

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM golang:1.22-alpine AS builder
# go.mod declares "go 1.22.5" (CI also builds with 1.22.5), so a 1.21 image
# would have to download a newer toolchain at build time.

WORKDIR /app
COPY . .
RUN go mod download
# main.go lives in cmd/terraster next to config.go and server.go; build the
# whole package (cmd/main.go does not exist).
RUN go build -o terraster ./cmd/terraster

FROM alpine:latest
RUN apk --no-cache add ca-certificates
WORKDIR /root/
COPY --from=builder /app/terraster .
# NOTE(review): config.yaml is listed in .gitignore, so it must be present in the
# build context for this COPY to succeed; docker-compose.yml also bind-mounts it.
COPY config.yaml .
13 | 14 | EXPOSE 8080 8081 9090 15 | CMD ["./terraster", "--config", "config.yaml"] 16 | -------------------------------------------------------------------------------- /pkg/plugin/action.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | // Action represents what the proxy should do with the request/response 4 | type Action int 5 | 6 | const ( 7 | // Continue indicates that the request/response should proceed normally 8 | Continue Action = iota 9 | // Stop indicates that the request/response should be stopped 10 | Stop 11 | // Modify indicates that the request/response has been modified and should proceed 12 | Modify 13 | ) 14 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines. 
4 | 5 | ### [0.4.7](https://github.com/unkn0wn-root/terraster/compare/v0.4.6...v0.4.7) (2025-09-11) 6 | 7 | 8 | ### Bug Fixes 9 | 10 | * **pool:** rewrite struct name ([4f458b1](https://github.com/unkn0wn-root/terraster/commit/4f458b127785d6bd7a50768cd06daff874db0d40)) 11 | -------------------------------------------------------------------------------- /pkg/plugin/handler.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | // Handler defines the interface that all plugins must implement 9 | type Handler interface { 10 | // ProcessRequest processes the request before it's sent to the backend 11 | ProcessRequest(ctx context.Context, req *http.Request) *Result 12 | 13 | // ProcessResponse processes the response before it's sent back to the client 14 | ProcessResponse(ctx context.Context, resp *http.Response) *Result 15 | 16 | Name() string 17 | Priority() int 18 | Cleanup() error 19 | } 20 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.22.5' 23 | 24 | - name: Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 
29 | -------------------------------------------------------------------------------- /internal/crypto/ciphers.go: -------------------------------------------------------------------------------- 1 | package certmanager 2 | 3 | import "crypto/tls" 4 | 5 | // default ciphers for terraster 6 | var TerrasterCiphers = []uint16{ 7 | // TLS 1.3 ciphers 8 | tls.TLS_AES_256_GCM_SHA384, 9 | tls.TLS_AES_128_GCM_SHA256, 10 | tls.TLS_CHACHA20_POLY1305_SHA256, 11 | 12 | // ECDSA ciphers (TLS 1.2) 13 | tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 14 | tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 15 | tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, 16 | 17 | // RSA ciphers (TLS 1.2) 18 | tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 19 | tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 20 | tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, 21 | 22 | // Prevent downgrade attacks 23 | tls.TLS_FALLBACK_SCSV, 24 | } 25 | -------------------------------------------------------------------------------- /internal/pool/buffer_pool.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import "sync" 4 | 5 | // BufferPool is a wrapper around sync.Pool that provides a pool of reusable byte slices. 6 | type BufferPool struct { 7 | sync.Pool 8 | } 9 | 10 | // NewBufferPool initializes and returns a new instance of BufferPool. 11 | func NewBufferPool() *BufferPool { 12 | return &BufferPool{ 13 | Pool: sync.Pool{ 14 | New: func() interface{} { 15 | return make([]byte, 32*1024) // 32KB default size 16 | }, 17 | }, 18 | } 19 | } 20 | 21 | // Get retrieves a byte slice from the BufferPool. 22 | func (b *BufferPool) Get() []byte { 23 | return b.Pool.Get().([]byte) 24 | } 25 | 26 | // Put returns a byte slice back to the BufferPool for reuse. 
27 | func (b *BufferPool) Put(buf []byte) { 28 | b.Pool.Put(buf) 29 | } 30 | -------------------------------------------------------------------------------- /pkg/algorithm/least_connections.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type LeastConnections struct{} 8 | 9 | func (lc *LeastConnections) Name() string { 10 | return "least-connections" 11 | } 12 | 13 | func (lc *LeastConnections) NextServer(pool ServerPool, _ *http.Request, w *http.ResponseWriter) *Server { 14 | servers := pool.GetBackends() 15 | if len(servers) == 0 { 16 | return nil 17 | } 18 | 19 | var selectedServer *Server 20 | var minConn int32 = -1 21 | 22 | for _, server := range servers { 23 | if !server.Alive.Load() || !server.CanAcceptConnection() { 24 | continue 25 | } 26 | 27 | if minConn == -1 || server.ConnectionCount < minConn { 28 | minConn = server.ConnectionCount 29 | selectedServer = server 30 | } 31 | } 32 | 33 | return selectedServer 34 | } 35 | -------------------------------------------------------------------------------- /pkg/algorithm/round_robin.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type RoundRobin struct{} 8 | 9 | func (rr *RoundRobin) Name() string { 10 | return "round-robin" 11 | } 12 | 13 | func (rr *RoundRobin) NextServer(pool ServerPool, _ *http.Request, w *http.ResponseWriter) *Server { 14 | servers := pool.GetBackends() 15 | if len(servers) == 0 { 16 | return nil 17 | } 18 | 19 | currentIdx := pool.GetCurrentIndex() 20 | next := currentIdx + 1 21 | pool.SetCurrentIndex(next) 22 | 23 | idx := next % uint64(len(servers)) 24 | l := uint64(len(servers)) 25 | 26 | for i := uint64(0); i < l; i++ { 27 | serverIdx := (idx + i) % l 28 | server := servers[serverIdx] 29 | if server.Alive.Load() && server.CanAcceptConnection() { 30 | return server 31 | } 32 | 
33 | } 34 | 35 | return nil 36 | } 37 | -------------------------------------------------------------------------------- /pkg/algorithm/ip_hash.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "hash/fnv" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | type IPHash struct{} 10 | 11 | func (ih *IPHash) Name() string { 12 | return "ip-hash" 13 | } 14 | 15 | func (ih *IPHash) NextServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server { 16 | servers := pool.GetBackends() 17 | if len(servers) == 0 { 18 | return nil 19 | } 20 | 21 | ip := strings.Split(r.RemoteAddr, ":")[0] 22 | 23 | h := fnv.New32a() 24 | h.Write([]byte(ip)) 25 | hash := h.Sum32() 26 | 27 | available := make([]*Server, 0) 28 | for _, server := range servers { 29 | if server.Alive.Load() { 30 | available = append(available, server) 31 | } 32 | } 33 | 34 | if len(available) == 0 { 35 | return nil 36 | } 37 | 38 | return available[hash%uint32(len(available))] 39 | } 40 | -------------------------------------------------------------------------------- /cmd/terraster/config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/unkn0wn-root/terraster/internal/config" 5 | "go.uber.org/zap" 6 | ) 7 | 8 | // ConfigManager handles configuration loading and provides defaults 9 | type ConfigManager struct { 10 | logger *zap.Logger 11 | } 12 | 13 | func NewConfigManager(logger *zap.Logger) *ConfigManager { 14 | return &ConfigManager{ 15 | logger: logger, 16 | } 17 | } 18 | 19 | // LoadAPIConfig loads the API configuration with graceful fallback to set admin api as disabled 20 | func (cm *ConfigManager) LoadAPIConfig(path string) *config.APIConfig { 21 | cfg, err := config.LoadAPIConfig(path) 22 | if err != nil { 23 | cm.logger.Warn("Failed to load Admin API configuration file. 
Admin API is disabled", 24 | zap.Error(err), 25 | zap.String("path", path)) 26 | 27 | return &config.APIConfig{ 28 | API: config.API{ 29 | Enabled: false, 30 | }, 31 | } 32 | } 33 | 34 | return cfg 35 | } 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 david0 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pkg/trace/request_id.go: -------------------------------------------------------------------------------- 1 | package trace 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "github.com/google/uuid" 8 | ) 9 | 10 | // ContextKey is a custom type for context keys to avoid collisions. 
11 | type ContextKey string 12 | 13 | const RequestIDKey ContextKey = "request_id" 14 | 15 | type RequestID struct{} 16 | 17 | func WithRequestID() *RequestID { 18 | return &RequestID{} 19 | } 20 | 21 | // Middleware generates a unique request ID for each incoming HTTP request, 22 | // stores it in the context, and sets it in the response headers. 23 | func (r *RequestID) Middleware(next http.Handler) http.Handler { 24 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 | requestID := uuid.New().String() 26 | w.Header().Set("X-Request-ID", requestID) 27 | ctx := context.WithValue(r.Context(), RequestIDKey, requestID) 28 | next.ServeHTTP(w, r.WithContext(ctx)) 29 | }) 30 | } 31 | 32 | // GetRequestID retrieves the request ID from the context. 33 | func GetRequestID(ctx context.Context) string { 34 | if ctx == nil { 35 | return "" 36 | } 37 | if reqID, ok := ctx.Value(RequestIDKey).(string); ok { 38 | return reqID 39 | } 40 | return "" 41 | } 42 | -------------------------------------------------------------------------------- /pkg/algorithm/bounded_least_conn.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "math/rand" 5 | "net/http" 6 | "sync/atomic" 7 | ) 8 | 9 | type BoundedLeastConnections struct { 10 | sampleSize int 11 | } 12 | 13 | func NewBoundedLeastConnections(sampleSize int) *BoundedLeastConnections { 14 | return &BoundedLeastConnections{ 15 | sampleSize: sampleSize, 16 | } 17 | } 18 | 19 | func (blc *BoundedLeastConnections) Name() string { 20 | return "bounded-least-connections" 21 | } 22 | 23 | func (blc *BoundedLeastConnections) NextServer(pool ServerPool, _ *http.Request, w http.ResponseWriter) *Server { 24 | backends := pool.GetBackends() 25 | if len(backends) == 0 { 26 | return nil 27 | } 28 | 29 | // Get sample of servers 30 | sampleSize := min(blc.sampleSize, len(backends)) 31 | indices := rand.Perm(len(backends))[:sampleSize] 32 | 33 | var 
selectedServer *Server 34 | minConn := int32(-1) 35 | 36 | for _, idx := range indices { 37 | server := backends[idx] 38 | if !server.Alive.Load() { 39 | continue 40 | } 41 | 42 | connections := atomic.LoadInt32(&server.ConnectionCount) 43 | if minConn == -1 || connections < minConn { 44 | minConn = connections 45 | selectedServer = server 46 | } 47 | } 48 | 49 | return selectedServer 50 | } 51 | -------------------------------------------------------------------------------- /pkg/algorithm/algorithm.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | "sync/atomic" 6 | "time" 7 | ) 8 | 9 | type Algorithm interface { 10 | NextServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server 11 | Name() string 12 | } 13 | 14 | type ServerPool interface { 15 | GetBackends() []*Server 16 | GetCurrentIndex() uint64 17 | SetCurrentIndex(idx uint64) 18 | } 19 | 20 | type Server struct { 21 | URL string 22 | Weight int 23 | CurrentWeight atomic.Int32 24 | ConnectionCount int32 25 | MaxConnections int32 26 | Alive atomic.Bool 27 | LastResponseTime time.Duration 28 | } 29 | 30 | func CreateAlgorithm(name string) Algorithm { 31 | switch name { 32 | case "round-robin": 33 | return &RoundRobin{} 34 | case "weighted-round-robin": 35 | return &WeightedRoundRobin{} 36 | case "least-connections": 37 | return &LeastConnections{} 38 | case "ip-hash": 39 | return &IPHash{} 40 | case "least-response-time": 41 | return NewLeastResponseTime() 42 | case "sticky-session": 43 | return NewSessionAffinity() 44 | default: 45 | return &RoundRobin{} // default algorithm 46 | } 47 | } 48 | 49 | func (b *Server) CanAcceptConnection() bool { 50 | return atomic.LoadInt32(&b.ConnectionCount) < int32(b.MaxConnections) 51 | } 52 | -------------------------------------------------------------------------------- /internal/svcache/cache.go: 
-------------------------------------------------------------------------------- 1 | package svcache 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/unkn0wn-root/terraster/internal/service" 7 | ) 8 | 9 | // ServiceKey uniquely identifies a service within the Terraster application. 10 | // It is composed of the service's host, port, and protocol, ensuring that each service 11 | // can be distinctly referenced and managed within the system. 12 | type ServiceKey struct { 13 | Host string // The hostname where the service is accessible, e.g., "api.example.com". 14 | Port int // The port number on which the service listens, e.g., 80 for HTTP or 443 for HTTPS. 15 | Protocol service.ServiceType // The protocol used by the service, either HTTP or HTTPS, determining how requests are handled. 16 | } 17 | 18 | // String generates a standardized string representation of the ServiceKey. 19 | // This format concatenates the Host, Port, and Protocol fields separated by pipes ("|"), 20 | // resulting in a string like "api.example.com|443|https". This representation is particularly 21 | // useful for creating unique keys for maps or caches, ensuring that each service can be 22 | // efficiently and accurately retrieved based on its unique identifier. 
23 | func (k ServiceKey) String() string { 24 | return fmt.Sprintf("%s|%d|%s", k.Host, k.Port, k.Protocol) 25 | } 26 | -------------------------------------------------------------------------------- /internal/admin/middleware/host.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | 7 | "github.com/unkn0wn-root/terraster/internal/middleware" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | // HostnameMiddleware validates incoming requests against a configured hostname 12 | type HostnameMiddleware struct { 13 | hostname string // Expected hostname to validate against 14 | logger *zap.Logger 15 | } 16 | 17 | func NewHostnameMiddleware(hostname string, logger *zap.Logger) middleware.Middleware { 18 | return &HostnameMiddleware{ 19 | hostname: hostname, 20 | logger: logger, 21 | } 22 | } 23 | 24 | func (m *HostnameMiddleware) Middleware(next http.Handler) http.Handler { 25 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | host, _, err := net.SplitHostPort(r.Host) 27 | if err != nil { 28 | m.logger.Error("Could not split host", zap.Error(err)) 29 | http.Error(w, "Could not verify target host", http.StatusInternalServerError) 30 | return 31 | } 32 | 33 | if host != m.hostname { 34 | m.logger.Warn("Invalid hostname", 35 | zap.String("expected", m.hostname), 36 | zap.String("received", host), 37 | zap.String("ip", r.RemoteAddr), 38 | ) 39 | http.Error(w, "Invalid host", http.StatusForbidden) 40 | return 41 | } 42 | 43 | next.ServeHTTP(w, r) 44 | }) 45 | } 46 | -------------------------------------------------------------------------------- /internal/auth/errors.go: -------------------------------------------------------------------------------- 1 | package apierr 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrUserLocked is returned when a user's account is temporarily locked due to excessive failed login attempts. 
7 | ErrUserLocked = errors.New("account is temporarily locked") 8 | // ErrInvalidToken is returned when a provided token is invalid or has expired. 9 | ErrInvalidToken = errors.New("invalid or expired token") 10 | // ErrRevokedToken is returned when a token has been explicitly revoked. 11 | ErrRevokedToken = errors.New("token has been revoked") 12 | // ErrMaxTokensReached is returned when a user has reached the maximum number of active tokens allowed. 13 | ErrMaxTokensReached = errors.New("maximum number of active tokens reached") 14 | // ErrInvalidCredentials is returned when a user provides incorrect authentication credentials. 15 | ErrInvalidCredentials = errors.New("invalid credentials") 16 | // ErrUsernameTaken is returned when attempting to create a user with a username that already exists. 17 | ErrUsernameTaken = errors.New("username already exists") 18 | // ErrPasswordExpired is returned when a user's password has expired and needs to be changed. 19 | ErrPasswordExpired = errors.New("password has expired") 20 | // ErrUserNotFound is returned when a user record is not found in the database. 
21 | ErrUserNotFound = errors.New("user not found") 22 | ) 23 | -------------------------------------------------------------------------------- /pkg/shutdown/shutdown.go: -------------------------------------------------------------------------------- 1 | package shutdown 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "sync" 8 | ) 9 | 10 | type Manager struct { 11 | handlers []func(context.Context) error 12 | mu sync.Mutex 13 | } 14 | 15 | func NewManager() *Manager { 16 | return &Manager{ 17 | handlers: make([]func(context.Context) error, 0), 18 | } 19 | } 20 | 21 | func (sh *Manager) AddHandler(handler func(context.Context) error) { 22 | sh.mu.Lock() 23 | defer sh.mu.Unlock() 24 | sh.handlers = append(sh.handlers, handler) 25 | } 26 | 27 | func (sh *Manager) Shutdown(ctx context.Context) error { 28 | var wg sync.WaitGroup 29 | for _, handler := range sh.handlers { 30 | wg.Add(1) 31 | go func(h func(context.Context) error) { 32 | defer wg.Done() 33 | if err := h(ctx); err != nil { 34 | log.Printf("Error during shutdown: %v", err) 35 | } 36 | }(handler) 37 | } 38 | 39 | done := make(chan struct{}) 40 | go func() { 41 | wg.Wait() 42 | close(done) 43 | }() 44 | 45 | select { 46 | case <-ctx.Done(): 47 | return ctx.Err() 48 | case <-done: 49 | return nil 50 | } 51 | } 52 | 53 | func (sh *Manager) RegisterShutdown(name string, shutdown func(context.Context) error) { 54 | sh.AddHandler(func(ctx context.Context) error { 55 | if err := shutdown(ctx); err != nil { 56 | return fmt.Errorf("%s shutdown: %w", name, err) 57 | } 58 | return nil 59 | }) 60 | } 61 | -------------------------------------------------------------------------------- /pkg/algorithm/weighted_round_robin.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | type WeightedRoundRobin struct { 8 | currentWeight int 9 | } 10 | 11 | func (wrr *WeightedRoundRobin) Name() string { 12 | return 
"weighted-round-robin" 13 | } 14 | 15 | func (wrr *WeightedRoundRobin) NextServer(pool ServerPool, _ *http.Request, w *http.ResponseWriter) *Server { 16 | servers := pool.GetBackends() 17 | if len(servers) == 0 { 18 | return nil 19 | } 20 | 21 | var totalWeight int32 = 0 22 | var maxWeight int32 = -1 23 | var selectedServer *Server 24 | 25 | // First pass: calculate total weight and find max weight server 26 | for _, server := range servers { 27 | if !server.Alive.Load() || !server.CanAcceptConnection() { 28 | continue 29 | } 30 | 31 | sw := int32(server.Weight) 32 | 33 | currentWeight := server.CurrentWeight.Load() 34 | newWeight := currentWeight + sw 35 | server.CurrentWeight.Store(newWeight) 36 | 37 | totalWeight += sw 38 | 39 | if selectedServer == nil || newWeight > maxWeight { 40 | selectedServer = server 41 | maxWeight = newWeight 42 | } 43 | } 44 | 45 | // If we found a server, decrease its current_weight 46 | if selectedServer != nil { 47 | newWeight := selectedServer.CurrentWeight.Load() - totalWeight 48 | selectedServer.CurrentWeight.Store(newWeight) 49 | return selectedServer 50 | } 51 | 52 | return nil 53 | } 54 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/unkn0wn-root/terraster 2 | 3 | go 1.22.5 4 | 5 | require ( 6 | github.com/golang-jwt/jwt/v4 v4.5.1 7 | github.com/google/uuid v1.6.0 8 | github.com/gorilla/websocket v1.5.3 9 | github.com/natefinch/lumberjack v2.0.0+incompatible 10 | github.com/wneessen/go-mail v0.5.2 11 | go.uber.org/zap v1.27.0 12 | golang.org/x/crypto v0.31.0 13 | golang.org/x/time v0.5.0 14 | gopkg.in/yaml.v2 v2.4.0 15 | modernc.org/sqlite v1.33.1 16 | ) 17 | 18 | require ( 19 | github.com/BurntSushi/toml v1.4.0 // indirect 20 | github.com/dustin/go-humanize v1.0.1 // indirect 21 | github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect 22 | github.com/mattn/go-isatty v0.0.20 // 
indirect 23 | github.com/ncruces/go-strftime v0.1.9 // indirect 24 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 25 | github.com/stretchr/testify v1.8.4 // indirect 26 | go.uber.org/multierr v1.10.0 // indirect 27 | golang.org/x/net v0.33.0 // indirect 28 | golang.org/x/sys v0.28.0 // indirect 29 | golang.org/x/text v0.21.0 // indirect 30 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect 31 | modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect 32 | modernc.org/libc v1.55.3 // indirect 33 | modernc.org/mathutil v1.6.0 // indirect 34 | modernc.org/memory v1.8.0 // indirect 35 | modernc.org/strutil v1.2.0 // indirect 36 | modernc.org/token v1.1.0 // indirect 37 | ) 38 | -------------------------------------------------------------------------------- /internal/pool/options.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/unkn0wn-root/terraster/internal/config" 7 | "github.com/unkn0wn-root/terraster/pkg/plugin" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | // WithURLRewriter is configuring the URLRewriteProxy. 12 | // It sets up a URL rewriter based on the provided Route and backend URL. 13 | // This allows the proxy to modify incoming request URLs according to the specified rewrite rules, 14 | // ensuring that requests are correctly routed to the intended backend services. 15 | func WithURLRewriter(config Route, backendURL *url.URL) ProxyOption { 16 | return func(p *URLRewriteProxy) { 17 | p.urlRewriter = NewURLRewriter(p.rConfig, backendURL) 18 | } 19 | } 20 | 21 | // WithLogger is configuring the URLRewriteProxy with a custom logger. 
22 | func WithLogger(logger *zap.Logger) ProxyOption { 23 | return func(p *URLRewriteProxy) { 24 | p.logger = logger 25 | } 26 | } 27 | 28 | // WithHeaderConfig sets custom req/res headers 29 | func WithHeaderConfig(cfg *config.Header) ProxyOption { 30 | return func(p *URLRewriteProxy) { 31 | if cfg == nil { 32 | return 33 | } 34 | 35 | p.headerHandler = NewHeaderHandler(*cfg) 36 | } 37 | } 38 | 39 | func WithPluginManager(pm *plugin.Manager) ProxyOption { 40 | return func(p *URLRewriteProxy) { 41 | p.pluginManager = pm 42 | p.pluginEnabled = pm != nil && pm.IsEnabled() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/admin/middleware/log.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/unkn0wn-root/terraster/internal/middleware" 7 | "go.uber.org/zap" 8 | ) 9 | 10 | type AccessLogMiddleware struct { 11 | logger *zap.Logger 12 | } 13 | 14 | func NewAccessLogMiddleware(logger *zap.Logger) middleware.Middleware { 15 | return &AccessLogMiddleware{ 16 | logger: logger, 17 | } 18 | } 19 | 20 | func (m *AccessLogMiddleware) Middleware(next http.Handler) http.Handler { 21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | m.logger.Info("Request to Admin API", 23 | zap.String("method", r.Method), 24 | zap.String("request url", r.URL.Path), 25 | zap.String("request addr.", r.RemoteAddr)) 26 | 27 | // Wrap response writer to capture status code 28 | sw := &statusResponseWriter{ResponseWriter: w} 29 | next.ServeHTTP(sw, r) 30 | m.logger.Info("Response from Admin API", 31 | zap.Int("status", sw.status), 32 | zap.String("method", r.Method), 33 | zap.String("request path", r.URL.Path)) 34 | }) 35 | } 36 | 37 | type statusResponseWriter struct { 38 | http.ResponseWriter 39 | status int 40 | } 41 | 42 | func (w *statusResponseWriter) WriteHeader(status int) { 43 | w.status = status 44 | 
w.ResponseWriter.WriteHeader(status) 45 | } 46 | 47 | func (w *statusResponseWriter) Write(b []byte) (int, error) { 48 | if w.status == 0 { 49 | w.status = http.StatusOK 50 | } 51 | return w.ResponseWriter.Write(b) 52 | } 53 | -------------------------------------------------------------------------------- /internal/config/api.config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | 6 | "gopkg.in/yaml.v2" 7 | ) 8 | 9 | // APIAuthConfig defines the authentication configuration for the API. 10 | // It includes JWT secrets, database paths, token management settings, and password policies. 11 | type APIConfig struct { 12 | API API `yaml:"api"` 13 | AdminDatabase Database `yaml:"database"` 14 | AdminAuth Auth `yaml:"auth"` 15 | } 16 | 17 | type API struct { 18 | Enabled bool `yaml:"enabled"` 19 | Host string `yaml:"host"` 20 | Port int `yaml:"port"` 21 | TLS *TLS `yaml:"tls"` 22 | Insecure bool `yaml:"insecure"` 23 | AllowedIPs []string `yaml:"allowed_ips"` 24 | Debug bool `yaml:"debug"` 25 | } 26 | 27 | type Database struct { 28 | Path string `yaml:"path"` 29 | } 30 | 31 | type Auth struct { 32 | JWTSecret string `yaml:"jwt_secret"` 33 | PasswordMinLength int `yaml:"password_min_length"` 34 | PasswordExpiryDays int `yaml:"password_expiry_days"` 35 | PasswordHistoryLimit int `yaml:"password_history_limit"` 36 | TokenCleanupInterval string `yaml:"token_cleanup_interval"` 37 | } 38 | 39 | func LoadAPIConfig(path string) (*APIConfig, error) { 40 | data, err := os.ReadFile(path) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | var config APIConfig 46 | if err := yaml.UnmarshalStrict(data, &config); err != nil { 47 | return nil, err 48 | } 49 | 50 | return &config, nil 51 | } 52 | -------------------------------------------------------------------------------- /pkg/logger/adapter.go: -------------------------------------------------------------------------------- 1 | package logger 
2 | 3 | import ( 4 | "strings" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | // adapter implements io.Writer and writes to Zap logger. 11 | type ZapWriter struct { 12 | logger *zap.Logger 13 | level zapcore.Level 14 | prefix string 15 | } 16 | 17 | // - logger: the Zap structured logger. 18 | // - level: the log level at which messages should be logged. 19 | // - prefix: the prefix to include as a separate field (optional). 20 | func NewZapWriter(logger *zap.Logger, level zapcore.Level, prefix string) *ZapWriter { 21 | return &ZapWriter{ 22 | logger: logger, 23 | level: level, 24 | prefix: prefix, 25 | } 26 | } 27 | 28 | // Write implements the io.Writer interface. 29 | func (w *ZapWriter) Write(p []byte) (n int, err error) { 30 | msg := strings.TrimSpace(string(p)) 31 | if msg == "" { 32 | return len(p), nil 33 | } 34 | 35 | fields := []zap.Field{} 36 | 37 | if w.prefix != "" { 38 | fields = append(fields, zap.String("prefix", w.prefix)) 39 | } 40 | 41 | switch w.level { 42 | case zapcore.DebugLevel: 43 | w.logger.Debug(msg, fields...) 44 | case zapcore.InfoLevel: 45 | w.logger.Info(msg, fields...) 46 | case zapcore.WarnLevel: 47 | w.logger.Warn(msg, fields...) 48 | case zapcore.ErrorLevel: 49 | w.logger.Error(msg, fields...) 50 | case zapcore.DPanicLevel: 51 | w.logger.DPanic(msg, fields...) 52 | case zapcore.PanicLevel: 53 | w.logger.Panic(msg, fields...) 54 | case zapcore.FatalLevel: 55 | w.logger.Fatal(msg, fields...) 56 | default: 57 | w.logger.Info(msg, fields...) 
58 | } 59 | 60 | return len(p), nil 61 | } 62 | -------------------------------------------------------------------------------- /internal/stls/stls.go: -------------------------------------------------------------------------------- 1 | package stls 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "sync" 7 | ) 8 | 9 | // TLSManager manages TLS configurations for virtual services sharing the same port 10 | type TLSManager struct { 11 | mu sync.RWMutex 12 | configs map[string]*tls.Config 13 | defaults *tls.Config 14 | } 15 | 16 | // NewTLSManager creates a new TLS configuration manager 17 | func NewTLSManager(defaultConfig *tls.Config) *TLSManager { 18 | return &TLSManager{ 19 | configs: make(map[string]*tls.Config), 20 | defaults: defaultConfig, 21 | } 22 | } 23 | 24 | // AddConfig adds or updates TLS configuration for a specific host 25 | func (tm *TLSManager) AddConfig(host string, config *tls.Config) { 26 | tm.mu.Lock() 27 | defer tm.mu.Unlock() 28 | tm.configs[host] = config 29 | } 30 | 31 | // GetConfig retrieves TLS configuration for a given host 32 | func (tm *TLSManager) GetConfig(host string) *tls.Config { 33 | tm.mu.RLock() 34 | defer tm.mu.RUnlock() 35 | 36 | if config, exists := tm.configs[host]; exists { 37 | return config 38 | } 39 | return tm.defaults 40 | } 41 | 42 | // GetCertificate is a callback function for TLS config that selects the appropriate 43 | // certificate based on the ClientHelloInfo 44 | func (tm *TLSManager) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 45 | if clientHello == nil || clientHello.ServerName == "" { 46 | return nil, fmt.Errorf("no SNI information available") 47 | } 48 | 49 | config := tm.GetConfig(clientHello.ServerName) 50 | if config == nil || config.GetCertificate == nil { 51 | return nil, fmt.Errorf("no certificate provider for host: %s", clientHello.ServerName) 52 | } 53 | 54 | return config.GetCertificate(clientHello) 55 | } 56 | 
-------------------------------------------------------------------------------- /internal/middleware/rate_limiter.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | 6 | "golang.org/x/time/rate" 7 | ) 8 | 9 | // RateLimiterMiddleware provides rate limiting functionality to HTTP handlers. 10 | // It ensures that incoming requests are processed at a controlled rate, 11 | // preventing abuse and ensuring fair usage of server resources. 12 | type RateLimiterMiddleware struct { 13 | limiter *rate.Limiter 14 | } 15 | 16 | // NewRateLimiterMiddleware initializes and returns a new RateLimiterMiddleware. 17 | // Sets up the rate limiter with the specified requests per second (rps) and burst size. 18 | // If the burst size or rps are not provided (i.e., zero), default values are used. 19 | func NewRateLimiterMiddleware(rps float64, burst int) Middleware { 20 | if burst == 0 { 21 | burst = 50 22 | } 23 | if rps == 0 { 24 | rps = 20 25 | } 26 | return &RateLimiterMiddleware{ 27 | limiter: rate.NewLimiter(rate.Limit(rps), burst), 28 | } 29 | } 30 | 31 | // Middleware is the core function that applies the rate limiting to incoming HTTP requests. 32 | // It wraps the next handler in the chain, allowing controlled access based on the rate limiter's state. 33 | // For each incoming request, the middleware checks if the request is allowed by the rate limiter. 34 | // If the request exceeds the rate limit, it responds with a "Too Many Requests" error. 35 | // Otherwise, it forwards the request to the next handler in the chain. 
36 | func (m *RateLimiterMiddleware) Middleware(next http.Handler) http.Handler { 37 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | if !m.limiter.Allow() { 39 | http.Error(w, "Too Many Requests", http.StatusTooManyRequests) 40 | return 41 | } 42 | 43 | next.ServeHTTP(w, r) 44 | }) 45 | } 46 | -------------------------------------------------------------------------------- /pkg/proxy/websocket.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "strings" 7 | 8 | "github.com/gorilla/websocket" 9 | ) 10 | 11 | type WebSocketProxy struct { 12 | upgrader websocket.Upgrader 13 | backend *url.URL 14 | onConnect func(string) 15 | onClose func(string) 16 | } 17 | 18 | func NewWebSocketProxy(backend *url.URL) *WebSocketProxy { 19 | return &WebSocketProxy{ 20 | backend: backend, 21 | upgrader: websocket.Upgrader{ 22 | CheckOrigin: func(r *http.Request) bool { 23 | return true 24 | }, 25 | }, 26 | } 27 | } 28 | 29 | func (wp *WebSocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 30 | backendURL := *wp.backend 31 | backendURL.Scheme = strings.Replace(backendURL.Scheme, "http", "ws", 1) 32 | backendConn, _, err := websocket.DefaultDialer.Dial(backendURL.String(), nil) 33 | if err != nil { 34 | http.Error(w, "Could not connect to backend", http.StatusServiceUnavailable) 35 | return 36 | } 37 | defer backendConn.Close() 38 | 39 | clientConn, err := wp.upgrader.Upgrade(w, r, nil) 40 | if err != nil { 41 | return 42 | } 43 | defer clientConn.Close() 44 | 45 | if wp.onConnect != nil { 46 | wp.onConnect(r.RemoteAddr) 47 | } 48 | defer func() { 49 | if wp.onClose != nil { 50 | wp.onClose(r.RemoteAddr) 51 | } 52 | }() 53 | 54 | errChan := make(chan error, 2) 55 | go wp.proxy(clientConn, backendConn, errChan) 56 | go wp.proxy(backendConn, clientConn, errChan) 57 | 58 | <-errChan 59 | } 60 | 61 | func (wp *WebSocketProxy) proxy(dst, src 
*websocket.Conn, errChan chan error) { 62 | for { 63 | messageType, message, err := src.ReadMessage() 64 | if err != nil { 65 | errChan <- err 66 | return 67 | } 68 | 69 | err = dst.WriteMessage(messageType, message) 70 | if err != nil { 71 | errChan <- err 72 | return 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /log.config.json: -------------------------------------------------------------------------------- 1 | { 2 | "loggers": { 3 | "terraster": { 4 | "level": "debug", 5 | "outputPaths": ["terraster.log"], 6 | "errorOutputPaths": ["stderr"], 7 | "development": false, 8 | "logToConsole": true, 9 | "sampling": { 10 | "initial": 100, 11 | "thereafter": 100 12 | }, 13 | "encodingConfig": { 14 | "timeKey": "time", 15 | "levelKey": "level", 16 | "nameKey": "logger", 17 | "callerKey": "caller", 18 | "messageKey": "msg", 19 | "stacktraceKey": "stacktrace", 20 | "lineEnding": "\n", 21 | "levelEncoder": "lowercase", 22 | "timeEncoder": "iso8601", 23 | "durationEncoder": "string", 24 | "callerEncoder": "short" 25 | } 26 | }, 27 | "service_default": { 28 | "level": "info", 29 | "outputPaths": ["service_default.log"], 30 | "errorOutputPaths": ["service_default_error.log"], 31 | "development": false, 32 | "logToConsole": false, 33 | "sampling": { 34 | "initial": 100, 35 | "thereafter": 100 36 | }, 37 | "encodingConfig": { 38 | "timeKey": "time", 39 | "levelKey": "level", 40 | "nameKey": "logger", 41 | "callerKey": "caller", 42 | "messageKey": "msg", 43 | "stacktraceKey": "stacktrace", 44 | "lineEnding": "\n", 45 | "levelEncoder": "lowercase", 46 | "timeEncoder": "iso8601", 47 | "durationEncoder": "string", 48 | "callerEncoder": "short" 49 | }, 50 | "logRotation": { 51 | "enabled": true, 52 | "maxSizeMB": 200, 53 | "maxBackups": 5, 54 | "maxAgeDays": 15, 55 | "compress": true 56 | }, 57 | "sanitization": { 58 | "sensitiveFields": ["password", "token", "access_token", "refresh_token"], 59 | "mask": "****" 60 | } 61 
| } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /internal/server/util.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "crypto/tls" 5 | "net" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // hostNameNoPort extracts the hostname from a given host string by removing the port. 11 | // If the host string does not contain a port, it returns an empty string. 12 | func (s *Server) hostNameNoPort(host string) string { 13 | h, _, err := net.SplitHostPort(host) 14 | if err != nil { 15 | return "" 16 | } 17 | 18 | return h 19 | } 20 | 21 | // servicePort determines the port number to use for a service. 22 | // If a specific port is provided (non-zero), it returns that port. 23 | // Otherwise, it defaults to the standard HTTP port. 24 | func (s *Server) servicePort(port int) int { 25 | if port != 0 { 26 | return port 27 | } 28 | 29 | return DefaultHTTPPort 30 | } 31 | 32 | // hasHTTPSRedirects checks if any of the configured services require HTTP to HTTPS redirection. 33 | // Returns true if at least one service has HTTP redirects enabled, otherwise false. 34 | func (s *Server) hasHTTPSRedirects() bool { 35 | services := s.serviceManager.GetServices() 36 | for _, service := range services { 37 | if service.HTTPRedirect { 38 | return true 39 | } 40 | } 41 | return false 42 | } 43 | 44 | // parseHostPort parses a combined host and port string and determines the appropriate port based on TLS state. 45 | // If the host string does not contain a port, it assigns a default port based on whether TLS is enabled. 
46 | func parseHostPort(hostPort string, tlsState *tls.ConnectionState) (host string, port int, err error) { 47 | if !strings.Contains(hostPort, ":") { 48 | if tlsState != nil { 49 | return hostPort, DefaultHTTPSPort, nil 50 | } 51 | return hostPort, DefaultHTTPPort, nil 52 | } 53 | 54 | host, portStr, err := net.SplitHostPort(hostPort) 55 | if err != nil { 56 | return "", 0, err 57 | } 58 | 59 | port, err = strconv.Atoi(portStr) 60 | if err != nil { 61 | return "", 0, err 62 | } 63 | 64 | return host, port, nil 65 | } 66 | -------------------------------------------------------------------------------- /pkg/logger/sanitizer.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "strings" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | // wraps a zapcore.Core and sanitizes log entries. 11 | type SanitizerCore struct { 12 | zapcore.Core 13 | sensitiveFields []string 14 | mask string 15 | } 16 | 17 | func NewSanitizerCore(core zapcore.Core, sensitiveFields []string, mask string) *SanitizerCore { 18 | return &SanitizerCore{ 19 | Core: core, 20 | sensitiveFields: sensitiveFields, 21 | mask: mask, 22 | } 23 | } 24 | 25 | // adds structured context to the core. 26 | func (s *SanitizerCore) With(fields []zapcore.Field) zapcore.Core { 27 | return &SanitizerCore{ 28 | Core: s.Core.With(fields), 29 | sensitiveFields: s.sensitiveFields, 30 | mask: s.mask, 31 | } 32 | } 33 | 34 | // determines whether the supplied entry should be logged. 
35 | func (s *SanitizerCore) Check(entry zapcore.Entry, checkedEntry *zapcore.CheckedEntry) *zapcore.CheckedEntry { 36 | if s.Enabled(entry.Level) { 37 | return checkedEntry.AddCore(entry, s) 38 | } 39 | return checkedEntry 40 | } 41 | 42 | func (s *SanitizerCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { 43 | sanitizedFields := sanitizeFields(fields, s.sensitiveFields, s.mask) 44 | return s.Core.Write(entry, sanitizedFields) 45 | } 46 | 47 | // flushe buffered logs (if any). 48 | func (s *SanitizerCore) Sync() error { 49 | return s.Core.Sync() 50 | } 51 | 52 | // processes fields and masks sensitive data. 53 | func sanitizeFields(fields []zapcore.Field, sensitiveFields []string, mask string) []zapcore.Field { 54 | maskedFields := make([]zapcore.Field, len(fields)) 55 | copy(maskedFields, fields) 56 | 57 | for i, field := range maskedFields { 58 | for _, sensitive := range sensitiveFields { 59 | if strings.EqualFold(field.Key, sensitive) { 60 | // Replace the value with the mask 61 | maskedFields[i] = zap.String(field.Key, mask) 62 | break 63 | } 64 | } 65 | } 66 | 67 | return maskedFields 68 | } 69 | -------------------------------------------------------------------------------- /internal/auth/models/models.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type Role string 8 | 9 | const ( 10 | RoleAdmin Role = "admin" 11 | RoleReader Role = "reader" 12 | ) 13 | 14 | type User struct { 15 | ID int64 `json:"id"` 16 | Username string `json:"username"` 17 | Password string `json:"-"` 18 | Role Role `json:"role"` 19 | LastLoginAt *time.Time `json:"last_login_at"` 20 | LastLoginIP string `json:"last_login_ip"` 21 | FailedAttempts int `json:"-"` 22 | LockedUntil *time.Time `json:"-"` 23 | CreatedAt time.Time `json:"created_at"` 24 | UpdatedAt time.Time `json:"updated_at"` 25 | PasswordChangedAt time.Time `json:"password_changed_at"` 26 | } 27 | 28 | type 
Token struct { 29 | ID int64 `json:"id"` 30 | UserID int64 `json:"user_id"` 31 | Token string `json:"token"` 32 | RefreshToken string `json:"refresh_token,omitempty"` 33 | ExpiresAt time.Time `json:"expires_at"` 34 | CreatedAt time.Time `json:"created_at"` 35 | LastUsedAt time.Time `json:"last_used_at"` 36 | RevokedAt *time.Time `json:"revoked_at,omitempty"` 37 | ClientIP string `json:"client_ip"` 38 | UserAgent string `json:"user_agent"` 39 | JTI string `json:"-"` // JWT ID for tracking 40 | Role Role `json:"role"` 41 | } 42 | 43 | type AuditLog struct { 44 | ID int64 `json:"id"` 45 | UserID int64 `json:"user_id"` 46 | Action string `json:"action"` 47 | Resource string `json:"resource"` 48 | Status string `json:"status"` 49 | IP string `json:"ip"` 50 | UserAgent string `json:"user_agent"` 51 | Details string `json:"details"` 52 | CreatedAt time.Time `json:"created_at"` 53 | } 54 | 55 | type Session struct { 56 | Token *Token `json:"token"` 57 | LastUsed time.Time `json:"last_used"` 58 | ClientInfo string `json:"client_info"` 59 | Active bool `json:"active"` 60 | } 61 | -------------------------------------------------------------------------------- /pkg/plugin/result.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | // Result represents the outcome of plugin processing 10 | type Result struct { 11 | action Action 12 | StatusCode int 13 | ResponseBody []byte 14 | Headers http.Header 15 | } 16 | 17 | var ( 18 | ResultContinue = &Result{action: Continue} 19 | ResultModify = &Result{action: Modify} 20 | ) 21 | 22 | var resultPool = sync.Pool{ 23 | New: func() interface{} { 24 | return &Result{ 25 | Headers: make(http.Header), 26 | } 27 | }, 28 | } 29 | 30 | // NewResult creates a new result with options 31 | func NewResult(action Action, opts ...ResultOption) *Result { 32 | if action == Continue && len(opts) == 0 { 33 | return 
ResultContinue 34 | } 35 | if action == Modify && len(opts) == 0 { 36 | return ResultModify 37 | } 38 | 39 | r := resultPool.Get().(*Result) 40 | r.action = action 41 | 42 | for k := range r.Headers { 43 | delete(r.Headers, k) 44 | } 45 | 46 | r.StatusCode = 0 47 | r.ResponseBody = nil 48 | 49 | for _, opt := range opts { 50 | opt(r) 51 | } 52 | return r 53 | } 54 | 55 | // Action returns the result action 56 | func (r *Result) Action() Action { 57 | return r.action 58 | } 59 | 60 | // Release returns the result to the pool 61 | func (r *Result) Release() { 62 | if r == ResultContinue || r == ResultModify { 63 | return 64 | } 65 | resultPool.Put(r) 66 | } 67 | 68 | // ResultOption is a function that modifies a result 69 | type ResultOption func(*Result) 70 | 71 | // WithStatus sets the status code 72 | func WithStatus(code int) ResultOption { 73 | return func(r *Result) { 74 | r.StatusCode = code 75 | } 76 | } 77 | 78 | // WithHeaders adds headers 79 | func WithHeaders(headers http.Header) ResultOption { 80 | return func(r *Result) { 81 | for k, vals := range headers { 82 | r.Headers[k] = vals 83 | } 84 | } 85 | } 86 | 87 | // WithJSONResponse sets a JSON response body 88 | func WithJSONResponse(v interface{}) ResultOption { 89 | return func(r *Result) { 90 | if data, err := json.Marshal(v); err == nil { 91 | r.ResponseBody = data 92 | r.Headers.Set("Content-Type", "application/json") 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /tools/benchmark/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | func main() { 12 | url := flag.String("url", "http://localhost:8080", "URL to benchmark") 13 | concurrency := flag.Int("c", 10, "Number of concurrent requests") 14 | requests := flag.Int("n", 1000, "Total number of requests") 15 | duration := flag.Duration("d", 0, 
"Duration of the test") 16 | flag.Parse() 17 | 18 | results := make(chan time.Duration, *requests) 19 | errors := make(chan error, *requests) 20 | var wg sync.WaitGroup 21 | 22 | start := time.Now() 23 | client := &http.Client{ 24 | Timeout: time.Second * 10, 25 | } 26 | 27 | if *duration > 0 { 28 | timer := time.NewTimer(*duration) 29 | go func() { 30 | <-timer.C 31 | fmt.Println("Duration reached, stopping...") 32 | *requests = 0 33 | }() 34 | } 35 | 36 | // Start workers 37 | for i := 0; i < *concurrency; i++ { 38 | wg.Add(1) 39 | go func() { 40 | defer wg.Done() 41 | for i := 0; i < *requests / *concurrency; i++ { 42 | requestStart := time.Now() 43 | resp, err := client.Get(*url) 44 | if err != nil { 45 | errors <- err 46 | continue 47 | } 48 | resp.Body.Close() 49 | results <- time.Since(requestStart) 50 | } 51 | }() 52 | } 53 | 54 | // Wait for completion 55 | wg.Wait() 56 | close(results) 57 | close(errors) 58 | 59 | // Process results 60 | var total time.Duration 61 | var count int 62 | var min, max time.Duration 63 | errCount := 0 64 | 65 | for d := range results { 66 | if min == 0 || d < min { 67 | min = d 68 | } 69 | if d > max { 70 | max = d 71 | } 72 | total += d 73 | count++ 74 | } 75 | 76 | for range errors { 77 | errCount++ 78 | } 79 | 80 | // Print results 81 | fmt.Printf("\nBenchmark Results:\n") 82 | fmt.Printf("URL: %s\n", *url) 83 | fmt.Printf("Concurrency Level: %d\n", *concurrency) 84 | fmt.Printf("Time taken: %v\n", time.Since(start)) 85 | fmt.Printf("Complete requests: %d\n", count) 86 | fmt.Printf("Failed requests: %d\n", errCount) 87 | fmt.Printf("Requests per second: %.2f\n", float64(count)/time.Since(start).Seconds()) 88 | fmt.Printf("Mean latency: %v\n", total/time.Duration(count)) 89 | fmt.Printf("Min latency: %v\n", min) 90 | fmt.Printf("Max latency: %v\n", max) 91 | } 92 | -------------------------------------------------------------------------------- /pkg/logger/defaults.go: 
-------------------------------------------------------------------------------- 1 | package logger 2 | 3 | // Provides fallback logging settings for any logger not specified in log.config.json. 4 | var DefaultConfig = Config{ 5 | Level: "info", 6 | OutputPaths: []string{"stdout"}, 7 | ErrorOutputPaths: []string{ 8 | "stderr", 9 | }, 10 | Development: false, 11 | LogToConsole: false, 12 | Sampling: Sampling{ 13 | Initial: 100, 14 | Thereafter: 100, 15 | }, 16 | Encoding: Encoding{ 17 | TimeKey: "time", 18 | LevelKey: "level", 19 | NameKey: "logger", 20 | CallerKey: "caller", 21 | MessageKey: "msg", 22 | StacktraceKey: "stacktrace", 23 | LineEnding: "\n", 24 | LevelEncoder: "lowercase", 25 | TimeEncoder: "iso8601", 26 | DurationEncoder: "string", 27 | CallerEncoder: "short", 28 | }, 29 | LogRotation: LogRotation{ 30 | Enabled: true, 31 | MaxSizeMB: 100, 32 | MaxBackups: 7, 33 | MaxAgeDays: 30, 34 | Compress: true, 35 | }, 36 | Sanitization: Sanitization{ 37 | SensitiveFields: []string{ 38 | "password", 39 | "token", 40 | "access_token", 41 | "refresh_token", 42 | }, 43 | Mask: "****", 44 | }, 45 | } 46 | 47 | func assignDefaultValues(cfg *Config) { 48 | if cfg.Level == "" { 49 | cfg.Level = DefaultConfig.Level 50 | } 51 | if len(cfg.OutputPaths) == 0 { 52 | cfg.OutputPaths = DefaultConfig.OutputPaths 53 | } 54 | if len(cfg.ErrorOutputPaths) == 0 { 55 | cfg.ErrorOutputPaths = DefaultConfig.ErrorOutputPaths 56 | } 57 | if cfg.Encoding.LevelEncoder == "" { 58 | cfg.Encoding.LevelEncoder = DefaultConfig.Encoding.LevelEncoder 59 | } 60 | if cfg.Encoding.TimeEncoder == "" { 61 | cfg.Encoding.TimeEncoder = DefaultConfig.Encoding.TimeEncoder 62 | } 63 | if cfg.Encoding.DurationEncoder == "" { 64 | cfg.Encoding.DurationEncoder = DefaultConfig.Encoding.DurationEncoder 65 | } 66 | if cfg.Encoding.CallerEncoder == "" { 67 | cfg.Encoding.CallerEncoder = DefaultConfig.Encoding.CallerEncoder 68 | } 69 | if cfg.LogRotation.MaxSizeMB == 0 { 70 | cfg.LogRotation.MaxSizeMB = 
DefaultConfig.LogRotation.MaxSizeMB 71 | } 72 | if cfg.LogRotation.MaxBackups == 0 { 73 | cfg.LogRotation.MaxBackups = DefaultConfig.LogRotation.MaxBackups 74 | } 75 | if cfg.LogRotation.MaxAgeDays == 0 { 76 | cfg.LogRotation.MaxAgeDays = DefaultConfig.LogRotation.MaxAgeDays 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | port: 443 2 | 3 | services: 4 | - name: backend-api # service name 5 | host: internal-api1.local.com # service listener hostname 6 | port: 8455 # service listener port 7 | tls: # service tls configuration 8 | cert_file: "/path/to/api-cert.pem" 9 | key_file: "/path/to/api-key.pem" 10 | health_check: # service health check configuration - will be used by each location 11 | type: "http" 12 | path: "/" 13 | interval: "5s" 14 | timeout: "3s" 15 | thresholds: 16 | healthy: 2 17 | unhealthy: 3 18 | locations: 19 | - path: "/api/" # served path suffix so "https://internal-api1.local.com/api/" 20 | lb_policy: round-robin # load balancing policy 21 | http_redirect: true # http to https redirect 22 | redirect: "/" # redirect e.q. from "/" to "/api/" 23 | backends: 24 | - url: http://internal-api1.local.com:8455 25 | weight: 5 26 | max_connections: 1000 27 | health_check: # or have separate health check for each backend and override service health check 28 | type: "http" 29 | path: "/api_health" 30 | interval: "4s" 31 | timeout: "3s" 32 | thresholds: 33 | healthy: 1 34 | unhealthy: 2 35 | - url: http://internal-api2.local.com:8455 36 | weight: 3 37 | max_connections: 800 38 | 39 | - name: frontend 40 | host: frontend.local.com 41 | locations: 42 | - path: "/" 43 | lb_policy: least_connections 44 | http_redirect: false 45 | rewrite: "/frontend/" # rewrite e.q. 
from "/" to "/frontend/" in the backend service 46 | backends: 47 | - url: http://frontend-1.local.com:3000 48 | weight: 5 49 | max_connections: 1000 50 | 51 | - url: http://frontend-2.local.com:3000 52 | weight: 3 53 | max_connections: 800 54 | 55 | # global health check will be used by every service that don't have health_check configuration 56 | health_check: 57 | interval: 10s 58 | timeout: 2s 59 | path: /health 60 | 61 | rate_limit: # global rate limit for each service if not defined in the service 62 | requests_per_second: 100 63 | burst: 150 64 | 65 | connection_pool: 66 | max_idle: 100 67 | max_open: 1000 68 | idle_timeout: 90s 69 | -------------------------------------------------------------------------------- /pkg/algorithm/least_response_time.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type LeastResponseTime struct { 10 | mu sync.RWMutex 11 | responseTimes map[string]time.Duration 12 | decay float64 13 | updateInterval time.Duration 14 | } 15 | 16 | func NewLeastResponseTime() *LeastResponseTime { 17 | lrt := &LeastResponseTime{ 18 | responseTimes: make(map[string]time.Duration), 19 | decay: 0.8, 20 | updateInterval: time.Second * 10, 21 | } 22 | go lrt.periodicCleanup() 23 | return lrt 24 | } 25 | 26 | func (lrt *LeastResponseTime) Name() string { 27 | return "least-response-time" 28 | } 29 | 30 | func (lrt *LeastResponseTime) NextServer( 31 | pool ServerPool, 32 | _ *http.Request, 33 | w *http.ResponseWriter, 34 | ) *Server { 35 | backends := pool.GetBackends() 36 | if len(backends) == 0 { 37 | return nil 38 | } 39 | 40 | var selectedServer *Server 41 | minTime := time.Duration(-1) 42 | 43 | lrt.mu.RLock() 44 | defer lrt.mu.RUnlock() 45 | 46 | for _, server := range backends { 47 | if !server.Alive.Load() { 48 | continue 49 | } 50 | 51 | responseTime, exists := lrt.responseTimes[server.URL] 52 | if !exists { 53 | return 
server // Prefer untested servers 54 | } 55 | 56 | // Consider both response time and current connections 57 | adjustedTime := responseTime * time.Duration(server.ConnectionCount+1) 58 | if minTime == -1 || adjustedTime < minTime { 59 | minTime = adjustedTime 60 | selectedServer = server 61 | } 62 | } 63 | 64 | return selectedServer 65 | } 66 | 67 | func (lrt *LeastResponseTime) UpdateResponseTime(serverURL string, duration time.Duration) { 68 | lrt.mu.Lock() 69 | defer lrt.mu.Unlock() 70 | 71 | current, exists := lrt.responseTimes[serverURL] 72 | if !exists { 73 | lrt.responseTimes[serverURL] = duration 74 | return 75 | } 76 | 77 | // Exponential moving average 78 | lrt.responseTimes[serverURL] = time.Duration(float64(current)*lrt.decay + float64(duration)*(1-lrt.decay)) 79 | } 80 | 81 | func (lrt *LeastResponseTime) periodicCleanup() { 82 | ticker := time.NewTicker(lrt.updateInterval) 83 | for range ticker.C { 84 | lrt.mu.Lock() 85 | for server := range lrt.responseTimes { 86 | // Remove stale entries 87 | if _, exists := lrt.responseTimes[server]; !exists { 88 | delete(lrt.responseTimes, server) 89 | } 90 | } 91 | lrt.mu.Unlock() 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /internal/middleware/cors.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/config" 9 | ) 10 | 11 | type CORS struct { 12 | AllowedOrigins []string 13 | AllowedMethods []string 14 | AllowedHeaders []string 15 | ExposedHeaders []string 16 | AllowCredentials bool 17 | MaxAge int 18 | } 19 | 20 | // Initializes and returns a new CORS instance based on the provided configuration. 
// It uses the first middleware entry in cfg that carries a CORS section, and
// returns nil when there is none, so callers must check the result.
func NewCORSMiddleware(cfg *config.Terraster) *CORS {
	// Named corsCfg so it does not shadow the imported config package.
	var corsCfg *config.CORS
	for _, mw := range cfg.Middleware {
		if mw.CORS != nil {
			corsCfg = mw.CORS
			break
		}
	}

	if corsCfg == nil {
		return nil
	}

	return &CORS{
		AllowedOrigins:   corsCfg.AllowedOrigins,
		AllowedMethods:   corsCfg.AllowedMethods,
		AllowedHeaders:   corsCfg.AllowedHeaders,
		ExposedHeaders:   corsCfg.ExposedHeaders,
		AllowCredentials: corsCfg.AllowCredentials,
		MaxAge:           corsCfg.MaxAge,
	}
}

// Middleware is an HTTP middleware that sets CORS headers on incoming HTTP responses.
//
// Access-Control-Allow-Origin may carry only "*" or a single origin; browsers
// reject a comma-separated list.
//   - A "*" entry anywhere in AllowedOrigins allows every origin.
//   - Otherwise, the request's Origin header is echoed back when it appears in
//     the list. "Vary: Origin" is added so that caches keep per-origin
//     responses apart.
//   - Requests from unlisted origins get no Access-Control-Allow-Origin header,
//     which makes the browser block them.
//
// Only genuine preflight requests (OPTIONS carrying both Origin and
// Access-Control-Request-Method) are answered directly with 200. Any other
// OPTIONS request is passed on to next.
func (c *CORS) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		origin := r.Header.Get("Origin")

		if len(c.AllowedOrigins) > 0 {
			if c.allowsAnyOrigin() {
				h.Set("Access-Control-Allow-Origin", "*")
			} else {
				// The response depends on the request's Origin header.
				h.Add("Vary", "Origin")
				if origin != "" && c.isAllowedOrigin(origin) {
					h.Set("Access-Control-Allow-Origin", origin)
				}
			}
		}
		if len(c.AllowedMethods) > 0 {
			h.Set("Access-Control-Allow-Methods", strings.Join(c.AllowedMethods, ","))
		}
		if len(c.AllowedHeaders) > 0 {
			h.Set("Access-Control-Allow-Headers", strings.Join(c.AllowedHeaders, ","))
		}
		if len(c.ExposedHeaders) > 0 {
			h.Set("Access-Control-Expose-Headers", strings.Join(c.ExposedHeaders, ","))
		}
		if c.AllowCredentials {
			h.Set("Access-Control-Allow-Credentials", "true")
		}
		if c.MaxAge > 0 {
			h.Set("Access-Control-Max-Age", strconv.Itoa(c.MaxAge))
		}

		isPreflight := r.Method == http.MethodOptions &&
			origin != "" &&
			r.Header.Get("Access-Control-Request-Method") != ""
		if isPreflight {
			w.WriteHeader(http.StatusOK)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// allowsAnyOrigin reports whether AllowedOrigins contains the "*" wildcard.
func (c *CORS) allowsAnyOrigin() bool {
	for _, allowed := range c.AllowedOrigins {
		if allowed == "*" {
			return true
		}
	}
	return false
}

// isAllowedOrigin reports whether origin is listed in AllowedOrigins
// (compared case-insensitively).
func (c *CORS) isAllowedOrigin(origin string) bool {
	for _, allowed := range c.AllowedOrigins {
		if strings.EqualFold(allowed, origin) {
			return true
		}
	}
	return false
}

--------------------------------------------------------------------------------
/internal/pool/transport.go:
-------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "crypto/tls" 5 | "net/http" 6 | "time" 7 | 8 | pxErr "github.com/unkn0wn-root/terraster/internal/cerr" 9 | ) 10 | 11 | const ( 12 | // DefaultMaxIdleConnsPerHost is the maximum number of idle connections to keep per-host 13 | DefaultMaxIdleConnsPerHost = 32 14 | 15 | // DefaultIdleConnTimeout is the maximum amount of time an idle connection will remain idle before closing 16 | DefaultIdleConnTimeout = 30 * time.Second 17 | ) 18 | 19 | // Transport provides a custom implementation of http.RoundTripper that wraps 20 | // the standard http.Transport with additional configuration options for TLS, 21 | // connection pooling, and HTTP/2 support. 22 | type Transport struct { 23 | transport *http.Transport 24 | } 25 | 26 | // NewTransport creates and returns a new Transport instance with default configuration. 27 | // The default configuration includes: 28 | // - Connection pooling with DefaultMaxIdleConnsPerHost idle connections per host 29 | // - Idle connection timeout set to DefaultIdleConnTimeout 30 | // - Initialized TLS configuration 31 | func NewTransport(tr *http.Transport) *Transport { 32 | tr.MaxIdleConnsPerHost = DefaultMaxIdleConnsPerHost 33 | tr.IdleConnTimeout = DefaultIdleConnTimeout 34 | 35 | if tr.TLSClientConfig == nil { 36 | tr.TLSClientConfig = &tls.Config{} 37 | } 38 | 39 | return &Transport{ 40 | transport: tr, 41 | } 42 | } 43 | 44 | // ConfigureTransport sets up TLS and HTTP/2 settings for the transport. 45 | // It configures SNI (Server Name Indication), TLS verification, and HTTP/2 support. 
// With h2 == false, HTTP/2 is switched off entirely. ALPN advertises only
// "http/1.1", and TLSNextProto is replaced by an empty, non-nil map, which is
// how net/http is told not to negotiate HTTP/2 on TLS connections.
// With h2 == true, ForceAttemptHTTP2 is set so that HTTP/2 is still attempted
// although a custom TLSClientConfig is present. That branch touches nothing
// else, so it does not undo an earlier call made with h2 == false.
func (t *Transport) ConfigureTransport(serverName string, skipTLSVerify bool, h2 bool) {
	tr := t.transport
	tlsCfg := tr.TLSClientConfig

	tlsCfg.ServerName = serverName
	tlsCfg.InsecureSkipVerify = skipTLSVerify

	if h2 {
		// Set explicitly even though a transport such as http.DefaultTransport
		// may already have it enabled.
		tr.ForceAttemptHTTP2 = true
		return
	}

	tr.ForceAttemptHTTP2 = false
	tlsCfg.NextProtos = []string{"http/1.1"}
	tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
}

// RoundTrip implements http.RoundTripper by delegating to the wrapped
// *http.Transport. A transport error is wrapped in a proxy error tagged
// "round_trip" and the response is nil. On success, the response is returned
// unchanged.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := t.transport.RoundTrip(req)
	if err == nil {
		return resp, nil
	}
	return nil, pxErr.NewProxyError("round_trip", err)
}

// GetTransport returns the underlying http.Transport.
72 | func (t *Transport) GetTransport() http.RoundTripper { 73 | return t.transport 74 | } 75 | -------------------------------------------------------------------------------- /internal/admin/validatior.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/config" 9 | ) 10 | 11 | type ValidationError struct { 12 | Field string 13 | Error string 14 | } 15 | 16 | func (e ValidationError) String() string { 17 | return fmt.Sprintf("%s: %s", e.Field, e.Error) 18 | } 19 | 20 | type Validator interface { 21 | Validate() []ValidationError 22 | } 23 | 24 | type BackendRequest struct { 25 | URL string `json:"url"` 26 | Weight int `json:"weight"` 27 | MaxConnections int32 `json:"maxConnections"` 28 | SkipTLSVerify bool `json:"skipTLSVerify"` 29 | HealthCheck *config.HealthCheck `json:"healthCheck"` 30 | } 31 | 32 | func (r BackendRequest) Validate() []ValidationError { 33 | var errors []ValidationError 34 | 35 | if r.URL == "" { 36 | errors = append(errors, ValidationError{"url", "required"}) 37 | } 38 | if r.Weight <= 0 { 39 | errors = append(errors, ValidationError{"weight", "must be positive"}) 40 | } 41 | if r.MaxConnections <= 0 { 42 | errors = append(errors, ValidationError{"maxConnections", "must be positive"}) 43 | } 44 | 45 | if r.HealthCheck != nil { 46 | if errs := validateHealthCheck(r.HealthCheck); len(errs) > 0 { 47 | errors = append(errors, errs...) 
48 | } 49 | } 50 | 51 | return errors 52 | } 53 | 54 | func validateHealthCheck(hc *config.HealthCheck) []ValidationError { 55 | var errors []ValidationError 56 | 57 | if hc.Type != "http" && hc.Type != "tcp" { 58 | errors = append(errors, ValidationError{"healthCheck.type", "must be 'http' or 'tcp'"}) 59 | } 60 | if hc.Interval <= 0 { 61 | errors = append(errors, ValidationError{"healthCheck.interval", "must be positive"}) 62 | } 63 | if hc.Timeout <= 0 { 64 | errors = append(errors, ValidationError{"healthCheck.timeout", "must be positive"}) 65 | } 66 | if hc.Thresholds.Healthy <= 0 { 67 | errors = append(errors, ValidationError{"healthCheck.thresholds.healthy", "must be positive"}) 68 | } 69 | if hc.Thresholds.Unhealthy <= 0 { 70 | errors = append(errors, ValidationError{"healthCheck.thresholds.unhealthy", "must be positive"}) 71 | } 72 | 73 | return errors 74 | } 75 | 76 | // HTTP handler helper 77 | func DecodeAndValidate(w http.ResponseWriter, r *http.Request, v Validator) error { 78 | if err := json.NewDecoder(r.Body).Decode(v); err != nil { 79 | http.Error(w, "Invalid request payload: "+err.Error(), http.StatusBadRequest) 80 | return err 81 | } 82 | 83 | if errors := v.Validate(); len(errors) > 0 { 84 | msg := "Validation failed:" 85 | for _, err := range errors { 86 | msg += "\n" + err.String() 87 | } 88 | http.Error(w, msg, http.StatusBadRequest) 89 | return fmt.Errorf(msg) 90 | } 91 | 92 | return nil 93 | } 94 | -------------------------------------------------------------------------------- /pkg/logger/manager.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | 7 | "go.uber.org/zap" 8 | ) 9 | 10 | type LoggerManager struct { 11 | loggers map[string]*zap.Logger 12 | mu sync.RWMutex 13 | defaultConfig *Config 14 | } 15 | 16 | func NewLoggerManager(logsConfigPaths []string) (*LoggerManager, error) { 17 | lm := &LoggerManager{ 18 | loggers: 
make(map[string]*zap.Logger), 19 | defaultConfig: &DefaultConfig, 20 | } 21 | 22 | // init logger but since we need logger - panic if it fails 23 | if err := Init(logsConfigPaths, lm); err != nil { 24 | return nil, err 25 | } 26 | 27 | return lm, nil 28 | } 29 | 30 | // adds a new logger to the manager. 31 | // Returns error if a logger with the same name already exists. 32 | func (lm *LoggerManager) AddLogger(name string, logger *zap.Logger) error { 33 | if logger == nil { 34 | return fmt.Errorf("logger cannot be nil") 35 | } 36 | 37 | lm.mu.Lock() 38 | defer lm.mu.Unlock() 39 | 40 | if _, exists := lm.loggers[name]; exists { 41 | return fmt.Errorf("logger '%s' already exists", name) 42 | } 43 | 44 | lm.loggers[name] = logger 45 | return nil 46 | } 47 | 48 | // Retrieves a logger by name. Returns error if logger doesn't exist. 49 | func (lm *LoggerManager) GetLogger(name string) (*zap.Logger, error) { 50 | lm.mu.RLock() 51 | logger, exists := lm.loggers[name] 52 | lm.mu.RUnlock() 53 | if exists { 54 | return logger, nil 55 | } 56 | 57 | return nil, fmt.Errorf("logger '%s' not found", name) 58 | } 59 | 60 | // Returns a copy of the map containing all loggers. 61 | func (lm *LoggerManager) GetAllLoggers() map[string]*zap.Logger { 62 | lm.mu.RLock() 63 | defer lm.mu.RUnlock() 64 | 65 | copyMap := make(map[string]*zap.Logger, len(lm.loggers)) 66 | for k, v := range lm.loggers { 67 | copyMap[k] = v 68 | } 69 | return copyMap 70 | } 71 | 72 | // Sync flushes all loggers managed by LoggerManager. 73 | // Returns a multi-error containing all sync errors encountered. 
74 | func (lm *LoggerManager) Sync() error { 75 | lm.mu.RLock() 76 | defer lm.mu.RUnlock() 77 | 78 | var errs []error 79 | for name, logger := range lm.loggers { 80 | if err := logger.Sync(); err != nil { 81 | errs = append(errs, fmt.Errorf("failed to sync logger '%s': %w", name, err)) 82 | } 83 | } 84 | 85 | if len(errs) > 0 { 86 | return fmt.Errorf("sync errors: %v", errs) 87 | } 88 | return nil 89 | } 90 | 91 | // Removes a logger from the manager. 92 | func (lm *LoggerManager) RemoveLogger(name string) error { 93 | lm.mu.Lock() 94 | defer lm.mu.Unlock() 95 | 96 | if _, exists := lm.loggers[name]; !exists { 97 | return fmt.Errorf("logger '%s' not found", name) 98 | } 99 | 100 | delete(lm.loggers, name) 101 | return nil 102 | } 103 | -------------------------------------------------------------------------------- /internal/admin/middleware/ip.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/middleware" 9 | "go.uber.org/zap" 10 | ) 11 | 12 | // IPRestrictionMiddleware validates incoming requests against configured allowed IPs 13 | type IPRestrictionMiddleware struct { 14 | allowedIPs []string 15 | logger *zap.Logger 16 | } 17 | 18 | // NewIPRestrictionMiddleware creates a new middleware for IP-based access control 19 | func NewIPRestrictionMiddleware(allowedIPs []string, logger *zap.Logger) middleware.Middleware { 20 | return &IPRestrictionMiddleware{ 21 | allowedIPs: allowedIPs, 22 | logger: logger, 23 | } 24 | } 25 | 26 | // This middleware provides IP-based access control. 27 | // It validates the client's IP address against a configured list of allowed IPs. 
28 | // 29 | // The middleware follows these rules: 30 | // - If no IPs are configured (allowedIPs is empty), all requests are allowed 31 | // - If IPs are configured, only requests from those IPs are allowed 32 | // - Client IP is extracted from X-Forwarded-For header first, then X-Real-IP, finally falling back to RemoteAddr 33 | // 34 | // The function will return an HTTP 403 Forbidden status if the IP is not allowed, 35 | // or HTTP 500 Internal Server Error if the client IP cannot be determined. 36 | func (m *IPRestrictionMiddleware) Middleware(next http.Handler) http.Handler { 37 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | // if no ip configured - assume allow all 39 | if len(m.allowedIPs) == 0 { 40 | next.ServeHTTP(w, r) 41 | return 42 | } 43 | 44 | clientIP := extractIP(r) 45 | if clientIP == "" { 46 | http.Error(w, "Could not verify client IP", http.StatusInternalServerError) 47 | return 48 | } 49 | for _, allowedIP := range m.allowedIPs { 50 | if clientIP == allowedIP { 51 | next.ServeHTTP(w, r) 52 | return 53 | } 54 | } 55 | 56 | m.logger.Warn("Access denied: IP not allowed", 57 | zap.String("client_ip", clientIP), 58 | zap.Strings("allowed_ips", m.allowedIPs), 59 | ) 60 | http.Error(w, "Access denied", http.StatusForbidden) 61 | }) 62 | } 63 | 64 | // extractIP gets the real client IP, taking into account X-Forwarded-For and X-Real-IP headers 65 | func extractIP(r *http.Request) string { 66 | forwardedFor := r.Header.Get("X-Forwarded-For") 67 | if forwardedFor != "" { 68 | // X-Forwarded-For can contain multiple IPs; take the first one 69 | ips := strings.Split(forwardedFor, ",") 70 | if len(ips) > 0 { 71 | return strings.TrimSpace(ips[0]) 72 | } 73 | } 74 | 75 | realIP := r.Header.Get("X-Real-IP") 76 | if realIP != "" { 77 | return realIP 78 | } 79 | 80 | ip, _, err := net.SplitHostPort(r.RemoteAddr) 81 | if err != nil { 82 | return r.RemoteAddr 83 | } 84 | return ip 85 | } 86 | 
-------------------------------------------------------------------------------- /internal/middleware/security.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/unkn0wn-root/terraster/internal/config" 8 | ) 9 | 10 | type ServerSecurity struct { 11 | HSTS bool // Enables HTTP Strict Transport Security (HSTS). 12 | HSTSMaxAge int // Specifies the duration (in seconds) for which the browser should remember that the site is only to be accessed using HTTPS. 13 | HSTSIncludeSubDomains bool // If true, applies HSTS policy to all subdomains. 14 | HSTSPreload bool // If true, includes the site in browsers' HSTS preload lists. 15 | FrameOptions string // Specifies the X-Frame-Options header value to control whether the site can be embedded in frames. 16 | ContentTypeOptions bool // Enables the X-Content-Type-Options header to prevent MIME type sniffing. 17 | XSSProtection bool // Enables the X-XSS-Protection header to activate the browser's built-in XSS protection. 18 | } 19 | 20 | // NewSecurityMiddleware initializes and returns a new ServerSecurity instance based on the provided configuration. 21 | // Reads security-related settings from the configuration and sets up the corresponding fields. 22 | func NewSecurityMiddleware(cfg *config.Terraster) *ServerSecurity { 23 | var config *config.Security 24 | for _, mw := range cfg.Middleware { 25 | if mw.Security != nil { 26 | config = mw.Security 27 | break 28 | } 29 | } 30 | return &ServerSecurity{ 31 | HSTS: config.HSTS, 32 | HSTSMaxAge: config.HSTSMaxAge, 33 | HSTSIncludeSubDomains: config.HSTSIncludeSubDomains, 34 | HSTSPreload: config.HSTSPreload, 35 | FrameOptions: config.FrameOptions, 36 | ContentTypeOptions: config.ContentTypeOptions, 37 | XSSProtection: config.XSSProtection, 38 | } 39 | } 40 | 41 | // Middleware is an HTTP middleware that sets various security headers on incoming HTTP responses. 
42 | // It enhances the security posture of the server by configuring headers like HSTS, X-Frame-Options, 43 | // X-Content-Type-Options, and X-XSS-Protection based on the ServerSecurity settings. 44 | func (s *ServerSecurity) Middleware(next http.Handler) http.Handler { 45 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 46 | if s.HSTS { 47 | value := fmt.Sprintf("max-age=%d", s.HSTSMaxAge) 48 | if s.HSTSIncludeSubDomains { 49 | value += "; includeSubDomains" 50 | } 51 | if s.HSTSPreload { 52 | value += "; preload" 53 | } 54 | w.Header().Set("Strict-Transport-Security", value) 55 | } 56 | 57 | if s.FrameOptions != "" { 58 | w.Header().Set("X-Frame-Options", s.FrameOptions) 59 | } 60 | if s.ContentTypeOptions { 61 | w.Header().Set("X-Content-Type-Options", "nosniff") 62 | } 63 | if s.XSSProtection { 64 | w.Header().Set("X-XSS-Protection", "1; mode=block") 65 | } 66 | next.ServeHTTP(w, r) 67 | }) 68 | } 69 | -------------------------------------------------------------------------------- /internal/auth/middleware/auth_middleware.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | "sync" 8 | 9 | apierr "github.com/unkn0wn-root/terraster/internal/auth" 10 | "github.com/unkn0wn-root/terraster/internal/auth/service" 11 | "golang.org/x/time/rate" 12 | ) 13 | 14 | type AuthMiddleware struct { 15 | authService *service.AuthService 16 | rateLimiter *RateLimiter 17 | } 18 | 19 | func NewAuthMiddleware(authService *service.AuthService) *AuthMiddleware { 20 | return &AuthMiddleware{ 21 | authService: authService, 22 | rateLimiter: NewRateLimiter(10, 30), // 10 requests per second, burst of 30 23 | } 24 | } 25 | 26 | func (m *AuthMiddleware) Authenticate(next http.Handler) http.Handler { 27 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 28 | if !m.rateLimiter.Allow(r.RemoteAddr) { 29 | http.Error(w, "Rate limit 
exceeded", http.StatusTooManyRequests) 30 | return 31 | } 32 | 33 | authHeader := r.Header.Get("Authorization") 34 | if authHeader == "" { 35 | http.Error(w, "No token provided", http.StatusUnauthorized) 36 | return 37 | } 38 | 39 | tokenParts := strings.Split(authHeader, " ") 40 | if len(tokenParts) != 2 || tokenParts[0] != "Bearer" { 41 | http.Error(w, "Invalid authorization header", http.StatusUnauthorized) 42 | return 43 | } 44 | 45 | claims, err := m.authService.ValidateToken(tokenParts[1]) 46 | if err != nil { 47 | switch err { 48 | case apierr.ErrRevokedToken: 49 | http.Error(w, "Token has been revoked", http.StatusUnauthorized) 50 | case apierr.ErrInvalidToken: 51 | http.Error(w, "Invalid token", http.StatusUnauthorized) 52 | default: 53 | http.Error(w, "Authentication failed", http.StatusUnauthorized) 54 | } 55 | return 56 | } 57 | 58 | ctx := context.WithValue(r.Context(), "user_claims", claims) 59 | 60 | w.Header().Set("X-Content-Type-Options", "nosniff") 61 | w.Header().Set("X-Frame-Options", "DENY") 62 | w.Header().Set("X-XSS-Protection", "1; mode=block") 63 | w.Header().Set("Content-Security-Policy", "default-src 'self'") 64 | w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") 65 | 66 | next.ServeHTTP(w, r.WithContext(ctx)) 67 | }) 68 | } 69 | 70 | type RateLimiter struct { 71 | limiters map[string]*rate.Limiter 72 | mu sync.RWMutex 73 | rate rate.Limit 74 | burst int 75 | } 76 | 77 | func NewRateLimiter(requestsPerSecond, burst int) *RateLimiter { 78 | return &RateLimiter{ 79 | limiters: make(map[string]*rate.Limiter), 80 | rate: rate.Limit(requestsPerSecond), 81 | burst: burst, 82 | } 83 | } 84 | 85 | func (rl *RateLimiter) Allow(key string) bool { 86 | rl.mu.Lock() 87 | limiter, exists := rl.limiters[key] 88 | if !exists { 89 | limiter = rate.NewLimiter(rl.rate, rl.burst) 90 | rl.limiters[key] = limiter 91 | } 92 | rl.mu.Unlock() 93 | 94 | return limiter.Allow() 95 | } 96 | 
-------------------------------------------------------------------------------- /internal/middleware/circuit_breaker.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | // BackendState holds the state and failure information for a single backend. 11 | type BackendState struct { 12 | mu sync.RWMutex 13 | failures int 14 | lastFailure time.Time 15 | state atomic.Value // string: "closed", "open", "half-open" 16 | } 17 | 18 | type CircuitBreaker struct { 19 | failureThreshold int 20 | resetTimeout time.Duration 21 | backends sync.Map // map[string]*BackendState 22 | } 23 | 24 | func NewCircuitBreaker(threshold int, timeout time.Duration) *CircuitBreaker { 25 | return &CircuitBreaker{ 26 | failureThreshold: threshold, 27 | resetTimeout: timeout, 28 | backends: sync.Map{}, 29 | } 30 | } 31 | 32 | // Middleware wraps the HTTP handler with circuit breaker logic. 
33 | func (cb *CircuitBreaker) Middleware(next http.Handler) http.Handler { 34 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | backend := r.URL.Host 36 | if backend == "" { 37 | backend = r.Host // fallback if URL.Host is empty 38 | } 39 | 40 | bsIface, _ := cb.backends.LoadOrStore(backend, &BackendState{}) 41 | bs := bsIface.(*BackendState) 42 | if bs.state.Load() == nil { 43 | bs.state.Store("closed") 44 | } 45 | 46 | currentState := bs.state.Load().(string) 47 | if currentState == "open" { 48 | lastFailure := bs.lastFailure 49 | if time.Since(lastFailure) > cb.resetTimeout { 50 | bs.state.Store("half-open") 51 | } else { 52 | http.Error(w, "Service temporarily unavailable", http.StatusServiceUnavailable) 53 | return 54 | } 55 | } 56 | 57 | sw := &statusWriter{ResponseWriter: w} 58 | next.ServeHTTP(sw, r) 59 | 60 | if sw.status >= 500 { 61 | cb.recordFailure(backend, bs) 62 | } else if sw.status > 0 { 63 | cb.recordSuccess(backend, bs) 64 | } 65 | }) 66 | } 67 | 68 | // increments the failure count and updates the state if necessary. 69 | func (cb *CircuitBreaker) recordFailure(backend string, bs *BackendState) { 70 | bs.mu.Lock() 71 | defer bs.mu.Unlock() 72 | 73 | // Reset failure count if last failure was before resetTimeout 74 | if time.Since(bs.lastFailure) > cb.resetTimeout { 75 | bs.failures = 0 76 | } 77 | 78 | bs.failures++ 79 | bs.lastFailure = time.Now() 80 | 81 | if bs.failures >= cb.failureThreshold { 82 | bs.state.Store("open") 83 | } 84 | } 85 | 86 | // decrements the failure count or resets the state based on current state. 
87 | func (cb *CircuitBreaker) recordSuccess(backend string, bs *BackendState) { 88 | bs.mu.Lock() 89 | defer bs.mu.Unlock() 90 | 91 | currentState := bs.state.Load().(string) 92 | 93 | if currentState == "half-open" { 94 | bs.state.Store("closed") 95 | bs.failures = 0 96 | } else if currentState == "closed" && bs.failures > 0 { 97 | bs.failures-- 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /internal/middleware/compression.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bufio" 5 | "compress/gzip" 6 | "fmt" 7 | "io" 8 | "net" 9 | "net/http" 10 | "strings" 11 | ) 12 | 13 | // CompressionMiddleware provides response compression functionality to HTTP handlers. 14 | type CompressionMiddleware struct{} 15 | 16 | func NewCompressionMiddleware() Middleware { 17 | return &CompressionMiddleware{} 18 | } 19 | 20 | // Middleware is the core function that applies response compression to HTTP responses. 21 | // It wraps the next handler in the chain, enabling gzip compression for eligible responses. 22 | // For each incoming request, the middleware checks if the client accepts gzip encoding by inspecting 23 | // the "Accept-Encoding" header. 24 | // If gzip is supported, it wraps the ResponseWriter with a gzip.Writer to compress the response. 25 | // It sets the "Content-Encoding" header to "gzip" and removes the "Content-Length" header since 26 | // the length of the compressed response is not known in advance. 27 | // If gzip is not supported, it forwards the request to the next handler without modifying the response. 
28 | func (c *CompressionMiddleware) Middleware(next http.Handler) http.Handler { 29 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 30 | if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { 31 | next.ServeHTTP(w, r) 32 | return 33 | } 34 | 35 | gz := gzip.NewWriter(w) 36 | defer gz.Close() 37 | 38 | w.Header().Set("Content-Encoding", "gzip") 39 | w.Header().Del("Content-Length") 40 | 41 | compressedWriter := compressionWriter{ 42 | Writer: gz, 43 | ResponseWriter: w, 44 | } 45 | 46 | next.ServeHTTP(compressedWriter, r) 47 | }) 48 | } 49 | 50 | // compressionWriter is a custom ResponseWriter that wraps the original ResponseWriter 51 | // and an io.Writer (specifically a gzip.Writer). It overrides the Write method to ensure 52 | // that data is written through the gzip.Writer, enabling response compression. 53 | type compressionWriter struct { 54 | io.Writer 55 | http.ResponseWriter 56 | } 57 | 58 | // Write overrides the default Write method to write compressed data. 59 | // It writes the byte slice 'b' to the embedded io.Writer, which compresses the data 60 | // before sending it to the client. 61 | func (c compressionWriter) Write(b []byte) (int, error) { 62 | return c.Writer.Write(b) // Delegate the write operation to the embedded io.Writer. 63 | } 64 | 65 | // Flush allows the compressionWriter to support flushing of the response. 66 | func (c compressionWriter) Flush() { 67 | if flusher, ok := c.ResponseWriter.(http.Flusher); ok { 68 | flusher.Flush() 69 | } 70 | } 71 | 72 | // Hijack allows the compressionWriter to support connection hijacking. 
73 | func (c compressionWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 74 | if hijacker, ok := c.ResponseWriter.(http.Hijacker); ok { 75 | return hijacker.Hijack() 76 | } 77 | return nil, nil, fmt.Errorf("upstream ResponseWriter does not implement http.Hijacker") 78 | } 79 | -------------------------------------------------------------------------------- /cmd/terraster/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/auth/database" 9 | "github.com/unkn0wn-root/terraster/internal/auth/service" 10 | "github.com/unkn0wn-root/terraster/internal/config" 11 | "github.com/unkn0wn-root/terraster/internal/server" 12 | "github.com/unkn0wn-root/terraster/pkg/logger" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | type ServerBuilder struct { 17 | config *config.Terraster 18 | apiConfig *config.APIConfig 19 | logger *zap.Logger 20 | logManager *logger.LoggerManager 21 | } 22 | 23 | func NewServerBuilder( 24 | cfg *config.Terraster, 25 | apiCfg *config.APIConfig, 26 | logger *zap.Logger, 27 | logManager *logger.LoggerManager, 28 | ) *ServerBuilder { 29 | return &ServerBuilder{ 30 | config: cfg, 31 | apiConfig: apiCfg, 32 | logger: logger, 33 | logManager: logManager, 34 | } 35 | } 36 | 37 | // BuildServer constructs the server with all necessary components 38 | func (sb *ServerBuilder) BuildServer(ctx context.Context, errChan chan<- error) (*server.Server, error) { 39 | var db *database.SQLiteDB 40 | var authService *service.AuthService 41 | 42 | if sb.apiConfig.API.Enabled { 43 | var err error 44 | db, err = database.NewSQLiteDB(sb.apiConfig.AdminDatabase.Path) 45 | if err != nil { 46 | return nil, fmt.Errorf("failed to initialize database: %w", err) 47 | } 48 | 49 | authService = service.NewAuthService(db, sb.buildAuthConfig()) 50 | } 51 | 52 | srv, err := server.NewServer( 53 | ctx, 54 | errChan, 55 | sb.config, 
56 | sb.apiConfig, 57 | authService, 58 | sb.logger, 59 | sb.logManager, 60 | ) 61 | 62 | if err != nil { 63 | if authService != nil { 64 | authService.Close() 65 | } 66 | return nil, err 67 | } 68 | 69 | return srv, nil 70 | } 71 | 72 | func initializeServer( 73 | ctx context.Context, 74 | cfg *config.Terraster, 75 | apiConfig *config.APIConfig, 76 | errChan chan error, 77 | logger *zap.Logger, 78 | logManager *logger.LoggerManager, 79 | ) *server.Server { 80 | builder := NewServerBuilder(cfg, apiConfig, logger, logManager) 81 | srv, err := builder.BuildServer(ctx, errChan) 82 | if err != nil { 83 | logger.Fatal("Failed to initialize server", zap.Error(err)) 84 | } 85 | 86 | return srv 87 | } 88 | 89 | func (sb *ServerBuilder) buildAuthConfig() service.AuthConfig { 90 | return service.AuthConfig{ 91 | JWTSecret: []byte(sb.apiConfig.AdminAuth.JWTSecret), 92 | TokenExpiry: 15 * time.Minute, 93 | RefreshTokenExpiry: 7 * 24 * time.Hour, 94 | MaxLoginAttempts: 5, 95 | LockDuration: 15 * time.Minute, 96 | MaxActiveTokens: 5, 97 | TokenCleanupInterval: 7 * time.Hour, 98 | PasswordMinLength: 12, 99 | RequireUppercase: true, 100 | RequireNumber: true, 101 | RequireSpecialChar: true, 102 | PasswordExpiryDays: sb.apiConfig.AdminAuth.PasswordExpiryDays, 103 | PasswordHistoryLimit: sb.apiConfig.AdminAuth.PasswordHistoryLimit, 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /pkg/algorithm/adaptive.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type AdaptiveLoadBalancer struct { 10 | mu sync.RWMutex 11 | algorithms map[string]Algorithm 12 | stats map[string]*AlgorithmStats 13 | currentAlgo string 14 | evaluationInterval time.Duration 15 | } 16 | 17 | type AlgorithmStats struct { 18 | ResponseTimes []time.Duration 19 | ErrorCount int 20 | RequestCount int 21 | } 22 | 23 | func 
NewAdaptiveLoadBalancer() *AdaptiveLoadBalancer { 24 | alb := &AdaptiveLoadBalancer{ 25 | algorithms: map[string]Algorithm{ 26 | "round-robin": &RoundRobin{}, 27 | "least-conn": &LeastConnections{}, 28 | "response-time": NewLeastResponseTime(), 29 | }, 30 | stats: make(map[string]*AlgorithmStats), 31 | currentAlgo: "round-robin", 32 | evaluationInterval: time.Minute, 33 | } 34 | 35 | go alb.periodicEvaluation() 36 | return alb 37 | } 38 | 39 | func (alb *AdaptiveLoadBalancer) NextServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server { 40 | alb.mu.RLock() 41 | algo := alb.algorithms[alb.currentAlgo] 42 | alb.mu.RUnlock() 43 | 44 | return algo.NextServer(pool, r, w) 45 | } 46 | 47 | func (alb *AdaptiveLoadBalancer) RecordMetrics(algorithm string, responseTime time.Duration, isError bool) { 48 | alb.mu.Lock() 49 | defer alb.mu.Unlock() 50 | 51 | if _, exists := alb.stats[algorithm]; !exists { 52 | alb.stats[algorithm] = &AlgorithmStats{} 53 | } 54 | 55 | stats := alb.stats[algorithm] 56 | stats.ResponseTimes = append(stats.ResponseTimes, responseTime) 57 | stats.RequestCount++ 58 | if isError { 59 | stats.ErrorCount++ 60 | } 61 | } 62 | 63 | func (alb *AdaptiveLoadBalancer) periodicEvaluation() { 64 | ticker := time.NewTicker(alb.evaluationInterval) 65 | for range ticker.C { 66 | alb.evaluateAlgorithms() 67 | } 68 | } 69 | 70 | func (alb *AdaptiveLoadBalancer) evaluateAlgorithms() { 71 | alb.mu.Lock() 72 | defer alb.mu.Unlock() 73 | 74 | var bestAlgo string 75 | bestScore := 0.0 76 | 77 | for algo, stats := range alb.stats { 78 | score := alb.calculateScore(stats) 79 | if score > bestScore { 80 | bestScore = score 81 | bestAlgo = algo 82 | } 83 | } 84 | 85 | if bestAlgo != "" { 86 | alb.currentAlgo = bestAlgo 87 | } 88 | 89 | // Reset stats 90 | alb.stats = make(map[string]*AlgorithmStats) 91 | } 92 | 93 | func (alb *AdaptiveLoadBalancer) calculateScore(stats *AlgorithmStats) float64 { 94 | if stats.RequestCount == 0 { 95 | return 0 96 | } 97 | 98 | 
// Calculate average response time 99 | var totalTime time.Duration 100 | for _, rt := range stats.ResponseTimes { 101 | totalTime += rt 102 | } 103 | avgResponseTime := totalTime / time.Duration(len(stats.ResponseTimes)) 104 | 105 | // Calculate error rate 106 | errorRate := float64(stats.ErrorCount) / float64(stats.RequestCount) 107 | 108 | // Score formula: higher is better 109 | // We want low response times and low error rates 110 | responseTimeScore := 1.0 / float64(avgResponseTime) 111 | errorScore := 1.0 - errorRate 112 | 113 | return responseTimeScore*0.7 + errorScore*0.3 114 | } 115 | -------------------------------------------------------------------------------- /internal/pool/headers.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "net/http" 5 | "path/filepath" 6 | "strings" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/config" 9 | ) 10 | 11 | // HeaderHandler manages request and response header modifications 12 | type HeaderHandler struct { 13 | headerConfig config.Header 14 | placeholders map[string]func(*http.Request) string 15 | } 16 | 17 | // NewHeaderHandler creates a new HeaderHandler 18 | func NewHeaderHandler(cfg config.Header) *HeaderHandler { 19 | return &HeaderHandler{ 20 | headerConfig: cfg, 21 | placeholders: map[string]func(*http.Request) string{ 22 | "${remote_addr}": func(r *http.Request) string { return r.RemoteAddr }, 23 | "${host}": func(r *http.Request) string { return r.Host }, 24 | "${uri}": func(r *http.Request) string { return r.RequestURI }, 25 | "${method}": func(r *http.Request) string { return r.Method }, 26 | }, 27 | } 28 | } 29 | 30 | // ProcessRequestHeaders modifies the request headers 31 | func (h *HeaderHandler) ProcessRequestHeaders(req *http.Request) { 32 | for _, header := range h.headerConfig.RemoveRequestHeaders { 33 | req.Header.Del(header) 34 | } 35 | 36 | for key, value := range h.headerConfig.RequestHeaders { 37 | processedValue 
:= h.processPlaceholders(value, req) 38 | req.Header.Set(key, processedValue) 39 | } 40 | } 41 | 42 | // ProcessResponseHeaders modifies the response headers 43 | func (h *HeaderHandler) ProcessResponseHeaders(resp *http.Response) { 44 | for _, header := range h.headerConfig.RemoveResponseHeaders { 45 | resp.Header.Del(header) 46 | } 47 | 48 | for key, value := range h.headerConfig.ResponseHeaders { 49 | processedValue := h.processPlaceholders(value, resp.Request) 50 | resp.Header.Set(key, processedValue) 51 | } 52 | } 53 | 54 | // processPlaceholders replaces placeholder values with actual request values 55 | func (h *HeaderHandler) processPlaceholders(value string, req *http.Request) string { 56 | if req == nil { 57 | return value 58 | } 59 | 60 | result := value 61 | for placeholder, getter := range h.placeholders { 62 | if strings.Contains(value, placeholder) { 63 | result = strings.ReplaceAll(result, placeholder, getter(req)) 64 | } 65 | } 66 | 67 | return result 68 | } 69 | 70 | // TypeByURLPath checks if provided URL path (image123.jpg) is in whitelised extensions 71 | func TypeByURLPath(path string) string { 72 | ext := filepath.Ext(path) 73 | switch ext { 74 | case ".html", ".htm": 75 | return "text/html; charset=utf-8" 76 | case ".css": 77 | return "text/css; charset=utf-8" 78 | case ".js": 79 | return "application/javascript" 80 | case ".jpg", ".jpeg": 81 | return "image/jpeg" 82 | case ".png": 83 | return "image/png" 84 | case ".gif": 85 | return "image/gif" 86 | case ".pdf": 87 | return "application/pdf" 88 | case ".doc": 89 | return "application/msword" 90 | case ".docx": 91 | return "application/vnd.openxmlformats-officedocument.wordprocessingml.document" 92 | case ".xls": 93 | return "application/vnd.ms-excel" 94 | case ".xlsx": 95 | return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" 96 | default: 97 | return "application/octet-stream" // fallback to octet-stream if we can't determinate content type 98 | } 99 | } 100 | 
-------------------------------------------------------------------------------- /internal/pool/backend.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "net/url" 5 | "sync/atomic" 6 | 7 | "github.com/unkn0wn-root/terraster/internal/config" 8 | ) 9 | 10 | type Backend struct { 11 | URL *url.URL // The URL of the backend server, including scheme, host, and port. 12 | Host string // The hostname extracted from the URL, used for logging and identification. 13 | Alive atomic.Bool // Atomic flag indicating whether the backend is currently alive and reachable. 14 | Weight int // The weight assigned to the backend for load balancing purposes. 15 | CurrentWeight atomic.Int32 // The current weight used in certain load balancing algorithms (e.g., weighted round-robin). 16 | Proxy *URLRewriteProxy // The proxy instance responsible for handling HTTP requests to this backend. 17 | ConnectionCount int32 // The current number of active connections to this backend. 18 | MaxConnections int32 // The maximum number of concurrent connections allowed to this backend. 19 | SuccessCount int32 // The total number of successful requests processed by this backend. 20 | FailureCount int32 // The total number of failed requests processed by this backend. 21 | HealthCheckCfg *config.HealthCheck // Configuration settings for health checks specific to this backend. 22 | } 23 | 24 | // GetURL returns the string representation of the backend's URL. 25 | func (b *Backend) GetURL() string { 26 | return b.URL.String() 27 | } 28 | 29 | // GetWeight retrieves the current weight assigned to the backend. 30 | // The weight influences the load balancing decision, determining the proportion of traffic this backend receives. 31 | func (b *Backend) GetWeight() int { 32 | return b.Weight 33 | } 34 | 35 | // GetCurrentWeight fetches the current weight of the backend. 
36 | func (b *Backend) GetCurrentWeight() int { 37 | return int(b.CurrentWeight.Load()) 38 | } 39 | 40 | // SetCurrentWeight sets the current weight of the backend to the specified value. 41 | func (b *Backend) SetCurrentWeight(weight int) { 42 | b.CurrentWeight.Store(int32(weight)) 43 | } 44 | 45 | // GetConnectionCount returns the current number of active connections to the backend. 46 | func (b *Backend) GetConnectionCount() int { 47 | return int(atomic.LoadInt32(&b.ConnectionCount)) 48 | } 49 | 50 | // IsAlive checks whether the backend is currently marked as alive. 51 | func (b *Backend) IsAlive() bool { 52 | return b.Alive.Load() 53 | } 54 | 55 | // SetAlive updates the alive status of the backend. 56 | func (b *Backend) SetAlive(alive bool) { 57 | b.Alive.Store(alive) 58 | } 59 | 60 | // IncrementConnections attempts to increment the active connection count for the backend. 61 | // It ensures that the connection count does not exceed the maximum allowed. 62 | func (b *Backend) IncrementConnections() bool { 63 | for { 64 | current := atomic.LoadInt32(&b.ConnectionCount) 65 | if current >= int32(b.MaxConnections) { 66 | return false 67 | } 68 | 69 | if atomic.CompareAndSwapInt32(&b.ConnectionCount, current, current+1) { 70 | return true 71 | } 72 | } 73 | } 74 | 75 | // DecrementConnections decrements the active connection count for the backend. 76 | // This should be called when a connection to the backend is closed or terminated. 77 | // It ensures that the connection count accurately reflects the current load. 
78 | func (b *Backend) DecrementConnections() { 79 | atomic.AddInt32(&b.ConnectionCount, -1) 80 | } 81 | -------------------------------------------------------------------------------- /cmd/terraster/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | "os" 8 | "os/signal" 9 | "strings" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/unkn0wn-root/terraster/internal/config" 14 | "github.com/unkn0wn-root/terraster/internal/server" 15 | "github.com/unkn0wn-root/terraster/pkg/logger" 16 | "go.uber.org/zap" 17 | ) 18 | 19 | func main() { 20 | configPath := flag.String("config", "config.yaml", "path to main config file") 21 | servicesDir := flag.String("services", "", "optional directory containing services configurations") 22 | apiConfigPath := flag.String("api-config", "api.config.yaml", "path to API config file") 23 | customLogConfigs := flag.String("log-config", "", "comma-separated paths to custom provided log config files") 24 | flag.Parse() 25 | 26 | logManager, logger := initializeLogging(customLogConfigs) 27 | defer syncLoggers(logManager) 28 | 29 | cfg, apiConfig := loadConfigs(configPath, servicesDir, apiConfigPath, logger) 30 | 31 | errChan := make(chan error, 1) 32 | ctx, cancel := context.WithCancel(context.Background()) 33 | defer cancel() 34 | 35 | srv := initializeServer(ctx, cfg, apiConfig, errChan, logger, logManager) 36 | runServer(ctx, cancel, srv, errChan, logger) 37 | 38 | } 39 | 40 | // initializeLogging initializes the logger manager and retrieves the main logger. 
41 | func initializeLogging(customLogConfigs *string) (*logger.LoggerManager, *zap.Logger) { 42 | logConfigPaths := []string{"log.config.json"} 43 | if *customLogConfigs != "" { 44 | customConfigs := strings.Split(*customLogConfigs, ",") 45 | for _, customConfig := range customConfigs { 46 | if tp := strings.TrimSpace(customConfig); tp != "" { 47 | logConfigPaths = append(logConfigPaths, tp) 48 | } 49 | } 50 | } 51 | 52 | logManager, err := logger.NewLoggerManager(logConfigPaths) 53 | if err != nil { 54 | log.Fatalf("Failed to initialize logger: %v", err) 55 | } 56 | 57 | logger, err := logManager.GetLogger("terraster") 58 | if err != nil { 59 | log.Fatalf("Failed to get logger: %v", err) 60 | } 61 | return logManager, logger 62 | } 63 | 64 | // Ensure that all logger buffers are flushed before the application exits. 65 | func syncLoggers(logManager *logger.LoggerManager) { 66 | if err := logManager.Sync(); err != nil { 67 | log.Fatalf("Failed to sync loggers: %s", err) 68 | } 69 | } 70 | 71 | // Load and merge all configuration files 72 | func loadConfigs(configPath, servicesDir, apiConfigPath *string, logger *zap.Logger) (*config.Terraster, *config.APIConfig) { 73 | cfg, err := config.MergeConfigs(*configPath, *servicesDir, logger) 74 | if err != nil { 75 | logger.Fatal("Failed to load and merge configs", zap.Error(err)) 76 | } 77 | 78 | configManager := NewConfigManager(logger) 79 | apiConfig := configManager.LoadAPIConfig(*apiConfigPath) 80 | return cfg, apiConfig 81 | } 82 | 83 | // runServerWithGracefulShutdown starts the server and listens for shutdown signals 84 | func runServer( 85 | ctx context.Context, 86 | cancel context.CancelFunc, 87 | srv *server.Server, 88 | errChan chan error, 89 | logger *zap.Logger, 90 | ) { 91 | go func() { 92 | if err := srv.Start(); err != nil { 93 | errChan <- err 94 | } 95 | }() 96 | 97 | sigChan := make(chan os.Signal, 1) 98 | signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 99 | select { 100 | case <-sigChan: 101 | 
logger.Warn("Shutdown signal received. Initializing graceful shutdown") 102 | cancel() 103 | case err := <-errChan: 104 | logger.Fatal("Server error triggered shutdown", zap.Error(err)) 105 | case <-ctx.Done(): 106 | return 107 | } 108 | 109 | shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) 110 | defer shutdownCancel() 111 | if err := srv.Shutdown(shutdownCtx); err != nil && err != context.Canceled { 112 | logger.Fatal("Error during shutdown", zap.Error(err)) 113 | } 114 | logger.Info("Server shutdown completed") 115 | } 116 | -------------------------------------------------------------------------------- /internal/admin/api.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net/http" 5 | "net/http/pprof" 6 | 7 | "github.com/golang-jwt/jwt/v4" 8 | "github.com/unkn0wn-root/terraster/internal/auth/handlers" 9 | "github.com/unkn0wn-root/terraster/internal/auth/models" 10 | auth_service "github.com/unkn0wn-root/terraster/internal/auth/service" 11 | "github.com/unkn0wn-root/terraster/internal/config" 12 | "github.com/unkn0wn-root/terraster/internal/service" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | // AdminAPI represents the administrative API for managing the load balancer. 17 | type AdminAPI struct { 18 | enabled bool 19 | serviceManager *service.Manager 20 | mux *http.ServeMux 21 | config *config.APIConfig 22 | authService *auth_service.AuthService 23 | authHandler *handlers.AuthHandler 24 | logger *zap.Logger 25 | } 26 | 27 | // NewAdminAPI creates a new instance of AdminAPI with the provided service manager and configuration. 
28 | func NewAdminAPI( 29 | manager *service.Manager, 30 | cfg *config.APIConfig, 31 | authService *auth_service.AuthService, 32 | logger *zap.Logger, 33 | ) *AdminAPI { 34 | api := &AdminAPI{ 35 | enabled: cfg.API.Enabled, 36 | serviceManager: manager, 37 | mux: http.NewServeMux(), 38 | config: cfg, 39 | authService: authService, 40 | authHandler: handlers.NewAuthHandler(authService), 41 | logger: logger, 42 | } 43 | api.registerRoutes() 44 | return api 45 | } 46 | 47 | // registerRoutes sets up the HTTP handlers for various administrative endpoints. 48 | func (a *AdminAPI) registerRoutes() { 49 | // Auth routes 50 | a.mux.HandleFunc("/api/auth/login", a.authHandler.Login) 51 | a.mux.HandleFunc("/api/auth/refresh", a.authHandler.RefreshToken) 52 | 53 | // Protected routes 54 | a.mux.Handle("/api/auth/change-password", 55 | a.requireAuth(http.HandlerFunc(a.authHandler.ChangePassword))) 56 | 57 | // Admin-only routes 58 | a.mux.Handle("/api/backends", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(a.handleBackends))) 59 | a.mux.Handle("/api/config", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(a.handleConfig))) 60 | 61 | // Admin-only debug route if enabled in config 62 | if a.config.API.Debug { 63 | a.mux.Handle("/debug/pprof/", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(pprof.Index))) 64 | a.mux.Handle("/debug/pprof/cmdline", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(pprof.Cmdline))) 65 | a.mux.Handle("/debug/pprof/profile", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(pprof.Profile))) 66 | a.mux.Handle("/debug/pprof/symbol", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(pprof.Symbol))) 67 | a.mux.Handle("/debug/pprof/trace", a.registerMiddleware(models.RoleAdmin, http.HandlerFunc(pprof.Trace))) 68 | a.mux.Handle("/debug/pprof/heap", a.registerMiddleware(models.RoleAdmin, pprof.Handler("heap"))) 69 | a.mux.Handle("/debug/pprof/goroutine", a.registerMiddleware(models.RoleAdmin, 
pprof.Handler("goroutine"))) 70 | } 71 | 72 | // Reader routes 73 | a.mux.Handle("/api/services", a.registerMiddleware(models.RoleReader, http.HandlerFunc(a.handleServices))) 74 | a.mux.Handle("/api/health", a.registerMiddleware(models.RoleReader, http.HandlerFunc(a.handleHealth))) 75 | a.mux.Handle("/api/stats", a.registerMiddleware(models.RoleReader, http.HandlerFunc(a.handleStats))) 76 | a.mux.Handle("/api/locations", a.registerMiddleware(models.RoleReader, http.HandlerFunc(a.handleLocations))) 77 | } 78 | 79 | func (a *AdminAPI) requireRole(role models.Role, next http.Handler) http.Handler { 80 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 | claims := r.Context().Value("user_claims").(*jwt.MapClaims) 82 | userRole := models.Role((*claims)["role"].(string)) 83 | 84 | if userRole != role && userRole != models.RoleAdmin { 85 | http.Error(w, "Forbidden", http.StatusForbidden) 86 | return 87 | } 88 | 89 | next.ServeHTTP(w, r) 90 | }) 91 | } 92 | 93 | func (a *AdminAPI) registerMiddleware(role models.Role, h http.Handler) http.Handler { 94 | return a.requireAuthStrict(a.requireRole(role, h)) 95 | } 96 | -------------------------------------------------------------------------------- /scripts/database/api_util.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "time" 9 | 10 | "github.com/unkn0wn-root/terraster/internal/auth/database" 11 | "github.com/unkn0wn-root/terraster/internal/auth/models" 12 | "github.com/unkn0wn-root/terraster/internal/auth/service" 13 | "github.com/unkn0wn-root/terraster/internal/config" 14 | ) 15 | 16 | type Config struct { 17 | DBPath string 18 | JWTSecret string 19 | PasswordMinLength int 20 | TokenCleanupInterval string 21 | RequireUppercase bool 22 | RequireNumber bool 23 | RequireSpecialChar bool 24 | } 25 | 26 | func main() { 27 | var ( 28 | username = flag.String("username", "", "Username for the 
new user") 29 | password = flag.String("password", "", "Password for the new user") 30 | role = flag.String("role", "reader", "Role for the new user (admin or reader)") 31 | listUsers = flag.Bool("list", false, "List all users") 32 | configPath = flag.String("config", "./api.config.yaml", "Path to configuration file") 33 | ) 34 | flag.Parse() 35 | if *configPath == "" { 36 | log.Fatalf("Config file path is required") 37 | } 38 | 39 | cfg, err := config.LoadAPIConfig(*configPath) 40 | if err != nil { 41 | log.Fatalf("Failed to load configuration: %v", err) 42 | } 43 | 44 | // Initialize configuration 45 | apiCfg := Config{ 46 | DBPath: cfg.AdminDatabase.Path, 47 | JWTSecret: cfg.AdminAuth.JWTSecret, 48 | PasswordMinLength: cfg.AdminAuth.PasswordMinLength, 49 | TokenCleanupInterval: cfg.AdminAuth.TokenCleanupInterval, 50 | RequireUppercase: true, 51 | RequireNumber: true, 52 | RequireSpecialChar: true, 53 | } 54 | 55 | // Initialize database 56 | db, err := database.NewSQLiteDB(apiCfg.DBPath) 57 | if err != nil { 58 | log.Fatalf("Failed to initialize database: %v", err) 59 | } 60 | 61 | tokenDuration, err := time.ParseDuration(apiCfg.TokenCleanupInterval) 62 | if err != nil { 63 | tokenDuration = 24 * time.Hour 64 | } 65 | 66 | // Initialize auth service 67 | authService := service.NewAuthService(db, service.AuthConfig{ 68 | JWTSecret: []byte(apiCfg.JWTSecret), 69 | TokenExpiry: 7 * 24 * 60 * 60, // 7 days 70 | TokenCleanupInterval: tokenDuration, 71 | RefreshTokenExpiry: 7 * 24 * time.Hour, // 7-day refresh token 72 | MaxLoginAttempts: 5, 73 | LockDuration: 15 * 60, // 15 minutes 74 | MaxActiveTokens: 5, 75 | PasswordMinLength: apiCfg.PasswordMinLength, 76 | RequireUppercase: apiCfg.RequireUppercase, 77 | RequireNumber: apiCfg.RequireNumber, 78 | RequireSpecialChar: apiCfg.RequireSpecialChar, 79 | }) 80 | 81 | // Handle list users command 82 | if *listUsers { 83 | if err := listAllUsers(db); err != nil { 84 | log.Fatalf("Failed to list users: %v", err) 85 | } 86 | 
return 87 | } 88 | 89 | // Validate inputs for user creation 90 | if *username == "" || *password == "" { 91 | flag.Usage() 92 | os.Exit(1) 93 | } 94 | 95 | // Validate role 96 | userRole := models.Role(*role) 97 | if userRole != models.RoleAdmin && userRole != models.RoleReader { 98 | log.Fatalf("Invalid role. Must be 'admin' or 'reader'") 99 | } 100 | 101 | // Create user 102 | err = authService.CreateUser(*username, *password, userRole) 103 | if err != nil { 104 | log.Fatalf("Failed to create user: %v", err) 105 | } 106 | 107 | fmt.Printf("Successfully created user '%s' with role '%s'\n", *username, *role) 108 | } 109 | 110 | func listAllUsers(db *database.SQLiteDB) error { 111 | users, err := db.ListUsers() 112 | if err != nil { 113 | return err 114 | } 115 | 116 | if len(users) == 0 { 117 | fmt.Println("No users found in database") 118 | return nil 119 | } 120 | 121 | fmt.Println("\nUser List:") 122 | fmt.Println("----------------------------------------") 123 | fmt.Printf("%-5s %-20s %-10s %-20s\n", "ID", "Username", "Role", "Created At") 124 | fmt.Println("----------------------------------------") 125 | 126 | for _, user := range users { 127 | fmt.Printf("%-5d %-20s %-10s %-20s\n", 128 | user.ID, 129 | user.Username, 130 | user.Role, 131 | user.CreatedAt.Format("2006-01-02 15:04:05"), 132 | ) 133 | } 134 | fmt.Println("----------------------------------------") 135 | return nil 136 | } 137 | -------------------------------------------------------------------------------- /pkg/algorithm/session_affinity.go: -------------------------------------------------------------------------------- 1 | package algorithm 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/binary" 6 | "hash/fnv" 7 | "net/http" 8 | "strconv" 9 | "time" 10 | ) 11 | 12 | const ( 13 | // sessionCookie is the cookie identifier for client session tracking 14 | sessionCookie = "t_px__SESSION_ID" 15 | // defaultSessionTTL defines the default cookie lifetime 16 | defaultSessionTTL = 24 * time.Hour 17 
| ) 18 | 19 | type SessionAffinity struct { 20 | fallbackAlgo Algorithm // alternative algorithm when sticky session fails 21 | sessionTTL time.Duration // session cookie lifetime 22 | useSecure bool // secure cookie settings 23 | } 24 | 25 | // NewSessionAffinity creates a new sticky session manager with default settings 26 | func NewSessionAffinity() *SessionAffinity { 27 | return &SessionAffinity{ 28 | fallbackAlgo: &RoundRobin{}, 29 | sessionTTL: defaultSessionTTL, 30 | useSecure: true, 31 | } 32 | } 33 | 34 | // Name returns the identifier of the session affinity 35 | func (ss *SessionAffinity) Name() string { 36 | return "sticky-session" 37 | } 38 | 39 | // NextServer selects the appropriate backend server based on session stickiness 40 | func (ss *SessionAffinity) NextServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server { 41 | servers := pool.GetBackends() 42 | if len(servers) == 0 { 43 | return nil 44 | } 45 | 46 | cookie, err := r.Cookie(sessionCookie) 47 | if err == http.ErrNoCookie { 48 | return ss.selectNewServer(pool, r, w) 49 | } 50 | 51 | return ss.selectServerFromSession(cookie, servers, pool, r, w) 52 | } 53 | 54 | // selectNewServer handles new client connections without existing session 55 | func (ss *SessionAffinity) selectNewServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server { 56 | server := ss.fallbackAlgo.NextServer(pool, r, w) 57 | if server == nil { 58 | return nil 59 | } 60 | 61 | sessionID := ss.newSessionID(server.URL) 62 | http.SetCookie(*w, ss.newSessionCookie(sessionID)) 63 | return server 64 | } 65 | 66 | // selectServerFromSession handles clients with existing session cookies 67 | func (ss *SessionAffinity) selectServerFromSession( 68 | cookie *http.Cookie, 69 | servers []*Server, 70 | pool ServerPool, 71 | r *http.Request, 72 | w *http.ResponseWriter, 73 | ) *Server { 74 | sessionID, err := strconv.ParseUint(cookie.Value, 10, 64) 75 | if err != nil { 76 | return ss.selectNewServer(pool, r, 
w) 77 | } 78 | 79 | serverHash := uint32(sessionID >> 32) 80 | for _, s := range servers { 81 | if computeURLHash(s.URL) == serverHash { 82 | if s.Alive.Load() && s.CanAcceptConnection() { 83 | return s 84 | } 85 | break 86 | } 87 | } 88 | 89 | return ss.selectFallbackServer(pool, r, w) 90 | } 91 | 92 | // selectFallbackServer manages failover when the original server is unavailable 93 | func (ss *SessionAffinity) selectFallbackServer(pool ServerPool, r *http.Request, w *http.ResponseWriter) *Server { 94 | newServer := ss.fallbackAlgo.NextServer(pool, r, w) 95 | if newServer != nil { 96 | newSessionID := ss.newSessionID(newServer.URL) 97 | http.SetCookie(*w, ss.newSessionCookie(newSessionID)) 98 | } 99 | return newServer 100 | } 101 | 102 | // newSessionCookie creates an HTTP cookie with the session information 103 | func (ss *SessionAffinity) newSessionCookie(sessionID uint64) *http.Cookie { 104 | return &http.Cookie{ 105 | Name: sessionCookie, 106 | Value: strconv.FormatUint(sessionID, 10), 107 | Path: "/", 108 | HttpOnly: true, 109 | Secure: ss.useSecure, 110 | SameSite: http.SameSiteStrictMode, 111 | MaxAge: int(ss.sessionTTL.Seconds()), 112 | } 113 | } 114 | 115 | // newSessionID generates a unique session identifier for a server 116 | func (ss *SessionAffinity) newSessionID(serverURL string) uint64 { 117 | serverHash := computeURLHash(serverURL) 118 | nonce := generateNonce() 119 | return (uint64(serverHash) << 32) | uint64(nonce) 120 | } 121 | 122 | // computeURLHash creates a hash of the server URL for consistent mapping 123 | func computeURLHash(url string) uint32 { 124 | h := fnv.New32a() 125 | h.Write([]byte(url)) 126 | return h.Sum32() 127 | } 128 | 129 | // generateNonce creates a random 32-bit value for session uniqueness 130 | func generateNonce() uint32 { 131 | b := make([]byte, 4) 132 | _, _ = rand.Read(b) 133 | return binary.BigEndian.Uint32(b) 134 | } 135 | -------------------------------------------------------------------------------- 
/internal/middleware/logging.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | "time" 7 | 8 | "go.uber.org/zap" 9 | "go.uber.org/zap/zapcore" 10 | ) 11 | 12 | type LoggingMiddleware struct { 13 | logger *zap.Logger 14 | logLevel zapcore.Level 15 | includeHeaders bool 16 | includeQuery bool 17 | excludePaths []string 18 | } 19 | 20 | type LoggingOption func(*LoggingMiddleware) 21 | 22 | func WithLogLevel(level zapcore.Level) LoggingOption { 23 | return func(l *LoggingMiddleware) { 24 | l.logLevel = level 25 | } 26 | } 27 | 28 | // enables logging of request headers. 29 | func WithHeaders(enabled bool) LoggingOption { 30 | return func(l *LoggingMiddleware) { 31 | if enabled { 32 | l.includeHeaders = true 33 | } 34 | } 35 | } 36 | 37 | // enables logging of query parameters. 38 | func WithQueryParams(enabled bool) LoggingOption { 39 | return func(l *LoggingMiddleware) { 40 | if enabled { 41 | l.includeQuery = true 42 | } 43 | } 44 | } 45 | 46 | // excludes specified paths from logging. 47 | func WithExcludePaths(paths []string) LoggingOption { 48 | return func(l *LoggingMiddleware) { 49 | l.excludePaths = paths 50 | } 51 | } 52 | 53 | func NewLoggingMiddleware(logger *zap.Logger, opts ...LoggingOption) *LoggingMiddleware { 54 | lm := &LoggingMiddleware{ 55 | logger: logger, 56 | logLevel: zapcore.InfoLevel, 57 | includeHeaders: false, 58 | includeQuery: false, 59 | excludePaths: []string{}, 60 | } 61 | 62 | for _, opt := range opts { 63 | opt(lm) 64 | } 65 | 66 | return lm 67 | } 68 | 69 | // wraps http.ResponseWriter to capture status code and response size. 70 | type responseWriter struct { 71 | http.ResponseWriter 72 | status int 73 | size int64 74 | } 75 | 76 | // captures the status code. 77 | func (rw *responseWriter) WriteHeader(code int) { 78 | rw.status = code 79 | rw.ResponseWriter.WriteHeader(code) 80 | } 81 | 82 | // captures the response size. 
83 | func (rw *responseWriter) Write(b []byte) (int, error) { 84 | size, err := rw.ResponseWriter.Write(b) 85 | rw.size += int64(size) 86 | return size, err 87 | } 88 | 89 | func (l *LoggingMiddleware) shouldExcludePath(path string) bool { 90 | for _, excludePath := range l.excludePaths { 91 | if strings.HasPrefix(path, excludePath) { 92 | return true 93 | } 94 | } 95 | return false 96 | } 97 | 98 | func (l *LoggingMiddleware) Middleware(next http.Handler) http.Handler { 99 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 100 | if l.shouldExcludePath(r.URL.Path) { 101 | next.ServeHTTP(w, r) 102 | return 103 | } 104 | 105 | start := time.Now() 106 | rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} 107 | next.ServeHTTP(rw, r) 108 | duration := time.Since(start) 109 | 110 | fields := make([]zap.Field, 0, 8) 111 | fields = append(fields, 112 | zap.String("method", r.Method), 113 | zap.String("path", r.URL.Path), 114 | zap.Int("status", rw.status), 115 | zap.Duration("duration", duration), 116 | zap.String("ip", getIPAddress(r)), 117 | zap.String("user_agent", r.UserAgent()), 118 | zap.Int64("response_size", rw.size), 119 | ) 120 | 121 | if l.includeQuery && len(r.URL.RawQuery) > 0 { 122 | queryParams := make(map[string]string) 123 | for key, values := range r.URL.Query() { 124 | queryParams[key] = strings.Join(values, ",") 125 | } 126 | fields = append(fields, zap.Any("query_params", queryParams)) 127 | } 128 | 129 | if l.includeHeaders { 130 | headers := make(map[string]string) 131 | for key, values := range r.Header { 132 | headers[key] = strings.Join(values, ",") 133 | } 134 | fields = append(fields, zap.Any("headers", headers)) 135 | } 136 | 137 | switch { 138 | case rw.status >= 500: 139 | l.logger.Error("Server error", fields...) 140 | case rw.status >= 400: 141 | l.logger.Warn("Client error", fields...) 142 | default: 143 | l.logger.Info("Request completed", fields...) 
144 | } 145 | }) 146 | } 147 | 148 | // extracts the IP address from the request. 149 | func getIPAddress(r *http.Request) string { 150 | xff := r.Header.Get("X-Forwarded-For") 151 | if xff != "" { 152 | parts := strings.Split(xff, ",") 153 | return strings.TrimSpace(parts[0]) 154 | } 155 | 156 | ip := r.RemoteAddr 157 | if colon := strings.LastIndex(ip, ":"); colon != -1 { 158 | return ip[:colon] 159 | } 160 | return ip 161 | } 162 | -------------------------------------------------------------------------------- /internal/pool/url_rewriter.go: -------------------------------------------------------------------------------- 1 | package pool 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "path" 7 | "strings" 8 | ) 9 | 10 | // URLRewriter is responsible for rewriting incoming request URLs based on predefined rules. 11 | // It handles path prefix stripping, URL rewriting, and redirects to ensure that requests 12 | // are correctly routed to the appropriate backend services. 13 | type URLRewriter struct { 14 | path string // The URL path prefix that should be matched and potentially stripped from incoming requests. 15 | rewriteURL string // The target URL to rewrite the incoming request's path to, if specified. 16 | backendPath string // The base path of the backend service to which requests are being proxied. 17 | shouldStripPath bool // A flag indicating whether the path prefix should be stripped from the incoming request's URL. 18 | redirect string // The URL to which requests should be redirected, if redirection is configured. 19 | } 20 | 21 | // RewriteConfig holds configuration settings for URL rewriting and redirection. 22 | // It defines how incoming request paths should be transformed before being forwarded 23 | // to the backend services. 24 | type Rewrite struct { 25 | ProxyPath string // The path prefix that the proxy should handle and potentially strip from incoming requests. 
26 | RewriteURL string // The URL to which the incoming request's path should be rewritten. 27 | Redirect string // The URL to redirect the request to, if redirection is enabled. 28 | } 29 | 30 | // NewURLRewriter determines whether the path prefix should be stripped and sets up the necessary rewrite and redirect rules. 31 | func NewURLRewriter(config Rewrite, backendURL *url.URL) *URLRewriter { 32 | backendPath := backendURL.Path 33 | if backendPath == "" { 34 | backendPath = "/" 35 | } 36 | 37 | normalizedFrontend := path.Clean("/" + config.ProxyPath) 38 | normalizedBackend := path.Clean(backendPath) 39 | 40 | shouldStripPath := true 41 | if normalizedFrontend != "/" && normalizedBackend != "/" { 42 | shouldStripPath = normalizedFrontend != normalizedBackend 43 | } 44 | 45 | return &URLRewriter{ 46 | path: config.ProxyPath, 47 | rewriteURL: config.RewriteURL, 48 | backendPath: backendPath, 49 | shouldStripPath: shouldStripPath, 50 | redirect: config.Redirect, 51 | } 52 | } 53 | 54 | // shouldRedirect determines whether the incoming HTTP request should be redirected based on the URLRewriter's configuration. 55 | // checks if redirection is enabled and if the request matches the criteria for redirection. 56 | func (r *URLRewriter) shouldRedirect(req *http.Request) (bool, string) { 57 | if r.redirect == "" { 58 | return false, "" 59 | } 60 | 61 | if r.path == "/" && req.URL.Path == "/" { 62 | return true, r.redirect 63 | } 64 | 65 | return false, "" 66 | } 67 | 68 | // rewriteRequestURL modifies the incoming HTTP request's URL to target the backend service. 69 | // updates the scheme and host, and conditionally strips the path prefix based on the URLRewriter's settings. 
70 | func (r *URLRewriter) rewriteRequestURL(req *http.Request, targetURL *url.URL) { 71 | req.URL.Scheme = targetURL.Scheme 72 | req.URL.Host = targetURL.Host 73 | 74 | if r.shouldStripPath { 75 | r.stripPathPrefix(req) 76 | } 77 | } 78 | 79 | // stripPathPrefix removes the configured path prefix from the incoming HTTP request's URL path. 80 | func (r *URLRewriter) stripPathPrefix(req *http.Request) { 81 | trimmed := strings.TrimPrefix(req.URL.Path, r.path) 82 | if !strings.HasPrefix(trimmed, "/") { 83 | trimmed = "/" + trimmed 84 | } 85 | 86 | if r.path == "/" && req.URL.Path == "/" && r.rewriteURL == "" { 87 | return 88 | } 89 | 90 | if r.rewriteURL == "" { 91 | req.URL.Path = trimmed 92 | } else { 93 | ru := r.rewriteURL 94 | if !strings.HasPrefix(ru, "/") { 95 | ru = "/" + ru 96 | } 97 | 98 | if len(ru) > 1 && strings.HasSuffix(ru, "/") { 99 | ru = strings.TrimSuffix(ru, "/") 100 | } 101 | req.URL.Path = ru + trimmed 102 | } 103 | } 104 | 105 | // rewriteRedirectURL modifies the Location header in HTTP redirect responses to ensure consistency with the original host. 106 | // Updates the host and path of the redirect URL based on the original request and the URLRewriter's configuration. 107 | func (r *URLRewriter) rewriteRedirectURL(locURL *url.URL, originalHost string) { 108 | locURL.Host = originalHost 109 | 110 | // If no rewrite URL is specified and the redirect path does not already start with the proxy path, 111 | // prepend the proxy path to the redirect URL's path. 
112 | if r.rewriteURL == "" && !strings.HasPrefix(locURL.Path, r.path) && r.shouldStripPath { 113 | locURL.Path = r.path + locURL.Path 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /pkg/logger/async_core.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | 9 | "go.uber.org/zap/zapcore" 10 | ) 11 | 12 | type LogEntry struct { 13 | Entry zapcore.Entry 14 | Fields []zapcore.Field 15 | } 16 | 17 | // wraps a zapcore.Core and handles asynchronous, batched logging. 18 | type AsyncCore struct { 19 | core zapcore.Core 20 | entryChan chan LogEntry 21 | wg sync.WaitGroup 22 | quit chan struct{} 23 | bufferSize int 24 | batchSize int 25 | flushInterval time.Duration 26 | droppedLogs uint64 27 | batchPool *sync.Pool 28 | } 29 | 30 | // initializes a new AsyncCore with batching and tracking. 31 | // bufferSize: size of the buffered channel 32 | // batchSize: number of log entries per batch 33 | // flushInterval: maximum time to wait before flushing a batch 34 | func NewAsyncCore(core zapcore.Core, bufferSize, batchSize int, flushInterval time.Duration) *AsyncCore { 35 | if bufferSize <= 0 { 36 | bufferSize = 10000 37 | } 38 | if batchSize <= 0 || batchSize > bufferSize { 39 | batchSize = bufferSize / 10 40 | } 41 | if flushInterval <= 0 { 42 | flushInterval = time.Second 43 | } 44 | ac := &AsyncCore{ 45 | core: core, 46 | entryChan: make(chan LogEntry, bufferSize), 47 | quit: make(chan struct{}), 48 | bufferSize: bufferSize, 49 | batchSize: batchSize, 50 | flushInterval: flushInterval, 51 | batchPool: &sync.Pool{ 52 | New: func() interface{} { 53 | batch := make([]LogEntry, 0, batchSize) 54 | return &batch 55 | }, 56 | }, 57 | } 58 | 59 | ac.wg.Add(2) 60 | go ac.processEntries() 61 | go ac.monitorDroppedLogs() 62 | 63 | return ac 64 | } 65 | 66 | // listens to the entry channel and writes logs in 
// batches. A batch is written when it reaches batchSize or when the flush
// ticker fires. After quit is closed, entries already buffered in the channel
// are drained and the remainder is written exactly once by the deferred flush.
func (ac *AsyncCore) processEntries() {
	defer ac.wg.Done()

	ticker := time.NewTicker(ac.flushInterval)
	defer ticker.Stop()

	batchPtr := ac.batchPool.Get().(*[]LogEntry)
	batch := (*batchPtr)[:0]

	// flush writes the pending entries (if any) and empties the batch while
	// keeping its capacity.
	flush := func() {
		if len(batch) > 0 {
			ac.writeBatch(batch)
			batch = batch[:0]
		}
	}

	defer func() {
		flush()
		*batchPtr = batch
		ac.batchPool.Put(batchPtr)
	}()

	for {
		select {
		case logEntry := <-ac.entryChan:
			batch = append(batch, logEntry)
			if len(batch) >= ac.batchSize {
				flush()
			}
		case <-ticker.C:
			flush()
		case <-ac.quit:
			// Drain remaining log entries before exiting. The final partial batch
			// is left to the deferred flush so it is not written twice.
			for {
				select {
				case logEntry := <-ac.entryChan:
					batch = append(batch, logEntry)
					if len(batch) >= ac.batchSize {
						flush()
					}
				default:
					return
				}
			}
		}
	}
}

// writeBatch writes a batch of log entries to the underlying core. A failed
// write is reported on standard output and does not stop the remaining entries.
func (ac *AsyncCore) writeBatch(batch []LogEntry) {
	for _, logEntry := range batch {
		if err := ac.core.Write(logEntry.Entry, logEntry.Fields); err != nil {
			fmt.Printf("Failed to write log entry: %v\n", err)
		}
	}
}

// monitorDroppedLogs periodically logs the number of dropped logs.
func (ac *AsyncCore) monitorDroppedLogs() {
	defer ac.wg.Done()

	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			dropped := atomic.SwapUint64(&ac.droppedLogs, 0)
			if dropped > 0 {
				entry := zapcore.Entry{
					Level:      zapcore.WarnLevel,
					Message:    fmt.Sprintf("Dropped %d log entries due to full buffer", dropped),
					Time:       time.Now(),
					LoggerName: "AsyncCore",
				}
				// Written straight to the wrapped core (not via entryChan), so a full
				// buffer cannot drop it. NOTE(review): this runs concurrently with
				// processEntries; assumes the wrapped core tolerates concurrent writes.
				_ = ac.core.Write(entry, nil)
			}
		case <-ac.quit:
			return
		}
	}
}

// Enabled reports whether the wrapped core logs at the given level.
func (ac *AsyncCore) Enabled(level zapcore.Level) bool {
	return ac.core.Enabled(level)
}

// With returns a core that attaches fields to every entry written through it.
//
// Every queued entry is written by the single processEntries goroutine through
// ac.core, so a child wrapping ac.core.With(fields) would never have its fields
// written. The child therefore keeps the context fields itself and adds them to
// each entry it enqueues. It shares ac's queue, goroutines and shutdown.
func (ac *AsyncCore) With(fields []zapcore.Field) zapcore.Core {
	return &asyncCoreWithFields{
		AsyncCore: ac,
		fields:    append([]zapcore.Field(nil), fields...),
	}
}

// Check adds ac to the checked entry when the entry's level is enabled.
func (ac *AsyncCore) Check(entry zapcore.Entry, checkedEntry *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if ac.Enabled(entry.Level) {
		return checkedEntry.AddCore(entry, ac)
	}
	return checkedEntry
}

// Write enqueues the entry with its fields without blocking. When the buffer
// is full the entry is dropped and counted; the returned error is always nil.
func (ac *AsyncCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
	logEntry := LogEntry{
		Entry:  entry,
		Fields: fields,
	}
	select {
	case ac.entryChan <- logEntry:
		return nil
	default:
		atomic.AddUint64(&ac.droppedLogs, 1)
		return nil
	}
}

// asyncCoreWithFields is the zapcore.Core returned by AsyncCore.With. Enabled
// and Sync come from the embedded AsyncCore; fields holds the context fields
// accumulated through With calls.
type asyncCoreWithFields struct {
	*AsyncCore
	fields []zapcore.Field
}

// With returns a new child whose context fields are c's followed by fields.
func (c *asyncCoreWithFields) With(fields []zapcore.Field) zapcore.Core {
	merged := make([]zapcore.Field, 0, len(c.fields)+len(fields))
	merged = append(merged, c.fields...)
	merged = append(merged, fields...)
	return &asyncCoreWithFields{AsyncCore: c.AsyncCore, fields: merged}
}

// Check adds the child itself (not the embedded parent) so that Write below,
// which attaches the context fields, is the one invoked.
func (c *asyncCoreWithFields) Check(entry zapcore.Entry, checkedEntry *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if c.Enabled(entry.Level) {
		return checkedEntry.AddCore(entry, c)
	}
	return checkedEntry
}

// Write enqueues the entry with the context fields first, then the call-site fields.
func (c *asyncCoreWithFields) Write(entry zapcore.Entry, fields []zapcore.Field) error {
	if len(fields) == 0 {
		return c.AsyncCore.Write(entry, c.fields)
	}
	merged := make([]zapcore.Field, 0, len(c.fields)+len(fields))
	merged = append(merged, c.fields...)
	merged = append(merged, fields...)
	return c.AsyncCore.Write(entry, merged)
}

// Sync writes out all buffered log entries, stops the background goroutines and
// syncs the underlying core. Entries written after Sync are queued but never processed.
193 | func (ac *AsyncCore) Sync() error { 194 | close(ac.quit) 195 | ac.wg.Wait() 196 | return ac.core.Sync() 197 | } 198 | -------------------------------------------------------------------------------- /internal/server/vhost.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | "sync" 7 | 8 | "github.com/unkn0wn-root/terraster/internal/config" 9 | "github.com/unkn0wn-root/terraster/internal/middleware" 10 | "github.com/unkn0wn-root/terraster/internal/service" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | // HostHandler represents a pre-configured handler for a specific host. 15 | // By storing both logger and handler, we avoid context lookups during request processing. 16 | type HostHandler struct { 17 | logger *zap.Logger 18 | handler http.Handler 19 | } 20 | 21 | // VirtualServiceHandler manages multiple service handlers on the same port. 22 | // This is the core component that enables efficient handling of multiple 23 | // services sharing the same port while maintaining separate configurations. 24 | type VirtualServiceHandler struct { 25 | handlers map[string]*HostHandler 26 | mu sync.RWMutex 27 | defaultHandler http.Handler 28 | } 29 | 30 | // NewVirtualServiceHandler creates a new handler manager. 31 | // This is called once per port during server initialization. 32 | func NewVirtualServiceHandler() *VirtualServiceHandler { 33 | return &VirtualServiceHandler{ 34 | handlers: make(map[string]*HostHandler), 35 | } 36 | } 37 | 38 | // AddService configures and stores a complete handler chain for a service. 39 | // This is called during service registration, NOT during request processing. 40 | // The key performance aspect is that this build the entire middleware chain 41 | // once during initialization, rather than per-request. 
func (mh *VirtualServiceHandler) AddService(s *Server, svc *service.ServiceInfo) {
	mh.mu.Lock()
	defer mh.mu.Unlock()

	hostname := strings.ToLower(svc.Host)
	if svc.ServiceType() == service.HTTP && svc.HTTPRedirect {
		// Redirect-only HTTP services get just the redirect handler, no middleware.
		mh.handlers[hostname] = &HostHandler{
			logger:  svc.Logger,
			handler: s.createRedirectHandler(svc),
		}
		return
	}

	baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s.handleRequest(w, r)
	})

	// Start from the globally configured middlewares; service-level settings then
	// replace a global middleware of the same type, or are appended.
	chain := middleware.NewMiddlewareChain()
	chain.AddConfiguredMiddlewares(s.config, svc.Logger)

	for _, mw := range svc.Middleware {
		switch {
		case mw.RateLimit != nil:
			rl := middleware.NewRateLimiterMiddleware(
				mw.RateLimit.RequestsPerSecond,
				mw.RateLimit.Burst,
			)
			chain.Replace(rl)
		case mw.CircuitBreaker != nil:
			cb := middleware.NewCircuitBreaker(
				mw.CircuitBreaker.FailureThreshold,
				mw.CircuitBreaker.ResetTimeout,
			)
			chain.Replace(cb)
		case mw.Security != nil:
			sec := middleware.NewSecurityMiddleware(s.config)
			chain.Replace(sec)
		case mw.CORS != nil:
			cors := middleware.NewCORSMiddleware(s.config)
			chain.Replace(cors)
		case mw.Compression:
			compressor := middleware.NewCompressionMiddleware()
			chain.Replace(compressor)
		}
	}

	logOpts := &config.LogOptions{
		Headers:     false,
		QueryParams: false,
	}
	if slop := svc.LogOptions; slop != nil {
		logOpts = slop
	}
	logger := middleware.NewLoggingMiddleware(
		svc.Logger,
		middleware.WithLogLevel(zap.InfoLevel),
		middleware.WithHeaders(logOpts.Headers),
		// Query-parameter logging is controlled by its own option, not by Headers.
		middleware.WithQueryParams(logOpts.QueryParams),
		middleware.WithExcludePaths([]string{"/api/auth/login", "/api/auth/refresh"}),
	)
	chain.Use(logger)

	mh.handlers[hostname] = &HostHandler{
		logger:  svc.Logger,
		handler: chain.Then(baseHandler),
	}
}

// ServeHTTP is the entry point for all requests. It dispatches to the handler
// registered for the request's host (port removed, lower-cased) or answers 404.
func (mh *VirtualServiceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	mh.mu.RLock()
	hostHandler, exists := mh.handlers[mh.hostKey(r.Host)]
	mh.mu.RUnlock()

	if !exists {
		http.Error(w, "Service not found", http.StatusNotFound)
		return
	}
	// Direct dispatch to pre-configured handler
	hostHandler.handler.ServeHTTP(w, r)
}

// hostKey returns the lower-cased host without its port.
// Supports hostnames, IPv4, bracketed IPv6 ("[2001:db8::1]:80" -> "[2001:db8::1]")
// and bare IPv6. It returns no error: malformed hosts simply match no service.
func (mh *VirtualServiceHandler) hostKey(host string) string {
	// r.Host can be empty (e.g. an HTTP/1.0 request without a Host header);
	// indexing host[0] would panic.
	if host == "" {
		return ""
	}

	// IPv6 with port - [2001:db8::1]:80
	if host[0] == '[' {
		end := last(host, ']')
		if end > 0 {
			// If port exists after IPv6 address
			if len(host) > end+1 && host[end+1] == ':' {
				return strings.ToLower(host[:end+1])
			}
			return strings.ToLower(host)
		}
	}

	// No colons at all - bare hostname
	firstColon := strings.IndexByte(host, ':')
	if firstColon < 0 {
		return strings.ToLower(host)
	}

	// A second colon means a bare IPv6 address; only scan after the first colon.
	if strings.IndexByte(host[firstColon+1:], ':') >= 0 {
		return strings.ToLower(host)
	}

	// hostname:port or IPv4:port
	return strings.ToLower(host[:firstColon])
}

// last returns the index of the last occurrence of b in s, or -1 if absent.
func last(s string, b byte) int {
	for i := len(s) - 1; i >= 0; i-- {
		if s[i] == b {
			return i
		}
	}
	return -1
}

--------------------------------------------------------------------------------
/internal/auth/validation/password.go:
-------------------------------------------------------------------------------- 1 | package validation 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "unicode" 7 | ) 8 | 9 | var ( 10 | ErrPasswordTooShort = errors.New("password is too short") 11 | ErrPasswordTooLong = errors.New("password is too long") 12 | ErrMissingUppercase = errors.New("password must contain at least one uppercase letter") 13 | ErrMissingLowercase = errors.New("password must contain at least one lowercase letter") 14 | ErrMissingNumber = errors.New("password must contain at least one number") 15 | ErrMissingSpecial = errors.New("password must contain at least one special character") 16 | ErrContainsUsername = errors.New("password cannot contain the username") 17 | ErrCommonPassword = errors.New("password is too common") 18 | ErrConsecutiveChars = errors.New("password contains consecutive repeated characters") 19 | ErrSequentialChars = errors.New("password contains sequential characters") 20 | ) 21 | 22 | type PasswordPolicy struct { 23 | MinLength int 24 | MaxLength int 25 | RequireUppercase bool 26 | RequireLowercase bool 27 | RequireNumbers bool 28 | RequireSpecial bool 29 | MaxRepeatingChars int 30 | PreventSequential bool 31 | PreventUsernamePart bool 32 | } 33 | 34 | // DefaultPasswordPolicy returns a recommended password policy 35 | func DefaultPasswordPolicy() PasswordPolicy { 36 | return PasswordPolicy{ 37 | MinLength: 12, 38 | MaxLength: 128, 39 | RequireUppercase: true, 40 | RequireLowercase: true, 41 | RequireNumbers: true, 42 | RequireSpecial: true, 43 | MaxRepeatingChars: 3, 44 | PreventSequential: true, 45 | PreventUsernamePart: true, 46 | } 47 | } 48 | 49 | type PasswordValidator struct { 50 | policy PasswordPolicy 51 | } 52 | 53 | func NewPasswordValidator(policy PasswordPolicy) *PasswordValidator { 54 | return &PasswordValidator{ 55 | policy: policy, 56 | } 57 | } 58 | 59 | func (v *PasswordValidator) ValidatePassword(password string, username string) error { 60 | if 
len(password) < v.policy.MinLength { 61 | return ErrPasswordTooShort 62 | } 63 | if len(password) > v.policy.MaxLength { 64 | return ErrPasswordTooLong 65 | } 66 | 67 | var ( 68 | hasUpper bool 69 | hasLower bool 70 | hasNumber bool 71 | hasSpecial bool 72 | ) 73 | 74 | for _, char := range password { 75 | switch { 76 | case unicode.IsUpper(char): 77 | hasUpper = true 78 | case unicode.IsLower(char): 79 | hasLower = true 80 | case unicode.IsNumber(char): 81 | hasNumber = true 82 | case unicode.IsPunct(char) || unicode.IsSymbol(char): 83 | hasSpecial = true 84 | } 85 | } 86 | 87 | if v.policy.RequireUppercase && !hasUpper { 88 | return ErrMissingUppercase 89 | } 90 | if v.policy.RequireLowercase && !hasLower { 91 | return ErrMissingLowercase 92 | } 93 | if v.policy.RequireNumbers && !hasNumber { 94 | return ErrMissingNumber 95 | } 96 | if v.policy.RequireSpecial && !hasSpecial { 97 | return ErrMissingSpecial 98 | } 99 | 100 | if v.policy.MaxRepeatingChars > 0 { 101 | if err := v.checkRepeatingChars(password); err != nil { 102 | return err 103 | } 104 | } 105 | 106 | if v.policy.PreventSequential { 107 | if err := v.checkSequentialChars(password); err != nil { 108 | return err 109 | } 110 | } 111 | 112 | if v.policy.PreventUsernamePart && username != "" { 113 | if err := v.checkUsernameInPassword(password, username); err != nil { 114 | return err 115 | } 116 | } 117 | 118 | if err := v.checkCommonPasswords(password); err != nil { 119 | return err 120 | } 121 | 122 | return nil 123 | } 124 | 125 | func (v *PasswordValidator) checkRepeatingChars(password string) error { 126 | var count int 127 | var lastChar rune 128 | 129 | for i, char := range password { 130 | if i == 0 { 131 | lastChar = char 132 | count = 1 133 | continue 134 | } 135 | 136 | if char == lastChar { 137 | count++ 138 | if count > v.policy.MaxRepeatingChars { 139 | return ErrConsecutiveChars 140 | } 141 | } else { 142 | lastChar = char 143 | count = 1 144 | } 145 | } 146 | return nil 147 | } 148 | 149 
| func (v *PasswordValidator) checkSequentialChars(password string) error { 150 | sequences := []string{ 151 | "abcdefghijklmnopqrstuvwxyz", 152 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ", 153 | "0123456789", 154 | } 155 | 156 | lowPass := strings.ToLower(password) 157 | for _, seq := range sequences { 158 | for i := 0; i < len(seq)-2; i++ { 159 | if strings.Contains(lowPass, seq[i:i+3]) { 160 | return ErrSequentialChars 161 | } 162 | // Check reverse sequences too 163 | if strings.Contains(lowPass, reverse(seq[i:i+3])) { 164 | return ErrSequentialChars 165 | } 166 | } 167 | } 168 | return nil 169 | } 170 | 171 | func (v *PasswordValidator) checkUsernameInPassword(password, username string) error { 172 | if len(username) < 3 { 173 | return nil 174 | } 175 | 176 | lowPass := strings.ToLower(password) 177 | lowUser := strings.ToLower(username) 178 | 179 | if strings.Contains(lowPass, lowUser) { 180 | return ErrContainsUsername 181 | } 182 | 183 | return nil 184 | } 185 | 186 | func (v *PasswordValidator) checkCommonPasswords(password string) error { 187 | commonPasswords := map[string]bool{ 188 | "password123": true, 189 | "12345678": true, 190 | "qwerty123": true, 191 | "admin123": true, 192 | "letmein": true, 193 | "welcome1": true, 194 | } 195 | 196 | if commonPasswords[strings.ToLower(password)] { 197 | return ErrCommonPassword 198 | } 199 | 200 | return nil 201 | } 202 | 203 | func reverse(s string) string { 204 | runes := []rune(s) 205 | for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { 206 | runes[i], runes[j] = runes[j], runes[i] 207 | } 208 | return string(runes) 209 | } 210 | -------------------------------------------------------------------------------- /internal/cerr/backend.go: -------------------------------------------------------------------------------- 1 | package cerr 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "net" 9 | "net/http" 10 | "strconv" 11 | "syscall" 12 | ) 13 | 14 | // Common proxy error types 15 | var ( 16 | 
ErrBackendUnavailable = errors.New("server unavailable") 17 | ErrBackendTimeout = errors.New("server timeout") 18 | ErrInvalidRedirect = errors.New("invalid redirect received from server") 19 | ) 20 | 21 | // Retry header constants define the retry mechanism configuration. 22 | const ( 23 | RetryAfter = "Retry-After" 24 | RetryAfterSec = 5 25 | ) 26 | 27 | // ErrorResponse represents the structure of error responses sent to clients 28 | type ErrorResponse struct { 29 | Status string `json:"status"` 30 | Message string `json:"message"` 31 | RetryAfter int `json:"retry_after,omitempty"` 32 | } 33 | 34 | // ProxyErrorCode represents specific error conditions in the proxy 35 | type ProxyErrorCode int 36 | 37 | const ( 38 | ErrCodeUnknown ProxyErrorCode = iota 39 | ErrCodeBackendConnFailed 40 | ErrCodeBackendTimeout 41 | ErrCodeInvalidResponse 42 | ErrCodeTLSError 43 | ErrCodeClientDisconnect 44 | ) 45 | 46 | // ProxyError represents a detailed error that occurs during proxy operations 47 | type ProxyError struct { 48 | Op string 49 | Code ProxyErrorCode 50 | Message string 51 | Err error 52 | Retryable bool 53 | StatusCode int 54 | } 55 | 56 | func (e *ProxyError) Error() string { 57 | if e.Err != nil { 58 | return fmt.Sprintf("%s: %s: %v", e.Op, e.Message, e.Err) 59 | } 60 | return fmt.Sprintf("%s: %s", e.Op, e.Message) 61 | } 62 | 63 | func (e *ProxyError) Unwrap() error { 64 | return e.Err 65 | } 66 | 67 | // IsTemporaryError determines if an error is temporary and the request can be retried 68 | func IsTemporaryError(err error) bool { 69 | var proxyErr *ProxyError 70 | if errors.As(err, &proxyErr) { 71 | return proxyErr.Retryable 72 | } 73 | 74 | var netErr net.Error 75 | if errors.As(err, &netErr) && netErr.Timeout() { 76 | return true 77 | } 78 | 79 | var opErr *net.OpError 80 | if errors.As(err, &opErr) { 81 | var syscallErr syscall.Errno 82 | if errors.As(opErr.Err, &syscallErr) { 83 | switch syscallErr { 84 | case 85 | syscall.ECONNREFUSED, 86 | 
syscall.ECONNRESET, 87 | syscall.ETIMEDOUT, 88 | syscall.EPIPE, 89 | syscall.ECONNABORTED, 90 | syscall.EHOSTDOWN, 91 | syscall.ENETUNREACH, 92 | syscall.EHOSTUNREACH: 93 | return true 94 | } 95 | } 96 | 97 | var dnsErr *net.DNSError 98 | if errors.As(opErr.Err, &dnsErr) { 99 | return dnsErr.IsTemporary 100 | } 101 | } 102 | 103 | return false 104 | } 105 | 106 | // NewProxyError creates a new ProxyError with appropriate defaults based on the error type 107 | func NewProxyError(op string, err error) *ProxyError { 108 | pe := &ProxyError{ 109 | Op: op, 110 | Err: err, 111 | Code: ErrCodeUnknown, 112 | StatusCode: http.StatusBadGateway, 113 | Retryable: false, 114 | } 115 | 116 | switch { 117 | case errors.Is(err, context.Canceled): 118 | pe.Code = ErrCodeClientDisconnect 119 | pe.Message = "Request canceled by client" 120 | pe.StatusCode = 499 // Client closed request 121 | pe.Retryable = false 122 | 123 | case errors.Is(err, ErrBackendUnavailable): 124 | pe.Code = ErrCodeBackendConnFailed 125 | pe.Message = "Backend server unavailable" 126 | pe.StatusCode = http.StatusBadGateway 127 | pe.Retryable = true 128 | 129 | case errors.Is(err, ErrBackendTimeout): 130 | pe.Code = ErrCodeBackendTimeout 131 | pe.Message = "Backend server timeout" 132 | pe.StatusCode = http.StatusGatewayTimeout 133 | pe.Retryable = true 134 | 135 | default: 136 | var opErr *net.OpError 137 | if errors.As(err, &opErr) { 138 | pe.Retryable = IsTemporaryError(err) 139 | 140 | var dnsErr *net.DNSError 141 | if errors.As(opErr.Err, &dnsErr) { 142 | pe.Code = ErrCodeBackendConnFailed 143 | pe.Message = fmt.Sprintf("DNS error: %s", dnsErr.Error()) 144 | pe.StatusCode = http.StatusBadGateway 145 | pe.Retryable = dnsErr.IsTemporary 146 | return pe 147 | } 148 | 149 | var syscallErr syscall.Errno 150 | if errors.As(opErr.Err, &syscallErr) { 151 | switch syscallErr { 152 | case syscall.ECONNREFUSED: 153 | pe.Message = "Connection refused by backend" 154 | case syscall.ECONNRESET: 155 | pe.Message = 
"Connection reset by backend" 156 | case syscall.ETIMEDOUT: 157 | pe.Code = ErrCodeBackendTimeout 158 | pe.Message = "Connection timed out" 159 | pe.StatusCode = http.StatusGatewayTimeout 160 | default: 161 | pe.Message = fmt.Sprintf("Network error: %s", syscallErr.Error()) 162 | } 163 | pe.Code = ErrCodeBackendConnFailed 164 | return pe 165 | } 166 | } 167 | 168 | var netErr net.Error 169 | if errors.As(err, &netErr) { 170 | if netErr.Timeout() { 171 | pe.Code = ErrCodeBackendTimeout 172 | pe.Message = "Network timeout" 173 | pe.StatusCode = http.StatusGatewayTimeout 174 | pe.Retryable = true 175 | } else { 176 | pe.Code = ErrCodeBackendConnFailed 177 | pe.Message = "Network error" 178 | pe.Retryable = IsTemporaryError(err) 179 | } 180 | return pe 181 | } 182 | 183 | pe.Message = fmt.Sprintf("Unexpected error: %v", err) 184 | } 185 | 186 | return pe 187 | } 188 | 189 | // WriteErrorResponse writes a structured error response to the client 190 | func WriteErrorResponse(w http.ResponseWriter, err error) { 191 | var pe *ProxyError 192 | if !errors.As(err, &pe) { 193 | pe = NewProxyError("unknown", err) 194 | } 195 | 196 | response := ErrorResponse{ 197 | Status: "error", 198 | Message: pe.Message, 199 | } 200 | 201 | if pe.Retryable { 202 | response.RetryAfter = RetryAfterSec 203 | w.Header().Set(RetryAfter, strconv.Itoa(RetryAfterSec)) 204 | } 205 | 206 | w.Header().Set("Content-Type", "application/json") 207 | w.WriteHeader(pe.StatusCode) 208 | json.NewEncoder(w).Encode(response) 209 | } 210 | -------------------------------------------------------------------------------- /internal/middleware/middleware.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "reflect" 9 | "time" 10 | 11 | "github.com/unkn0wn-root/terraster/internal/config" 12 | "go.uber.org/zap" 13 | ) 14 | 15 | // Middleware defines an interface for HTTP middleware. 
// Implementations receive the next handler in the chain and return a handler
// that wraps extra behavior around it.
type Middleware interface {
	Middleware(next http.Handler) http.Handler
}

// statusWriter wraps an http.ResponseWriter and records the status code sent
// and the number of body bytes written through it.
type statusWriter struct {
	http.ResponseWriter
	status int
	length int
}

// newStatusWriter wraps w, assuming 200 OK until WriteHeader reports otherwise.
func newStatusWriter(w http.ResponseWriter) *statusWriter {
	sw := &statusWriter{ResponseWriter: w}
	sw.status = http.StatusOK
	return sw
}

// WriteHeader remembers code and then forwards it to the wrapped writer.
func (w *statusWriter) WriteHeader(code int) {
	w.status = code
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards p to the wrapped writer and adds the bytes actually written
// to the running length. A zero status (a statusWriter not created through
// newStatusWriter) is treated as an implicit 200 OK.
func (w *statusWriter) Write(p []byte) (int, error) {
	if w.status == 0 {
		w.status = http.StatusOK
	}
	written, err := w.ResponseWriter.Write(p)
	w.length += written
	return written, err
}

// Status reports the recorded HTTP status code.
func (w *statusWriter) Status() int { return w.status }

// Length reports how many response body bytes have been written.
func (w *statusWriter) Length() int { return w.length }

// Hijack lets handlers behind the middleware take over the underlying
// connection (e.g. for protocol upgrades) when the wrapped writer implements http.Hijacker.
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
		return hijacker.Hijack()
	}
	return nil, nil, fmt.Errorf("upstream ResponseWriter does not implement http.Hijacker")
}

// Flush sends buffered response data to the client when the wrapped
// ResponseWriter implements http.Flusher; otherwise it does nothing.
func (w *statusWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController uses it to
// reach features statusWriter does not implement itself (e.g. SetWriteDeadline);
// without it such calls fail with http.ErrNotSupported through this wrapper.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// MiddlewareChain manages a sequence of middleware.
// It chains multiple middleware together and applies them to a final HTTP handler.
type MiddlewareChain struct {
	middlewares []Middleware // Middleware in the order they should be applied.
}

// NewMiddlewareChain returns a MiddlewareChain holding the given middleware.
// The arguments are copied, so later Use/Replace calls never write into a
// slice the caller passed with `mws...`.
func NewMiddlewareChain(middlewares ...Middleware) *MiddlewareChain {
	return &MiddlewareChain{
		middlewares: append([]Middleware(nil), middlewares...),
	}
}

// Use appends middleware to the end of the chain.
func (c *MiddlewareChain) Use(middleware Middleware) {
	c.middlewares = append(c.middlewares, middleware)
}

// Replace swaps the first middleware in the chain whose dynamic type equals
// middleware's, keeping its position; if there is none, middleware is appended.
func (c *MiddlewareChain) Replace(middleware Middleware) {
	want := reflect.TypeOf(middleware)
	for i, mw := range c.middlewares {
		if reflect.TypeOf(mw) == want {
			c.middlewares[i] = middleware
			return
		}
	}

	// If the middleware doesn't exist, add it to the chain
	c.Use(middleware)
}

// Then applies the middleware chain to the final HTTP handler.
115 | // It wraps the final handler with each middleware in reverse order, so that the first middleware added 116 | // is the first to process the request. 117 | func (c *MiddlewareChain) Then(final http.Handler) http.Handler { 118 | if final == nil { 119 | final = http.DefaultServeMux 120 | } 121 | 122 | // Wrap the final handler with each middleware, starting from the last added. 123 | for i := len(c.middlewares) - 1; i >= 0; i-- { 124 | final = c.middlewares[i].Middleware(final) 125 | } 126 | return final 127 | } 128 | 129 | // AddConfiguredMiddlewars adds middleware to the chain based on the provided configuration. 130 | // It checks the configuration for enabled middleware features like Circuit Breaker, Rate Limiting, and Security, 131 | // and adds the corresponding middleware to the chain. 132 | func (c *MiddlewareChain) AddConfiguredMiddlewares(config *config.Terraster, logger *zap.Logger) { 133 | for _, mw := range config.Middleware { 134 | switch { 135 | case mw.CircuitBreaker != nil: 136 | cir := mw.CircuitBreaker 137 | threshold := cir.FailureThreshold 138 | resetTimeout := cir.ResetTimeout 139 | if threshold == 0 { 140 | threshold = 5 141 | } 142 | 143 | if resetTimeout == 0 { 144 | resetTimeout = 30 * time.Second 145 | } 146 | 147 | cb := NewCircuitBreaker(threshold, resetTimeout) 148 | c.Use(cb) 149 | logger.Info("Global Circuit Breaker middleware configured", 150 | zap.Int("failure_threshold", threshold), 151 | zap.Duration("reset_timeout", resetTimeout)) 152 | 153 | case mw.RateLimit != nil: 154 | rml := mw.RateLimit 155 | rl := NewRateLimiterMiddleware(rml.RequestsPerSecond, rml.Burst) 156 | c.Use(rl) 157 | logger.Info("Global Rate Limiter middleware configured", 158 | zap.Float64("requests_per_second", rml.RequestsPerSecond), 159 | zap.Int("burst", rml.Burst)) 160 | 161 | case mw.Security != nil: 162 | sec := NewSecurityMiddleware(config) 163 | c.Use(sec) 164 | logger.Info("Global Security middleware configured") 165 | 166 | case mw.CORS != nil: 
167 | cors := NewCORSMiddleware(config) 168 | c.Use(cors) 169 | 170 | logger.Info("Global CORS middleware enabled configured") 171 | } 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /pkg/plugin/manager.go: -------------------------------------------------------------------------------- 1 | package plugin 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "path/filepath" 10 | "plugin" 11 | "sort" 12 | "sync" 13 | "sync/atomic" 14 | "time" 15 | 16 | "go.uber.org/zap" 17 | ) 18 | 19 | const ( 20 | DefaultTimeout = 5 * time.Second // DefaultTimeout defines the maximum time allowed for plugin processing 21 | ) 22 | 23 | var ( 24 | ErrFailedToRead = errors.New("Failed to read plugin directory") 25 | ErrFailedToLoad = errors.New("Failed to load plugin") 26 | ) 27 | 28 | // Manager handles loading, initialization, and execution of plugins 29 | type Manager struct { 30 | plugins []Handler // Ordered list of plugin handlers 31 | logger *zap.Logger 32 | enabled atomic.Bool 33 | mu sync.RWMutex 34 | } 35 | 36 | // NewManager creates a Manager instance with initial capacity of 10 plugins 37 | func NewManager(logger *zap.Logger) *Manager { 38 | return &Manager{ 39 | plugins: make([]Handler, 0, 10), 40 | logger: logger, 41 | } 42 | } 43 | 44 | // Initialize loads all .so plugin files from the specified directory. 45 | // Plugins are sorted by priority after loading. 46 | // Context is used to cancel the initialization process. 47 | func (pm *Manager) Initialize(ctx context.Context, pluginDir string) error { 48 | if _, err := os.Stat(pluginDir); os.IsNotExist(err) { 49 | pm.logger.Info("Plugins directory not found. 
Plugin not enabled", zap.String("path", pluginDir)) 50 | return nil 51 | } 52 | 53 | files, err := filepath.Glob(filepath.Join(pluginDir, "*.so")) 54 | if err != nil { 55 | return fmt.Errorf("%w: %v", ErrFailedToRead, err) 56 | } 57 | 58 | if len(files) == 0 { 59 | return nil 60 | } 61 | 62 | plugins := make([]Handler, 0, len(files)) 63 | for _, file := range files { 64 | select { 65 | case <-ctx.Done(): 66 | return ctx.Err() 67 | default: 68 | handler, err := pm.loadPlugin(file) 69 | if err != nil { 70 | pm.logger.Error("Failed to load plugin", 71 | zap.String("file", file), 72 | zap.Error(err)) 73 | return ErrFailedToLoad 74 | } 75 | plugins = append(plugins, handler) 76 | } 77 | } 78 | 79 | // Sort plugins by priority 80 | sort.Slice(plugins, func(i, j int) bool { 81 | return plugins[i].Priority() < plugins[j].Priority() 82 | }) 83 | 84 | pm.mu.Lock() 85 | pm.plugins = plugins 86 | pm.enabled.Store(true) 87 | pm.mu.Unlock() 88 | 89 | pm.logger.Info("Plugin system initialized", 90 | zap.Int("plugins_loaded", len(plugins)), 91 | zap.String("plugin_dir", pluginDir), 92 | ) 93 | 94 | return nil 95 | } 96 | 97 | // loadPlugin loads a single plugin from the given path. 98 | // The plugin must export a "New" function that returns a Handler. 
99 | func (pm *Manager) loadPlugin(path string) (Handler, error) { 100 | p, err := plugin.Open(path) 101 | if err != nil { 102 | return nil, fmt.Errorf("failed to open plugin: %w", err) 103 | } 104 | 105 | newFunc, err := p.Lookup("New") 106 | if err != nil { 107 | return nil, fmt.Errorf("plugin does not export 'New' symbol: %w", err) 108 | } 109 | 110 | createPlugin, ok := newFunc.(func() Handler) 111 | if !ok { 112 | return nil, fmt.Errorf("plugin 'New' has wrong signature") 113 | } 114 | 115 | handler := createPlugin() 116 | pm.logger.Info("Loaded plugin", 117 | zap.String("name", handler.Name()), 118 | zap.Int("priority", handler.Priority()), 119 | zap.String("path", path), 120 | ) 121 | 122 | return handler, nil 123 | } 124 | 125 | // ProcessRequest executes plugins in priority order for HTTP requests. 126 | // Returns early on timeout or if a plugin returns Stop action. 127 | func (pm *Manager) ProcessRequest(req *http.Request) *Result { 128 | if !pm.enabled.Load() { 129 | return ResultContinue 130 | } 131 | 132 | ctx := req.Context() 133 | if _, hasDeadline := ctx.Deadline(); !hasDeadline { 134 | var cancel context.CancelFunc 135 | ctx, cancel = context.WithTimeout(ctx, DefaultTimeout) 136 | defer cancel() 137 | } 138 | 139 | plugins := pm.getPluginsNoLock() 140 | for _, p := range plugins { 141 | select { 142 | case <-ctx.Done(): 143 | return NewResult( 144 | Stop, 145 | WithStatus(http.StatusGatewayTimeout), 146 | WithJSONResponse(map[string]string{ 147 | "error": "plugin processing timeout", 148 | }), 149 | ) 150 | default: 151 | result := p.ProcessRequest(ctx, req) 152 | action := result.Action() 153 | if action == Stop { 154 | return result 155 | } 156 | 157 | if result != ResultContinue && result != ResultModify { 158 | result.Release() 159 | } 160 | } 161 | } 162 | 163 | return ResultContinue 164 | } 165 | 166 | // ProcessResponse executes plugins in priority order for HTTP responses. 
167 | // Returns early on context cancellation or if a plugin returns Stop action. 168 | func (pm *Manager) ProcessResponse(resp *http.Response) *Result { 169 | if !pm.enabled.Load() { 170 | return ResultContinue 171 | } 172 | 173 | ctx := resp.Request.Context() 174 | plugins := pm.getPluginsNoLock() 175 | 176 | for _, p := range plugins { 177 | select { 178 | case <-ctx.Done(): 179 | return ResultContinue 180 | default: 181 | result := p.ProcessResponse(ctx, resp) 182 | action := result.Action() 183 | 184 | if action == Stop { 185 | return result 186 | } 187 | 188 | if result != ResultContinue && result != ResultModify { 189 | result.Release() 190 | } 191 | } 192 | } 193 | 194 | return ResultContinue 195 | } 196 | 197 | // getPluginsNoLock returns the current plugins slice without locking 198 | func (pm *Manager) getPluginsNoLock() []Handler { 199 | return pm.plugins 200 | } 201 | 202 | // Shutdown cleans up all plugins and disables the manager. 203 | func (pm *Manager) Shutdown(ctx context.Context) error { 204 | pm.mu.Lock() 205 | defer pm.mu.Unlock() 206 | 207 | for _, p := range pm.plugins { 208 | select { 209 | case <-ctx.Done(): 210 | return ctx.Err() 211 | default: 212 | if err := p.Cleanup(); err != nil { 213 | pm.logger.Error("Plugin cleanup failed", 214 | zap.String("plugin", p.Name()), 215 | zap.Error(err), 216 | ) 217 | } 218 | } 219 | } 220 | 221 | pm.enabled.Store(false) 222 | pm.plugins = nil 223 | return nil 224 | } 225 | 226 | // IsEnabled returns whether the plugin manager is currently enabled. 227 | func (pm *Manager) IsEnabled() bool { 228 | return pm.enabled.Load() 229 | } 230 | -------------------------------------------------------------------------------- /plugins/README.md: -------------------------------------------------------------------------------- 1 | # Terraster Plugin Development Guide 2 | 3 | This guide explains how to develop custom plugins for the Terraster. 
Plugins allow you to extend the load balancer's functionality by intercepting and modifying requests and responses. 4 | 5 | ## CGO Requirements 6 | 7 | ### Important: CGO is Required 8 | 9 | Terraster and all plugins must be compiled with CGO enabled to function properly. This is because the plugin system relies on Go's plugin package, which requires CGO support. 10 | 11 | To compile Terraster with CGO: 12 | ```sh 13 | CGO_ENABLED=1 go build 14 | ``` 15 | 16 | To compile your plugins with CGO: 17 | ```sh 18 | CGO_ENABLED=1 go build -buildmode=plugin -o your_plugin.so your_plugin.go 19 | ``` 20 | 21 | Note: If CGO is not enabled during compilation, plugins will fail to load and Terraster will not function correctly. 22 | 23 | ## Plugin Interface 24 | 25 | To create a plugin, you need to implement the `Handler` interface: 26 | 27 | ```go 28 | type Handler interface { 29 | ProcessRequest(ctx context.Context, req *http.Request) *Result 30 | ProcessResponse(ctx context.Context, resp *http.Response) *Result 31 | Name() string 32 | Priority() int 33 | Cleanup() error 34 | } 35 | ``` 36 | 37 | ### Required Methods 38 | 39 | - `ProcessRequest`: Processes incoming requests before they reach the backend 40 | - `ProcessResponse`: Processes responses before they're sent back to the client 41 | - `Name`: Returns the plugin's unique identifier 42 | - `Priority`: Determines the plugin's execution order (lower numbers execute first) 43 | - `Cleanup`: Handles plugin cleanup when shutting down 44 | 45 | ## Creating a Plugin 46 | 47 | ### 1. Basic Structure 48 | 49 | Start by creating a new `main` package for your plugin (`go build -buildmode=plugin` only builds `main` packages) that exports a `New` function returning a `plugin.Handler`: 50 | 51 | ```go 52 | package main 53 | 54 | import ( 55 | "context" 56 | "net/http" 57 | "github.com/unkn0wn-root/terraster/pkg/plugin" 58 | ) 59 | 60 | type MyPlugin struct { 61 | // Plugin configuration and state 62 | } 63 | 64 | func New() plugin.Handler { 65 | return &MyPlugin{} 66 | } 67 | ``` 68 | 69 | ### 2.
Configuration 70 | 71 | It's recommended to use a configuration structure to make your plugin configurable: 72 | 73 | ```go 74 | type Config struct { 75 | // Define your configuration fields 76 | Enabled bool `json:"enabled"` 77 | SomeOption string `json:"some_option"` 78 | } 79 | ``` 80 | 81 | ### 3. Implementing the Interface 82 | 83 | Each method in the interface serves a specific purpose: 84 | 85 | #### ProcessRequest 86 | 87 | ```go 88 | func (p *MyPlugin) ProcessRequest(ctx context.Context, req *http.Request) *plugin.Result { 89 | // Modify or validate the request 90 | // Return plugin.ResultModify to continue processing 91 | // Return a custom result with plugin.NewResult to stop processing 92 | return plugin.ResultModify 93 | } 94 | ``` 95 | 96 | #### ProcessResponse 97 | 98 | ```go 99 | func (p *MyPlugin) ProcessResponse(ctx context.Context, resp *http.Response) *plugin.Result { 100 | // Modify or transform the response 101 | return plugin.ResultModify 102 | } 103 | ``` 104 | 105 | #### Additional Required Methods 106 | 107 | ```go 108 | func (p *MyPlugin) Name() string { 109 | return "my_plugin" 110 | } 111 | 112 | func (p *MyPlugin) Priority() int { 113 | return 50 // Middle priority 114 | } 115 | 116 | func (p *MyPlugin) Cleanup() error { 117 | // Clean up resources 118 | return nil 119 | } 120 | ``` 121 | 122 | ## Result Types 123 | 124 | The plugin system uses `Result` objects to control request/response flow: 125 | 126 | ```go 127 | // Continue processing 128 | plugin.ResultModify 129 | 130 | // Stop processing and return custom response 131 | plugin.NewResult( 132 | plugin.Stop, 133 | plugin.WithStatus(http.StatusBadRequest), 134 | plugin.WithJSONResponse(map[string]string{ 135 | "error": "Invalid request", 136 | }), 137 | ) 138 | ``` 139 | 140 | ## Best Practices 141 | 142 | 1. **Error Handling** 143 | - Always handle errors gracefully 144 | - Use appropriate HTTP status codes 145 | - Log errors with context for debugging 146 | 147 | 2. 
**Performance** 148 | - Minimize memory allocations 149 | - Avoid blocking operations 150 | 151 | 3. **Context Usage** 152 | - Respect context cancellation 153 | - Use context for request-scoped values 154 | - Don't store context values across requests 155 | 156 | 4. **Configuration** 157 | - Make your plugin configurable 158 | - Provide sensible defaults 159 | - Validate configuration on startup 160 | 161 | 5. **Resource Management** 162 | - Clean up resources in the Cleanup method 163 | - Use sync.Pool for frequently allocated objects 164 | - Implement proper connection pooling if needed 165 | 166 | ## Example Plugin 167 | 168 | See the [example plugin](https://github.com/unkn0wn-root/terraster/blob/main/plugins/example/main.go) for a complete implementation demonstrating: 169 | 170 | - Request/Response processing 171 | - Rate limiting 172 | - CORS handling 173 | - Authentication 174 | - Error handling 175 | - Configuration management 176 | 177 | ## Integration 178 | 179 | To integrate your plugin with Terraster, you have two options: 180 | 181 | ### 1. Default Plugins Directory 182 | 183 | Place your compiled plugin (.so file) in the `plugins` directory in the Terraster root: 184 | 185 | ``` 186 | terraster/ 187 | ├── plugins/ 188 | │ └── your_plugin.so 189 | ``` 190 | 191 | Remember to compile your plugin with CGO enabled: 192 | ```sh 193 | CGO_ENABLED=1 go build -buildmode=plugin -o plugins/your_plugin.so your_plugin.go 194 | ``` 195 | 196 | ### 2. Custom Plugin Directory 197 | 198 | Specify a custom plugin directory in your Terraster configuration: 199 | 200 | ```yaml 201 | plugin_directory: "/path/to/plugins" # Will default to `./plugins` if not specified 202 | ``` 203 | 204 | The load balancer will automatically discover and load plugins from the configured directory during startup. Make sure: 205 | 1. Your plugin implements the required Handler interface and exports a `New` function returning `plugin.Handler` 206 | 2. The plugin is compiled with CGO enabled 207 | 3. The plugin file has a .so extension 208 | 4.
The plugin is built from a `main` package (`-buildmode=plugin` ignores non-main packages) 209 | 210 | ## Testing 211 | 212 | It's recommended to thoroughly test your plugins: 213 | 214 | ```go 215 | func TestMyPlugin(t *testing.T) { 216 | p := New() 217 | 218 | // Test request processing 219 | req := httptest.NewRequest("GET", "/test", nil) 220 | result := p.ProcessRequest(context.Background(), req) 221 | 222 | // Assert expected behavior 223 | // ... 224 | } 225 | ``` 226 | -------------------------------------------------------------------------------- /internal/health/checker.go: -------------------------------------------------------------------------------- 1 | package health 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "strings" 10 | "sync" 11 | "sync/atomic" 12 | "time" 13 | 14 | "github.com/unkn0wn-root/terraster/internal/config" 15 | "github.com/unkn0wn-root/terraster/internal/pool" 16 | "go.uber.org/zap" 17 | "go.uber.org/zap/zapcore" 18 | ) 19 | 20 | // Supported health check types 21 | const ( 22 | HealthCheckTypeHTTP = "http" 23 | HealthCheckTypeTCP = "tcp" 24 | ) 25 | 26 | // Checker periodically checks the health of backends in registered ServerPools. 27 | type Checker struct { 28 | interval time.Duration 29 | timeout time.Duration 30 | pools []*pool.ServerPool 31 | mu sync.RWMutex 32 | client *http.Client 33 | logger *zap.Logger 34 | running atomic.Bool 35 | cancel context.CancelFunc 36 | wg sync.WaitGroup 37 | prefix string 38 | } 39 | 40 | // creates a new health checker with the given interval and timeout.
41 | func NewChecker( 42 | config *config.HealthCheck, 43 | logger *zap.Logger, 44 | prefix string, 45 | ) *Checker { 46 | hc := &http.Client{ 47 | Timeout: config.Timeout, 48 | } 49 | 50 | if config.SkipTLSVerify { 51 | hc.Transport = &http.Transport{ 52 | TLSClientConfig: &tls.Config{ 53 | InsecureSkipVerify: true, 54 | }, 55 | } 56 | } 57 | 58 | return &Checker{ 59 | interval: config.Interval, 60 | timeout: config.Timeout, 61 | pools: make([]*pool.ServerPool, 0), 62 | client: hc, 63 | logger: logger, 64 | prefix: prefix, 65 | } 66 | } 67 | 68 | func (c *Checker) RegisterPool(p *pool.ServerPool) { 69 | c.mu.Lock() 70 | defer c.mu.Unlock() 71 | c.pools = append(c.pools, p) 72 | } 73 | 74 | // begins the health checking process. 75 | func (c *Checker) Start(ctx context.Context) { 76 | if !c.running.CompareAndSwap(false, true) { 77 | c.logf(zapcore.InfoLevel, "Health checker already running") 78 | return 79 | } 80 | 81 | ctx, cancel := context.WithCancel(ctx) 82 | c.cancel = cancel 83 | c.wg.Add(1) 84 | go func() { 85 | defer c.wg.Done() 86 | ticker := time.NewTicker(c.interval) 87 | defer ticker.Stop() 88 | 89 | c.logf(zapcore.InfoLevel, "Health checker started") 90 | 91 | for { 92 | select { 93 | case <-ticker.C: 94 | c.checkAllPools() 95 | case <-ctx.Done(): 96 | c.logf(zapcore.InfoLevel, "Health checker stopping") 97 | return 98 | } 99 | } 100 | }() 101 | } 102 | 103 | func (c *Checker) Stop() { 104 | if c.running.Load() { 105 | c.cancel() 106 | c.wg.Wait() 107 | c.running.Store(false) 108 | } 109 | } 110 | 111 | // iterates over all registered server pools and checks their backends. 
112 | func (c *Checker) checkAllPools() { 113 | c.mu.RLock() 114 | pools := make([]*pool.ServerPool, len(c.pools)) 115 | copy(pools, c.pools) 116 | c.mu.RUnlock() 117 | 118 | var wg sync.WaitGroup 119 | for _, p := range pools { 120 | wg.Add(1) 121 | go func(pool *pool.ServerPool) { 122 | defer wg.Done() 123 | c.checkPool(pool) 124 | }(p) 125 | } 126 | wg.Wait() 127 | } 128 | 129 | // performs health checks on all backends within a single ServerPool. 130 | func (c *Checker) checkPool(s *pool.ServerPool) { 131 | backends := s.GetAllBackends() 132 | var wg sync.WaitGroup 133 | for _, backend := range backends { 134 | wg.Add(1) 135 | go func(b *pool.Backend) { 136 | defer wg.Done() 137 | c.checkBackend(b) 138 | }(backend) 139 | } 140 | wg.Wait() 141 | } 142 | 143 | // performs a health check on a single backend based on its type. 144 | func (c *Checker) checkBackend(b *pool.Backend) { 145 | switch strings.ToLower(b.HealthCheckCfg.Type) { 146 | case HealthCheckTypeHTTP: 147 | c.performHTTPHealthCheck(b) 148 | case HealthCheckTypeTCP: 149 | c.performTCPHealthCheck(b) 150 | default: 151 | c.logf(zap.WarnLevel, "Unsupported health check type '%s' for backend %s", b.HealthCheckCfg.Type, b.URL) 152 | c.updateBackendHealth(b, false) 153 | } 154 | } 155 | 156 | // http-based health check 157 | func (c *Checker) performHTTPHealthCheck(b *pool.Backend) { 158 | healthPath := b.HealthCheckCfg.Path 159 | if healthPath == "" { 160 | healthPath = "/health" // default health path 161 | } 162 | 163 | healthURL := *b.URL 164 | healthURL.Path = healthPath 165 | req, err := http.NewRequest("GET", healthURL.String(), nil) 166 | if err != nil { 167 | c.logf(zap.WarnLevel, "Failed to create HTTP health check request for %s: %v", b.URL, err) 168 | c.updateBackendHealth(b, false) 169 | return 170 | } 171 | 172 | resp, err := c.client.Do(req) 173 | if err != nil { 174 | c.logf(zap.WarnLevel, "HTTP health check failed for %s: %v", b.URL, err) 175 | c.updateBackendHealth(b, false) 176 | return 
177 | } 178 | defer resp.Body.Close() 179 | 180 | if resp.StatusCode >= 200 && resp.StatusCode < 300 { 181 | c.updateBackendHealth(b, true) 182 | } else { 183 | c.logf(zap.WarnLevel, "HTTP health check returned non-2xx for %s: %d", b.URL, resp.StatusCode) 184 | c.updateBackendHealth(b, false) 185 | } 186 | } 187 | 188 | // TCP-based health check. 189 | func (c *Checker) performTCPHealthCheck(b *pool.Backend) { 190 | healthAddress := b.URL.Host 191 | host, port, err := net.SplitHostPort(healthAddress) 192 | if err != nil { 193 | host = healthAddress 194 | if b.URL.Scheme == "https" { 195 | port = "443" 196 | } else { 197 | port = "80" 198 | } 199 | } 200 | 201 | tcpAddress := net.JoinHostPort(host, port) 202 | conn, err := net.DialTimeout("tcp", tcpAddress, c.timeout) 203 | if err != nil { 204 | c.logf(zap.WarnLevel, "TCP health check failed for %s: %v", b.URL, err) 205 | c.updateBackendHealth(b, false) 206 | return 207 | } 208 | 209 | conn.Close() 210 | c.updateBackendHealth(b, true) 211 | } 212 | 213 | // updates the backend's health status based on the check result. 
214 | func (c *Checker) updateBackendHealth(b *pool.Backend, healthy bool) { 215 | if healthy { 216 | newSuccess := atomic.AddInt32(&b.SuccessCount, 1) 217 | atomic.StoreInt32(&b.FailureCount, 0) 218 | if newSuccess >= int32(b.HealthCheckCfg.Thresholds.Healthy) { 219 | if !b.Alive.Load() { 220 | c.logf(zap.InfoLevel, "Backend %s marked as healthy", b.URL) 221 | s := findServerPool(c.pools, b) 222 | if s != nil { 223 | s.MarkBackendStatus(b.URL, true) 224 | } 225 | } 226 | } 227 | } else { 228 | newFailure := atomic.AddInt32(&b.FailureCount, 1) 229 | atomic.StoreInt32(&b.SuccessCount, 0) 230 | if newFailure >= int32(b.HealthCheckCfg.Thresholds.Unhealthy) { 231 | if b.Alive.Load() { 232 | c.logf(zap.WarnLevel, "Backend %s marked as unhealthy", b.URL) 233 | s := findServerPool(c.pools, b) 234 | if s != nil { 235 | s.MarkBackendStatus(b.URL, false) 236 | } 237 | } 238 | } 239 | } 240 | } 241 | 242 | // Helper method to log with prefix 243 | func (c *Checker) logf(level zapcore.Level, template string, args ...interface{}) { 244 | msg := fmt.Sprintf(template, args...) 245 | switch level { 246 | case zapcore.InfoLevel: 247 | c.logger.Info(msg, zap.String("prefix", c.prefix)) 248 | case zapcore.WarnLevel: 249 | c.logger.Warn(msg, zap.String("prefix", c.prefix)) 250 | case zapcore.ErrorLevel: 251 | c.logger.Error(msg, zap.String("prefix", c.prefix)) 252 | } 253 | } 254 | 255 | // locates the ServerPool containing the specified backend. 
256 | func findServerPool(pools []*pool.ServerPool, backend *pool.Backend) *pool.ServerPool { 257 | for _, s := range pools { 258 | currentBackends := s.GetAllBackends() 259 | for _, b := range currentBackends { 260 | if b == backend { 261 | return s 262 | } 263 | } 264 | } 265 | 266 | return nil 267 | } 268 | -------------------------------------------------------------------------------- /plugins/example/main.go: -------------------------------------------------------------------------------- 1 | package example_plugin 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "sync" 10 | "time" 11 | 12 | "github.com/unkn0wn-root/terraster/pkg/plugin" 13 | ) 14 | 15 | type Config struct { 16 | AllowedOrigins []string `json:"allowed_origins"` 17 | RateLimit int `json:"rate_limit"` 18 | AuthHeader string `json:"auth_header"` 19 | Debug bool `json:"debug"` 20 | } 21 | 22 | type ExamplePlugin struct { 23 | config Config 24 | rateLimits sync.Map // client -> limit data 25 | logger Logger // interface for logging 26 | } 27 | 28 | // RateLimitData tracks rate limiting data 29 | type RateLimitData struct { 30 | count int 31 | resetAt time.Time 32 | lastAccess time.Time 33 | } 34 | 35 | type Logger interface { 36 | Info(msg string, fields ...interface{}) 37 | Error(msg string, fields ...interface{}) 38 | } 39 | 40 | func New() plugin.Handler { 41 | // Read config from environment or file 42 | config := loadConfig() 43 | 44 | return &ExamplePlugin{ 45 | config: config, 46 | } 47 | } 48 | 49 | func loadConfig() Config { 50 | // Try to load from environment first 51 | if configJSON := os.Getenv("EXAMPLE_PLUGIN_CONFIG"); configJSON != "" { 52 | var config Config 53 | if err := json.Unmarshal([]byte(configJSON), &config); err == nil { 54 | return config 55 | } 56 | } 57 | 58 | // Default config 59 | return Config{ 60 | AllowedOrigins: []string{"*"}, 61 | RateLimit: 100, 62 | AuthHeader: "X-API-Key", 63 | Debug: false, 64 | } 65 | } 66 | 67 | // 
ProcessRequest handles incoming requests 68 | func (p *ExamplePlugin) ProcessRequest(ctx context.Context, req *http.Request) *plugin.Result { 69 | // 1. CORS Pre-flight handling 70 | if req.Method == http.MethodOptions { 71 | return p.handleCORS(req) 72 | } 73 | 74 | // 2. Rate limiting 75 | if exceeded, resetAt := p.isRateLimitExceeded(req); exceeded { 76 | return plugin.NewResult( 77 | plugin.Stop, 78 | plugin.WithStatus(http.StatusTooManyRequests), 79 | plugin.WithHeaders(http.Header{ 80 | "X-RateLimit-Reset": []string{fmt.Sprint(resetAt.Unix())}, 81 | }), 82 | plugin.WithJSONResponse(map[string]interface{}{ 83 | "error": "rate limit exceeded", 84 | "reset_at": resetAt.Format(time.RFC3339), 85 | "limit": p.config.RateLimit, 86 | "interval": "1 minute", 87 | }), 88 | ) 89 | } 90 | 91 | // 3. Authentication 92 | if !p.isAuthenticated(req) { 93 | return plugin.NewResult( 94 | plugin.Stop, 95 | plugin.WithStatus(http.StatusUnauthorized), 96 | plugin.WithJSONResponse(map[string]string{ 97 | "error": fmt.Sprintf("missing or invalid %s header", p.config.AuthHeader), 98 | }), 99 | ) 100 | } 101 | 102 | // 4. Request modification example 103 | // Add custom headers 104 | req.Header.Set("X-Processed-By", "example-plugin") 105 | req.Header.Set("X-Request-Start", time.Now().Format(time.RFC3339)) 106 | 107 | // Store timing info in context for response processing 108 | *req = *req.WithContext(context.WithValue(req.Context(), "req_start", time.Now())) 109 | 110 | return plugin.ResultModify 111 | } 112 | 113 | // ProcessResponse handles outgoing responses 114 | func (p *ExamplePlugin) ProcessResponse(ctx context.Context, resp *http.Response) *plugin.Result { 115 | // 1. Add security headers 116 | resp.Header.Set("X-Content-Type-Options", "nosniff") 117 | resp.Header.Set("X-Frame-Options", "DENY") 118 | resp.Header.Set("X-XSS-Protection", "1; mode=block") 119 | 120 | // 2. 
Add timing headers if we have the start time 121 | if startTime, ok := resp.Request.Context().Value("req_start").(time.Time); ok { 122 | processingTime := time.Since(startTime).Milliseconds() 123 | resp.Header.Set("X-Processing-Time", fmt.Sprintf("%dms", processingTime)) 124 | } 125 | 126 | // 3. Add CORS headers for non-OPTIONS requests 127 | if resp.Request.Method != http.MethodOptions { 128 | origin := resp.Request.Header.Get("Origin") 129 | if p.isOriginAllowed(origin) { 130 | resp.Header.Set("Access-Control-Allow-Origin", origin) 131 | resp.Header.Set("Access-Control-Allow-Credentials", "true") 132 | } 133 | } 134 | 135 | // 4. Handle specific error cases (example) 136 | if resp.StatusCode >= 500 { 137 | // Log server errors 138 | p.logError("Backend error", map[string]interface{}{ 139 | "status": resp.StatusCode, 140 | "path": resp.Request.URL.Path, 141 | "method": resp.Request.Method, 142 | }) 143 | 144 | // Optionally modify error response 145 | if p.config.Debug { 146 | return plugin.ResultModify 147 | } 148 | 149 | return plugin.NewResult( 150 | plugin.Stop, 151 | plugin.WithStatus(http.StatusBadGateway), 152 | plugin.WithJSONResponse(map[string]string{ 153 | "error": "service temporarily unavailable", 154 | }), 155 | ) 156 | } 157 | 158 | return plugin.ResultModify 159 | } 160 | 161 | func (p *ExamplePlugin) handleCORS(req *http.Request) *plugin.Result { 162 | origin := req.Header.Get("Origin") 163 | if !p.isOriginAllowed(origin) { 164 | return plugin.NewResult( 165 | plugin.Stop, 166 | plugin.WithStatus(http.StatusForbidden), 167 | plugin.WithJSONResponse(map[string]string{ 168 | "error": "origin not allowed", 169 | }), 170 | ) 171 | } 172 | 173 | return plugin.NewResult( 174 | plugin.Stop, 175 | plugin.WithStatus(http.StatusOK), 176 | plugin.WithHeaders(http.Header{ 177 | "Access-Control-Allow-Origin": []string{origin}, 178 | "Access-Control-Allow-Methods": []string{"GET, POST, PUT, DELETE, OPTIONS"}, 179 | "Access-Control-Allow-Headers": 
[]string{"Content-Type, Authorization, " + p.config.AuthHeader}, 180 | "Access-Control-Allow-Credentials": []string{"true"}, 181 | "Access-Control-Max-Age": []string{"86400"}, 182 | }), 183 | ) 184 | } 185 | 186 | func (p *ExamplePlugin) isRateLimitExceeded(req *http.Request) (bool, time.Time) { 187 | clientID := req.RemoteAddr 188 | 189 | now := time.Now() 190 | if data, exists := p.rateLimits.Load(clientID); exists { 191 | limit := data.(*RateLimitData) 192 | 193 | // Reset if window has passed 194 | if now.Sub(limit.resetAt) >= time.Minute { 195 | limit.count = 0 196 | limit.resetAt = now.Add(time.Minute) 197 | } 198 | 199 | limit.count++ 200 | limit.lastAccess = now 201 | 202 | if limit.count > p.config.RateLimit { 203 | return true, limit.resetAt 204 | } 205 | return false, limit.resetAt 206 | } 207 | 208 | // First request for this client 209 | p.rateLimits.Store(clientID, &RateLimitData{ 210 | count: 1, 211 | resetAt: now.Add(time.Minute), 212 | lastAccess: now, 213 | }) 214 | 215 | return false, now.Add(time.Minute) 216 | } 217 | 218 | func (p *ExamplePlugin) isAuthenticated(req *http.Request) bool { 219 | authToken := req.Header.Get(p.config.AuthHeader) 220 | return authToken != "" // Just example. 
Shoudl have real auth logic 221 | } 222 | 223 | func (p *ExamplePlugin) isOriginAllowed(origin string) bool { 224 | if len(p.config.AllowedOrigins) == 0 || p.config.AllowedOrigins[0] == "*" { 225 | return true 226 | } 227 | 228 | for _, allowed := range p.config.AllowedOrigins { 229 | if allowed == origin { 230 | return true 231 | } 232 | } 233 | return false 234 | } 235 | 236 | func (p *ExamplePlugin) logError(msg string, fields map[string]interface{}) { 237 | if p.logger != nil { 238 | p.logger.Error(msg, fields) 239 | } 240 | } 241 | 242 | // Required plugin interface methods 243 | func (p *ExamplePlugin) Name() string { 244 | return "example_plugin" 245 | } 246 | 247 | func (p *ExamplePlugin) Priority() int { 248 | return 50 // Middle priority 249 | } 250 | 251 | func (p *ExamplePlugin) Cleanup() error { 252 | // Clean up rate limit data older than 1 hour 253 | now := time.Now() 254 | p.rateLimits.Range(func(key, value interface{}) bool { 255 | data := value.(*RateLimitData) 256 | if now.Sub(data.lastAccess) > time.Hour { 257 | p.rateLimits.Delete(key) 258 | } 259 | return true 260 | }) 261 | return nil 262 | } 263 | -------------------------------------------------------------------------------- /internal/auth/handlers/auth_handler.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/golang-jwt/jwt/v4" 9 | apierr "github.com/unkn0wn-root/terraster/internal/auth" 10 | "github.com/unkn0wn-root/terraster/internal/auth/models" 11 | "github.com/unkn0wn-root/terraster/internal/auth/service" 12 | ) 13 | 14 | type AuthHandler struct { 15 | authService *service.AuthService 16 | } 17 | 18 | func NewAuthHandler(authService *service.AuthService) *AuthHandler { 19 | return &AuthHandler{ 20 | authService: authService, 21 | } 22 | } 23 | 24 | type LoginRequest struct { 25 | Username string `json:"username"` 26 | Password string 
`json:"password"` 27 | } 28 | 29 | type LoginResponse struct { 30 | Token string `json:"token"` 31 | RefreshToken string `json:"refresh_token"` 32 | Type string `json:"type"` 33 | ExpiresAt time.Time `json:"expires_at"` 34 | Role string `json:"role"` 35 | } 36 | 37 | type CreateUserRequest struct { 38 | Username string `json:"username"` 39 | Password string `json:"password"` 40 | Role string `json:"role"` 41 | } 42 | 43 | type ChangePasswordRequest struct { 44 | OldPassword string `json:"old_password"` 45 | NewPassword string `json:"new_password"` 46 | } 47 | 48 | func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { 49 | if r.Method != http.MethodPost { 50 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 51 | return 52 | } 53 | 54 | var req LoginRequest 55 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 56 | http.Error(w, "Invalid request body", http.StatusBadRequest) 57 | return 58 | } 59 | 60 | token, err := h.authService.AuthenticateUser(req.Username, req.Password, r) 61 | if err != nil { 62 | switch err { 63 | case apierr.ErrUserLocked: 64 | http.Error(w, "Account is temporarily locked", http.StatusForbidden) 65 | case apierr.ErrInvalidCredentials: 66 | http.Error(w, "Invalid credentials", http.StatusUnauthorized) 67 | case apierr.ErrPasswordExpired: 68 | http.Error(w, "Password has expired", http.StatusUnauthorized) 69 | default: 70 | http.Error(w, "Authentication failed", http.StatusInternalServerError) 71 | } 72 | return 73 | } 74 | 75 | response := LoginResponse{ 76 | Token: token.Token, 77 | RefreshToken: token.RefreshToken, 78 | Type: "Bearer", 79 | ExpiresAt: token.ExpiresAt, 80 | Role: string(token.Role), 81 | } 82 | 83 | w.Header().Set("Content-Type", "application/json") 84 | json.NewEncoder(w).Encode(response) 85 | } 86 | 87 | func (h *AuthHandler) CreateUser(w http.ResponseWriter, r *http.Request) { 88 | if r.Method != http.MethodPost { 89 | http.Error(w, "Method not allowed", 
http.StatusMethodNotAllowed) 90 | return 91 | } 92 | 93 | // Check if requester is admin 94 | claims := r.Context().Value("user_claims").(*jwt.MapClaims) 95 | if role := (*claims)["role"].(string); role != string(models.RoleAdmin) { 96 | http.Error(w, "Forbidden", http.StatusForbidden) 97 | return 98 | } 99 | 100 | var req CreateUserRequest 101 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 102 | http.Error(w, "Invalid request body", http.StatusBadRequest) 103 | return 104 | } 105 | 106 | err := h.authService.CreateUser(req.Username, req.Password, models.Role(req.Role)) 107 | if err != nil { 108 | http.Error(w, err.Error(), http.StatusBadRequest) 109 | return 110 | } 111 | 112 | w.WriteHeader(http.StatusCreated) 113 | } 114 | 115 | func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) { 116 | if r.Method != http.MethodPost { 117 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 118 | return 119 | } 120 | 121 | token := r.Header.Get("Authorization") 122 | if token == "" || len(token) < 7 || token[:7] != "Bearer " { 123 | http.Error(w, "Invalid authorization header", http.StatusBadRequest) 124 | return 125 | } 126 | 127 | if err := h.authService.RevokeToken(token[7:]); err != nil { 128 | http.Error(w, "Error revoking token", http.StatusInternalServerError) 129 | return 130 | } 131 | 132 | w.WriteHeader(http.StatusNoContent) 133 | } 134 | 135 | func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { 136 | if r.Method != http.MethodPost { 137 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 138 | return 139 | } 140 | 141 | var req struct { 142 | RefreshToken string `json:"refresh_token"` 143 | } 144 | 145 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 146 | http.Error(w, "Invalid request body", http.StatusBadRequest) 147 | return 148 | } 149 | 150 | token, err := h.authService.RefreshToken(req.RefreshToken, r) 151 | if err != nil { 152 | switch err { 153 | case 
apierr.ErrInvalidToken: 154 | http.Error(w, "Invalid refresh token", http.StatusUnauthorized) 155 | case apierr.ErrRevokedToken: 156 | http.Error(w, "Refresh token has been revoked", http.StatusUnauthorized) 157 | default: 158 | http.Error(w, "Error refreshing token", http.StatusInternalServerError) 159 | } 160 | return 161 | } 162 | 163 | response := LoginResponse{ 164 | Token: token.Token, 165 | RefreshToken: token.RefreshToken, 166 | Type: "Bearer", 167 | ExpiresAt: token.ExpiresAt, 168 | Role: string(token.Role), 169 | } 170 | 171 | w.Header().Set("Content-Type", "application/json") 172 | json.NewEncoder(w).Encode(response) 173 | } 174 | 175 | func (h *AuthHandler) ListSessions(w http.ResponseWriter, r *http.Request) { 176 | if r.Method != http.MethodGet { 177 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 178 | return 179 | } 180 | 181 | claims := r.Context().Value("user_claims").(*jwt.MapClaims) 182 | userID := int64((*claims)["user_id"].(float64)) 183 | 184 | sessions, err := h.authService.GetActiveSessions(userID) 185 | if err != nil { 186 | http.Error(w, "Error fetching sessions", http.StatusInternalServerError) 187 | return 188 | } 189 | 190 | w.Header().Set("Content-Type", "application/json") 191 | json.NewEncoder(w).Encode(sessions) 192 | } 193 | 194 | // Add to AuthHandler 195 | func (h *AuthHandler) GetPasswordRequirements(w http.ResponseWriter, r *http.Request) { 196 | if r.Method != http.MethodGet { 197 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 198 | return 199 | } 200 | 201 | config := h.authService.GetConfig() 202 | 203 | requirements := map[string]interface{}{ 204 | "minLength": config.PasswordMinLength, 205 | "maxLength": 128, 206 | "requires": map[string]bool{ 207 | "uppercase": config.RequireUppercase, 208 | "lowercase": true, 209 | "number": config.RequireNumber, 210 | "special": config.RequireSpecialChar, 211 | }, 212 | "preventions": []string{ 213 | "3 or more consecutive identical characters", 
214 | "Sequential characters (abc, 123)", 215 | "Username in password", 216 | "Common passwords", 217 | }, 218 | } 219 | 220 | w.Header().Set("Content-Type", "application/json") 221 | json.NewEncoder(w).Encode(requirements) 222 | } 223 | 224 | func (h *AuthHandler) ChangePassword(w http.ResponseWriter, r *http.Request) { 225 | if r.Method != http.MethodPost { 226 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 227 | return 228 | } 229 | 230 | // Get user from context 231 | claims := r.Context().Value("user_claims").(*jwt.MapClaims) 232 | userID := int64((*claims)["user_id"].(float64)) 233 | 234 | var req ChangePasswordRequest 235 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { 236 | http.Error(w, "Invalid request body", http.StatusBadRequest) 237 | return 238 | } 239 | 240 | if err := h.authService.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil { 241 | switch err { 242 | case apierr.ErrInvalidCredentials: 243 | http.Error(w, "Current password is incorrect", http.StatusUnauthorized) 244 | default: 245 | http.Error(w, err.Error(), http.StatusBadRequest) 246 | } 247 | return 248 | } 249 | 250 | w.WriteHeader(http.StatusCreated) 251 | } 252 | -------------------------------------------------------------------------------- /pkg/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "github.com/natefinch/lumberjack" 12 | "go.uber.org/zap" 13 | "go.uber.org/zap/zapcore" 14 | ) 15 | 16 | var ( 17 | managerOnce sync.Once 18 | initErr error 19 | ) 20 | 21 | type Config struct { 22 | Level string `json:"level"` 23 | OutputPaths []string `json:"outputPaths"` 24 | ErrorOutputPaths []string `json:"errorOutputPaths"` 25 | Development bool `json:"development"` 26 | LogToConsole bool `json:"logToConsole"` 27 | Sampling Sampling `json:"sampling"` 28 | Encoding 
Encoding `json:"encodingConfig"` 29 | LogRotation LogRotation `json:"logRotation"` 30 | Sanitization Sanitization `json:"sanitization"` 31 | } 32 | 33 | type Sampling struct { 34 | Initial int `json:"initial"` 35 | Thereafter int `json:"thereafter"` 36 | } 37 | 38 | type Encoding struct { 39 | TimeKey string `json:"timeKey"` 40 | LevelKey string `json:"levelKey"` 41 | NameKey string `json:"nameKey"` 42 | CallerKey string `json:"callerKey"` 43 | MessageKey string `json:"messageKey"` 44 | StacktraceKey string `json:"stacktraceKey"` 45 | LineEnding string `json:"lineEnding"` 46 | LevelEncoder string `json:"levelEncoder"` 47 | TimeEncoder string `json:"timeEncoder"` 48 | DurationEncoder string `json:"durationEncoder"` 49 | CallerEncoder string `json:"callerEncoder"` 50 | } 51 | 52 | type LogRotation struct { 53 | Enabled bool `json:"enabled"` 54 | MaxSizeMB int `json:"maxSizeMB"` 55 | MaxBackups int `json:"maxBackups"` 56 | MaxAgeDays int `json:"maxAgeDays"` 57 | Compress bool `json:"compress"` 58 | } 59 | 60 | // Sanitization configures sensitive field sanitization. 61 | type Sanitization struct { 62 | SensitiveFields []string `json:"sensitiveFields"` 63 | Mask string `json:"mask"` 64 | } 65 | 66 | // Init initializes the loggers based on multiple configuration files. 67 | // It should be called once at the start of the application. 68 | func Init(configPaths []string, manager *LoggerManager) error { 69 | managerOnce.Do(func() { 70 | for _, configPath := range configPaths { 71 | var cfgMap map[string]Config 72 | data, err := os.ReadFile(configPath) 73 | if err != nil { 74 | if os.IsNotExist(err) { 75 | fmt.Printf("Configuration file '%s' not found. 
Skipping.\n", configPath) 76 | continue 77 | } else { 78 | initErr = fmt.Errorf("failed to read configuration file '%s': %w", configPath, err) 79 | return 80 | } 81 | } 82 | 83 | var configWrapper struct { 84 | Loggers map[string]Config `json:"loggers"` 85 | } 86 | if err := json.Unmarshal(data, &configWrapper); err != nil { 87 | initErr = fmt.Errorf("failed to parse configuration file '%s': %w", configPath, err) 88 | return 89 | } 90 | 91 | cfgMap = configWrapper.Loggers 92 | for name, cfg := range cfgMap { 93 | logger, err := buildLogger(name, &cfg) 94 | if err != nil { 95 | initErr = fmt.Errorf("failed to build logger '%s': %w", name, err) 96 | return 97 | } 98 | if err := manager.AddLogger(name, logger); err != nil { 99 | initErr = fmt.Errorf("failed to add logger '%s' from config '%s': %w", name, configPath, err) 100 | return 101 | } 102 | } 103 | } 104 | 105 | // If no loggers were loaded from config files, initialize default logger 106 | if len(manager.loggers) == 0 { 107 | defaultLogger, err := buildLogger("default", manager.defaultConfig) 108 | if err != nil { 109 | initErr = fmt.Errorf("failed to build default logger: %w", err) 110 | return 111 | } 112 | if err := manager.AddLogger("default", defaultLogger); err != nil { 113 | initErr = fmt.Errorf("failed to add default logger: %w", err) 114 | return 115 | } 116 | } 117 | }) 118 | 119 | return initErr 120 | } 121 | 122 | func buildLogger(name string, cfg *Config) (*zap.Logger, error) { 123 | assignDefaultValues(cfg) 124 | 125 | encoderConfig := zapcore.EncoderConfig{ 126 | TimeKey: cfg.Encoding.TimeKey, 127 | LevelKey: cfg.Encoding.LevelKey, 128 | NameKey: cfg.Encoding.NameKey, 129 | CallerKey: cfg.Encoding.CallerKey, 130 | MessageKey: cfg.Encoding.MessageKey, 131 | StacktraceKey: cfg.Encoding.StacktraceKey, 132 | LineEnding: cfg.Encoding.LineEnding, 133 | EncodeLevel: getZapLevelEncoder(cfg.Encoding.LevelEncoder), 134 | EncodeTime: getZapTimeEncoder(cfg.Encoding.TimeEncoder), 135 | EncodeDuration: 
getZapDurationEncoder(cfg.Encoding.DurationEncoder), 136 | EncodeCaller: getZapCallerEncoder(cfg.Encoding.CallerEncoder), 137 | } 138 | 139 | consoleEncoderConfig := encoderConfig 140 | consoleEncoderConfig.EncodeLevel = coloredLevelEncoder 141 | consoleEncoder := zapcore.NewConsoleEncoder(consoleEncoderConfig) 142 | 143 | jsonEncoder := zapcore.NewJSONEncoder(encoderConfig) 144 | atomicLevel := zap.NewAtomicLevelAt(getZapLevel(cfg.Level)) 145 | 146 | var allCores []zapcore.Core 147 | if cfg.Development || cfg.LogToConsole { 148 | consoleWS := zapcore.Lock(os.Stdout) 149 | consoleCore := zapcore.NewCore(consoleEncoder, consoleWS, atomicLevel) 150 | allCores = append(allCores, consoleCore) 151 | } 152 | 153 | for _, path := range cfg.OutputPaths { 154 | if path == "stdout" || path == "stderr" { 155 | continue 156 | } 157 | 158 | var fileWS zapcore.WriteSyncer 159 | if cfg.LogRotation.Enabled { 160 | lj := ljLogger(path, cfg.LogRotation) 161 | fileWS = zapcore.AddSync(lj) 162 | } else { 163 | file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 164 | if err != nil { 165 | return nil, fmt.Errorf("Failed to open log file '%s': %v\n", path, err) 166 | } 167 | fileWS = zapcore.AddSync(file) 168 | } 169 | 170 | fileCore := zapcore.NewCore(jsonEncoder, fileWS, atomicLevel) 171 | asyncFileCore := NewAsyncCore(fileCore, 1000, 100, 500*time.Millisecond) // bufferSize, batchSize, flushInterval 172 | allCores = append(allCores, asyncFileCore) 173 | } 174 | 175 | combinedCore := zapcore.NewTee(allCores...) 176 | 177 | if len(cfg.Sanitization.SensitiveFields) > 0 { 178 | combinedCore = NewSanitizerCore(combinedCore, cfg.Sanitization.SensitiveFields, cfg.Sanitization.Mask) 179 | } 180 | 181 | logger := zap.New(combinedCore, 182 | zap.AddCaller(), 183 | zap.AddStacktrace(zap.ErrorLevel), 184 | ).Named(name) 185 | 186 | return logger, nil 187 | } 188 | 189 | // maps string levels to zapcore.Level. 
190 | func getZapLevel(level string) zapcore.Level { 191 | switch strings.ToLower(level) { 192 | case "debug": 193 | return zap.DebugLevel 194 | case "info": 195 | return zap.InfoLevel 196 | case "warn", "warning": 197 | return zap.WarnLevel 198 | case "error": 199 | return zap.ErrorLevel 200 | case "dpanic": 201 | return zap.DPanicLevel 202 | case "panic": 203 | return zap.PanicLevel 204 | case "fatal": 205 | return zap.FatalLevel 206 | default: 207 | return zap.InfoLevel 208 | } 209 | } 210 | 211 | // maps string encoders to zapcore.LevelEncoder. 212 | func getZapLevelEncoder(encoder string) zapcore.LevelEncoder { 213 | switch strings.ToLower(encoder) { 214 | case "lowercase": 215 | return zapcore.LowercaseLevelEncoder 216 | case "uppercase": 217 | return zapcore.CapitalLevelEncoder 218 | case "capital": 219 | return zapcore.CapitalLevelEncoder 220 | default: 221 | return zapcore.LowercaseLevelEncoder 222 | } 223 | } 224 | 225 | // maps string encoders to zapcore.TimeEncoder. 226 | func getZapTimeEncoder(encoder string) zapcore.TimeEncoder { 227 | switch strings.ToLower(encoder) { 228 | case "iso8601": 229 | return zapcore.ISO8601TimeEncoder 230 | case "epoch": 231 | return zapcore.EpochTimeEncoder 232 | case "millis": 233 | return zapcore.EpochMillisTimeEncoder 234 | case "nanos": 235 | return zapcore.EpochNanosTimeEncoder 236 | default: 237 | return zapcore.ISO8601TimeEncoder 238 | } 239 | } 240 | 241 | // maps string encoders to zapcore.DurationEncoder. 242 | func getZapDurationEncoder(encoder string) zapcore.DurationEncoder { 243 | switch strings.ToLower(encoder) { 244 | case "string": 245 | return zapcore.StringDurationEncoder 246 | case "seconds": 247 | return zapcore.SecondsDurationEncoder 248 | case "millis": 249 | return zapcore.MillisDurationEncoder 250 | case "nanos": 251 | return zapcore.NanosDurationEncoder 252 | default: 253 | return zapcore.StringDurationEncoder 254 | } 255 | } 256 | 257 | // maps string encoders to zapcore.CallerEncoder. 
258 | func getZapCallerEncoder(encoder string) zapcore.CallerEncoder { 259 | switch strings.ToLower(encoder) { 260 | case "full": 261 | return zapcore.FullCallerEncoder 262 | case "short": 263 | return zapcore.ShortCallerEncoder 264 | default: 265 | return zapcore.ShortCallerEncoder 266 | } 267 | } 268 | 269 | // adds color codes to log levels for console output - this is a bit slow so only in dev 270 | func coloredLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { 271 | var level string 272 | switch l { 273 | case zapcore.DebugLevel: 274 | level = "\x1b[36m" + l.String() + "\x1b[0m" // Cyan 275 | case zapcore.InfoLevel: 276 | level = "\x1b[32m" + l.String() + "\x1b[0m" // Green 277 | case zapcore.WarnLevel: 278 | level = "\x1b[33m" + l.String() + "\x1b[0m" // Yellow 279 | case zapcore.ErrorLevel: 280 | level = "\x1b[31m" + l.String() + "\x1b[0m" // Red 281 | case zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel: 282 | level = "\x1b[35m" + l.String() + "\x1b[0m" // Magenta 283 | default: 284 | level = l.String() 285 | } 286 | enc.AppendString(level) 287 | } 288 | 289 | // creates a new Lumberjack logger with the given path and configuration. 
290 | func ljLogger(path string, l LogRotation) *lumberjack.Logger { 291 | return &lumberjack.Logger{ 292 | Filename: path, 293 | MaxSize: l.MaxSizeMB, 294 | MaxBackups: l.MaxBackups, 295 | MaxAge: l.MaxAgeDays, 296 | Compress: l.Compress, 297 | } 298 | } 299 | -------------------------------------------------------------------------------- /internal/service/manager.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | 9 | "github.com/unkn0wn-root/terraster/internal/config" 10 | "github.com/unkn0wn-root/terraster/internal/pool" 11 | "github.com/unkn0wn-root/terraster/pkg/algorithm" 12 | "github.com/unkn0wn-root/terraster/pkg/plugin" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | var ( 17 | ErrServiceAlreadyExists = errors.New("service already exists") 18 | ErrDuplicateLocation = errors.New("duplicate location path") 19 | ErrNotDefined = errors.New("service must have either host or name defined") 20 | ) 21 | 22 | // ServiceType represents the type of service protocol, either HTTP or HTTPS. 23 | type ServiceType string 24 | 25 | const ( 26 | HTTP ServiceType = "http" 27 | HTTPS ServiceType = "https" 28 | ) 29 | 30 | // Manager is responsible for managing all the services within the Terraster application. 31 | type Manager struct { 32 | services map[string]*ServiceInfo // A map of service identifiers to their corresponding ServiceInfo. 33 | pluginManager *plugin.Manager 34 | logger *zap.Logger // Logger instance for logging service manager activities. 35 | mu sync.RWMutex // Mutex to ensure thread-safe access to the services map. 36 | } 37 | 38 | // ServiceInfo contains comprehensive information about a service, including its routing and backend configurations. 39 | type ServiceInfo struct { 40 | Name string // The unique name of the service. 41 | Host string // The host address where the service is accessible. 
42 | Port int // The port number on which the service listens. 43 | TLS *config.TLS // TLS configuration for the service, if HTTPS is enabled. 44 | HTTPRedirect bool // Indicates whether HTTP requests should be redirected to HTTPS. 45 | RedirectPort int // The port to which HTTP requests are redirected for HTTPS. 46 | HealthCheck *config.HealthCheck // Health check configuration specific to the service. 47 | Locations []*LocationInfo // A slice of LocationInfo representing different routing paths for the service. 48 | Middleware []config.Middleware // Middleware configurations for the service. 49 | LogName string // LogName will be used to get service logger from config. 50 | LogOptions *config.LogOptions // LogOptions define diffrent logger options such as headers or query params logging 51 | Logger *zap.Logger // Logger instance for logging service activities. 52 | Headers *config.Header // Request/Response custom headers 53 | } 54 | 55 | // ServiceType determines the protocol type of the service based on its TLS configuration. 56 | // It returns HTTPS if TLS is enabled, otherwise HTTP. 57 | func (s *ServiceInfo) ServiceType() ServiceType { 58 | if s.TLS != nil && s.TLS.Enabled { 59 | return HTTPS 60 | } 61 | return HTTP 62 | } 63 | 64 | // LocationInfo contains routing and backend information for a specific path within a service. 65 | // Defines how incoming requests matching the path should be handled and which backend servers to proxy to. 66 | type LocationInfo struct { 67 | Path string // The URL path that this location handles. 68 | Rewrite string // The URL rewrite rule applied to incoming requests. 69 | Algorithm algorithm.Algorithm // The load balancing algorithm used to select a backend server. 70 | ServerPool *pool.ServerPool // The pool of backend servers associated with this location. 71 | } 72 | 73 | // NewManager initializes and returns a new instance of Manager. 
74 | // It sets up services based on the provided configuration and initializes their respective server pools. 75 | // If no services are defined in the configuration but backends are provided, it creates a default service. 76 | func NewManager(cfg *config.Terraster, logger *zap.Logger, pm *plugin.Manager) (*Manager, error) { 77 | m := &Manager{ 78 | services: make(map[string]*ServiceInfo), 79 | pluginManager: pm, 80 | logger: logger, 81 | } 82 | 83 | if len(cfg.Services) == 0 && len(cfg.Backends) > 0 { 84 | host := cfg.Host 85 | if host == "" { 86 | host = "localhost" 87 | } 88 | 89 | defaultService := config.Service{ 90 | Name: "default", 91 | Host: host, 92 | Port: 8080, 93 | TLS: &cfg.TLS, 94 | Locations: []config.Location{ 95 | { 96 | Path: "", 97 | LoadBalancer: "round-robin", 98 | Backends: cfg.Backends, 99 | }, 100 | }, 101 | } 102 | if err := m.AddService(defaultService, cfg.HealthCheck); err != nil { 103 | return nil, err 104 | } 105 | } else { 106 | for _, svc := range cfg.Services { 107 | hcCfg := svc.HealthCheck 108 | if hcCfg == nil { 109 | hcCfg = cfg.HealthCheck 110 | } 111 | if err := m.AddService(svc, hcCfg); err != nil { 112 | return nil, err 113 | } 114 | } 115 | } 116 | 117 | return m, nil 118 | } 119 | 120 | // AddService adds a new service to the Manager with the provided configuration and health check settings. 121 | // Processes each location within the service, creates corresponding server pools, and ensures no duplicate services or locations exist. 
122 | func (m *Manager) AddService(service config.Service, globalHealthCheck *config.HealthCheck) error { 123 | locations := make([]*LocationInfo, 0, len(service.Locations)) 124 | locationPaths := make(map[string]bool) 125 | for _, location := range service.Locations { 126 | if location.Path == "" { 127 | location.Path = "/" 128 | } 129 | if _, exist := locationPaths[location.Path]; exist { 130 | return ErrDuplicateLocation 131 | } 132 | if len(location.Backends) == 0 { 133 | return fmt.Errorf("service %s, location %s: no backends defined", 134 | service.Name, location.Path) 135 | } 136 | 137 | locationPaths[location.Path] = true 138 | serverPool, err := m.createServerPool(service, location, globalHealthCheck) 139 | if err != nil { 140 | return err 141 | } 142 | 143 | locations = append(locations, &LocationInfo{ 144 | Path: location.Path, 145 | Algorithm: algorithm.CreateAlgorithm(location.LoadBalancer), 146 | Rewrite: location.Rewrite, 147 | ServerPool: serverPool, 148 | }) 149 | } 150 | 151 | k := service.Name 152 | if k == "" { 153 | k = service.Host 154 | } 155 | if k == "" { 156 | return ErrNotDefined 157 | } 158 | if _, exist := m.services[k]; exist { 159 | return ErrServiceAlreadyExists 160 | } 161 | 162 | // Determine the health check configuration for the service. 163 | // Use the service-specific configuration if provided; otherwise, fallback to the global configuration. 
164 | serviceHealthCheck := globalHealthCheck 165 | if service.HealthCheck != nil && service.HealthCheck.Type != "" { 166 | serviceHealthCheck = service.HealthCheck 167 | } 168 | 169 | m.mu.Lock() 170 | m.services[k] = &ServiceInfo{ 171 | Name: service.Name, 172 | Host: service.Host, 173 | Port: service.Port, 174 | TLS: service.TLS, 175 | HTTPRedirect: service.HTTPRedirect, 176 | RedirectPort: service.RedirectPort, 177 | HealthCheck: serviceHealthCheck, 178 | Locations: locations, 179 | Middleware: service.Middleware, 180 | LogName: service.LogName, 181 | LogOptions: service.LogOptions, 182 | Headers: service.Headers, 183 | } 184 | m.mu.Unlock() 185 | return nil 186 | } 187 | 188 | // GetService retrieves the service and location information based on the provided host, path, and port. 189 | // If hostOnly is true, it returns only the ServiceInfo without matching a specific location. 190 | func (m *Manager) GetService( 191 | host, path string, 192 | port int, 193 | hostOnly bool, 194 | ) (*ServiceInfo, *LocationInfo, error) { 195 | m.mu.RLock() 196 | defer m.mu.RUnlock() 197 | 198 | var matchedService *ServiceInfo 199 | for _, service := range m.services { 200 | if matchHost(service.Host, host) && service.Port == port { 201 | if hostOnly { 202 | return service, nil, nil 203 | } 204 | matchedService = service 205 | break 206 | } 207 | } 208 | if matchedService == nil { 209 | return nil, nil, fmt.Errorf("service not found for host %s", host) 210 | } 211 | 212 | var matchedLocation *LocationInfo 213 | var matchedLen int 214 | for _, location := range matchedService.Locations { 215 | if strings.HasPrefix(path, location.Path) && len(location.Path) > matchedLen { 216 | matchedLocation = location 217 | matchedLen = len(location.Path) 218 | } 219 | } 220 | 221 | if matchedLocation == nil { 222 | return nil, nil, fmt.Errorf("location not found for path %s", path) 223 | } 224 | return matchedService, matchedLocation, nil 225 | } 226 | 227 | // GetServiceByName retrieves a 
service based on its unique name. 228 | func (m *Manager) GetServiceByName(name string) *ServiceInfo { 229 | m.mu.RLock() 230 | defer m.mu.RUnlock() 231 | 232 | for _, service := range m.services { 233 | if service.Name == name { 234 | return service 235 | } 236 | } 237 | return nil 238 | } 239 | 240 | // GetServices returns a slice of all services managed by the Manager. 241 | func (m *Manager) GetServices() []*ServiceInfo { 242 | m.mu.RLock() 243 | defer m.mu.RUnlock() 244 | 245 | services := make([]*ServiceInfo, 0, len(m.services)) 246 | for _, service := range m.services { 247 | services = append(services, service) 248 | } 249 | return services 250 | } 251 | 252 | // AssignLogger assigns a logger to a specific service based on its name. 253 | func (m *Manager) AssignLogger(serviceName string, logger *zap.Logger) { 254 | m.mu.Lock() 255 | defer m.mu.Unlock() 256 | if svc, exists := m.services[serviceName]; exists { 257 | svc.Logger = logger 258 | } 259 | } 260 | 261 | // createServerPool initializes and configures a ServerPool for a given service location. 262 | // It sets up the load balancing algorithm and adds all backends associated with the location to the pool. 263 | func (m *Manager) createServerPool( 264 | svc config.Service, 265 | lc config.Location, 266 | serviceHealthCheck *config.HealthCheck, 267 | ) (*pool.ServerPool, error) { 268 | pm := m.pluginManager 269 | if pm != nil && svc.DisablePluginLoad { 270 | pm = nil 271 | } 272 | serverPool := pool.NewServerPool(&svc, pm, m.logger) 273 | serverPool.UpdateConfig(pool.PoolConfig{ 274 | Algorithm: lc.LoadBalancer, 275 | }) 276 | 277 | for _, backend := range lc.Backends { 278 | rc := pool.Route{ 279 | Path: lc.Path, // The path associated with the backend. 280 | RewriteURL: lc.Rewrite, // URL rewrite rules for the backend. 281 | Redirect: lc.Redirect, // Redirect settings if applicable. 282 | SkipTLSVerify: backend.SkipTLSVerify, // TLS verification settings for the backend. 
283 | SNI: backend.ServerName, // SNI (Server Name Indication name) 284 | // allow http2 since most backends will support that so 285 | // if http2 is not explicitly set i config - http2 is allowed 286 | // if is set then use config value 287 | HTTP2: backend.HTTP2 == nil || *backend.HTTP2, 288 | } 289 | backendHealthCheck := serviceHealthCheck 290 | if backend.HealthCheck != nil { 291 | backendHealthCheck = backend.HealthCheck 292 | } 293 | if err := serverPool.AddBackend(backend, rc, backendHealthCheck); err != nil { 294 | return nil, err 295 | } 296 | } 297 | return serverPool, nil 298 | } 299 | 300 | // matchHost determines if the provided host matches the given pattern. 301 | // Supports wildcard patterns, allowing for flexible host matching. 302 | func matchHost(pattern, host string) bool { 303 | if !strings.Contains(pattern, "*") { 304 | return strings.EqualFold(pattern, host) 305 | } 306 | if pattern == "*" { 307 | return true 308 | } 309 | // Patterns starting with "*." are treated as wildcard subdomains. 310 | if strings.HasPrefix(pattern, "*.") { 311 | suffix := pattern[1:] // Remove the asterisk. 
312 | return strings.HasSuffix(strings.ToLower(host), strings.ToLower(suffix)) 313 | } 314 | return false 315 | } 316 | -------------------------------------------------------------------------------- /internal/crypto/certmanager.go: -------------------------------------------------------------------------------- 1 | package certmanager 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "crypto/x509" 7 | "fmt" 8 | "sync" 9 | "time" 10 | 11 | "go.uber.org/zap" 12 | "golang.org/x/crypto/acme/autocert" 13 | 14 | "github.com/unkn0wn-root/terraster/internal/config" 15 | "github.com/wneessen/go-mail" 16 | ) 17 | 18 | type NoopAlerter struct{} 19 | 20 | func (n *NoopAlerter) Alert(domain string, expiry time.Time) error { 21 | return nil 22 | } 23 | 24 | // Alerter defines the interface for certificate expiration alerting 25 | type Alerter interface { 26 | Alert(domain string, expiry time.Time) error 27 | } 28 | 29 | type EmailAlerter struct { 30 | client *mail.Client 31 | fromEmail string 32 | toEmails []string 33 | logger *zap.Logger 34 | } 35 | 36 | func NewEmailAlerter(cfg AlertingConfig, logger *zap.Logger) (*EmailAlerter, error) { 37 | client, err := mail.NewClient(cfg.SMTPHost, 38 | mail.WithPort(cfg.SMTPPort), 39 | mail.WithSMTPAuth(mail.SMTPAuthPlain), 40 | mail.WithUsername(cfg.FromEmail), 41 | mail.WithPassword(cfg.FromPass), 42 | mail.WithTLSPolicy(mail.TLSMandatory), 43 | ) 44 | if err != nil { 45 | return nil, fmt.Errorf("failed to create mail client: %w", err) 46 | } 47 | 48 | return &EmailAlerter{ 49 | client: client, 50 | fromEmail: cfg.FromEmail, 51 | toEmails: cfg.ToEmails, 52 | logger: logger, 53 | }, nil 54 | } 55 | 56 | type CertCache interface { 57 | Get(ctx context.Context, key string) ([]byte, error) 58 | Put(ctx context.Context, key string, data []byte) error 59 | Delete(ctx context.Context, key string) error 60 | } 61 | 62 | type InMemoryCertCache struct { 63 | mu sync.RWMutex 64 | cache map[string][]byte 65 | } 66 | 67 | func 
NewInMemoryCertCache() *InMemoryCertCache { 68 | return &InMemoryCertCache{ 69 | cache: make(map[string][]byte), 70 | } 71 | } 72 | 73 | func (c *InMemoryCertCache) Get(ctx context.Context, key string) ([]byte, error) { 74 | c.mu.RLock() 75 | defer c.mu.RUnlock() 76 | data, exists := c.cache[key] 77 | if !exists { 78 | return nil, fmt.Errorf("no cache entry for %s", key) 79 | } 80 | return data, nil 81 | } 82 | 83 | func (c *InMemoryCertCache) Put(ctx context.Context, key string, data []byte) error { 84 | c.mu.Lock() 85 | defer c.mu.Unlock() 86 | c.cache[key] = data 87 | return nil 88 | } 89 | 90 | func (c *InMemoryCertCache) Delete(ctx context.Context, key string) error { 91 | c.mu.Lock() 92 | defer c.mu.Unlock() 93 | delete(c.cache, key) 94 | return nil 95 | } 96 | 97 | type certStatus struct { 98 | exists bool 99 | isValid bool 100 | expiresAt time.Time 101 | error error 102 | } 103 | 104 | type CertManager struct { 105 | manager *autocert.Manager 106 | cache CertCache 107 | domains []string 108 | certDir string 109 | certs sync.Map // map[string]*tls.Certificate 110 | logger *zap.Logger 111 | config *config.Terraster 112 | alerter Alerter 113 | checkInterval time.Duration 114 | expirationThresh time.Duration 115 | stopChan chan struct{} 116 | } 117 | 118 | type AlertingConfig struct { 119 | Enabled bool 120 | SMTPHost string 121 | SMTPPort int 122 | FromEmail string 123 | FromPass string 124 | ToEmails []string 125 | } 126 | 127 | func NewCertManager( 128 | domains []string, 129 | certDir string, 130 | cache CertCache, 131 | ctx context.Context, 132 | cfg *config.Terraster, 133 | alerting AlertingConfig, 134 | logger *zap.Logger, 135 | ) (*CertManager, error) { 136 | var alerter Alerter = &NoopAlerter{} 137 | if alerting.Enabled { 138 | emailAlerter, err := NewEmailAlerter(alerting, logger) 139 | if err != nil { 140 | return nil, fmt.Errorf("failed to create email alerter: %w", err) 141 | } 142 | alerter = emailAlerter 143 | } 144 | 145 | checkInterval := 
cfg.CertManager.CheckInterval 146 | if checkInterval == 0 { 147 | checkInterval = 24 * time.Hour // 24 hours 148 | } 149 | 150 | expirationThresh := cfg.CertManager.ExpirationThresh 151 | if expirationThresh == 0 { 152 | expirationThresh = 30 * 24 * time.Hour // 30 days 153 | } 154 | 155 | cm := &CertManager{ 156 | domains: domains, 157 | certDir: certDir, 158 | cache: cache, 159 | alerter: alerter, 160 | logger: logger, 161 | config: cfg, 162 | stopChan: make(chan struct{}), 163 | } 164 | 165 | cm.checkInterval = checkInterval 166 | cm.expirationThresh = expirationThresh 167 | 168 | cm.manager = &autocert.Manager{ 169 | Cache: cache, 170 | Prompt: autocert.AcceptTOS, 171 | HostPolicy: cm.hostPolicy, 172 | } 173 | 174 | // Load local certificates during initialization 175 | cm.loadLocalCertificates() 176 | 177 | // Start periodic certificate check 178 | go cm.periodicCertCheck(ctx) 179 | 180 | return cm, nil 181 | } 182 | 183 | func NewAlertingConfig(cfg *config.Terraster) AlertingConfig { 184 | return AlertingConfig{ 185 | Enabled: cfg.CertManager.Alerting.Enabled, 186 | SMTPHost: cfg.CertManager.Alerting.SMTPHost, 187 | SMTPPort: cfg.CertManager.Alerting.SMTPPort, 188 | FromEmail: cfg.CertManager.Alerting.FromEmail, 189 | FromPass: cfg.CertManager.Alerting.FromPass, 190 | ToEmails: cfg.CertManager.Alerting.ToEmails, 191 | } 192 | } 193 | 194 | // hostPolicy ensures that only configured domains are allowed. 195 | func (cm *CertManager) hostPolicy(ctx context.Context, host string) error { 196 | for _, domain := range cm.domains { 197 | if host == domain { 198 | return nil 199 | } 200 | } 201 | 202 | return fmt.Errorf("host %q not configured", host) 203 | } 204 | 205 | // loadLocalCertificates loads local TLS certificates and stores them in the certs map. 
206 | func (cm *CertManager) loadLocalCertificates() { 207 | for _, svc := range cm.config.Services { 208 | if svc.TLS != nil && svc.TLS.Enabled { 209 | cert, err := tls.LoadX509KeyPair(svc.TLS.CertFile, svc.TLS.KeyFile) 210 | if err != nil { 211 | cm.logger.Warn("Failed to load local certificate. Will use autocert", 212 | zap.String("host", svc.Host), 213 | zap.Error(err)) 214 | continue 215 | } 216 | cm.certs.Store(svc.Host, &cert) 217 | cm.logger.Info("Loaded local certificate", zap.String("host", svc.Host)) 218 | } 219 | } 220 | } 221 | 222 | // GetCertificate retrieves the TLS certificate for the given client hello. 223 | func (cm *CertManager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 224 | // Try local certificate first (if any) 225 | if cert, ok := cm.certs.Load(hello.ServerName); ok { 226 | return cert.(*tls.Certificate), nil 227 | } 228 | 229 | // If not found, fetch using autocert - slow path 230 | // You should own domain and configure Let's Encrypt to accept fetching certs 231 | cert, err := cm.manager.GetCertificate(hello) 232 | if err != nil { 233 | return nil, err 234 | } 235 | 236 | cm.certs.Store(hello.ServerName, cert) 237 | 238 | return cert, nil 239 | } 240 | 241 | // periodicCertCheck periodically checks for certificate expirations. 242 | func (cm *CertManager) periodicCertCheck(ctx context.Context) { 243 | if cm.domains == nil || len(cm.domains) == 0 { 244 | cm.logger.Warn("No domains configured for certificate check. Periodic check will not run.") 245 | return 246 | } 247 | 248 | ticker := time.NewTicker(cm.checkInterval) 249 | defer ticker.Stop() 250 | 251 | for { 252 | select { 253 | case <-ticker.C: 254 | cm.checkCerts() 255 | case <-cm.stopChan: 256 | cm.logger.Info("Periodic certificate check stopped") 257 | return 258 | case <-ctx.Done(): 259 | cm.logger.Info("Context cancelled. 
Periodic certificate check stopped") 260 | return 261 | } 262 | } 263 | } 264 | 265 | // validateCertificate validates the given certificate. 266 | func (cm *CertManager) validateCertificate(cert *tls.Certificate) certStatus { 267 | if cert == nil { 268 | return certStatus{ 269 | exists: false, 270 | error: fmt.Errorf("certificate is nil"), 271 | isValid: false, 272 | } 273 | } 274 | 275 | // Parse the Leaf certificate if it's not already parsed 276 | if cert.Leaf == nil && len(cert.Certificate) > 0 { 277 | leaf, err := x509.ParseCertificate(cert.Certificate[0]) 278 | if err != nil { 279 | return certStatus{ 280 | exists: true, 281 | error: fmt.Errorf("failed to parse certificate: %w", err), 282 | isValid: false, 283 | } 284 | } 285 | cert.Leaf = leaf 286 | } 287 | 288 | now := time.Now() 289 | status := certStatus{ 290 | exists: true, 291 | isValid: now.Before(cert.Leaf.NotAfter) && now.After(cert.Leaf.NotBefore), 292 | expiresAt: cert.Leaf.NotAfter, 293 | } 294 | 295 | return status 296 | } 297 | 298 | // checkCerts checks the certificates for expiration and sends alerts if necessary. 
299 | func (cm *CertManager) checkCerts() { 300 | now := time.Now() 301 | var checkErrors []error 302 | 303 | cm.certs.Range(func(key, value interface{}) bool { 304 | domain := key.(string) 305 | cert := value.(*tls.Certificate) 306 | 307 | status := cm.validateCertificate(cert) 308 | if !status.exists { 309 | cm.logger.Warn("No certificate found for domain", 310 | zap.String("domain", domain), 311 | zap.Error(status.error)) 312 | checkErrors = append(checkErrors, fmt.Errorf("domain %s: %w", domain, status.error)) 313 | return true 314 | } 315 | 316 | if status.error != nil { 317 | cm.logger.Error("Certificate validation failed", 318 | zap.String("domain", domain), 319 | zap.Error(status.error)) 320 | checkErrors = append(checkErrors, fmt.Errorf("domain %s: %w", domain, status.error)) 321 | return true 322 | } 323 | 324 | if !status.isValid { 325 | cm.logger.Error("Invalid certificate", 326 | zap.String("domain", domain), 327 | zap.Time("expires_at", status.expiresAt)) 328 | return true 329 | } 330 | 331 | timeLeft := status.expiresAt.Sub(now) 332 | if timeLeft < cm.expirationThresh { 333 | daysLeft := int(timeLeft.Hours() / 24) 334 | cm.logger.Warn("Certificate approaching expiration", 335 | zap.String("domain", domain), 336 | zap.Time("expires_at", status.expiresAt), 337 | zap.Int("time_left", daysLeft)) 338 | 339 | if err := cm.alerter.Alert(domain, status.expiresAt); err != nil { 340 | cm.logger.Error("Failed to send alert", 341 | zap.String("domain", domain), 342 | zap.Error(err)) 343 | } 344 | 345 | return true 346 | } 347 | 348 | cm.logger.Debug("Certificate valid", 349 | zap.String("domain", domain), 350 | zap.Time("expires_at", status.expiresAt), 351 | zap.String("time_left", formatDuration(timeLeft))) 352 | 353 | return true 354 | }) 355 | 356 | if len(checkErrors) > 0 { 357 | cm.logger.Error("Certificate check completed with errors", 358 | zap.Int("error_count", len(checkErrors)), 359 | zap.Errors("errors", checkErrors)) 360 | } 361 | } 362 | 363 | func 
(e *EmailAlerter) Alert(domain string, expiry time.Time) error { 364 | msg := mail.NewMsg() 365 | if err := msg.From(e.fromEmail); err != nil { 366 | return fmt.Errorf("failed to set From address: %w", err) 367 | } 368 | 369 | if err := msg.To(e.toEmails...); err != nil { 370 | return fmt.Errorf("failed to set To address: %w", err) 371 | } 372 | 373 | msg.Subject(fmt.Sprintf("Certificate Expiration Warning - %s", domain)) 374 | msg.SetBodyString(mail.TypeTextPlain, fmt.Sprintf( 375 | "The TLS certificate for %s will expire on %s.\n\nPlease renew the certificate before expiration to prevent service interruption.", 376 | domain, 377 | expiry.Format(time.RFC3339), 378 | )) 379 | 380 | if err := e.client.DialAndSend(msg); err != nil { 381 | return fmt.Errorf("failed to send alert email: %w", err) 382 | } 383 | 384 | return nil 385 | } 386 | 387 | func (cm *CertManager) Stop() { 388 | close(cm.stopChan) 389 | } 390 | 391 | // formatDuration formats time to more human readable string 392 | func formatDuration(d time.Duration) string { 393 | d = d.Round(time.Hour) // round to nearest hour 394 | days := int(d.Hours() / 24) 395 | hours := int(d.Hours()) % 24 396 | 397 | switch { 398 | case days > 0: 399 | return fmt.Sprintf("%dd %dh", days, hours) 400 | case hours > 0: 401 | return fmt.Sprintf("%dh", hours) 402 | default: 403 | return "less than 1h" 404 | } 405 | } 406 | --------------------------------------------------------------------------------