├── internal
├── middlewares
│ ├── security
│ │ ├── trackerhistory.go
│ │ ├── routing_anomaly_score_test.go
│ │ └── routing_anomaly_score.go
│ ├── addheader.go
│ ├── middleware.go
│ ├── omit_headers.go
│ ├── gzip.go
│ ├── omit_headers_test.go
│ ├── sizelimiter.go
│ ├── timeout.go
│ ├── logging.go
│ ├── minify.go
│ ├── sizelimiter_test.go
│ ├── timeout_test.go
│ ├── cache.go
│ ├── openapi.go
│ ├── otel.go
│ ├── logging_test.go
│ ├── minify_test.go
│ ├── ratelimit.go
│ ├── gzip_test.go
│ ├── ratelimit_test.go
│ ├── openapi_test.go
│ ├── cache_test.go
│ └── responsecapture.go
├── contextvalues
│ ├── version.go
│ └── tracer.go
├── handlers
│ ├── proxy.go
│ ├── files.go
│ ├── files_test.go
│ ├── balancer.go
│ └── balancer_test.go
└── config
│ ├── config_test.go
│ └── config.go
├── pkg
├── cron
│ ├── macros.go
│ ├── README.md
│ ├── cron.go
│ ├── schedule.go
│ └── cron_test.go
├── multimux
│ ├── multimux.go
│ └── multimux_test.go
├── tracker
│ ├── tracker.go
│ └── tracker_test.go
├── pathgraph
│ └── pathgraph.go
└── monitor
│ ├── monitor.go
│ └── monitor_test.go
├── .github
└── workflows
│ └── go-tests.yml
├── dev
├── server
│ └── main.go
├── otel
│ ├── otel-collector-config.yaml
│ └── docker-compose.yaml
└── client
│ └── main.go
├── cmd
├── main.go
├── openapi.yaml
└── config.yaml
├── checks.go
├── LICENCE
├── gatego.go
├── go.mod
├── server.go
├── handler.go
├── otel.go
├── config-schema.json
├── go.sum
└── README.md
/internal/middlewares/security/trackerhistory.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | type trackerHistory struct {
4 | jumpsCount int
5 | jumpsScoreSum float64
6 | }
7 |
8 | func (th trackerHistory) Avg() float64 {
9 | return th.jumpsScoreSum / float64(th.jumpsCount)
10 | }
11 |
--------------------------------------------------------------------------------
/pkg/cron/macros.go:
--------------------------------------------------------------------------------
1 | package cron
2 |
3 | const (
4 | Yearly = "@yearly"
5 | Annually = "@annually"
6 | Monthly = "@monthly"
7 | Weekly = "@weekly"
8 | Daily = "@daily"
9 | Midnight = "@midnight"
10 | Hourly = "@hourly"
11 | Minutely = "@minutely"
12 | )
13 |
14 | var macros = map[string]string{
15 | Yearly: "0 0 1 1 *",
16 | Annually: "0 0 1 1 *",
17 | Monthly: "0 0 1 * *",
18 | Weekly: "0 0 * * 0",
19 | Daily: "0 0 * * *",
20 | Midnight: "0 0 * * *",
21 | Hourly: "0 * * * *",
22 | Minutely: "* * * * *",
23 | }
24 |
--------------------------------------------------------------------------------
/.github/workflows/go-tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Tests
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 |
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Set up Go
20 | uses: actions/setup-go@v4
21 | with:
22 | go-version: '1.22'
23 |
24 | - name: Test
25 | run: go test -v ./...
26 |
--------------------------------------------------------------------------------
/internal/middlewares/addheader.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 |
7 | "go.opentelemetry.io/otel/trace"
8 | )
9 |
10 | func NewAddHeadersMiddleware(headers map[string]string) Middleware {
11 | return func(next http.Handler) http.Handler {
12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13 | span := trace.SpanFromContext(r.Context())
14 |
15 | for header, value := range headers {
16 | r.Header.Set(header, value)
17 | span.AddEvent(fmt.Sprintf("Added header %s to request", header))
18 | }
19 | next.ServeHTTP(w, r)
20 | })
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/dev/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "net/http"
7 | )
8 |
9 | func main() {
10 | server := http.NewServeMux()
11 |
12 | server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
13 | fmt.Printf("%s, %s, %s, %v\n", r.Proto, r.Host, r.URL, r.Header)
14 | w.Header().Set("Content-Type", "application/json")
15 | w.WriteHeader(200)
16 | w.Write([]byte(`{ "hello" : 1.5 , "good" : true }`))
17 | })
18 |
19 | fmt.Println("Running server at '127.0.0.1:4007'")
20 |
21 | err := http.ListenAndServe("127.0.0.1:4007", server)
22 | if err != nil {
23 | log.Fatal(err)
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/internal/contextvalues/version.go:
--------------------------------------------------------------------------------
1 | package contextvalues
2 |
3 | import "context"
4 |
5 | // Define a custom type for context keys to avoid collisions
6 | type versionKeyType string
7 |
8 | var versionKey = versionKeyType("version")
9 |
10 | // Add version to context
11 | func AddVersionToContext(ctx context.Context, version string) context.Context {
12 | return context.WithValue(ctx, versionKey, version)
13 | }
14 |
15 | // Retrieve version from context
16 | func VersionFromContext(ctx context.Context) string {
17 | version := ""
18 | if v, ok := ctx.Value(versionKey).(string); ok {
19 | version = v
20 | }
21 | return version
22 | }
23 |
--------------------------------------------------------------------------------
/cmd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 | "os/signal"
8 |
9 | "github.com/hvuhsg/gatego"
10 | "github.com/hvuhsg/gatego/internal/config"
11 | )
12 |
13 | const version = "0.0.1"
14 |
15 | func main() {
16 | // Handle SIGINT (CTRL+C) gracefully.
17 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
18 | defer stop()
19 |
20 | config, err := config.ParseConfig("config.yaml", version)
21 | if err != nil {
22 | log.Fatal(err)
23 | }
24 |
25 | log.Default().Println("Config loaded successfully")
26 |
27 | server := gatego.New(ctx, config, version)
28 |
29 | err = server.Run()
30 | if err != nil {
31 | log.Fatalln(err)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/internal/contextvalues/tracer.go:
--------------------------------------------------------------------------------
1 | package contextvalues
2 |
3 | import (
4 | "context"
5 |
6 | "go.opentelemetry.io/otel/trace"
7 | )
8 |
9 | // Define a custom type for context keys to avoid collisions
10 | type tracerKeyType string
11 |
12 | var tracerKey = tracerKeyType("tracer")
13 |
14 | // Add tracer to context
15 | func AddTracerToContext(ctx context.Context, tracer trace.Tracer) context.Context {
16 | return context.WithValue(ctx, tracerKey, tracer)
17 | }
18 |
19 | // Retrieve tracer from context
20 | func TracerFromContext(ctx context.Context) trace.Tracer {
21 | var tracer trace.Tracer = nil
22 | if t, ok := ctx.Value(tracerKey).(trace.Tracer); ok {
23 | tracer = t
24 | }
25 | return tracer
26 | }
27 |
--------------------------------------------------------------------------------
/dev/otel/otel-collector-config.yaml:
--------------------------------------------------------------------------------
1 | receivers:
2 | otlp:
3 | protocols:
4 | grpc:
5 | endpoint: 0.0.0.0:4317
6 | http:
7 | endpoint: 0.0.0.0:4318
8 |
9 | processors:
10 | batch:
11 | timeout: 1s
12 | send_batch_size: 1024
13 |
14 | memory_limiter:
15 | check_interval: 1s
16 | limit_mib: 1000
17 | spike_limit_mib: 200
18 |
19 | exporters:
20 | otlp:
21 | endpoint: "jaeger:4317"
22 | tls:
23 | insecure: true
24 |
25 | debug:
26 | verbosity: detailed
27 |
28 | extensions:
29 | health_check:
30 | endpoint: 0.0.0.0:13133
31 |
32 | service:
33 | extensions: [health_check]
34 | pipelines:
35 | traces:
36 | receivers: [otlp]
37 | processors: [memory_limiter, batch]
38 | exporters: [otlp, debug]
--------------------------------------------------------------------------------
/checks.go:
--------------------------------------------------------------------------------
1 | package gatego
2 |
3 | import (
4 | "github.com/hvuhsg/gatego/internal/config"
5 | "github.com/hvuhsg/gatego/pkg/monitor"
6 | )
7 |
8 | func createMonitorChecks(services []config.Service) []monitor.Check {
9 | checks := make([]monitor.Check, 0)
10 | for _, service := range services {
11 | for _, path := range service.Paths {
12 | for _, checkConfig := range path.Checks {
13 | check := monitor.Check{
14 | Name: checkConfig.Name,
15 | Cron: checkConfig.Cron,
16 | URL: checkConfig.URL,
17 | Method: checkConfig.Method,
18 | Timeout: checkConfig.Timeout,
19 | Headers: checkConfig.Headers,
20 | OnFailure: checkConfig.OnFailure,
21 | }
22 |
23 | checks = append(checks, check)
24 | }
25 | }
26 | }
27 |
28 | return checks
29 | }
30 |
--------------------------------------------------------------------------------
/dev/otel/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | services:
2 | # Jaeger
3 | jaeger:
4 | image: jaegertracing/all-in-one:latest
5 | ports:
6 | - "16686:16686" # Jaeger UI
7 | - "14250:14250" # Model used by collector
8 | environment:
9 | - COLLECTOR_OTLP_ENABLED=true
10 |
11 | # OpenTelemetry Collector
12 | otel-collector:
13 | image: otel/opentelemetry-collector-contrib:latest
14 | command: ["--config=/etc/otel-collector-config.yaml"]
15 | volumes:
16 | - ./otel-collector-config.yaml:/etc/otel-collector-config.yaml
17 | ports:
18 | - "4317:4317" # OTLP gRPC receiver
19 | - "4318:4318" # OTLP http receiver
20 | - "8888:8888" # Prometheus metrics exposed by the collector
21 | - "8889:8889" # Prometheus exporter metrics
22 | - "13133:13133" # Health check extension
23 | depends_on:
24 | - jaeger
--------------------------------------------------------------------------------
/internal/middlewares/middleware.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import "net/http"
4 |
5 | type Middleware func(http.Handler) http.Handler
6 |
7 | type HandlerWithMiddleware struct {
8 | finalHandler http.Handler
9 | middlewares []Middleware
10 | }
11 |
12 | func NewHandlerWithMiddleware(handler http.Handler) *HandlerWithMiddleware {
13 | return &HandlerWithMiddleware{
14 | finalHandler: handler,
15 | middlewares: []Middleware{},
16 | }
17 | }
18 |
19 | func (h *HandlerWithMiddleware) Add(middleware Middleware) {
20 | h.middlewares = append(h.middlewares, middleware)
21 | }
22 |
23 | func (h *HandlerWithMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
24 | // Chain the middlewares around the final handler
25 | handler := h.finalHandler
26 | for i := len(h.middlewares) - 1; i >= 0; i-- {
27 | handler = h.middlewares[i](handler)
28 | }
29 | handler.ServeHTTP(w, r)
30 | }
31 |
--------------------------------------------------------------------------------
/internal/middlewares/omit_headers.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 |
7 | "go.opentelemetry.io/otel/trace"
8 | )
9 |
10 | // OmitHeaders middleware removes specified headers from the response to enhance security.
11 | func NewOmitHeadersMiddleware(headersToOmit []string) func(http.Handler) http.Handler {
12 | return func(next http.Handler) http.Handler {
13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14 | span := trace.SpanFromContext(r.Context())
15 |
16 | rc := NewRecorder()
17 | next.ServeHTTP(rc, r)
18 |
19 | // Omit headers from response
20 | for _, header := range headersToOmit {
21 | if rc.Result().Header.Get(header) != "" {
22 | rc.Result().Header.Del(header)
23 | span.AddEvent(fmt.Sprintf("Removed response header %s", header))
24 | }
25 | }
26 |
27 | rc.WriteHeadersTo(w)
28 | w.WriteHeader(rc.Result().StatusCode)
29 | w.Write(rc.Body.Bytes())
30 | })
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/internal/middlewares/gzip.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "compress/gzip"
5 | "net/http"
6 | "strings"
7 | )
8 |
9 | // GzipMiddleware compresses the response using gzip if the client supports it
10 | func GzipMiddleware(next http.Handler) http.Handler {
11 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
12 | // Check if the client accepts gzip encoding
13 | if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
14 | // Client doesn't support gzip, serve the next handler
15 | next.ServeHTTP(w, r)
16 | return
17 | }
18 |
19 | // Create a gzip.Writer
20 | gzipWriter := gzip.NewWriter(w)
21 | defer gzipWriter.Close()
22 |
23 | // Serve the next handler, writing the response into the ResponseCapture
24 | rc := NewRecorder()
25 | next.ServeHTTP(rc, r)
26 |
27 | rc.WriteHeadersTo(w)
28 |
29 | w.Header().Del("Content-Length")
30 | w.Header().Set("Content-Encoding", "gzip") // Set Content-Encoding header
31 |
32 | w.WriteHeader(rc.Result().StatusCode)
33 |
34 | gzipWriter.Write(rc.Body.Bytes())
35 | })
36 | }
37 |
--------------------------------------------------------------------------------
/internal/handlers/proxy.go:
--------------------------------------------------------------------------------
1 | package handlers
2 |
3 | import (
4 | "net/http"
5 | "net/http/httputil"
6 | "net/url"
7 |
8 | "github.com/hvuhsg/gatego/internal/config"
9 | "github.com/hvuhsg/gatego/internal/contextvalues"
10 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
11 | )
12 |
13 | type Proxy struct {
14 | proxy *httputil.ReverseProxy
15 | }
16 |
17 | func NewProxy(service config.Service, path config.Path) (Proxy, error) {
18 | serviceURL, err := url.Parse(*path.Destination)
19 | if err != nil {
20 | return Proxy{}, err
21 | }
22 |
23 | proxy := httputil.NewSingleHostReverseProxy(serviceURL)
24 |
25 | server := Proxy{proxy: proxy}
26 | return server, nil
27 | }
28 |
29 | func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
30 | tracer := contextvalues.TracerFromContext(r.Context())
31 | if tracer != nil {
32 | ctx, span := tracer.Start(r.Context(), "request.upstream")
33 | span.SetAttributes(semconv.HTTPServerAttributesFromHTTPRequest(r.Host, r.URL.Path, r)...)
34 | r = r.WithContext(ctx)
35 | defer span.End()
36 | }
37 | p.proxy.ServeHTTP(w, r)
38 | }
39 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yehoyada
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/cmd/openapi.yaml:
--------------------------------------------------------------------------------
1 | openapi: 3.1.0
2 |
3 | info:
4 | title: Simple API
5 | version: 1.0.0
6 | description: A simple API with one root path and one query parameter
7 |
8 | paths:
9 | /:
10 | post:
11 | summary: Root endpoint
12 | description: Returns a greeting message
13 | parameters:
14 | - in: query
15 | name: name
16 | schema:
17 | type: string
18 | maxLength: 10
19 | required: true
20 | description: Name of the person to greet
21 | responses:
22 | '200':
23 | description: Successful response
24 | content:
25 | application/json:
26 | schema:
27 | type: object
28 | properties:
29 | message:
30 | type: string
31 | example: "Hello, World!"
32 | '400':
33 | description: Bad request
34 | content:
35 | application/json:
36 | schema:
37 | type: object
38 | properties:
39 | error:
40 | type: string
41 | example: "Invalid query parameter"
--------------------------------------------------------------------------------
/pkg/multimux/multimux.go:
--------------------------------------------------------------------------------
1 | // This package implement a mutil-mux an http handler
2 | // that acts as seprate http.ServeMux for each registred host
3 |
4 | package multimux
5 |
6 | import (
7 | "net/http"
8 | "strings"
9 | "sync"
10 | )
11 |
12 | type MultiMux struct {
13 | Hosts sync.Map
14 | }
15 |
16 | func NewMultiMux() *MultiMux {
17 | return &MultiMux{Hosts: sync.Map{}}
18 | }
19 |
20 | func (mm *MultiMux) RegisterHandler(host string, pattern string, handler http.Handler) {
21 | cleanedHost := cleanHost(host)
22 | muxAny, _ := mm.Hosts.LoadOrStore(cleanedHost, http.NewServeMux())
23 | mux := muxAny.(*http.ServeMux)
24 |
25 | cleanedPattern := strings.ToLower(pattern)
26 |
27 | mux.Handle(cleanedPattern, handler)
28 | }
29 |
30 | func (mm *MultiMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
31 | host := r.Host
32 | cleanedHost := cleanHost(host)
33 | muxAny, exists := mm.Hosts.Load(cleanedHost)
34 |
35 | if !exists {
36 | w.WriteHeader(http.StatusNotFound)
37 | return
38 | }
39 |
40 | mux := muxAny.(*http.ServeMux)
41 | mux.ServeHTTP(w, r)
42 | }
43 |
44 | func cleanHost(domain string) string {
45 | return removePort(strings.ToLower(domain))
46 | }
47 |
48 | func removePort(addr string) string {
49 | if i := strings.LastIndex(addr, ":"); i != -1 {
50 | return addr[:i]
51 | }
52 | return addr
53 | }
54 |
--------------------------------------------------------------------------------
/internal/middlewares/omit_headers_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/hvuhsg/gatego/internal/middlewares"
9 | )
10 |
11 | // TestOmitHeadersMiddleware_OmitResponseHeaders tests that headers are omitted from the response
12 | func TestOmitHeadersMiddleware_OmitResponseHeaders(t *testing.T) {
13 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14 | w.Header().Set("Authorization", "Bearer some-secret-token")
15 | w.Header().Set("X-API-Key", "secret-api-key")
16 | w.WriteHeader(http.StatusOK)
17 | w.Write([]byte("OK"))
18 | })
19 |
20 | headers := []string{"Authorization", "X-API-Key"}
21 | handler := middlewares.NewOmitHeadersMiddleware(headers)(nextHandler)
22 |
23 | req := httptest.NewRequest(http.MethodGet, "/", nil)
24 |
25 | rr := httptest.NewRecorder()
26 | handler.ServeHTTP(rr, req)
27 |
28 | if rr.Header().Get("Authorization") != "" {
29 | t.Errorf("expected 'Authorization' header to be omitted, got %s", rr.Header().Get("Authorization"))
30 | }
31 | if rr.Header().Get("X-API-Key") != "" {
32 | t.Errorf("expected 'X-API-Key' header to be omitted, got %s", rr.Header().Get("X-API-Key"))
33 | }
34 |
35 | if rr.Body.String() != "OK" {
36 | t.Errorf("expected 'OK', got %s", rr.Body.String())
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/internal/handlers/files.go:
--------------------------------------------------------------------------------
1 | package handlers
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "path"
7 | "strings"
8 | )
9 |
10 | type Files struct {
11 | basePath string
12 | handler http.Handler
13 | }
14 |
15 | func NewFiles(dirPath string, basePath string) Files {
16 | return Files{handler: http.FileServer(http.Dir(dirPath)), basePath: basePath}
17 | }
18 |
19 | func (f Files) ServeHTTP(w http.ResponseWriter, r *http.Request) {
20 | cleanedPath, err := removeBaseURLPath(f.basePath, r.URL.Path)
21 | if err == nil {
22 | r.URL.Path = cleanedPath
23 | }
24 |
25 | f.handler.ServeHTTP(w, r)
26 | }
27 |
28 | func removeBaseURLPath(basePath, fullPath string) (string, error) {
29 | // Ensure paths start with "/"
30 | basePath = "/" + strings.Trim(basePath, "/")
31 | fullPath = "/" + strings.Trim(fullPath, "/")
32 |
33 | // Normalize paths
34 | basePath = path.Clean(basePath)
35 | fullPath = path.Clean(fullPath)
36 |
37 | // Check if the full path starts with the base path
38 | if !strings.HasPrefix(fullPath, basePath) {
39 | return "", fmt.Errorf("full path %s is not in base path %s", fullPath, basePath)
40 | }
41 |
42 | // Remove the base path
43 | relPath := strings.TrimPrefix(fullPath, basePath)
44 |
45 | // Ensure the relative path starts with "/"
46 | relPath = "/" + strings.TrimPrefix(relPath, "/")
47 |
48 | return relPath, nil
49 | }
50 |
--------------------------------------------------------------------------------
/internal/middlewares/sizelimiter.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | )
9 |
10 | // NewRequestSizeLimitMiddleware limits the size of the request body to the specified limit in bytes.
11 | func NewRequestSizeLimitMiddleware(maxSize uint64) func(http.Handler) http.Handler {
12 | return func(next http.Handler) http.Handler {
13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14 | // Create a buffer to read the body
15 | buf := new(bytes.Buffer)
16 |
17 | // Use io.LimitReader to limit the size of the request body
18 | limitedReader := io.LimitReader(r.Body, int64(maxSize+1)) // Allow one extra byte for overflow detection
19 | _, err := io.Copy(buf, limitedReader) // Copy the limited input into the buffer
20 |
21 | // Check for errors
22 | if err != nil {
23 | http.Error(w, "Error reading request body", http.StatusInternalServerError)
24 | return
25 | }
26 |
27 | // Check if we exceeded the maximum size
28 | if buf.Len() > int(maxSize) {
29 | http.Error(w, fmt.Sprintf("Request body too large. Maximum allowed size is %d bytes.", maxSize), http.StatusRequestEntityTooLarge)
30 | return
31 | }
32 |
33 | // Restore the request body for further processing
34 | r.Body = io.NopCloser(bytes.NewReader(buf.Bytes()))
35 |
36 | // Proceed to the next handler
37 | next.ServeHTTP(w, r)
38 | })
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/dev/client/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "log"
9 | "net/http"
10 | )
11 |
12 | func sendRequest() *http.Response {
13 | // Sample data to send in JSON format
14 | data := map[string]interface{}{
15 | "key1": "value1",
16 | "key2": "value2",
17 | "key3": 123,
18 | }
19 |
20 | // Convert the data to JSON
21 | jsonData, err := json.Marshal(data)
22 | if err != nil {
23 | log.Fatal("Error marshaling JSON:", err)
24 | }
25 |
26 | // Create a new POST request with the JSON payload
27 | req, err := http.NewRequest(http.MethodPost, "http://localhost:8004/?name=yoyo", bytes.NewBuffer(jsonData))
28 | if err != nil {
29 | log.Fatal("Error creating request:", err)
30 | }
31 |
32 | // Set the appropriate Content-Type header for JSON
33 | req.Header.Set("Content-Type", "application/json")
34 |
35 | // Send the POST request
36 | client := http.DefaultClient
37 | response, err := client.Do(req)
38 | if err != nil {
39 | log.Fatal("Error sending request:", err)
40 | }
41 |
42 | return response
43 | }
44 | func main() {
45 | resp := sendRequest()
46 | defer resp.Body.Close() // Always defer closing the response body
47 |
48 | // Check the response status code
49 | if resp.StatusCode > 299 {
50 | log.Printf("Error: received status code %d", resp.StatusCode)
51 | }
52 |
53 | fmt.Println(resp)
54 |
55 | // Read the response body
56 | data, err := io.ReadAll(resp.Body)
57 | if err != nil {
58 | log.Fatal(err)
59 | }
60 |
61 | // Print the response body
62 | fmt.Println(string(data))
63 | }
64 |
--------------------------------------------------------------------------------
/internal/middlewares/timeout.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "time"
7 |
8 | "go.opentelemetry.io/otel/trace"
9 | )
10 |
11 | // NewTimeoutMiddleware returns an HTTP handler that wraps the provided handler with a timeout.
12 | // If the processing takes longer than the specified timeout, it returns a 503 Service Unavailable error.
13 | func NewTimeoutMiddleware(timeout time.Duration) func(next http.Handler) http.Handler {
14 | return func(next http.Handler) http.Handler {
15 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16 | span := trace.SpanFromContext(r.Context())
17 |
18 | // Create a context with the specified timeout
19 | ctx, cancel := context.WithTimeout(r.Context(), timeout)
20 | defer cancel() // Make sure to cancel the context when done
21 |
22 | // Create a new request with the timeout context
23 | r = r.WithContext(ctx)
24 |
25 | // Channel to capture when the request processing finishes
26 | done := make(chan struct{})
27 |
28 | go func() {
29 | // Serve the request
30 | next.ServeHTTP(w, r)
31 | // Signal that the request processing is done
32 | close(done)
33 | }()
34 |
35 | select {
36 | case <-ctx.Done():
37 | // If the context is canceled (due to timeout), return an error response
38 | if ctx.Err() == context.DeadlineExceeded {
39 | span.AddEvent("Request timed out")
40 | http.Error(w, "Request timed out", http.StatusGatewayTimeout)
41 | }
42 | case <-done:
43 | // If the request finished within the timeout, return the result
44 | return
45 | }
46 | })
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/cmd/config.yaml:
--------------------------------------------------------------------------------
1 | # yaml-language-server: $schema=https://raw.githubusercontent.com/hvuhsg/gatego/refs/heads/main/config-schema.json
2 |
3 | version: '0.0.1'
4 | host: localhost
5 | port: 8004
6 |
7 | # open_telemetry:
8 | # endpoint: "localhost:4317"
9 | # sample_ratio: 1
10 |
11 | services:
12 | - domain: localhost
13 |
14 | anomaly_detection:
15 | active: true
16 |
17 | endpoints:
18 | - path: /
19 | # directory: /home/yoyo/ # Instead of destination
20 | destination: http://127.0.0.1:4007/
21 | # backend:
22 | # balance_policy: 'least-latency' # Can be 'round-robin', 'random', or 'least-latency'
23 | # servers:
24 | # - url: http://127.0.0.1:4007/
25 | # weight: 1
26 | # - url: http://127.0.0.1:4008/
27 | # weight: 2
28 |
29 | minify: [js, html, css, json, xml, svg]
30 |
31 | gzip: true
32 |
33 | timeout: 3s # Default (30s)
34 | max_size: 1024 # Default (10MB)
35 |
36 | ratelimits:
37 | - ip-60/m # Limit requests from the same IP to 6 requests per minute.
38 | - ip-100/d
39 |
40 | openapi: openapi.yaml
41 |
42 | checks:
43 | - name: "DB Health"
44 | cron: "* * * * *"
45 | method: GET
46 | url: "http://127.0.0.1:4007/check_db"
47 | timeout: 5s
48 | headers:
49 | Host: domain.org
50 | Authorization: "Bearer abc123"
51 | on_failure: |
52 | echo Health check '$check_name' failed at $date with error: $error
53 |
54 | omit_headers: [Authorization, X-API-Key, X-Secret-Token]
55 |
56 | cache: true
57 |
--------------------------------------------------------------------------------
/internal/middlewares/logging.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "net/http"
7 | "time"
8 | )
9 |
10 | func formatDuration(ms int64) string {
11 | if ms < 1000 {
12 | return fmt.Sprintf("%dms", ms)
13 | }
14 | return fmt.Sprintf("%.1fs", float64(ms)/1000)
15 | }
16 |
17 | // Logging middleware log the request / response with the log style of nginx
18 | func NewLoggingMiddleware(out io.Writer) func(http.Handler) http.Handler {
19 | return func(next http.Handler) http.Handler {
20 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21 | start := time.Now().UnixMilli()
22 |
23 | rh := &responseHook{ResponseWriter: w, respSize: 0}
24 | next.ServeHTTP(rh, r)
25 |
26 | end := time.Now().UnixMilli()
27 |
28 | scheme := "http"
29 | if r.TLS != nil {
30 | scheme = "https"
31 | }
32 | fullURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
33 |
34 | method := r.Method
35 | path := r.URL.Path
36 | responseSize := rh.respSize
37 | remoteAddr := r.RemoteAddr
38 | date := time.Now().Format("2006-01-02 15:04:05")
39 | userAgent := r.UserAgent()
40 | statusCode := rh.statusCode
41 | duration := formatDuration(end - start)
42 |
43 | fmt.Fprintf(out, "%s - - [%s] \"%s %s %s\" %d %d %s \"%s\" \"%s\"\n", remoteAddr, date, method, path, r.Proto, statusCode, responseSize, duration, fullURL, userAgent)
44 | })
45 | }
46 | }
47 |
48 | type responseHook struct {
49 | http.ResponseWriter
50 | respSize int
51 | statusCode int
52 | }
53 |
54 | func (rh *responseHook) Write(b []byte) (int, error) {
55 | // Save the length of the response
56 | rh.respSize += len(b)
57 |
58 | return rh.ResponseWriter.Write(b)
59 | }
60 |
61 | func (rh *responseHook) WriteHeader(statusCode int) {
62 | // Save status code
63 | rh.statusCode = statusCode
64 |
65 | rh.ResponseWriter.WriteHeader(statusCode)
66 | }
67 |
--------------------------------------------------------------------------------
/gatego.go:
--------------------------------------------------------------------------------
1 | package gatego
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/hvuhsg/gatego/internal/config"
9 | "github.com/hvuhsg/gatego/internal/contextvalues"
10 | "github.com/hvuhsg/gatego/pkg/monitor"
11 | )
12 |
13 | const serviceName = "gatego"
14 |
15 | type GateGo struct {
16 | config config.Config
17 | monitor *monitor.Monitor
18 | ctx context.Context
19 | }
20 |
21 | func New(ctx context.Context, config config.Config, version string) *GateGo {
22 | ctx = contextvalues.AddVersionToContext(ctx, version)
23 | return &GateGo{config: config, ctx: ctx}
24 | }
25 |
26 | func (gg GateGo) Run() error {
27 | useOtel := gg.config.OTEL != nil
28 | if useOtel {
29 | otelConfig := otelConfig{
30 | ServiceName: serviceName,
31 | SampleRatio: gg.config.OTEL.SampleRatio,
32 | CollectorTimeout: time.Second * 5, // TODO: Add to config
33 | TraceCollectorEndpoint: gg.config.OTEL.Endpoint,
34 | MetricCollectorEndpoint: gg.config.OTEL.Endpoint,
35 | LogsCollectorEndpoint: gg.config.OTEL.Endpoint,
36 | }
37 | shutdown, err := setupOTelSDK(gg.ctx, otelConfig)
38 | if err != nil {
39 | return err
40 | }
41 | defer shutdown(context.Background())
42 | }
43 |
44 | // Create checks start monitoring
45 | healthChecks := createMonitorChecks(gg.config.Services)
46 | gg.monitor = monitor.New(time.Second*5, healthChecks...)
47 | gg.monitor.Start()
48 |
49 | server, err := newServer(gg.ctx, gg.config, useOtel)
50 | if err != nil {
51 | return err
52 | }
53 | defer server.Shutdown(gg.ctx)
54 |
55 | serveErrChan, err := server.serve(gg.config.TLS.CertFile, gg.config.TLS.KeyFile)
56 | if err != nil {
57 | return err
58 | }
59 |
60 | // Wait for interruption.
61 | select {
62 | case err = <-serveErrChan:
63 | return err
64 | case <-gg.ctx.Done():
65 | fmt.Println("\nShutting down...")
66 | return server.Shutdown(context.Background())
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/pkg/cron/README.md:
--------------------------------------------------------------------------------
1 | # Cron
2 |
3 | A Go package that implements a crontab-like service to execute and schedule repetitive tasks/jobs.
4 |
5 | ## Features
6 |
7 | - Supports cron expressions for flexible scheduling
8 | - Allows registering and managing multiple jobs
9 | - Provides macros for common schedule patterns
10 | - Supports custom timezones
11 | - Allows setting custom tick intervals
12 | - Supports starting and stopping the cron service
13 |
14 | ## Installation
15 |
16 | ```sh
17 | go get github.com/hvuhsg/gatego/pkg/cron
18 | ```
19 |
20 | ## Usage
21 |
22 | ```go
23 | package main
24 |
25 | import (
26 | "fmt"
27 | "time"
28 |
29 | "github.com/hvuhsg/gatego/pkg/cron"
30 | )
31 |
32 | func main() {
33 | c := cron.New()
34 |
35 | // Register a job
36 | c.MustAdd("job1", "*/5 * * * *", func() {
37 | fmt.Println("Running job1...")
38 | })
39 |
40 | // Set a custom timezone
41 | loc, _ := time.LoadLocation("Asia/Tokyo")
42 | c.SetTimezone(loc)
43 |
44 | // Set a custom tick interval
45 | c.SetInterval(5 * time.Second)
46 |
47 | // Start the cron service
48 | c.Start()
49 |
50 | // Stop the cron service after 30 seconds
51 | time.Sleep(30 * time.Second)
52 | c.Stop()
53 | }
54 | ```
55 |
56 | ## Cron Expression Format
57 |
58 | The package supports the following cron expression format:
59 |
60 | ```
61 | * * * * *
62 | │ │ │ │ │
63 | │ │ │ │ └── Day of Week (0-6)
64 | │ │ │ └──── Month (1-12)
65 | │ │ └────── Day of Month (1-31)
66 | │ └──────── Hour (0-23)
67 | └────────── Minute (0-59)
68 | ```
69 |
70 | It also supports the following macros:
71 |
72 | - `@yearly` or `@annually`: Run once a year at midnight on the first day of the year
73 | - `@monthly`: Run once a month at midnight on the first day of the month
74 | - `@weekly`: Run once a week at midnight on Sunday
75 | - `@daily` or `@midnight`: Run once a day at midnight
76 | - `@hourly`: Run once an hour at the beginning of the hour
77 | - `@minutely`: Run once a minute at the beginning of the minute
78 |
--------------------------------------------------------------------------------
/pkg/tracker/tracker.go:
--------------------------------------------------------------------------------
1 | package tracker
2 |
3 | import (
4 | "crypto/rand"
5 | "encoding/hex"
6 | "net/http"
7 | )
8 |
9 | type Tracker interface {
10 | GetTrackerID(*http.Request) string
11 | SetTracker(http.ResponseWriter) (string, error)
12 | RemoveTracker(*http.Request)
13 | }
14 |
15 | type cookieTracker struct {
16 | cookieName string
17 | trackerMaxAge int
18 | secureCookie bool
19 | }
20 |
21 | func NewCookieTracker(cookieName string, maxAge int, isSecure bool) cookieTracker {
22 | return cookieTracker{cookieName: cookieName, trackerMaxAge: maxAge, secureCookie: isSecure}
23 | }
24 |
25 | // Get the tracker id from request or return empty string if not found
26 | func (ct cookieTracker) GetTrackerID(r *http.Request) string {
27 | cookie, err := r.Cookie(ct.cookieName)
28 |
29 | if err != nil {
30 | return ""
31 | }
32 |
33 | return cookie.Value
34 | }
35 |
36 | // Set tracer into response and return the tracker id
37 | func (ct cookieTracker) SetTracker(w http.ResponseWriter) (string, error) {
38 | traceID, err := generateTraceID()
39 | if err != nil {
40 | return "", err
41 | }
42 |
43 | http.SetCookie(w, &http.Cookie{
44 | Name: ct.cookieName,
45 | Value: traceID,
46 | Path: "/",
47 | MaxAge: ct.trackerMaxAge,
48 | HttpOnly: true,
49 | Secure: ct.secureCookie,
50 | SameSite: http.SameSiteLaxMode,
51 | })
52 |
53 | return traceID, nil
54 | }
55 |
56 | func (ct cookieTracker) RemoveTracker(r *http.Request) {
57 | // Get existing cookies
58 | oldCookies := r.Cookies()
59 |
60 | // Create new headers without the cookie we want to remove
61 | r.Header.Del("Cookie")
62 |
63 | // Add back all cookies except the one we want to remove
64 | for _, cookie := range oldCookies {
65 | if cookie.Name != ct.cookieName {
66 | r.AddCookie(cookie)
67 | }
68 | }
69 | }
70 |
71 | func generateTraceID() (string, error) {
72 | bytes := make([]byte, 16)
73 | if _, err := rand.Read(bytes); err != nil {
74 | return "", err
75 | }
76 | return hex.EncodeToString(bytes), nil
77 | }
78 |
--------------------------------------------------------------------------------
/internal/middlewares/minify.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "strconv"
7 |
8 | "github.com/tdewolff/minify/v2"
9 | "github.com/tdewolff/minify/v2/css"
10 | "github.com/tdewolff/minify/v2/html"
11 | "github.com/tdewolff/minify/v2/js"
12 | "github.com/tdewolff/minify/v2/json"
13 | "github.com/tdewolff/minify/v2/svg"
14 | "github.com/tdewolff/minify/v2/xml"
15 | "go.opentelemetry.io/otel/trace"
16 | )
17 |
18 | type MinifyConfig struct {
19 | ALL bool
20 | JS bool
21 | CSS bool
22 | HTML bool
23 | JSON bool
24 | SVG bool
25 | XML bool
26 | }
27 |
28 | func NewMinifyMiddleware(config MinifyConfig) Middleware {
29 | m := minify.New()
30 |
31 | // Add minifiers for the different content types
32 | if config.HTML || config.ALL {
33 | m.AddFunc("text/html", html.Minify)
34 | }
35 | if config.CSS || config.ALL {
36 | m.AddFunc("text/css", css.Minify)
37 | }
38 | if config.JS || config.ALL {
39 | m.AddFunc("application/javascript", js.Minify)
40 | }
41 | if config.JSON || config.ALL {
42 | m.AddFunc("application/json", json.Minify)
43 | }
44 | if config.SVG || config.ALL {
45 | m.AddFunc("image/svg+xml", svg.Minify)
46 | }
47 | if config.XML || config.ALL {
48 | m.AddFunc("application/xml", xml.Minify)
49 | }
50 |
51 | return func(next http.Handler) http.Handler {
52 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53 | span := trace.SpanFromContext(r.Context())
54 |
55 | // Create a custom ResponseWriter to capture the response
56 | rc := NewRecorder()
57 |
58 | // Serve the next handler
59 | next.ServeHTTP(rc, r)
60 |
61 | // Get the content type of the response
62 | contentType := rc.Header().Get("Content-Type")
63 |
64 | minifiedContent, err := m.Bytes(contentType, rc.Body.Bytes())
65 | if err != nil {
66 | rc.WriteTo(w) // Return the original response
67 | return
68 | }
69 |
70 | span.AddEvent(fmt.Sprintf("Minified response content, content-type = %s", contentType))
71 |
72 | // Write the minified content to the response
73 | w.Header().Set("Content-Length", strconv.Itoa(len(minifiedContent)))
74 | rc.WriteHeadersTo(w)
75 | w.WriteHeader(rc.Result().StatusCode)
76 |
77 | w.Write(minifiedContent)
78 | })
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/internal/middlewares/sizelimiter_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "bytes"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 |
9 | "github.com/hvuhsg/gatego/internal/middlewares"
10 | )
11 |
12 | // TestRequestSizeLimitMiddleware tests the RequestSizeLimitMiddleware
13 | func TestRequestSizeLimitMiddleware(t *testing.T) {
14 | tests := []struct {
15 | name string
16 | body []byte
17 | maxSize uint64
18 | expectedCode int
19 | expectedBody string
20 | }{
21 | {
22 | name: "Within limit",
23 | body: []byte("This is within the limit."),
24 | maxSize: 30,
25 | expectedCode: http.StatusOK,
26 | },
27 | {
28 | name: "Exactly at limit",
29 | body: bytes.Repeat([]byte("A"), 30), // 30 bytes
30 | maxSize: 30,
31 | expectedCode: http.StatusOK,
32 | },
33 | {
34 | name: "Exceeds limit",
35 | body: bytes.Repeat([]byte("A"), 31), // 31 bytes
36 | maxSize: 30,
37 | expectedCode: http.StatusRequestEntityTooLarge,
38 | expectedBody: "Request body too large. Maximum allowed size is 30 bytes.\n",
39 | },
40 | }
41 |
42 | for _, tt := range tests {
43 | t.Run(tt.name, func(t *testing.T) {
44 | // Create a request with the test body
45 | req := httptest.NewRequest("POST", "http://example.com", bytes.NewReader(tt.body))
46 | rr := httptest.NewRecorder()
47 |
48 | // Use a simple handler that just returns OK
49 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
50 | w.WriteHeader(http.StatusOK)
51 | })
52 |
53 | // Create the middleware with the specified maxSize
54 | middleware := middlewares.NewRequestSizeLimitMiddleware(tt.maxSize)
55 |
56 | // Serve the request through the middleware
57 | middleware(handler).ServeHTTP(rr, req)
58 |
59 | // Check the response code
60 | if rr.Code != tt.expectedCode {
61 | t.Errorf("expected status code %d, got %d", tt.expectedCode, rr.Code)
62 | }
63 |
64 | // Check the response body if applicable
65 | if tt.expectedBody != "" {
66 | if rr.Body.String() != tt.expectedBody {
67 | t.Errorf("expected response body %q, got %q", tt.expectedBody, rr.Body.String())
68 | }
69 | }
70 | })
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/hvuhsg/gatego
2 |
3 | go 1.22.0
4 |
5 | require gopkg.in/yaml.v3 v3.0.1
6 |
7 | require (
8 | github.com/hashicorp/go-version v1.7.0
9 | github.com/tdewolff/minify/v2 v2.21.0
10 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0
11 | go.opentelemetry.io/otel/log v0.7.0
12 | go.opentelemetry.io/otel/sdk/log v0.7.0
13 | go.opentelemetry.io/otel/trace v1.31.0
14 | )
15 |
16 | require (
17 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
18 | github.com/go-logr/logr v1.4.2 // indirect
19 | github.com/go-logr/stdr v1.2.2 // indirect
20 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
21 | go.opentelemetry.io/otel/metric v1.31.0 // indirect
22 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect
23 | golang.org/x/sys v0.26.0 // indirect
24 | golang.org/x/text v0.19.0 // indirect
25 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
26 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect
27 | google.golang.org/grpc v1.67.1 // indirect
28 | google.golang.org/protobuf v1.35.1 // indirect
29 | )
30 |
31 | require (
32 | github.com/davecgh/go-spew v1.1.1 // indirect
33 | github.com/getkin/kin-openapi v0.128.0
34 | github.com/go-openapi/jsonpointer v0.21.0 // indirect
35 | github.com/go-openapi/swag v0.23.0 // indirect
36 | github.com/google/uuid v1.6.0
37 | github.com/gorilla/mux v1.8.0 // indirect
38 | github.com/invopop/yaml v0.3.1 // indirect
39 | github.com/josharian/intern v1.0.0 // indirect
40 | github.com/mailru/easyjson v0.7.7 // indirect
41 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
42 | github.com/patrickmn/go-cache v2.1.0+incompatible
43 | github.com/perimeterx/marshmallow v1.1.5 // indirect
44 | github.com/pmezard/go-difflib v1.0.0 // indirect
45 | github.com/stretchr/testify v1.9.0
46 | github.com/tdewolff/parse/v2 v2.7.17 // indirect
47 | go.opentelemetry.io/otel v1.31.0
48 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0
49 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0
50 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0
51 | go.opentelemetry.io/otel/sdk v1.31.0
52 | go.opentelemetry.io/otel/sdk/metric v1.31.0
53 | golang.org/x/net v0.30.0
54 | golang.org/x/time v0.7.0
55 | )
56 |
--------------------------------------------------------------------------------
/internal/middlewares/timeout_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | "time"
9 |
10 | "github.com/hvuhsg/gatego/internal/middlewares"
11 | )
12 |
13 | func TestTimeoutMiddleware(t *testing.T) {
14 | tests := []struct {
15 | name string
16 | timeout time.Duration
17 | handlerSleep time.Duration
18 | expectedStatus int
19 | }{
20 | {
21 | name: "Request completes before timeout",
22 | timeout: 100 * time.Millisecond,
23 | handlerSleep: 50 * time.Millisecond,
24 | expectedStatus: http.StatusOK,
25 | },
26 | {
27 | name: "Request times out",
28 | timeout: 50 * time.Millisecond,
29 | handlerSleep: 100 * time.Millisecond,
30 | expectedStatus: http.StatusGatewayTimeout,
31 | },
32 | }
33 |
34 | for _, tt := range tests {
35 | t.Run(tt.name, func(t *testing.T) {
36 | // Create a test handler that sleeps for the specified duration
37 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38 | time.Sleep(tt.handlerSleep)
39 | w.WriteHeader(http.StatusOK)
40 | })
41 |
42 | // Wrap the handler with our middleware
43 | wrappedHandler := middlewares.NewTimeoutMiddleware(tt.timeout)(handler)
44 |
45 | // Create a test request
46 | req, err := http.NewRequest("GET", "/test", nil)
47 | if err != nil {
48 | t.Fatal(err)
49 | }
50 |
51 | // Create a ResponseRecorder to record the response
52 | rr := httptest.NewRecorder()
53 |
54 | // Serve the request using our wrapped handler
55 | wrappedHandler.ServeHTTP(rr, req)
56 |
57 | // Check the status code
58 | if status := rr.Code; status != tt.expectedStatus {
59 | t.Errorf("handler returned wrong status code: got %v want %v",
60 | status, tt.expectedStatus)
61 | }
62 | })
63 | }
64 | }
65 |
66 | func TestTimeoutMiddlewareCancelContext(t *testing.T) {
67 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68 | // Sleep until the context is canceled
69 | <-r.Context().Done()
70 | // Check if the context was canceled due to a timeout
71 | if r.Context().Err() == context.DeadlineExceeded {
72 | w.WriteHeader(http.StatusGatewayTimeout)
73 | }
74 |
75 | w.WriteHeader(http.StatusOK)
76 | })
77 |
78 | wrappedHandler := middlewares.NewTimeoutMiddleware(50 * time.Millisecond)(handler)
79 |
80 | req, err := http.NewRequest("GET", "/test", nil)
81 | if err != nil {
82 | t.Fatal(err)
83 | }
84 |
85 | rr := httptest.NewRecorder()
86 |
87 | wrappedHandler.ServeHTTP(rr, req)
88 |
89 | if status := rr.Code; status != http.StatusGatewayTimeout {
90 | t.Errorf("handler returned wrong status code: got %v want %v",
91 | status, http.StatusGatewayTimeout)
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/internal/middlewares/cache.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "net/http"
5 | "strconv"
6 | "strings"
7 | "time"
8 |
9 | "github.com/patrickmn/go-cache"
10 | "go.opentelemetry.io/otel/trace"
11 | )
12 |
13 | const DEFAULT_CACHE_TTL = time.Minute * 1
14 | const CLEANUP_CACHE_INTERVAL = time.Minute * 10
15 |
16 | var responseCache = cache.New(DEFAULT_CACHE_TTL, CLEANUP_CACHE_INTERVAL) // Default cache with a placeholder TTL
17 |
18 | type CachedResponse struct {
19 | statusCode int
20 | body []byte
21 | headers http.Header
22 | }
23 |
24 | func NewCacheMiddleware() Middleware {
25 | return func(next http.Handler) http.Handler {
26 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27 | span := trace.SpanFromContext(r.Context())
28 |
29 | // Check if response response is already cached
30 | cachedResponse, found := responseCache.Get(r.URL.String())
31 | if found {
32 | span.AddEvent("Cache hit")
33 | response := cachedResponse.(CachedResponse)
34 | for header := range response.headers {
35 | w.Header().Set(header, response.headers.Get(header))
36 | }
37 | w.WriteHeader(response.statusCode)
38 | w.Write(response.body)
39 | return
40 | }
41 |
42 | // Serve the next handler and capture the response
43 | rc := NewRecorder()
44 | next.ServeHTTP(rc, r)
45 |
46 | // Get cache control headers
47 | cacheControl := rc.Header().Get("Cache-Control")
48 | maxAge := getCacheMaxAge(cacheControl)
49 | expires := getCacheExpires(rc.Header().Get("Expires"))
50 |
51 | // Determine TTL based on cache headers
52 | ttl := time.Second * 0
53 | if maxAge > 0 {
54 | ttl = time.Duration(maxAge) * time.Second
55 | } else if !expires.IsZero() {
56 | ttl = time.Until(expires)
57 | }
58 |
59 | // Cache the response if it's cacheable
60 | if ttl > 0 {
61 | cachedResponse := CachedResponse{statusCode: rc.Result().StatusCode, body: rc.Body.Bytes(), headers: rc.Result().Header}
62 | responseCache.Set(r.URL.String(), cachedResponse, ttl)
63 | span.AddEvent("Response stored in cache")
64 | }
65 |
66 | // Write the captured response (original or cached)
67 | rc.WriteTo(w)
68 | })
69 | }
70 | }
71 |
72 | func getCacheMaxAge(cacheControl string) int {
73 | for _, directive := range strings.Split(cacheControl, ",") {
74 | directive = strings.TrimSpace(directive)
75 | if strings.HasPrefix(directive, "max-age=") {
76 | maxAge, err := strconv.Atoi(strings.TrimPrefix(directive, "max-age="))
77 | if err == nil {
78 | return maxAge
79 | }
80 | }
81 | }
82 | return 0
83 | }
84 |
85 | func getCacheExpires(expiresHeader string) time.Time {
86 | expires, err := time.Parse(time.RFC1123, expiresHeader)
87 | if err != nil {
88 | return time.Time{}
89 | }
90 | return expires
91 | }
92 |
--------------------------------------------------------------------------------
/internal/middlewares/openapi.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 |
7 | "github.com/getkin/kin-openapi/openapi3"
8 | "github.com/getkin/kin-openapi/openapi3filter"
9 | "github.com/getkin/kin-openapi/routers/gorillamux"
10 | "go.opentelemetry.io/otel/trace"
11 | )
12 |
13 | func NewOpenAPIValidationMiddleware(specPath string) (Middleware, error) {
14 | loader := &openapi3.Loader{IsExternalRefsAllowed: true}
15 | doc, err := loader.LoadFromFile(specPath)
16 | if err != nil {
17 | return nil, fmt.Errorf("error loading OpenAPI spec: %w", err)
18 | }
19 |
20 | if err := doc.Validate(loader.Context); err != nil {
21 | return nil, fmt.Errorf("error validating OpenAPI spec: %w", err)
22 | }
23 |
24 | router, err := gorillamux.NewRouter(doc)
25 | if err != nil {
26 | return nil, fmt.Errorf("error creating router: %w", err)
27 | }
28 |
29 | return func(next http.Handler) http.Handler {
30 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31 | span := trace.SpanFromContext(r.Context())
32 |
33 | route, pathParams, err := router.FindRoute(r)
34 | if err != nil {
35 | span.AddEvent("Request path not found in openapi spec")
36 | http.Error(w, fmt.Sprintf("Error finding route: %v", err), http.StatusBadRequest)
37 | return
38 | }
39 |
40 | requestValidationInput := &openapi3filter.RequestValidationInput{
41 | Request: r,
42 | PathParams: pathParams,
43 | Route: route,
44 | }
45 |
46 | if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil {
47 | span.AddEvent(fmt.Sprintf("Error while validating request with openapi spec. err = %v", err))
48 | http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
49 | return
50 | }
51 |
52 | span.AddEvent("Request validated by openapi spec")
53 |
54 | rc := NewRecorder()
55 | next.ServeHTTP(rc, r)
56 |
57 | responseValidationInput := &openapi3filter.ResponseValidationInput{
58 | RequestValidationInput: requestValidationInput,
59 | Status: rc.Result().StatusCode,
60 | Header: rc.Header(),
61 | }
62 |
63 | if rc.Body.Bytes() != nil {
64 | responseValidationInput.SetBodyBytes(rc.Body.Bytes())
65 | }
66 |
67 | if err := openapi3filter.ValidateResponse(r.Context(), responseValidationInput); err != nil {
68 | span.AddEvent(fmt.Sprintf("Error while validating response with openapi spec. err = %v", err))
69 | http.Error(w, fmt.Sprintf("Invalid response: %v", err), http.StatusInternalServerError)
70 | return
71 | }
72 |
73 | span.AddEvent("Response validated by openapi spec")
74 |
75 | rc.WriteHeadersTo(w)
76 | w.WriteHeader(rc.Result().StatusCode)
77 | if rc.Body.Bytes() != nil {
78 | w.Write(rc.Body.Bytes())
79 | }
80 | })
81 | }, nil
82 | }
83 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | package gatego
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "net"
8 | "net/http"
9 | "os"
10 | "time"
11 |
12 | "github.com/hvuhsg/gatego/internal/config"
13 | "github.com/hvuhsg/gatego/pkg/multimux"
14 | )
15 |
16 | type gategoServer struct {
17 | *http.Server
18 | }
19 |
20 | func newServer(ctx context.Context, config config.Config, useOtel bool) (*gategoServer, error) {
21 | multimuxer, err := createMultiMuxer(ctx, config.Services, useOtel)
22 | if err != nil {
23 | return nil, err
24 | }
25 |
26 | addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
27 |
28 | // Start HTTP server.
29 | server := &http.Server{
30 | Addr: addr,
31 | BaseContext: func(_ net.Listener) context.Context { return ctx },
32 | ReadTimeout: time.Second,
33 | WriteTimeout: 10 * time.Second,
34 | Handler: multimuxer,
35 | }
36 |
37 | return &gategoServer{Server: server}, nil
38 | }
39 |
40 | func createMultiMuxer(ctx context.Context, services []config.Service, useOtel bool) (*multimux.MultiMux, error) {
41 | mm := multimux.NewMultiMux()
42 |
43 | for _, service := range services {
44 | for _, path := range service.Paths {
45 | handler, err := NewHandler(ctx, useOtel, service, path)
46 | if err != nil {
47 | return nil, err
48 | }
49 |
50 | mm.RegisterHandler(service.Domain, path.Path, handler)
51 | }
52 | }
53 |
54 | return mm, nil
55 | }
56 |
57 | func (gs *gategoServer) serve(certfile *string, keyfile *string) (chan error, error) {
58 | supportTLS, err := checkTLSConfig(certfile, keyfile)
59 | if err != nil {
60 | return nil, err
61 | }
62 |
63 | serveErr := make(chan error, 1)
64 |
65 | go func() {
66 | if supportTLS {
67 | log.Default().Printf("Serving proxy with TLS %s\n", gs.Addr)
68 | serveErr <- gs.ListenAndServeTLS(*certfile, *keyfile)
69 | } else {
70 | log.Default().Printf("Serving proxy %s\n", gs.Addr)
71 | serveErr <- gs.ListenAndServe()
72 | }
73 | }()
74 |
75 | return serveErr, nil
76 | }
77 |
78 | func checkTLSConfig(certfile *string, keyfile *string) (bool, error) {
79 | if keyfile == nil || certfile == nil || *keyfile == "" || *certfile == "" {
80 | return false, nil
81 | }
82 |
83 | if !fileExists(*keyfile) {
84 | return false, fmt.Errorf("can't find keyfile at '%s'", *keyfile)
85 | }
86 |
87 | if !fileExists(*certfile) {
88 | return false, fmt.Errorf("can't find certfile at '%s'", *certfile)
89 | }
90 |
91 | return true, nil
92 | }
93 |
94 | func fileExists(filepath string) bool {
95 | _, err := os.Stat(filepath)
96 |
97 | if os.IsNotExist(err) {
98 | return false
99 | }
100 |
101 | // If we cant check the file info we probably can't open the file
102 | if err != nil {
103 | return false
104 | }
105 |
106 | return true
107 | }
108 |
--------------------------------------------------------------------------------
/pkg/pathgraph/pathgraph.go:
--------------------------------------------------------------------------------
1 | package pathgraph
2 |
3 | import "strings"
4 |
5 | const incRate = 1
6 | const baseWeight = 0
7 |
8 | // PathVertex represents a vertex in the graph
9 | type PathVertex struct {
10 | Path string
11 | Weight float64
12 | }
13 |
14 | // PathGraph represents a weighted directed graph of navigation paths
15 | type PathGraph struct {
16 | // Map of source path to map of destination paths and their weights
17 | adjacencyList map[string]map[string]*PathVertex
18 | }
19 |
20 | // NewPathGraph creates a new instance of PathGraph
21 | func NewPathGraph() *PathGraph {
22 | return &PathGraph{
23 | adjacencyList: make(map[string]map[string]*PathVertex),
24 | }
25 | }
26 |
27 | // AddJump adds or updates a path transition in the graph
28 | func (g *PathGraph) AddJump(sourcePath, destPath string) float64 {
29 | sourcePath = normalizePath(sourcePath)
30 | destPath = normalizePath(destPath)
31 |
32 | // Initialize source path if it doesn't exist
33 | if _, exists := g.adjacencyList[sourcePath]; !exists {
34 | g.adjacencyList[sourcePath] = make(map[string]*PathVertex)
35 | }
36 |
37 | // Get or create destination node
38 | vertex, exists := g.adjacencyList[sourcePath][destPath]
39 | if !exists {
40 | vertex = &PathVertex{
41 | Path: destPath,
42 | Weight: baseWeight,
43 | }
44 | g.adjacencyList[sourcePath][destPath] = vertex
45 | }
46 |
47 | // Increment weight
48 | vertex.Weight += incRate
49 |
50 | return vertex.Weight - 1 // The original weight (before the jump)
51 | }
52 |
53 | // GetDestinations returns all destinations and their weights for a given source path
54 | func (g *PathGraph) GetDestinations(sourcePath string) map[string]float64 {
55 | sourcePath = normalizePath(sourcePath)
56 |
57 | result := make(map[string]float64)
58 |
59 | if vertexs, exists := g.adjacencyList[sourcePath]; exists {
60 | for path, vertex := range vertexs {
61 | result[path] = vertex.Weight
62 | }
63 | }
64 |
65 | return result
66 | }
67 |
68 | // GetAllPaths returns all unique paths in the graph
69 | func (g *PathGraph) GetAllPaths() []string {
70 | pathSet := make(map[string]bool)
71 |
72 | // Add all source paths
73 | for sourcePath := range g.adjacencyList {
74 | pathSet[sourcePath] = true
75 |
76 | // Add all destination paths
77 | for destPath := range g.adjacencyList[sourcePath] {
78 | pathSet[destPath] = true
79 | }
80 | }
81 |
82 | // Convert set to slice
83 | paths := make([]string, 0, len(pathSet))
84 | for path := range pathSet {
85 | paths = append(paths, path)
86 | }
87 |
88 | return paths
89 | }
90 |
91 | func normalizePath(path string) string {
92 | if len(path) == 0 || path[0] != '/' {
93 | path = "/" + path
94 | }
95 |
96 | path = strings.ToLower(path)
97 |
98 | return path
99 | }
100 |
--------------------------------------------------------------------------------
/internal/middlewares/security/routing_anomaly_score_test.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "math"
5 | "testing"
6 | )
7 |
8 | func TestCalcAnomalyRating(t *testing.T) {
9 | tests := []struct {
10 | name string
11 | detector *RoutingAnomalyDetector
12 | trackerHistory *trackerHistory
13 | want float64
14 | }{
15 | {
16 | name: "Below threshold returns 0",
17 | detector: &RoutingAnomalyDetector{
18 | numberOfJumps: 99,
19 | scoreSum: 100,
20 | avgDiviation: 10,
21 | tresholdForRating: 100,
22 | minScore: 100,
23 | maxScore: 200,
24 | anomalyHeaderName: "test",
25 | },
26 | trackerHistory: &trackerHistory{jumpsScoreSum: 50, jumpsCount: 1},
27 | want: 0,
28 | },
29 | {
30 | name: "Score below minScore returns 0",
31 | detector: &RoutingAnomalyDetector{
32 | numberOfJumps: 101,
33 | scoreSum: 500,
34 | avgDiviation: 100,
35 | tresholdForRating: 100,
36 | minScore: 100,
37 | maxScore: 200,
38 | anomalyHeaderName: "test",
39 | },
40 | trackerHistory: &trackerHistory{jumpsScoreSum: 95, jumpsCount: 1},
41 | want: 0,
42 | },
43 | {
44 | name: "Score above maxScore returns 1",
45 | detector: &RoutingAnomalyDetector{
46 | numberOfJumps: 101,
47 | scoreSum: 1000,
48 | avgDiviation: 1,
49 | tresholdForRating: 100,
50 | minScore: 100,
51 | maxScore: 200,
52 | anomalyHeaderName: "test",
53 | },
54 | trackerHistory: &trackerHistory{jumpsScoreSum: 50, jumpsCount: 1},
55 | want: 1,
56 | },
57 | {
58 | name: "Normal score calculation",
59 | detector: &RoutingAnomalyDetector{
60 | numberOfJumps: 100,
61 | scoreSum: 500,
62 | avgDiviation: 50,
63 | tresholdForRating: 100,
64 | minScore: 100,
65 | maxScore: 200,
66 | anomalyHeaderName: "test",
67 | },
68 | trackerHistory: &trackerHistory{jumpsScoreSum: 90, jumpsCount: 1},
69 | want: 0.6, // assuming avg score of 10 units
70 | },
71 | {
72 | name: "Zero avgDiviation handling",
73 | detector: &RoutingAnomalyDetector{
74 | numberOfJumps: 101,
75 | scoreSum: 100,
76 | avgDiviation: 0,
77 | tresholdForRating: 100,
78 | minScore: 100,
79 | maxScore: 200,
80 | anomalyHeaderName: "test",
81 | },
82 | trackerHistory: &trackerHistory{jumpsScoreSum: 100, jumpsCount: 1},
83 | want: 1, // should return max score due to division by zero protection
84 | },
85 | }
86 |
87 | for _, tt := range tests {
88 | t.Run(tt.name, func(t *testing.T) {
89 | got := tt.detector.calcAnomalyRating(tt.trackerHistory)
90 |
91 | if math.Abs(got-tt.want) > 0.0001 { // Using small epsilon for float comparison
92 | t.Errorf("calcAnomalyRating() = %v, want %v", got, tt.want)
93 | }
94 | })
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/internal/middlewares/otel.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "net/http"
7 |
8 | "github.com/hvuhsg/gatego/internal/contextvalues"
9 | "go.opentelemetry.io/otel"
10 | "go.opentelemetry.io/otel/attribute"
11 | "go.opentelemetry.io/otel/codes"
12 | "go.opentelemetry.io/otel/propagation"
13 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
14 | "go.opentelemetry.io/otel/trace"
15 | )
16 |
17 | const tracerName = "request"
18 | const spanName = "middlewares"
19 |
20 | type OTELConfig struct {
21 | ServiceDomain string
22 | BasePath string
23 | }
24 |
25 | func NewOpenTelemetryMiddleware(ctx context.Context, config OTELConfig) (Middleware, error) {
26 | tp := otel.GetTracerProvider()
27 | tracer := tp.Tracer(tracerName)
28 |
29 | return func(next http.Handler) http.Handler {
30 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31 | // Add tracer to request context
32 | r = r.WithContext(contextvalues.AddTracerToContext(r.Context(), tracer))
33 |
34 | // Create span for request
35 | ctx, span := tracer.Start(
36 | r.Context(),
37 | spanName,
38 | trace.WithAttributes(semconv.NetAttributesFromHTTPRequest("", r)...),
39 | trace.WithSpanKind(trace.SpanKindServer),
40 | )
41 | defer span.End()
42 |
43 | // Add request-specific attributes
44 | attrs := make([]attribute.KeyValue, 0)
45 | attrs = append(attrs, semconv.HTTPUserAgentKey.String(r.UserAgent()))
46 | attrs = append(attrs, semconv.HTTPServerAttributesFromHTTPRequest(config.ServiceDomain, config.BasePath, r)...)
47 | span.SetAttributes(
48 | attrs...,
49 | )
50 |
51 | // Handle panic recovery
52 | defer func() {
53 | if err := recover(); err != nil {
54 | span.SetStatus(codes.Error, fmt.Sprintf("panic: %v", err))
55 | span.RecordError(fmt.Errorf("%v", err))
56 | panic(err) // Re-panic after recording error
57 | }
58 | }()
59 |
60 | // Propegate open telemetry context via the request to the upstream service
61 | otel.GetTextMapPropagator().Inject(r.Context(), propagation.HeaderCarrier(r.Header))
62 |
63 | // Add span to request context
64 | rc := NewRecorder()
65 | next.ServeHTTP(rc, r.WithContext(ctx))
66 |
67 | // Set status and attributes based on response code
68 | statusCode := rc.Result().StatusCode
69 | span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(statusCode)...)
70 | if statusCode >= 400 {
71 | span.SetStatus(codes.Error, http.StatusText(statusCode))
72 | if statusCode >= 500 {
73 | span.RecordError(fmt.Errorf("server error: %d", statusCode))
74 | }
75 | } else {
76 | span.SetStatus(codes.Ok, "")
77 | }
78 |
79 | // Add response information
80 | span.SetAttributes(
81 | attribute.Int64("http.response_size", rc.Result().ContentLength),
82 | attribute.String("http.response_content_type", rc.Result().Header.Get("Content-Type")),
83 | )
84 |
85 | // Return response
86 | rc.WriteTo(w)
87 | })
88 | }, nil
89 | }
90 |
--------------------------------------------------------------------------------
/pkg/monitor/monitor.go:
--------------------------------------------------------------------------------
1 | package monitor
2 |
3 | import (
4 | "fmt"
5 | "log"
6 | "net/http"
7 | "os/exec"
8 | "strings"
9 | "time"
10 |
11 | "github.com/google/uuid"
12 | "github.com/hvuhsg/gatego/pkg/cron"
13 | )
14 |
15 | type Check struct {
16 | Name string
17 | Cron string
18 | URL string
19 | Method string
20 | Timeout time.Duration
21 | Headers map[string]string
22 | OnFailure string
23 | }
24 |
25 | func (c Check) run(onFailure func(error)) func() {
26 | return func() {
27 | // Create a client with timeout
28 | client := &http.Client{
29 | Timeout: c.Timeout,
30 | }
31 |
32 | // Create new request
33 | req, err := http.NewRequest(c.Method, c.URL, nil)
34 | if err != nil {
35 | log.Default().Printf("Check <%s> error creating check request URL=%s Method=%s\n", c.Name, c.URL, c.Method)
36 | onFailure(err)
37 | return
38 | }
39 |
40 | // Add headers
41 | for key, value := range c.Headers {
42 | req.Header.Add(key, value)
43 | }
44 |
45 | // Send request
46 | resp, err := client.Do(req)
47 | if err != nil {
48 | log.Default().Printf("Check <%s> error sending request Error=%s\n", c.Name, err.Error())
49 | onFailure(err)
50 | return
51 | }
52 | defer resp.Body.Close()
53 |
54 | // Check status code
55 | if resp.StatusCode != http.StatusOK {
56 | log.Default().Printf("Check <%s> failed. Expected status code 200 got %d\n", c.Name, resp.StatusCode)
57 | onFailure(fmt.Errorf("expected status code 200 got %d", resp.StatusCode))
58 | return
59 | }
60 | }
61 | }
62 |
63 | func handleFailure(check Check, err error) error {
64 | // Expand command
65 | command := check.OnFailure
66 | date := time.Now().UTC().Format("2006-01-02 15:04:05")
67 | command = strings.ReplaceAll(command, "$date", date)
68 | command = strings.ReplaceAll(command, "$error", err.Error())
69 | command = strings.ReplaceAll(command, "$check_name", check.Name)
70 |
71 | // Run it
72 | args := strings.Split(command, " ")
73 | cmd := exec.Command(args[0], args[1:]...)
74 | if err := cmd.Start(); err != nil {
75 | return err
76 | }
77 | return nil
78 | }
79 |
80 | type Monitor struct {
81 | Delay time.Duration
82 | Checks []Check
83 | scheduler *cron.Cron
84 | }
85 |
86 | func New(delay time.Duration, checks ...Check) *Monitor {
87 | return &Monitor{Delay: delay, Checks: checks, scheduler: cron.New()}
88 | }
89 |
90 | func (m Monitor) Start() error {
91 | m.scheduler = cron.New()
92 |
93 | for _, check := range m.Checks {
94 | err := m.scheduler.Add(uuid.NewString(), check.Cron, check.run(func(err error) {
95 | if check.OnFailure != "" {
96 | if err := handleFailure(check, err); err != nil {
97 | log.Default().Printf("Failed to spawn on_failure command: %s\n", err)
98 | }
99 | }
100 | }))
101 | if err != nil {
102 | return err
103 | }
104 | }
105 |
106 | go func() {
107 | time.Sleep(m.Delay)
108 | m.scheduler.Start()
109 | log.Default().Println("Started running automated checks.")
110 | }()
111 |
112 | return nil
113 | }
114 |
--------------------------------------------------------------------------------
/internal/middlewares/logging_test.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "bytes"
5 | "net/http"
6 | "net/http/httptest"
7 | "regexp"
8 | "strings"
9 | "testing"
10 | )
11 |
12 | func TestLoggingMiddleware(t *testing.T) {
13 | tests := []struct {
14 | name string
15 | method string
16 | path string
17 | statusCode int
18 | responseBody string
19 | expectedLogParts []string
20 | }{
21 | {
22 | name: "GET request",
23 | method: "GET",
24 | path: "/test",
25 | statusCode: 200,
26 | responseBody: "OK",
27 | expectedLogParts: []string{
28 | "GET",
29 | "/test",
30 | "HTTP/1.1",
31 | "200",
32 | "2",
33 | "http://example.com/test",
34 | },
35 | },
36 | {
37 | name: "POST request with 404",
38 | method: "POST",
39 | path: "/notfound",
40 | statusCode: 404,
41 | responseBody: "Not Found",
42 | expectedLogParts: []string{
43 | "POST",
44 | "/notfound",
45 | "HTTP/1.1",
46 | "404",
47 | "9",
48 | "http://example.com/notfound",
49 | },
50 | },
51 | }
52 |
53 | for _, tt := range tests {
54 | t.Run(tt.name, func(t *testing.T) {
55 | // Create a buffer to capture the log output
56 | buf := &bytes.Buffer{}
57 |
58 | // Create a test handler that returns the specified status code and body
59 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60 | w.WriteHeader(tt.statusCode)
61 | w.Write([]byte(tt.responseBody))
62 | })
63 |
64 | // Create the logging middleware
65 | loggingMiddleware := NewLoggingMiddleware(buf)
66 |
67 | // Create a test server with the logging middleware
68 | ts := httptest.NewServer(loggingMiddleware(testHandler))
69 | defer ts.Close()
70 |
71 | // Create and send the request
72 | req, _ := http.NewRequest(tt.method, ts.URL+tt.path, nil)
73 | req.Host = "example.com" // Set a consistent host for testing
74 | resp, err := http.DefaultClient.Do(req)
75 | if err != nil {
76 | t.Fatalf("Error making request: %v", err)
77 | }
78 | defer resp.Body.Close()
79 |
80 | // Check the response
81 | if resp.StatusCode != tt.statusCode {
82 | t.Errorf("Expected status code %d, got %d", tt.statusCode, resp.StatusCode)
83 | }
84 |
85 | // Check the log output
86 | logOutput := buf.String()
87 | for _, expectedPart := range tt.expectedLogParts {
88 | if !strings.Contains(logOutput, expectedPart) {
89 | t.Errorf("Expected log to contain '%s', but it didn't. Log: %s", expectedPart, logOutput)
90 | }
91 | }
92 |
93 | // Check for the presence of a timestamp in the expected format
94 | timeStampFormat := "[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}"
95 | if matched, _ := regexp.MatchString(timeStampFormat, logOutput); !matched {
96 | t.Errorf("Expected log to contain a timestamp in format YYYY-MM-DD HH:MM:SS, but it didn't. Log: %s", logOutput)
97 | }
98 |
99 | // Check for the presence of a duration
100 | if !strings.Contains(logOutput, "ms") && !strings.Contains(logOutput, "s") {
101 | t.Errorf("Expected log to contain a duration, but it didn't. Log: %s", logOutput)
102 | }
103 | })
104 | }
105 | }
106 |
107 | func TestFormatDuration(t *testing.T) {
108 | tests := []struct {
109 | name string
110 | duration int64
111 | expected string
112 | }{
113 | {"Less than a second", 500, "500ms"},
114 | {"Exactly one second", 1000, "1.0s"},
115 | {"More than a second", 1500, "1.5s"},
116 | {"Multiple seconds", 3750, "3.8s"},
117 | }
118 |
119 | for _, tt := range tests {
120 | t.Run(tt.name, func(t *testing.T) {
121 | result := formatDuration(tt.duration)
122 | if result != tt.expected {
123 | t.Errorf("formatDuration(%d) = %s; want %s", tt.duration, result, tt.expected)
124 | }
125 | })
126 | }
127 | }
128 |
--------------------------------------------------------------------------------
/pkg/tracker/tracker_test.go:
--------------------------------------------------------------------------------
1 | package tracker
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "strings"
7 | "testing"
8 | )
9 |
10 | func TestNewCookieTracker(t *testing.T) {
11 | tracker := NewCookieTracker("testTracker", 3600, true)
12 |
13 | if tracker.cookieName != "testTracker" {
14 | t.Errorf("Expected cookieName to be 'testTracker', got %s", tracker.cookieName)
15 | }
16 | if tracker.trackerMaxAge != 3600 {
17 | t.Errorf("Expected trackerMaxAge to be 3600, got %d", tracker.trackerMaxAge)
18 | }
19 | if !tracker.secureCookie {
20 | t.Errorf("Expected secureCookie to be true")
21 | }
22 | }
23 |
24 | func TestGenerateTraceID(t *testing.T) {
25 | traceID1, err1 := generateTraceID()
26 | if err1 != nil {
27 | t.Fatalf("Unexpected error generating trace ID: %v", err1)
28 | }
29 |
30 | traceID2, err2 := generateTraceID()
31 | if err2 != nil {
32 | t.Fatalf("Unexpected error generating trace ID: %v", err2)
33 | }
34 |
35 | if len(traceID1) != 32 {
36 | t.Errorf("Expected trace ID length to be 32, got %d", len(traceID1))
37 | }
38 |
39 | if traceID1 == traceID2 {
40 | t.Error("Generated trace IDs should be unique")
41 | }
42 | }
43 |
44 | func TestSetTracker(t *testing.T) {
45 | tracker := NewCookieTracker("testTracker", 3600, true)
46 |
47 | // Create a test response writer
48 | w := httptest.NewRecorder()
49 |
50 | // Set tracker
51 | traceID, err := tracker.SetTracker(w)
52 | if err != nil {
53 | t.Fatalf("Unexpected error setting tracker: %v", err)
54 | }
55 |
56 | // Check response headers
57 | cookies := w.Result().Cookies()
58 | if len(cookies) != 1 {
59 | t.Fatalf("Expected 1 cookie, got %d", len(cookies))
60 | }
61 |
62 | cookie := cookies[0]
63 | if cookie.Name != "testTracker" {
64 | t.Errorf("Expected cookie name 'testTracker', got %s", cookie.Name)
65 | }
66 | if cookie.Value != traceID {
67 | t.Errorf("Cookie value does not match returned trace ID")
68 | }
69 | if cookie.Path != "/" {
70 | t.Errorf("Expected cookie path '/', got %s", cookie.Path)
71 | }
72 | if cookie.MaxAge != 3600 {
73 | t.Errorf("Expected MaxAge 3600, got %d", cookie.MaxAge)
74 | }
75 | if !cookie.HttpOnly {
76 | t.Errorf("Expected HttpOnly to be true")
77 | }
78 | }
79 |
80 | func TestGetTrackerID(t *testing.T) {
81 | tracker := NewCookieTracker("testTracker", 3600, true)
82 |
83 | // Test request without cookie
84 | req1 := httptest.NewRequest(http.MethodGet, "/", nil)
85 | trackerID1 := tracker.GetTrackerID(req1)
86 | if trackerID1 != "" {
87 | t.Errorf("Expected empty string when no cookie exists, got %s", trackerID1)
88 | }
89 |
90 | // Test request with cookie
91 | req2 := httptest.NewRequest(http.MethodGet, "/", nil)
92 | req2.AddCookie(&http.Cookie{
93 | Name: "testTracker",
94 | Value: "test-trace-id",
95 | })
96 |
97 | trackerID2 := tracker.GetTrackerID(req2)
98 | if trackerID2 != "test-trace-id" {
99 | t.Errorf("Expected 'test-trace-id', got %s", trackerID2)
100 | }
101 | }
102 |
103 | func TestRemoveTracker(t *testing.T) {
104 | tracker := NewCookieTracker("testTracker", 3600, true)
105 |
106 | // Create a request with multiple cookies
107 | req := httptest.NewRequest(http.MethodGet, "/", nil)
108 | req.AddCookie(&http.Cookie{Name: "testTracker", Value: "remove-me"})
109 | req.AddCookie(&http.Cookie{Name: "otherCookie", Value: "keep-me"})
110 |
111 | // Remove the specific tracker cookie
112 | tracker.RemoveTracker(req)
113 |
114 | // Check that the cookie header has been modified
115 | cookieHeader := req.Header.Get("Cookie")
116 | if strings.Contains(cookieHeader, "testTracker") {
117 | t.Errorf("testTracker cookie should have been removed")
118 | }
119 | if !strings.Contains(cookieHeader, "otherCookie=keep-me") {
120 | t.Errorf("Other cookies should be preserved")
121 | }
122 | }
123 |
124 | // Benchmark trace ID generation
125 | func BenchmarkGenerateTraceID(b *testing.B) {
126 | for i := 0; i < b.N; i++ {
127 | generateTraceID()
128 | }
129 | }
130 |
--------------------------------------------------------------------------------
/internal/middlewares/minify_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 |
8 | "github.com/hvuhsg/gatego/internal/middlewares"
9 | )
10 |
11 | // Helper function to create a basic next handler that returns content with a specific content type
12 | func createHandler(contentType, content string) http.Handler {
13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14 | w.Header().Set("Content-Type", contentType)
15 | w.WriteHeader(http.StatusOK)
16 | w.Write([]byte(content))
17 | })
18 | }
19 |
20 | // TestMinifyMiddleware_HTML tests HTML minification
21 | func TestMinifyMiddleware_HTML(t *testing.T) {
22 | handler := createHandler("text/html", "
Hello World
")
23 |
24 | config := middlewares.MinifyConfig{HTML: true}
25 | middleware := middlewares.NewMinifyMiddleware(config)
26 | minifiedHandler := middleware(handler)
27 |
28 | req := httptest.NewRequest(http.MethodGet, "/", nil)
29 | rr := httptest.NewRecorder()
30 |
31 | minifiedHandler.ServeHTTP(rr, req)
32 |
33 | expected := "Hello World
"
34 | if rr.Body.String() != expected {
35 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String())
36 | }
37 | }
38 |
39 | // TestMinifyMiddleware_CSS tests CSS minification
40 | func TestMinifyMiddleware_CSS(t *testing.T) {
41 | handler := createHandler("text/css", "body { color: red; }")
42 |
43 | config := middlewares.MinifyConfig{CSS: true}
44 | middleware := middlewares.NewMinifyMiddleware(config)
45 | minifiedHandler := middleware(handler)
46 |
47 | req := httptest.NewRequest(http.MethodGet, "/", nil)
48 | rr := httptest.NewRecorder()
49 |
50 | minifiedHandler.ServeHTTP(rr, req)
51 |
52 | expected := "body{color:red}"
53 | if rr.Body.String() != expected {
54 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String())
55 | }
56 | }
57 |
58 | // TestMinifyMiddleware_JS tests JS minification
59 | func TestMinifyMiddleware_JS(t *testing.T) {
60 | handler := createHandler("application/javascript", "function test() { return 1; }")
61 |
62 | config := middlewares.MinifyConfig{JS: true}
63 | middleware := middlewares.NewMinifyMiddleware(config)
64 | minifiedHandler := middleware(handler)
65 |
66 | req := httptest.NewRequest(http.MethodGet, "/", nil)
67 | rr := httptest.NewRecorder()
68 |
69 | minifiedHandler.ServeHTTP(rr, req)
70 |
71 | expected := "function test(){return 1}"
72 | if rr.Body.String() != expected {
73 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String())
74 | }
75 | }
76 |
77 | // TestMinifyMiddleware_JSON tests JSON minification
78 | func TestMinifyMiddleware_JSON(t *testing.T) {
79 | handler := createHandler("application/json", `{
80 | "name": "John",
81 | "age": 30
82 | }`)
83 |
84 | config := middlewares.MinifyConfig{JSON: true}
85 | middleware := middlewares.NewMinifyMiddleware(config)
86 | minifiedHandler := middleware(handler)
87 |
88 | req := httptest.NewRequest(http.MethodGet, "/", nil)
89 | rr := httptest.NewRecorder()
90 |
91 | minifiedHandler.ServeHTTP(rr, req)
92 |
93 | expected := `{"name":"John","age":30}`
94 | if rr.Body.String() != expected {
95 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String())
96 | }
97 | }
98 |
99 | // TestMinifyMiddleware_SkipUnsupported tests that unsupported content types are not minified
100 | func TestMinifyMiddleware_SkipUnsupported(t *testing.T) {
101 | handler := createHandler("text/plain", "This is a plain text file.")
102 |
103 | config := middlewares.MinifyConfig{HTML: true, CSS: true, JS: true}
104 | middleware := middlewares.NewMinifyMiddleware(config)
105 | minifiedHandler := middleware(handler)
106 |
107 | req := httptest.NewRequest(http.MethodGet, "/", nil)
108 | rr := httptest.NewRecorder()
109 |
110 | minifiedHandler.ServeHTTP(rr, req)
111 |
112 | expected := "This is a plain text file."
113 | if rr.Body.String() != expected {
114 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String())
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/internal/handlers/files_test.go:
--------------------------------------------------------------------------------
1 | package handlers
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "os"
7 | "path/filepath"
8 | "testing"
9 | )
10 |
11 | func TestRemoveBaseURLPath(t *testing.T) {
12 | tests := []struct {
13 | name string
14 | basePath string
15 | fullPath string
16 | want string
17 | wantErr bool
18 | }{
19 | {
20 | name: "simple path",
21 | basePath: "/api",
22 | fullPath: "/api/file.txt",
23 | want: "/file.txt",
24 | wantErr: false,
25 | },
26 | {
27 | name: "path with multiple segments",
28 | basePath: "/api/v1",
29 | fullPath: "/api/v1/docs/file.txt",
30 | want: "/docs/file.txt",
31 | wantErr: false,
32 | },
33 | {
34 | name: "paths with trailing slashes",
35 | basePath: "/api/",
36 | fullPath: "/api/file.txt/",
37 | want: "/file.txt",
38 | wantErr: false,
39 | },
40 | {
41 | name: "paths without leading slashes",
42 | basePath: "api",
43 | fullPath: "api/file.txt",
44 | want: "/file.txt",
45 | wantErr: false,
46 | },
47 | {
48 | name: "path not in base path",
49 | basePath: "/api",
50 | fullPath: "/other/file.txt",
51 | want: "",
52 | wantErr: true,
53 | },
54 | {
55 | name: "empty paths",
56 | basePath: "",
57 | fullPath: "/file.txt",
58 | want: "/file.txt",
59 | wantErr: false,
60 | },
61 | {
62 | name: "identical paths",
63 | basePath: "/api",
64 | fullPath: "/api",
65 | want: "/",
66 | wantErr: false,
67 | },
68 | }
69 |
70 | for _, tt := range tests {
71 | t.Run(tt.name, func(t *testing.T) {
72 | got, err := removeBaseURLPath(tt.basePath, tt.fullPath)
73 | if (err != nil) != tt.wantErr {
74 | t.Errorf("removeBaseURLPath() error = %v, wantErr %v", err, tt.wantErr)
75 | return
76 | }
77 | if got != tt.want {
78 | t.Errorf("removeBaseURLPath() = %v, want %v", got, tt.want)
79 | }
80 | })
81 | }
82 | }
83 |
84 | func TestFiles_ServeHTTP(t *testing.T) {
85 | // Create a temporary directory for test files
86 | tmpDir, err := os.MkdirTemp("", "files_test")
87 | if err != nil {
88 | t.Fatal(err)
89 | }
90 | defer os.RemoveAll(tmpDir)
91 |
92 | // Create a test file
93 | testContent := []byte("test file content")
94 | testFilePath := filepath.Join(tmpDir, "test.txt")
95 | if err := os.WriteFile(testFilePath, testContent, 0644); err != nil {
96 | t.Fatal(err)
97 | }
98 |
99 | tests := []struct {
100 | name string
101 | basePath string
102 | requestPath string
103 | expectedStatus int
104 | expectedBody string
105 | }{
106 | {
107 | name: "valid file request",
108 | basePath: "/files",
109 | requestPath: "/files/test.txt",
110 | expectedStatus: http.StatusOK,
111 | expectedBody: "test file content",
112 | },
113 | {
114 | name: "file not found",
115 | basePath: "/files",
116 | requestPath: "/files/nonexistent.txt",
117 | expectedStatus: http.StatusNotFound,
118 | expectedBody: "404 page not found\n",
119 | },
120 | {
121 | name: "path outside base path",
122 | basePath: "/files",
123 | requestPath: "/other/test.txt",
124 | expectedStatus: http.StatusNotFound,
125 | expectedBody: "404 page not found\n",
126 | },
127 | }
128 |
129 | for _, tt := range tests {
130 | t.Run(tt.name, func(t *testing.T) {
131 | // Create a new Files handler
132 | files := NewFiles(tmpDir, tt.basePath)
133 |
134 | // Create a test request
135 | req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
136 | w := httptest.NewRecorder()
137 |
138 | // Serve the request
139 | files.ServeHTTP(w, req)
140 |
141 | // Check status code
142 | if w.Code != tt.expectedStatus {
143 | t.Errorf("ServeHTTP() status = %v, want %v", w.Code, tt.expectedStatus)
144 | }
145 |
146 | // Check response body
147 | if w.Body.String() != tt.expectedBody {
148 | t.Errorf("ServeHTTP() body = %v, want %v", w.Body.String(), tt.expectedBody)
149 | }
150 | })
151 | }
152 | }
153 |
--------------------------------------------------------------------------------
/handler.go:
--------------------------------------------------------------------------------
1 | package gatego
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "net/http"
7 | "os"
8 | "slices"
9 |
10 | "github.com/hvuhsg/gatego/internal/config"
11 | "github.com/hvuhsg/gatego/internal/handlers"
12 | "github.com/hvuhsg/gatego/internal/middlewares"
13 | "github.com/hvuhsg/gatego/internal/middlewares/security"
14 | )
15 |
16 | var ErrUnsupportedBaseHandler = errors.New("base handler unsupported")
17 |
18 | func GetBaseHandler(service config.Service, path config.Path) (http.Handler, error) {
19 | if path.Destination != nil && *path.Destination != "" {
20 | return handlers.NewProxy(service, path)
21 | } else if path.Directory != nil && *path.Directory != "" {
22 | handler := handlers.NewFiles(*path.Directory, path.Path)
23 | return handler, nil
24 | } else if path.Backend != nil {
25 | return handlers.NewBalancer(service, path)
26 | } else {
27 | // Should not be reached (early validation should prevent it)
28 | return nil, ErrUnsupportedBaseHandler
29 | }
30 | }
31 |
32 | func NewHandler(ctx context.Context, useOtel bool, service config.Service, path config.Path) (http.Handler, error) {
33 | handler, err := GetBaseHandler(service, path)
34 | if err != nil {
35 | return nil, err
36 | }
37 |
38 | handlerWithMiddlewares := middlewares.NewHandlerWithMiddleware(handler)
39 |
40 | handlerWithMiddlewares.Add(middlewares.NewLoggingMiddleware(os.Stdout))
41 |
42 | // Open Telemetry
43 | if useOtel {
44 | otelMiddleware, err := middlewares.NewOpenTelemetryMiddleware(
45 | ctx,
46 | middlewares.OTELConfig{
47 | ServiceDomain: service.Domain,
48 | BasePath: path.Path,
49 | },
50 | )
51 | if err != nil {
52 | return nil, err
53 | }
54 | handlerWithMiddlewares.Add(otelMiddleware)
55 | }
56 |
57 | // Timeout
58 | if path.Timeout == 0 {
59 | path.Timeout = config.DefaultTimeout
60 | }
61 | handlerWithMiddlewares.Add(middlewares.NewTimeoutMiddleware(path.Timeout))
62 |
63 | // Max request size
64 | if path.MaxSize == 0 {
65 | path.MaxSize = config.DefaultMaxRequestSize
66 | }
67 | handlerWithMiddlewares.Add(middlewares.NewRequestSizeLimitMiddleware(path.MaxSize))
68 |
69 | // Rate limits
70 | if len(path.RateLimits) > 0 {
71 | ratelimiter, err := middlewares.NewRateLimitMiddleware(path.RateLimits)
72 | if err != nil {
73 | return nil, err
74 | }
75 | handlerWithMiddlewares.Add(ratelimiter)
76 | }
77 |
78 | // Add anomaly detector
79 | if service.AnomalyDetection != nil {
80 | handlerWithMiddlewares.Add(
81 | security.NewRoutingAnomalyDetector(
82 | service.AnomalyDetection.HeaderName,
83 | service.AnomalyDetection.TresholdForRating,
84 | service.AnomalyDetection.MinScore,
85 | service.AnomalyDetection.MaxScore).AddAnomalyScore,
86 | )
87 | }
88 |
89 | // Add headers
90 | if path.Headers != nil {
91 | handlerWithMiddlewares.Add(middlewares.NewAddHeadersMiddleware(*path.Headers))
92 | }
93 |
94 | // GZIP compression
95 | if path.Gzip != nil && *path.Gzip {
96 | handlerWithMiddlewares.Add(middlewares.GzipMiddleware)
97 | }
98 |
99 | // Remove response headers
100 | if len(path.OmitHeaders) > 0 {
101 | handlerWithMiddlewares.Add(middlewares.NewOmitHeadersMiddleware(path.OmitHeaders))
102 | }
103 |
104 | // Minify files
105 | minifyConfig := middlewares.MinifyConfig{
106 | ALL: slices.Contains(path.Minify, "all"),
107 | JS: slices.Contains(path.Minify, "js"),
108 | HTML: slices.Contains(path.Minify, "html"),
109 | CSS: slices.Contains(path.Minify, "css"),
110 | JSON: slices.Contains(path.Minify, "json"),
111 | SVG: slices.Contains(path.Minify, "svg"),
112 | XML: slices.Contains(path.Minify, "xml"),
113 | }
114 | handlerWithMiddlewares.Add(middlewares.NewMinifyMiddleware(minifyConfig))
115 |
116 | // OpenAPI validation
117 | if path.OpenAPI != nil {
118 | openapiMiddleware, err := middlewares.NewOpenAPIValidationMiddleware(*path.OpenAPI)
119 | if err != nil {
120 | return nil, err
121 | }
122 | handlerWithMiddlewares.Add(openapiMiddleware)
123 | }
124 |
125 | // Response cache
126 | if path.Cache {
127 | handlerWithMiddlewares.Add(middlewares.NewCacheMiddleware())
128 | }
129 |
130 | return handlerWithMiddlewares, nil
131 | }
132 |
--------------------------------------------------------------------------------
/pkg/multimux/multimux_test.go:
--------------------------------------------------------------------------------
1 | package multimux
2 |
3 | import (
4 | "fmt"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | )
9 |
10 | func TestRegisterHandler(t *testing.T) {
11 | tests := []struct {
12 | name string
13 | host string
14 | pattern string
15 | }{
16 | {"basic registration", "example.com", "/path"},
17 | {"with port", "example.com:8080", "/path"},
18 | {"uppercase host", "EXAMPLE.COM", "/path"},
19 | {"uppercase pattern", "/PATH", "/path"},
20 | {"with subdomain", "sub.example.com", "/path"},
21 | }
22 |
23 | for _, tt := range tests {
24 | t.Run(tt.name, func(t *testing.T) {
25 | mm := NewMultiMux()
26 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
27 | mm.RegisterHandler(tt.host, tt.pattern, handler)
28 |
29 | cleanedHost := cleanHost(tt.host)
30 | mux, exists := mm.Hosts.Load(cleanedHost)
31 | if !exists {
32 | t.Errorf("Host %s was not registered", cleanedHost)
33 | }
34 | if mux == nil {
35 | t.Errorf("ServeMux for host %s is nil", cleanedHost)
36 | }
37 | })
38 | }
39 | }
40 |
41 | func TestServeHTTP(t *testing.T) {
42 | tests := []struct {
43 | name string
44 | host string
45 | path string
46 | expectedStatus int
47 | expectedBody string
48 | }{
49 | {
50 | name: "existing host and path",
51 | host: "example.com",
52 | path: "/test",
53 | expectedStatus: http.StatusOK,
54 | expectedBody: "handler1",
55 | },
56 | {
57 | name: "existing host with port",
58 | host: "example.com:8080",
59 | path: "/test",
60 | expectedStatus: http.StatusOK,
61 | expectedBody: "handler1",
62 | },
63 | {
64 | name: "non-existing host",
65 | host: "unknown.com",
66 | path: "/test",
67 | expectedStatus: http.StatusNotFound,
68 | expectedBody: "",
69 | },
70 | }
71 |
72 | for _, tt := range tests {
73 | t.Run(tt.name, func(t *testing.T) {
74 | mm := NewMultiMux()
75 |
76 | // Register a test handler
77 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78 | fmt.Fprint(w, "handler1")
79 | })
80 | mm.RegisterHandler("example.com", "/test", handler)
81 |
82 | // Create test request
83 | req := httptest.NewRequest("GET", "http://"+tt.host+tt.path, nil)
84 | req.Host = tt.host
85 | w := httptest.NewRecorder()
86 |
87 | // Serve the request
88 | mm.ServeHTTP(w, req)
89 |
90 | // Check status code
91 | if w.Code != tt.expectedStatus {
92 | t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
93 | }
94 |
95 | // Check response body if expected
96 | if tt.expectedBody != "" && w.Body.String() != tt.expectedBody {
97 | t.Errorf("expected body %q, got %q", tt.expectedBody, w.Body.String())
98 | }
99 | })
100 | }
101 | }
102 |
103 | func TestCleanHost(t *testing.T) {
104 | tests := []struct {
105 | input string
106 | expected string
107 | }{
108 | {"example.com", "example.com"},
109 | {"EXAMPLE.COM", "example.com"},
110 | {"example.com:8080", "example.com"},
111 | {"EXAMPLE.COM:8080", "example.com"},
112 | {"sub.example.com:8080", "sub.example.com"},
113 | {"localhost", "localhost"},
114 | {"localhost:8080", "localhost"},
115 | }
116 |
117 | for _, tt := range tests {
118 | t.Run(tt.input, func(t *testing.T) {
119 | result := cleanHost(tt.input)
120 | if result != tt.expected {
121 | t.Errorf("cleanHost(%q) = %q; want %q", tt.input, result, tt.expected)
122 | }
123 | })
124 | }
125 | }
126 |
127 | func TestRemovePort(t *testing.T) {
128 | tests := []struct {
129 | input string
130 | expected string
131 | }{
132 | {"example.com", "example.com"},
133 | {"example.com:8080", "example.com"},
134 | {"example.com:80", "example.com"},
135 | {"localhost:8080", "localhost"},
136 | {"127.0.0.1:8080", "127.0.0.1"},
137 | {"[::1]:8080", "[::1]"},
138 | }
139 |
140 | for _, tt := range tests {
141 | t.Run(tt.input, func(t *testing.T) {
142 | result := removePort(tt.input)
143 | if result != tt.expected {
144 | t.Errorf("removePort(%q) = %q; want %q", tt.input, result, tt.expected)
145 | }
146 | })
147 | }
148 | }
149 |
--------------------------------------------------------------------------------
/internal/middlewares/ratelimit.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "net/http"
7 | "slices"
8 | "strconv"
9 | "strings"
10 | "sync"
11 | "time"
12 |
13 | "golang.org/x/time/rate"
14 | )
15 |
16 | var SupportedZones = []string{"ip"}
17 | var ErrZoneNotSupported = errors.New("rate limit zone is not supported")
18 |
19 | type RateLimiter struct {
20 | limiters sync.Map
21 | }
22 |
23 | type LimitConfig struct {
24 | Zone string
25 | Requests int
26 | Per time.Duration
27 | }
28 |
29 | func (lc LimitConfig) GetKey(r *http.Request) (key string, err error) {
30 | err = nil
31 | switch lc.Zone {
32 | case "ip":
33 | parts := strings.Split(r.RemoteAddr, ":")
34 | ip := parts[0]
35 | key = "ip:" + ip
36 | default:
37 | err = errors.New("rate limit zone is not supported")
38 | }
39 | key = strconv.Itoa(int(lc.Per.Seconds())) + "|" + strconv.Itoa(lc.Requests) + "!" + key
40 | return
41 | }
42 |
43 | func NewRateLimiter() *RateLimiter {
44 | return &RateLimiter{}
45 | }
46 |
47 | func (rl *RateLimiter) addLimiter(key string, limit rate.Limit, burst int) {
48 | limiter := rate.NewLimiter(limit, burst)
49 | rl.limiters.Store(key, limiter)
50 | }
51 |
52 | func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
53 | if limiter, ok := rl.limiters.Load(key); ok {
54 | return limiter.(*rate.Limiter)
55 | }
56 | return nil
57 | }
58 |
59 | func ParseLimitConfig(config string) (LimitConfig, error) {
60 | parts := strings.Split(config, "-")
61 | if len(parts) != 2 {
62 | return LimitConfig{}, fmt.Errorf("invalid limit config: %s", config)
63 | }
64 | zone := parts[0]
65 | if !slices.Contains(SupportedZones, strings.ToLower(zone)) {
66 | return LimitConfig{}, ErrZoneNotSupported
67 | }
68 |
69 | limitParts := strings.Split(parts[1], "/")
70 | if len(limitParts) != 2 {
71 | return LimitConfig{}, fmt.Errorf("invalid limit config: %s", config)
72 | }
73 |
74 | requests, err := strconv.Atoi(limitParts[0])
75 | if err != nil {
76 | return LimitConfig{}, fmt.Errorf("invalid requests number: %s", limitParts[0])
77 | }
78 |
79 | var duration time.Duration
80 | switch limitParts[1] {
81 | case "s":
82 | duration = time.Second
83 | case "m":
84 | duration = time.Minute
85 | case "h":
86 | duration = time.Hour
87 | case "d":
88 | duration = time.Hour * 24
89 | default:
90 | return LimitConfig{}, fmt.Errorf("invalid time unit: %s", limitParts[1])
91 | }
92 |
93 | return LimitConfig{
94 | Zone: zone,
95 | Requests: requests,
96 | Per: duration,
97 | }, nil
98 | }
99 |
100 | func NewRateLimitMiddleware(limits []string) (func(http.Handler) http.Handler, error) {
101 | rateLimiter := NewRateLimiter()
102 |
103 | // Pre-process ratelimit configs
104 | parsedLimits := make([]LimitConfig, 0, len(limits))
105 | for _, limit := range limits {
106 | parsedLimit, err := ParseLimitConfig(limit)
107 | if err != nil {
108 | return nil, err
109 | }
110 | parsedLimits = append(parsedLimits, parsedLimit)
111 | }
112 |
113 | return func(next http.Handler) http.Handler {
114 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115 | for _, config := range parsedLimits {
116 | key, err := config.GetKey(r)
117 | if err != nil {
118 | // Should never reach here (validation should prevent it)
119 | http.Error(w, err.Error(), http.StatusInternalServerError)
120 | }
121 |
122 | limiter := rateLimiter.getLimiter(key)
123 | if limiter == nil {
124 | rateLimiter.addLimiter(key, rate.Every(config.Per), config.Requests)
125 | limiter = rateLimiter.getLimiter(key)
126 | }
127 |
128 | if !limiter.Allow() {
129 | setRateLimitHeaders(w, limiter, config)
130 | http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
131 | return
132 | }
133 | }
134 | next.ServeHTTP(w, r)
135 | })
136 | }, nil
137 | }
138 |
139 | func setRateLimitHeaders(w http.ResponseWriter, limiter *rate.Limiter, config LimitConfig) {
140 | now := time.Now()
141 | limit := config.Requests
142 | remaining := int(limiter.Tokens())
143 | reset := now.Add(config.Per).Unix()
144 |
145 | w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit))
146 | w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
147 | w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", reset))
148 | }
149 |
--------------------------------------------------------------------------------
/internal/middlewares/gzip_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "bytes"
5 | "compress/gzip"
6 | "io"
7 | "net/http"
8 | "net/http/httptest"
9 | "testing"
10 |
11 | "github.com/hvuhsg/gatego/internal/middlewares"
12 | )
13 |
14 | // Helper to decode gzip data
15 | func decodeGzip(t *testing.T, gzippedBody []byte) string {
16 | gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedBody))
17 | if err != nil {
18 | t.Fatalf("failed to create gzip reader: %v", err)
19 | }
20 | defer gzipReader.Close()
21 |
22 | var decodedBody bytes.Buffer
23 | if _, err := io.Copy(&decodedBody, gzipReader); err != nil {
24 | t.Fatalf("failed to decode gzip body: %v", err)
25 | }
26 |
27 | return decodedBody.String()
28 | }
29 |
30 | // TestGzipMiddleware_NoGzipSupport tests the middleware when the client does not support gzip
31 | func TestGzipMiddleware_NoGzipSupport(t *testing.T) {
32 | // Create a test handler to be wrapped by the middleware
33 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
34 | w.Write([]byte("Hello, World"))
35 | })
36 |
37 | // Wrap the handler with GzipMiddleware
38 | handler := middlewares.GzipMiddleware(nextHandler)
39 |
40 | // Create a new HTTP request without gzip support
41 | req := httptest.NewRequest(http.MethodGet, "/", nil)
42 | req.Header.Set("Accept-Encoding", "identity")
43 |
44 | // Record the response
45 | rr := httptest.NewRecorder()
46 | handler.ServeHTTP(rr, req)
47 |
48 | // Check that the response is not gzipped
49 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "" {
50 | t.Errorf("expected no gzip encoding, got %s", encoding)
51 | }
52 |
53 | // Check the body content
54 | if rr.Body.String() != "Hello, World" {
55 | t.Errorf("expected 'Hello, World', got %s", rr.Body.String())
56 | }
57 | }
58 |
59 | // TestGzipMiddleware_WithGzipSupport tests the middleware when the client supports gzip
60 | func TestGzipMiddleware_WithGzipSupport(t *testing.T) {
61 | // Create a test handler to be wrapped by the middleware
62 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63 | w.WriteHeader(http.StatusOK)
64 | w.Write([]byte("Hello, World"))
65 | })
66 |
67 | // Wrap the handler with GzipMiddleware
68 | handler := middlewares.GzipMiddleware(nextHandler)
69 |
70 | // Create a new HTTP request with gzip support
71 | req := httptest.NewRequest(http.MethodGet, "/", nil)
72 | req.Header.Set("Accept-Encoding", "gzip")
73 |
74 | // Record the response
75 | rr := httptest.NewRecorder()
76 | handler.ServeHTTP(rr, req)
77 |
78 | // Check that the response is gzipped
79 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "gzip" {
80 | t.Errorf("expected gzip encoding, got %s", encoding)
81 | }
82 |
83 | // Decode the gzipped response body
84 | gzippedBody := rr.Body.Bytes()
85 | decodedBody := decodeGzip(t, gzippedBody)
86 |
87 | // Check the body content
88 | if decodedBody != "Hello, World" {
89 | t.Errorf("expected 'Hello, World', got %s", decodedBody)
90 | }
91 | }
92 |
93 | // TestGzipMiddleware_StatusCode tests that the middleware preserves status codes
94 | func TestGzipMiddleware_StatusCode(t *testing.T) {
95 | // Create a test handler to be wrapped by the middleware
96 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
97 | w.WriteHeader(http.StatusCreated)
98 | w.Write([]byte("Created"))
99 | })
100 |
101 | // Wrap the handler with GzipMiddleware
102 | handler := middlewares.GzipMiddleware(nextHandler)
103 |
104 | // Create a new HTTP request with gzip support
105 | req := httptest.NewRequest(http.MethodGet, "/", nil)
106 | req.Header.Set("Accept-Encoding", "gzip")
107 |
108 | // Record the response
109 | rr := httptest.NewRecorder()
110 | handler.ServeHTTP(rr, req)
111 |
112 | // Check that the response is gzipped
113 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "gzip" {
114 | t.Errorf("expected gzip encoding, got %s", encoding)
115 | }
116 |
117 | // Check that the status code is preserved
118 | if status := rr.Result().StatusCode; status != http.StatusCreated {
119 | t.Errorf("expected status code %d, got %d", http.StatusCreated, status)
120 | }
121 |
122 | // Decode the gzipped response body
123 | gzippedBody := rr.Body.Bytes()
124 | decodedBody := decodeGzip(t, gzippedBody)
125 |
126 | // Check the body content
127 | if decodedBody != "Created" {
128 | t.Errorf("expected 'Created', got %s", decodedBody)
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
/pkg/cron/cron.go:
--------------------------------------------------------------------------------
1 | // Package cron implements a crontab-like service to execute and schedule
2 | // repeative tasks/jobs.
3 | //
4 | // Example:
5 | //
6 | // c := cron.New()
7 | // c.MustAdd("dailyReport", "0 0 * * *", func() { ... })
8 | // c.Start()
9 | package cron
10 |
11 | import (
12 | "errors"
13 | "fmt"
14 | "sync"
15 | "time"
16 | )
17 |
18 | type job struct {
19 | schedule *Schedule
20 | run func()
21 | }
22 |
23 | // Cron is a crontab-like struct for tasks/jobs scheduling.
24 | type Cron struct {
25 | timezone *time.Location
26 | ticker *time.Ticker
27 | startTimer *time.Timer
28 | jobs map[string]*job
29 | interval time.Duration
30 |
31 | sync.RWMutex
32 | }
33 |
34 | // New create a new Cron struct with default tick interval of 1 minute
35 | // and timezone in UTC.
36 | //
37 | // You can change the default tick interval with Cron.SetInterval().
38 | // You can change the default timezone with Cron.SetTimezone().
39 | func New() *Cron {
40 | return &Cron{
41 | interval: 1 * time.Minute,
42 | timezone: time.UTC,
43 | jobs: map[string]*job{},
44 | }
45 | }
46 |
47 | // SetInterval changes the current cron tick interval
48 | // (it usually should be >= 1 minute).
49 | func (c *Cron) SetInterval(d time.Duration) {
50 | // update interval
51 | c.Lock()
52 | wasStarted := c.ticker != nil
53 | c.interval = d
54 | c.Unlock()
55 |
56 | // restart the ticker
57 | if wasStarted {
58 | c.Start()
59 | }
60 | }
61 |
62 | // SetTimezone changes the current cron tick timezone.
63 | func (c *Cron) SetTimezone(l *time.Location) {
64 | c.Lock()
65 | defer c.Unlock()
66 |
67 | c.timezone = l
68 | }
69 |
70 | // MustAdd is similar to Add() but panic on failure.
71 | func (c *Cron) MustAdd(jobId string, cronExpr string, run func()) {
72 | if err := c.Add(jobId, cronExpr, run); err != nil {
73 | panic(err)
74 | }
75 | }
76 |
77 | // Add registers a single cron job.
78 | //
79 | // If there is already a job with the provided id, then the old job
80 | // will be replaced with the new one.
81 | //
82 | // cronExpr is a regular cron expression, eg. "0 */3 * * *" (aka. at minute 0 past every 3rd hour).
83 | // Check cron.NewSchedule() for the supported tokens.
84 | func (c *Cron) Add(jobId string, cronExpr string, run func()) error {
85 | if run == nil {
86 | return errors.New("failed to add new cron job: run must be non-nil function")
87 | }
88 |
89 | c.Lock()
90 | defer c.Unlock()
91 |
92 | schedule, err := NewSchedule(cronExpr)
93 | if err != nil {
94 | return fmt.Errorf("failed to add new cron job: %w", err)
95 | }
96 |
97 | c.jobs[jobId] = &job{
98 | schedule: schedule,
99 | run: run,
100 | }
101 |
102 | return nil
103 | }
104 |
105 | // Remove removes a single cron job by its id.
106 | func (c *Cron) Remove(jobId string) {
107 | c.Lock()
108 | defer c.Unlock()
109 |
110 | delete(c.jobs, jobId)
111 | }
112 |
113 | // RemoveAll removes all registered cron jobs.
114 | func (c *Cron) RemoveAll() {
115 | c.Lock()
116 | defer c.Unlock()
117 |
118 | c.jobs = map[string]*job{}
119 | }
120 |
121 | // Total returns the current total number of registered cron jobs.
122 | func (c *Cron) Total() int {
123 | c.RLock()
124 | defer c.RUnlock()
125 |
126 | return len(c.jobs)
127 | }
128 |
129 | // Stop stops the current cron ticker (if not already).
130 | //
131 | // You can resume the ticker by calling Start().
132 | func (c *Cron) Stop() {
133 | c.Lock()
134 | defer c.Unlock()
135 |
136 | if c.startTimer != nil {
137 | c.startTimer.Stop()
138 | c.startTimer = nil
139 | }
140 |
141 | if c.ticker == nil {
142 | return // already stopped
143 | }
144 |
145 | c.ticker.Stop()
146 | c.ticker = nil
147 | }
148 |
149 | // Start starts the cron ticker.
150 | //
151 | // Calling Start() on already started cron will restart the ticker.
152 | func (c *Cron) Start() {
153 | c.Stop()
154 |
155 | // delay the ticker to start at 00 of 1 c.interval duration
156 | now := time.Now()
157 | next := now.Add(c.interval).Truncate(c.interval)
158 | delay := next.Sub(now)
159 |
160 | c.Lock()
161 | c.startTimer = time.AfterFunc(delay, func() {
162 | c.Lock()
163 | c.ticker = time.NewTicker(c.interval)
164 | c.Unlock()
165 |
166 | // run immediately at 00
167 | c.runDue(time.Now())
168 |
169 | // run after each tick
170 | go func() {
171 | for t := range c.ticker.C {
172 | c.runDue(t)
173 | }
174 | }()
175 | })
176 | c.Unlock()
177 | }
178 |
179 | // HasStarted checks whether the current Cron ticker has been started.
180 | func (c *Cron) HasStarted() bool {
181 | c.RLock()
182 | defer c.RUnlock()
183 |
184 | return c.ticker != nil
185 | }
186 |
187 | // runDue runs all registered jobs that are scheduled for the provided time.
188 | func (c *Cron) runDue(t time.Time) {
189 | c.RLock()
190 | defer c.RUnlock()
191 |
192 | moment := NewMoment(t.In(c.timezone))
193 |
194 | for _, j := range c.jobs {
195 | if j.schedule.IsDue(moment) {
196 | go j.run()
197 | }
198 | }
199 | }
200 |
--------------------------------------------------------------------------------
/internal/middlewares/ratelimit_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/http/httptest"
7 | "sync"
8 | "testing"
9 |
10 | "github.com/hvuhsg/gatego/internal/middlewares"
11 | )
12 |
13 | func TestRateLimitExceeded(t *testing.T) {
14 | limits := []string{"ip-1/s"}
15 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
16 | if err != nil {
17 | t.Fatalf("Error creating middleware: %v", err)
18 | }
19 |
20 | handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21 | w.WriteHeader(http.StatusOK)
22 | }))
23 |
24 | // Create a test server
25 | req := httptest.NewRequest("GET", "http://example.com", nil)
26 | req.RemoteAddr = "192.168.1.1:12345"
27 | rr := httptest.NewRecorder()
28 |
29 | // First request should pass
30 | handler.ServeHTTP(rr, req)
31 | if rr.Code != http.StatusOK {
32 | t.Errorf("expected status OK, got %v", rr.Code)
33 | }
34 |
35 | // Second request should fail (rate limit exceeded)
36 | rr = httptest.NewRecorder()
37 | handler.ServeHTTP(rr, req)
38 | if rr.Code != http.StatusTooManyRequests {
39 | t.Errorf("expected status TooManyRequests, got %v", rr.Code)
40 | }
41 | }
42 |
43 | func TestRateLimitHeaders(t *testing.T) {
44 | limits := []string{"ip-1/s"}
45 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
46 | if err != nil {
47 | t.Fatalf("Error creating middleware: %v", err)
48 | }
49 |
50 | handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51 | w.WriteHeader(http.StatusOK)
52 | }))
53 |
54 | req := httptest.NewRequest("GET", "http://example.com", nil)
55 | req.RemoteAddr = "192.168.1.1:12345"
56 | rr := httptest.NewRecorder()
57 |
58 | // First request should pass
59 | handler.ServeHTTP(rr, req)
60 |
61 | // Second request should fail
62 | handler.ServeHTTP(rr, req)
63 |
64 | if rr.Header().Get("X-RateLimit-Limit") != "1" {
65 | t.Errorf("expected X-RateLimit-Limit 1, got %s", rr.Header().Get("X-RateLimit-Limit"))
66 | }
67 | if rr.Header().Get("X-RateLimit-Remaining") != "0" {
68 | t.Errorf("expected X-RateLimit-Remaining 0, got %s", rr.Header().Get("X-RateLimit-Remaining"))
69 | }
70 | }
71 |
72 | func TestInvalidRateLimitConfig(t *testing.T) {
73 | limits := []string{"ip-invalid/s"}
74 | _, err := middlewares.NewRateLimitMiddleware(limits)
75 | if err == nil {
76 | t.Errorf("expected error for invalid config, got none")
77 | }
78 | }
79 |
80 | func TestUnsupportedZone(t *testing.T) {
81 | limits := []string{"unsupported-10/s"}
82 | _, err := middlewares.NewRateLimitMiddleware(limits)
83 | if !errors.Is(err, middlewares.ErrZoneNotSupported) {
84 | t.Errorf("expected ErrZoneNotSupported, got %v", err)
85 | }
86 | }
87 |
88 | func TestRateLimitConcurrentRequests(t *testing.T) {
89 | limits := []string{"ip-3/s"}
90 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
91 | if err != nil {
92 | t.Fatalf("Error creating middleware: %v", err)
93 | }
94 |
95 | handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
96 | w.WriteHeader(http.StatusOK)
97 | }))
98 |
99 | req := httptest.NewRequest("GET", "http://example.com", nil)
100 | req.RemoteAddr = "192.168.1.1:12345"
101 |
102 | var wg sync.WaitGroup
103 | var rateLimitedCount int
104 | var mu sync.Mutex
105 |
106 | for i := 0; i < 5; i++ {
107 | wg.Add(1)
108 | go func() {
109 | defer wg.Done()
110 | rr := httptest.NewRecorder()
111 | handler.ServeHTTP(rr, req)
112 |
113 | mu.Lock()
114 | if rr.Code == http.StatusTooManyRequests {
115 | rateLimitedCount++
116 | }
117 | mu.Unlock()
118 | }()
119 | }
120 |
121 | wg.Wait()
122 |
123 | if rateLimitedCount != 2 {
124 | t.Errorf("expected 2 rate limited requests, got %d", rateLimitedCount)
125 | }
126 | }
127 |
128 | func TestRateLimitDifferentTimeWindows(t *testing.T) {
129 | limits := []string{"ip-2/s", "ip-5/m"}
130 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
131 | if err != nil {
132 | t.Fatalf("Error creating middleware: %v", err)
133 | }
134 |
135 | handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136 | w.WriteHeader(http.StatusOK)
137 | }))
138 |
139 | req := httptest.NewRequest("GET", "http://example.com", nil)
140 | req.RemoteAddr = "192.168.1.1:12345"
141 | rr := httptest.NewRecorder()
142 |
143 | // First and second request should pass
144 | handler.ServeHTTP(rr, req)
145 | if rr.Code != http.StatusOK {
146 | t.Errorf("expected status OK, got %v", rr.Code)
147 | }
148 |
149 | rr = httptest.NewRecorder()
150 | handler.ServeHTTP(rr, req)
151 | if rr.Code != http.StatusOK {
152 | t.Errorf("expected status OK, got %v", rr.Code)
153 | }
154 |
155 | // Third request should fail due to 1-second window limit
156 | rr = httptest.NewRecorder()
157 | handler.ServeHTTP(rr, req)
158 | if rr.Code != http.StatusTooManyRequests {
159 | t.Errorf("expected status TooManyRequests, got %v", rr.Code)
160 | }
161 | }
162 |
--------------------------------------------------------------------------------
/internal/handlers/balancer.go:
--------------------------------------------------------------------------------
1 | package handlers
2 |
3 | import (
4 | "math"
5 | "math/rand"
6 | "net/http"
7 | "net/http/httputil"
8 | "net/url"
9 | "time"
10 |
11 | "github.com/hvuhsg/gatego/internal/config"
12 | "github.com/hvuhsg/gatego/internal/contextvalues"
13 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
14 | )
15 |
16 | type ServerAndWeight struct {
17 | server *httputil.ReverseProxy
18 | weight int
19 | url string
20 | }
21 |
22 | type BalancePolicy interface {
23 | GetNext() *httputil.ReverseProxy
24 | }
25 |
26 | type Balancer struct {
27 | policy BalancePolicy
28 | }
29 |
30 | func NewBalancer(service config.Service, path config.Path) (*Balancer, error) {
31 | serversConfig := path.Backend.Servers
32 |
33 | serversAndWeights := make([]ServerAndWeight, 0, len(serversConfig))
34 | for _, serverConfig := range serversConfig {
35 | serverURL, err := url.Parse(serverConfig.URL)
36 | if err != nil {
37 | return &Balancer{}, err
38 | }
39 |
40 | server := httputil.NewSingleHostReverseProxy(serverURL)
41 |
42 | serverWeight := int(serverConfig.Weight)
43 | if serverWeight < 1 {
44 | serverWeight = 1
45 | }
46 | serversAndWeights = append(serversAndWeights, ServerAndWeight{server: server, weight: serverWeight, url: serverConfig.URL})
47 | }
48 |
49 | var policy BalancePolicy
50 | switch path.Backend.BalancePolicy {
51 | case "round-robin":
52 | policy = NewRoundRobinPolicy(serversAndWeights)
53 | case "random":
54 | policy = NewRandomPolicy(serversAndWeights)
55 | case "least-latency":
56 | policy = NewLeastLatencyPolicy(serversAndWeights)
57 | }
58 |
59 | balancer := Balancer{policy: policy}
60 |
61 | return &balancer, nil
62 | }
63 |
64 | func (b *Balancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
65 | proxy := b.policy.GetNext()
66 |
67 | tracer := contextvalues.TracerFromContext(r.Context())
68 | if tracer != nil {
69 | ctx, span := tracer.Start(r.Context(), "request.upstream")
70 | span.SetAttributes(semconv.HTTPServerAttributesFromHTTPRequest(r.Host, r.URL.Path, r)...)
71 | r = r.WithContext(ctx)
72 | defer span.End()
73 | }
74 |
75 | proxy.ServeHTTP(w, r)
76 | }
77 |
78 | type RoundRobinPolicy struct {
79 | current int
80 | weightsSum int
81 | servers []ServerAndWeight
82 | }
83 |
84 | func NewRoundRobinPolicy(servers []ServerAndWeight) *RoundRobinPolicy {
85 | weightsSum := 0
86 | for _, server := range servers {
87 | weightsSum += server.weight
88 | }
89 |
90 | policy := &RoundRobinPolicy{current: 0, weightsSum: weightsSum, servers: servers}
91 | return policy
92 | }
93 |
94 | // The servers provided must be provided in the same order for accurate results
95 | func (rrp *RoundRobinPolicy) GetNext() *httputil.ReverseProxy {
96 | serverIndex := rrp.current
97 |
98 | for _, server := range rrp.servers {
99 | serverIndex -= server.weight
100 | if serverIndex < 0 {
101 | rrp.current += 1
102 | return server.server
103 | }
104 | }
105 |
106 | rrp.current = (rrp.current % rrp.weightsSum) + 1
107 | return rrp.servers[0].server
108 | }
109 |
110 | type RandomPolicy struct {
111 | weightsSum int
112 | servers []ServerAndWeight
113 | }
114 |
115 | func NewRandomPolicy(servers []ServerAndWeight) *RandomPolicy {
116 | weightsSum := 0
117 | for _, server := range servers {
118 | weightsSum += server.weight
119 | }
120 |
121 | return &RandomPolicy{weightsSum: weightsSum, servers: servers}
122 | }
123 |
124 | func (rp *RandomPolicy) GetNext() *httputil.ReverseProxy {
125 | randomServerIndex := rand.Intn(rp.weightsSum)
126 |
127 | for _, server := range rp.servers {
128 | randomServerIndex -= server.weight
129 | if randomServerIndex <= 0 {
130 | return server.server
131 | }
132 | }
133 |
134 | return rp.servers[0].server
135 | }
136 |
137 | type LeastLatencyPolicy struct {
138 | serversLatency map[string]int64
139 | servers []ServerAndWeight
140 | }
141 |
142 | func NewLeastLatencyPolicy(serversAndURLs []ServerAndWeight) *LeastLatencyPolicy {
143 | serversLatency := make(map[string]int64, len(serversAndURLs))
144 |
145 | for _, serverAndWeight := range serversAndURLs {
146 | serversLatency[serverAndWeight.url] = 0
147 | }
148 |
149 | return &LeastLatencyPolicy{servers: serversAndURLs, serversLatency: serversLatency}
150 | }
151 |
152 | func (llp *LeastLatencyPolicy) GetNext() *httputil.ReverseProxy {
153 |
154 | bestServerURL := llp.servers[0].url
155 | var bestLatency int64 = math.MaxInt64
156 |
157 | for url, latency := range llp.serversLatency {
158 | if latency < bestLatency {
159 | bestServerURL = url
160 | bestLatency = latency
161 | }
162 | }
163 |
164 | var chosenServer ServerAndWeight
165 | for _, server := range llp.servers {
166 | if server.url == bestServerURL {
167 | chosenServer = server
168 | break
169 | }
170 | }
171 |
172 | // TODO: use decaing latency for extream latency conditions
173 |
174 | startTime := time.Now().UnixMicro()
175 | chosenServer.server.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
176 | llp.serversLatency[chosenServer.url] = time.Now().UnixMicro() - startTime
177 | }
178 | chosenServer.server.ModifyResponse = func(r *http.Response) error {
179 | llp.serversLatency[chosenServer.url] = time.Now().UnixMicro() - startTime
180 | return nil
181 | }
182 |
183 | return chosenServer.server
184 | }
185 |
--------------------------------------------------------------------------------
/otel.go:
--------------------------------------------------------------------------------
1 | package gatego
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "time"
7 |
8 | "github.com/hvuhsg/gatego/internal/contextvalues"
9 | "go.opentelemetry.io/otel"
10 | "go.opentelemetry.io/otel/attribute"
11 | "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
12 | "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
13 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
14 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
15 | "go.opentelemetry.io/otel/log/global"
16 | "go.opentelemetry.io/otel/propagation"
17 | "go.opentelemetry.io/otel/sdk/log"
18 | "go.opentelemetry.io/otel/sdk/metric"
19 | "go.opentelemetry.io/otel/sdk/resource"
20 | "go.opentelemetry.io/otel/sdk/trace"
21 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
22 | )
23 |
24 | type otelConfig struct {
25 | TraceCollectorEndpoint string
26 | MetricCollectorEndpoint string
27 | LogsCollectorEndpoint string
28 | CollectorTimeout time.Duration
29 | ServiceName string
30 | SampleRatio float64
31 | }
32 |
33 | // setupOTelSDK bootstraps the OpenTelemetry pipeline.
34 | // If it does not return an error, make sure to call shutdown for proper cleanup.
35 | func setupOTelSDK(ctx context.Context, conf otelConfig) (func(context.Context) error, error) {
36 | var shutdownFuncs []func(context.Context) error
37 |
38 | // shutdown calls cleanup functions registered via shutdownFuncs.
39 | // The errors from the calls are joined.
40 | // Each registered cleanup will be invoked once.
41 | shutdown := func(ctx context.Context) error {
42 | var err error
43 | for _, fn := range shutdownFuncs {
44 | err = errors.Join(err, fn(ctx))
45 | }
46 | shutdownFuncs = nil
47 | return err
48 | }
49 |
50 | // handleErr calls shutdown for cleanup and makes sure that all errors are returned.
51 | handleErr := func(inErr error) (func(context.Context) error, error) {
52 | err := errors.Join(inErr, shutdown(ctx))
53 | return nil, err
54 | }
55 |
56 | // Set up propagator.
57 | prop := newPropagator()
58 | otel.SetTextMapPropagator(prop)
59 |
60 | // Set up resource
61 | resource := resource.NewWithAttributes(
62 | semconv.SchemaURL,
63 | semconv.ServiceNameKey.String(conf.ServiceName),
64 | semconv.TelemetrySDKLanguageGo,
65 | attribute.String("version", contextvalues.VersionFromContext(ctx)),
66 | )
67 |
68 | // Set up trace provider.
69 | tracerProvider, err := newTraceProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout, conf.SampleRatio)
70 | if err != nil {
71 | return handleErr(err)
72 | }
73 | shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown)
74 | otel.SetTracerProvider(tracerProvider)
75 |
76 | // Set up meter provider.
77 | meterProvider, err := newMeterProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout)
78 | if err != nil {
79 | return handleErr(err)
80 | }
81 | shutdownFuncs = append(shutdownFuncs, meterProvider.Shutdown)
82 | otel.SetMeterProvider(meterProvider)
83 |
84 | // Set up logger provider.
85 | loggerProvider, err := newLoggerProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout)
86 | if err != nil {
87 | return handleErr(err)
88 | }
89 | shutdownFuncs = append(shutdownFuncs, loggerProvider.Shutdown)
90 | global.SetLoggerProvider(loggerProvider)
91 |
92 | return shutdown, err
93 | }
94 |
95 | func newPropagator() propagation.TextMapPropagator {
96 | return propagation.NewCompositeTextMapPropagator(
97 | propagation.TraceContext{},
98 | propagation.Baggage{},
99 | )
100 | }
101 |
102 | func newTraceProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration, sampleRatio float64) (*trace.TracerProvider, error) {
103 | exporter, err := otlptrace.New(
104 | ctx,
105 | otlptracegrpc.NewClient(
106 | otlptracegrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint
107 | otlptracegrpc.WithTimeout(timeout),
108 | otlptracegrpc.WithInsecure(),
109 | ),
110 | )
111 | if err != nil {
112 | return nil, err
113 | }
114 |
115 | traceProvider := trace.NewTracerProvider(
116 | trace.WithResource(resource),
117 | trace.WithBatcher(exporter),
118 | trace.WithSampler(trace.TraceIDRatioBased(sampleRatio)),
119 | )
120 | return traceProvider, nil
121 | }
122 |
123 | func newMeterProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration) (*metric.MeterProvider, error) {
124 | exporter, err := otlpmetricgrpc.New(
125 | ctx,
126 | otlpmetricgrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint
127 | otlpmetricgrpc.WithTimeout(timeout),
128 | otlpmetricgrpc.WithInsecure(),
129 | )
130 | if err != nil {
131 | return nil, err
132 | }
133 |
134 | meterProvider := metric.NewMeterProvider(
135 | metric.WithResource(resource),
136 | metric.WithReader(
137 | metric.NewPeriodicReader(
138 | exporter,
139 | ),
140 | ),
141 | )
142 |
143 | return meterProvider, nil
144 | }
145 |
146 | func newLoggerProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration) (*log.LoggerProvider, error) {
147 | exporter, err := otlploggrpc.New(
148 | ctx,
149 | otlploggrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint
150 | otlploggrpc.WithTimeout(timeout),
151 | otlploggrpc.WithInsecure(),
152 | )
153 | if err != nil {
154 | return nil, err
155 | }
156 |
157 | loggerProvider := log.NewLoggerProvider(
158 | log.WithResource(resource),
159 | log.WithProcessor(log.NewBatchProcessor(exporter)),
160 | )
161 | return loggerProvider, nil
162 | }
163 |
--------------------------------------------------------------------------------
/internal/middlewares/security/routing_anomaly_score.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "math"
5 | "net/http"
6 | "net/url"
7 | "strconv"
8 | "sync"
9 |
10 | "github.com/hvuhsg/gatego/pkg/pathgraph"
11 | "github.com/hvuhsg/gatego/pkg/tracker"
12 | "go.opentelemetry.io/otel/attribute"
13 | "go.opentelemetry.io/otel/trace"
14 | )
15 |
16 | const (
17 | tracingCookieName = "sad-trc"
18 | cookieMaxAge = 24 * 60 * 60 // 24 hours in seconds
19 | refererHeaderName = "Referer"
20 | )
21 |
22 | // RoutingAnomalyDetector handles path tracking logic and manages user sessions
23 | type RoutingAnomalyDetector struct {
24 | graph *pathgraph.PathGraph
25 | numberOfJumps int
26 | scoreSum float64
27 | avgDiviation float64
28 | lastPaths sync.Map // Maps trace_id to last path
29 | trackerRoutingHistory sync.Map
30 | tracker tracker.Tracker
31 |
32 | tresholdForRating int // The number of requests before starting to calculate anomaly score
33 | minScore int // If the diviation form the avg diviation is lower then this then the session is not suspicuse
34 | maxScore int // If the diviation form the avg diviation is larger then this then the session is fully suspicuse
35 | anomalyHeaderName string
36 | }
37 |
38 | func NewRoutingAnomalyDetector(headerName string, tresholdForRating, minScore, maxScore int) *RoutingAnomalyDetector {
39 | return &RoutingAnomalyDetector{
40 | graph: pathgraph.NewPathGraph(),
41 | tracker: tracker.NewCookieTracker(tracingCookieName, cookieMaxAge, false),
42 | anomalyHeaderName: headerName,
43 | minScore: minScore,
44 | maxScore: maxScore,
45 | tresholdForRating: tresholdForRating,
46 | }
47 | }
48 |
49 | // NewPathTracker creates a new PathTracker instance
50 | func NewPathTracker(graph *pathgraph.PathGraph) *RoutingAnomalyDetector {
51 | return &RoutingAnomalyDetector{
52 | graph: graph,
53 | lastPaths: sync.Map{},
54 | }
55 | }
56 |
57 | // Claculate anomaly score based on global avg routing and tracker routing
58 | // This middleware uses a graph to represent every path called by users
59 | // Eeach source, destination path has a vertex with the score of how many requests jumpt it,
60 | // We save tracker (session) jumps history and calculate an anomaly score, and add it as header to the request.
61 | func (pt *RoutingAnomalyDetector) AddAnomalyScore(next http.Handler) http.Handler {
62 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63 | span := trace.SpanFromContext(r.Context())
64 |
65 | // Get or create trace ID
66 | traceID := pt.tracker.GetTrackerID(r)
67 |
68 | if traceID != "" {
69 | // We do not want the tracker to be sent to the downstream server
70 | pt.tracker.RemoveTracker(r)
71 | } else { // Create new tracker if not found
72 | var err error
73 | traceID, err = pt.tracker.SetTracker(w)
74 | if err != nil {
75 | // Log error but continue serving
76 | next.ServeHTTP(w, r)
77 | return
78 | }
79 |
80 | // Create tracker history
81 | trackerH := &trackerHistory{jumpsCount: 0, jumpsScoreSum: 0}
82 | pt.trackerRoutingHistory.Store(traceID, trackerH)
83 | }
84 |
85 | currentPath := r.URL.Path
86 |
87 | // Get last path for this trace ID
88 | lastPath, exists := pt.getLastPath(traceID, r)
89 | if !exists {
90 | lastPath = "" // empty path means the user has entered the site for the first time
91 | }
92 |
93 | jumpScore := pt.graph.AddJump(lastPath, currentPath)
94 | value, ok := pt.trackerRoutingHistory.Load(traceID)
95 |
96 | var trackerH *trackerHistory
97 | if ok {
98 | trackerH = value.(*trackerHistory)
99 | }
100 |
101 | // update tracker history with jump score
102 | trackerH.jumpsCount++
103 | trackerH.jumpsScoreSum += jumpScore
104 |
105 | // update global stats
106 | pt.numberOfJumps++
107 | pt.scoreSum += jumpScore
108 |
109 | pt.lastPaths.Store(traceID, currentPath)
110 |
111 | anomalyScore := pt.calcAnomalyRating(trackerH)
112 | span.SetAttributes(attribute.Float64("RoutingAnomalyScore", anomalyScore))
113 |
114 | r.Header.Set(pt.anomalyHeaderName, strconv.FormatFloat(anomalyScore, 'f', 2, 64))
115 |
116 | // Call the next handler
117 | next.ServeHTTP(w, r)
118 | })
119 | }
120 |
121 | // GetLastPath retrieves the last path for a given trace ID from storage or referer header (in this order)
122 | func (pt *RoutingAnomalyDetector) getLastPath(traceID string, r *http.Request) (string, bool) {
123 | path, exists := pt.lastPaths.Load(traceID)
124 |
125 | if !exists {
126 | u := r.Header.Get(refererHeaderName)
127 | url, err := url.Parse(u)
128 | if err == nil {
129 | path = url.Path
130 | }
131 | }
132 |
133 | return path.(string), exists
134 | }
135 |
136 | // 0 - is fully normal, 1 - fully suspicuse
137 | func (pt *RoutingAnomalyDetector) calcAnomalyRating(trackerH *trackerHistory) float64 {
138 | avgGlobalScore := (pt.scoreSum / float64(pt.numberOfJumps)) * 2
139 | avgTrackerScore := trackerH.Avg()
140 |
141 | diviation := math.Abs(avgGlobalScore - avgTrackerScore)
142 |
143 | // If avg diviation is 0 it will return +Inf and get the correct result
144 | anomalyScore := (diviation / (pt.avgDiviation / 100))
145 |
146 | // Update avgDiviation with new diviation
147 | pt.avgDiviation = ((pt.avgDiviation * float64(pt.numberOfJumps)) + diviation) / float64(pt.numberOfJumps)
148 |
149 | // Only return 0 until useage data is collected
150 | if pt.numberOfJumps < pt.tresholdForRating {
151 | return 0
152 | }
153 |
154 | if anomalyScore < float64(pt.minScore) {
155 | return 0
156 | }
157 |
158 | if anomalyScore > float64(pt.maxScore) {
159 | return 1
160 | }
161 |
162 | return (anomalyScore - float64(pt.minScore)) / 100
163 | }
164 |
--------------------------------------------------------------------------------
/pkg/cron/schedule.go:
--------------------------------------------------------------------------------
1 | package cron
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "strconv"
7 | "strings"
8 | "time"
9 | )
10 |
11 | // Moment represents a parsed single time moment.
12 | type Moment struct {
13 | Minute int `json:"minute"`
14 | Hour int `json:"hour"`
15 | Day int `json:"day"`
16 | Month int `json:"month"`
17 | DayOfWeek int `json:"dayOfWeek"`
18 | }
19 |
20 | // NewMoment creates a new Moment from the specified time.
21 | func NewMoment(t time.Time) *Moment {
22 | return &Moment{
23 | Minute: t.Minute(),
24 | Hour: t.Hour(),
25 | Day: t.Day(),
26 | Month: int(t.Month()),
27 | DayOfWeek: int(t.Weekday()),
28 | }
29 | }
30 |
31 | // Schedule stores parsed information for each time component when a cron job should run.
32 | type Schedule struct {
33 | Minutes map[int]struct{} `json:"minutes"`
34 | Hours map[int]struct{} `json:"hours"`
35 | Days map[int]struct{} `json:"days"`
36 | Months map[int]struct{} `json:"months"`
37 | DaysOfWeek map[int]struct{} `json:"daysOfWeek"`
38 | }
39 |
40 | // IsDue checks whether the provided Moment satisfies the current Schedule.
41 | func (s *Schedule) IsDue(m *Moment) bool {
42 | if _, ok := s.Minutes[m.Minute]; !ok {
43 | return false
44 | }
45 |
46 | if _, ok := s.Hours[m.Hour]; !ok {
47 | return false
48 | }
49 |
50 | if _, ok := s.Days[m.Day]; !ok {
51 | return false
52 | }
53 |
54 | if _, ok := s.DaysOfWeek[m.DayOfWeek]; !ok {
55 | return false
56 | }
57 |
58 | if _, ok := s.Months[m.Month]; !ok {
59 | return false
60 | }
61 |
62 | return true
63 | }
64 |
65 | // NewSchedule creates a new Schedule from a cron expression.
66 | //
67 | // A cron expression could be a macro OR 5 segments separated by space,
68 | // representing: minute, hour, day of the month, month and day of the week.
69 | //
70 | // The following segment formats are supported:
71 | // - wildcard: *
72 | // - range: 1-30
73 | // - step: */n or 1-30/n
74 | // - list: 1,2,3,10-20/n
75 | //
76 | // The following macros are supported:
77 | // - @yearly (or @annually)
78 | // - @monthly
79 | // - @weekly
80 | // - @daily (or @midnight)
81 | // - @hourly
82 | // - @minutely
83 | func NewSchedule(cronExpr string) (*Schedule, error) {
84 | if v, ok := macros[cronExpr]; ok {
85 | cronExpr = v
86 | }
87 |
88 | segments := strings.Split(cronExpr, " ")
89 | if len(segments) != 5 {
90 | return nil, errors.New("invalid cron expression - must be a valid macro or to have exactly 5 space separated segments")
91 | }
92 |
93 | minutes, err := parseCronSegment(segments[0], 0, 59)
94 | if err != nil {
95 | return nil, err
96 | }
97 |
98 | hours, err := parseCronSegment(segments[1], 0, 23)
99 | if err != nil {
100 | return nil, err
101 | }
102 |
103 | days, err := parseCronSegment(segments[2], 1, 31)
104 | if err != nil {
105 | return nil, err
106 | }
107 |
108 | months, err := parseCronSegment(segments[3], 1, 12)
109 | if err != nil {
110 | return nil, err
111 | }
112 |
113 | daysOfWeek, err := parseCronSegment(segments[4], 0, 6)
114 | if err != nil {
115 | return nil, err
116 | }
117 |
118 | return &Schedule{
119 | Minutes: minutes,
120 | Hours: hours,
121 | Days: days,
122 | Months: months,
123 | DaysOfWeek: daysOfWeek,
124 | }, nil
125 | }
126 |
127 | // parseCronSegment parses a single cron expression segment and
128 | // returns its time schedule slots.
129 | func parseCronSegment(segment string, min int, max int) (map[int]struct{}, error) {
130 | slots := map[int]struct{}{}
131 |
132 | list := strings.Split(segment, ",")
133 | for _, p := range list {
134 | stepParts := strings.Split(p, "/")
135 |
136 | // step (*/n, 1-30/n)
137 | var step int
138 | switch len(stepParts) {
139 | case 1:
140 | step = 1
141 | case 2:
142 | parsedStep, err := strconv.Atoi(stepParts[1])
143 | if err != nil {
144 | return nil, err
145 | }
146 | if parsedStep < 1 || parsedStep > max {
147 | return nil, fmt.Errorf("invalid segment step boundary - the step must be between 1 and the %d", max)
148 | }
149 | step = parsedStep
150 | default:
151 | return nil, errors.New("invalid segment step format - must be in the format */n or 1-30/n")
152 | }
153 |
154 | // find the min and max range of the segment part
155 | var rangeMin, rangeMax int
156 | if stepParts[0] == "*" {
157 | rangeMin = min
158 | rangeMax = max
159 | } else {
160 | // single digit (1) or range (1-30)
161 | rangeParts := strings.Split(stepParts[0], "-")
162 | switch len(rangeParts) {
163 | case 1:
164 | if step != 1 {
165 | return nil, errors.New("invalid segement step - step > 1 could be used only with the wildcard or range format")
166 | }
167 | parsed, err := strconv.Atoi(rangeParts[0])
168 | if err != nil {
169 | return nil, err
170 | }
171 | if parsed < min || parsed > max {
172 | return nil, errors.New("invalid segment value - must be between the min and max of the segment")
173 | }
174 | rangeMin = parsed
175 | rangeMax = rangeMin
176 | case 2:
177 | parsedMin, err := strconv.Atoi(rangeParts[0])
178 | if err != nil {
179 | return nil, err
180 | }
181 | if parsedMin < min || parsedMin > max {
182 | return nil, fmt.Errorf("invalid segment range minimum - must be between %d and %d", min, max)
183 | }
184 | rangeMin = parsedMin
185 |
186 | parsedMax, err := strconv.Atoi(rangeParts[1])
187 | if err != nil {
188 | return nil, err
189 | }
190 | if parsedMax < parsedMin || parsedMax > max {
191 | return nil, fmt.Errorf("invalid segment range maximum - must be between %d and %d", rangeMin, max)
192 | }
193 | rangeMax = parsedMax
194 | default:
195 | return nil, errors.New("invalid segment range format - the range must have 1 or 2 parts")
196 | }
197 | }
198 |
199 | // fill the slots
200 | for i := rangeMin; i <= rangeMax; i += step {
201 | slots[i] = struct{}{}
202 | }
203 | }
204 |
205 | return slots, nil
206 | }
207 |
--------------------------------------------------------------------------------
/internal/handlers/balancer_test.go:
--------------------------------------------------------------------------------
1 | package handlers
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "net/http/httputil"
7 | "net/url"
8 | "strings"
9 | "testing"
10 | "time"
11 |
12 | "github.com/hvuhsg/gatego/internal/config"
13 | )
14 |
15 | func TestNewBalancer(t *testing.T) {
16 | service := config.Service{}
17 | path := config.Path{
18 | Backend: &config.Backend{
19 | BalancePolicy: "round-robin",
20 | Servers: []struct {
21 | URL string "yaml:\"url\""
22 | Weight uint "yaml:\"weight\""
23 | }{
24 | {URL: "http://localhost:8001", Weight: 1},
25 | {URL: "http://localhost:8002", Weight: 2},
26 | },
27 | },
28 | }
29 |
30 | balancer, err := NewBalancer(service, path)
31 | if err != nil {
32 | t.Fatalf("Failed to create balancer: %v", err)
33 | }
34 |
35 | if balancer == nil {
36 | t.Fatal("Balancer is nil")
37 | }
38 | }
39 |
40 | func TestRoundRobinPolicy(t *testing.T) {
41 | servers := []ServerAndWeight{
42 | {server: createDummyProxy("http://localhost:8001/"), weight: 1, url: "http://localhost:8001/"},
43 | {server: createDummyProxy("http://localhost:8002/"), weight: 1, url: "http://localhost:8002/"},
44 | }
45 |
46 | policy := NewRoundRobinPolicy(servers)
47 |
48 | // Test the round-robin behavior
49 | expectedOrder := []string{"http://localhost:8001/", "http://localhost:8002/", "http://localhost:8001/", "http://localhost:8002/"}
50 | for i, expected := range expectedOrder {
51 | server := policy.GetNext()
52 | if server.Director == nil {
53 | t.Fatalf("Server %d is nil", i)
54 | }
55 | serverURL := getProxyURL(server)
56 | if serverURL != expected {
57 | t.Errorf("index = %d Expected server %s, got %s", i, expected, serverURL)
58 | }
59 | }
60 | }
61 |
62 | func TestRandomPolicy(t *testing.T) {
63 | servers := []ServerAndWeight{
64 | {server: createDummyProxy("http://localhost:8001"), weight: 1, url: "http://localhost:8001"},
65 | {server: createDummyProxy("http://localhost:8002"), weight: 1, url: "http://localhost:8002"},
66 | }
67 |
68 | policy := NewRandomPolicy(servers)
69 |
70 | // Test that we get a valid server (we can't test randomness easily)
71 | for i := 0; i < 10; i++ {
72 | server := policy.GetNext()
73 | if server == nil {
74 | t.Fatal("Got nil server from RandomPolicy")
75 | }
76 | }
77 | }
78 |
79 | func TestLeastLatencyPolicy(t *testing.T) {
80 | // Create mock servers
81 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82 | time.Sleep(20 * time.Millisecond)
83 | w.WriteHeader(http.StatusOK)
84 | w.Write([]byte("Slow response from server 1"))
85 | }))
86 | defer server1.Close()
87 |
88 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
89 | w.WriteHeader(http.StatusOK)
90 | w.Write([]byte("Fast response from server 2"))
91 | }))
92 | defer server2.Close()
93 |
94 | servers := []ServerAndWeight{
95 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server1.URL)), weight: 1, url: server1.URL},
96 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server2.URL)), weight: 1, url: server2.URL},
97 | }
98 |
99 | policy := NewLeastLatencyPolicy(servers)
100 |
101 | // Initially, all servers should have 0 latency
102 | server := policy.GetNext()
103 | if server == nil {
104 | t.Fatal("Got nil server from LeastLatencyPolicy")
105 | }
106 |
107 | // Simulate a request and update latency
108 | w := httptest.NewRecorder()
109 | r, _ := http.NewRequest("GET", server1.URL, nil)
110 | server.ServeHTTP(w, r)
111 |
112 | // The policy should now prefer the fast second server
113 | server = policy.GetNext()
114 | serverURL := strings.TrimSuffix(getProxyURL(server), "/")
115 | if serverURL != strings.TrimSuffix(server2.URL, "/") {
116 | t.Errorf("LeastLatencyPolicy did not choose the server with least latency Got %s Want %s", serverURL, server2.URL)
117 | }
118 | }
119 |
120 | func TestBalancerServeHTTP(t *testing.T) {
121 | // Create mock servers
122 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123 | w.WriteHeader(http.StatusOK)
124 | w.Write([]byte("Response from server 1"))
125 | }))
126 | defer server1.Close()
127 |
128 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
129 | w.WriteHeader(http.StatusOK)
130 | w.Write([]byte("Response from server 2"))
131 | }))
132 | defer server2.Close()
133 |
134 | // Create ServerAndWeight structs using the mock servers
135 | servers := []ServerAndWeight{
136 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server1.URL)), weight: 1, url: server1.URL},
137 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server2.URL)), weight: 1, url: server2.URL},
138 | }
139 |
140 | policy := NewRoundRobinPolicy(servers)
141 | balancer := &Balancer{policy: policy}
142 |
143 | w := httptest.NewRecorder()
144 | r, _ := http.NewRequest("GET", "http://example.com", nil)
145 |
146 | balancer.ServeHTTP(w, r)
147 |
148 | if w.Code != http.StatusOK {
149 | t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
150 | }
151 |
152 | w = httptest.NewRecorder()
153 | r, _ = http.NewRequest("GET", "http://example.com", nil)
154 |
155 | balancer.ServeHTTP(w, r)
156 |
157 | if w.Code != http.StatusOK {
158 | t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
159 | }
160 | }
161 |
162 | // Helper function to create a dummy reverse proxy
163 | func createDummyProxy(targetURL string) *httputil.ReverseProxy {
164 | url, _ := url.Parse(targetURL)
165 | return httputil.NewSingleHostReverseProxy(url)
166 | }
167 |
168 | // Helper function to get the target URL of a reverse proxy
169 | func getProxyURL(proxy *httputil.ReverseProxy) string {
170 | req, _ := http.NewRequest("GET", "http://example.com", nil)
171 | proxy.Director(req)
172 | return req.URL.String()
173 | }
174 |
175 | // Helper function to parse URL and panic on error
176 | func mustParseURL(rawURL string) *url.URL {
177 | u, err := url.Parse(rawURL)
178 | if err != nil {
179 | panic(err)
180 | }
181 | return u
182 | }
183 |
--------------------------------------------------------------------------------
/internal/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "os"
5 | "path/filepath"
6 | "testing"
7 | )
8 |
9 | func TestPathValidate(t *testing.T) {
10 | tests := []struct {
11 | name string
12 | path Path
13 | wantErr bool
14 | }{
15 | {"Valid path with destination", Path{Path: "/api", Destination: ptr("http://example.com")}, false},
16 | {"Valid path with directory", Path{Path: "/static", Directory: ptr("/var")}, false},
17 | {"Invalid path without leading slash", Path{Path: "api", Destination: ptr("http://example.com")}, true},
18 | {"Invalid destination URL", Path{Path: "/api", Destination: ptr("not-a-url")}, true},
19 | {"Invalid with both destination and directory", Path{Path: "/both", Destination: ptr("http://example.com"), Directory: ptr("/var/www")}, true},
20 | {"Invalid with neither destination nor directory", Path{Path: "/empty"}, true},
21 | }
22 |
23 | for _, tt := range tests {
24 | t.Run(tt.name, func(t *testing.T) {
25 | err := tt.path.validate()
26 | if (err != nil) != tt.wantErr {
27 | t.Errorf("Path.validate() error = %v, wantErr %v", err, tt.wantErr)
28 | }
29 | })
30 | }
31 | }
32 |
33 | func TestServiceValidate(t *testing.T) {
34 | tests := []struct {
35 | name string
36 | service Service
37 | wantErr bool
38 | }{
39 | {"Valid service", Service{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}, false},
40 | {"Invalid domain", Service{Domain: "not a domain", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}, true},
41 | {"Invalid path", Service{Domain: "example.com", Paths: []Path{{Path: "invalid", Destination: ptr("http://api.example.com")}}}, true},
42 | }
43 |
44 | for _, tt := range tests {
45 | t.Run(tt.name, func(t *testing.T) {
46 | err := tt.service.validate()
47 | if (err != nil) != tt.wantErr {
48 | t.Errorf("Service.validate() service = %v, error = %v, wantErr %v", err, tt.service, tt.wantErr)
49 | }
50 | })
51 | }
52 | }
53 |
54 | func TestConfigValidate(t *testing.T) {
55 | tests := []struct {
56 | name string
57 | config Config
58 | currentVersion string
59 | wantErr bool
60 | }{
61 | {"Valid config", Config{Version: "1.0.0", Host: "localhost", Port: 80, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", false},
62 | {"AutoTLS with port != 443", Config{Version: "1.0.0", Host: "localhost", Port: 80, TLS: TLS{Auto: true, Domains: []string{"example.com"}}, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", true},
63 | {"Missing version", Config{Host: "localhost"}, "1.0.0", true},
64 | {"Invalid version", Config{Version: "invalid", Host: "localhost"}, "1.0.0", true},
65 | {"Future version", Config{Version: "2.0.0", Host: "localhost"}, "1.0.0", true},
66 | {"Missing host", Config{Version: "1.0.0"}, "1.0.0", true},
67 | }
68 |
69 | for _, tt := range tests {
70 | t.Run(tt.name, func(t *testing.T) {
71 | err := tt.config.Validate(tt.currentVersion)
72 | if (err != nil) != tt.wantErr {
73 | t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr)
74 | }
75 | })
76 | }
77 | }
78 |
79 | func TestParseConfig(t *testing.T) {
80 | // Create a temporary directory for test files
81 | tempDir, err := os.MkdirTemp("", "config_test")
82 | if err != nil {
83 | t.Fatalf("Failed to create temp dir: %v", err)
84 | }
85 | defer os.RemoveAll(tempDir)
86 |
87 | // Create a valid config file
88 | validConfig := `
89 | version: "1.0.0"
90 | host: "localhost"
91 | port: 8080
92 | services:
93 | - domain: "example.com"
94 | endpoints:
95 | - path: "/api"
96 | destination: "http://api.example.com"
97 | `
98 | validConfigPath := filepath.Join(tempDir, "valid_config.yaml")
99 | err = os.WriteFile(validConfigPath, []byte(validConfig), 0644)
100 | if err != nil {
101 | t.Fatalf("Failed to write valid config file: %v", err)
102 | }
103 |
104 | // Create an invalid config file
105 | invalidConfig := `
106 | version: "invalid"
107 | host: "localhost"
108 | `
109 | invalidConfigPath := filepath.Join(tempDir, "invalid_config.yaml")
110 | err = os.WriteFile(invalidConfigPath, []byte(invalidConfig), 0644)
111 | if err != nil {
112 | t.Fatalf("Failed to write invalid config file: %v", err)
113 | }
114 |
115 | tests := []struct {
116 | name string
117 | filepath string
118 | currentVersion string
119 | wantErr bool
120 | }{
121 | {"Valid config", validConfigPath, "1.0.0", false},
122 | {"Invalid config", invalidConfigPath, "1.0.0", true},
123 | {"Non-existent file", filepath.Join(tempDir, "non_existent.yaml"), "1.0.0", true},
124 | }
125 |
126 | for _, tt := range tests {
127 | t.Run(tt.name, func(t *testing.T) {
128 | _, err := ParseConfig(tt.filepath, tt.currentVersion)
129 | if (err != nil) != tt.wantErr {
130 | t.Errorf("ParseConfig() error = %v, wantErr %v", err, tt.wantErr)
131 | }
132 | })
133 | }
134 | }
135 |
136 | func TestIsValidURL(t *testing.T) {
137 | tests := []struct {
138 | name string
139 | url string
140 | want bool
141 | }{
142 | {"Valid URL", "http://example.com", true},
143 | {"Valid URL with path", "https://example.com/path", true},
144 | {"Invalid URL", "not-a-url", false},
145 | {"Invalid URL", "not a domain", false},
146 | {"Missing scheme", "example.com", false},
147 | }
148 |
149 | for _, tt := range tests {
150 | t.Run(tt.name, func(t *testing.T) {
151 | if got := isValidURL(tt.url); got != tt.want {
152 | t.Errorf("isValidURL() = %v, want %v", got, tt.want)
153 | }
154 | })
155 | }
156 | }
157 |
158 | func TestIsValidDir(t *testing.T) {
159 | // Create a temporary directory for the test
160 | tempDir, err := os.MkdirTemp("", "dir_test")
161 | if err != nil {
162 | t.Fatalf("Failed to create temp dir: %v", err)
163 | }
164 | defer os.RemoveAll(tempDir)
165 |
166 | tests := []struct {
167 | name string
168 | path string
169 | want bool
170 | }{
171 | {"Valid directory", tempDir, true},
172 | {"Non-existent directory", filepath.Join(tempDir, "non_existent"), false},
173 | {"Empty path", "", false},
174 | }
175 |
176 | for _, tt := range tests {
177 | t.Run(tt.name, func(t *testing.T) {
178 | if got := isValidDir(tt.path); got != tt.want {
179 | t.Errorf("isValidDir() = %v, want %v", got, tt.want)
180 | }
181 | })
182 | }
183 | }
184 |
185 | // Helper function to create string pointers
186 | func ptr(s string) *string {
187 | return &s
188 | }
189 |
--------------------------------------------------------------------------------
/pkg/monitor/monitor_test.go:
--------------------------------------------------------------------------------
1 | package monitor
2 |
3 | import (
4 | "errors"
5 | "net/http"
6 | "net/http/httptest"
7 | "testing"
8 | "time"
9 | )
10 |
11 | func TestCheck_run(t *testing.T) {
12 | tests := []struct {
13 | name string
14 | server *httptest.Server
15 | check Check
16 | expectedError bool
17 | serverResponse int
18 | }{
19 | {
20 | name: "successful check",
21 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22 | w.WriteHeader(http.StatusOK)
23 | })),
24 | check: Check{
25 | Name: "test-check",
26 | Method: "GET",
27 | Timeout: 5 * time.Second,
28 | Headers: map[string]string{"X-Test": "test-value"},
29 | },
30 | expectedError: false,
31 | serverResponse: http.StatusOK,
32 | },
33 | {
34 | name: "failed check - wrong status code",
35 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36 | w.WriteHeader(http.StatusInternalServerError)
37 | })),
38 | check: Check{
39 | Name: "test-check-fail",
40 | Method: "GET",
41 | Timeout: 5 * time.Second,
42 | },
43 | expectedError: true,
44 | serverResponse: http.StatusInternalServerError,
45 | },
46 | {
47 | name: "check with timeout",
48 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
49 | time.Sleep(2 * time.Second)
50 | w.WriteHeader(http.StatusOK)
51 | })),
52 | check: Check{
53 | Name: "test-check-timeout",
54 | Method: "GET",
55 | Timeout: 1 * time.Second,
56 | },
57 | expectedError: true,
58 | },
59 | {
60 | name: "check with custom headers",
61 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
62 | if r.Header.Get("X-Custom") != "custom-value" {
63 | w.WriteHeader(http.StatusBadRequest)
64 | return
65 | }
66 | w.WriteHeader(http.StatusOK)
67 | })),
68 | check: Check{
69 | Name: "test-check-headers",
70 | Method: "GET",
71 | Timeout: 5 * time.Second,
72 | Headers: map[string]string{"X-Custom": "custom-value"},
73 | },
74 | expectedError: false,
75 | serverResponse: http.StatusOK,
76 | },
77 | }
78 |
79 | for _, tt := range tests {
80 | t.Run(tt.name, func(t *testing.T) {
81 | defer tt.server.Close()
82 |
83 | tt.check.URL = tt.server.URL
84 | tt.check.run(func(error) {})
85 | })
86 | }
87 | }
88 |
89 | func TestChecker_Start(t *testing.T) {
90 | tests := []struct {
91 | name string
92 | checker Monitor
93 | expectedError bool
94 | }{
95 | {
96 | name: "successful start",
97 | checker: Monitor{
98 | Delay: 1 * time.Second,
99 | Checks: []Check{
100 | {
101 | Name: "test-check",
102 | Cron: "* * * * *",
103 | Method: "GET",
104 | URL: "http://example.com",
105 | Timeout: 5 * time.Second,
106 | },
107 | },
108 | },
109 | expectedError: false,
110 | },
111 | {
112 | name: "invalid cron expression",
113 | checker: Monitor{
114 | Delay: 1 * time.Second,
115 | Checks: []Check{
116 | {
117 | Name: "test-check-invalid-cron",
118 | Cron: "invalid",
119 | Method: "GET",
120 | URL: "http://example.com",
121 | Timeout: 5 * time.Second,
122 | },
123 | },
124 | },
125 | expectedError: true,
126 | },
127 | }
128 |
129 | for _, tt := range tests {
130 | t.Run(tt.name, func(t *testing.T) {
131 | err := tt.checker.Start()
132 | if (err != nil) != tt.expectedError {
133 | t.Errorf("Checker.Start() error = %v, expectedError %v", err, tt.expectedError)
134 | }
135 |
136 | // Clean up scheduler if it was created
137 | if tt.checker.scheduler != nil {
138 | tt.checker.scheduler.Stop()
139 | }
140 | })
141 | }
142 | }
143 |
144 | func TestChecker_OnFailure(t *testing.T) {
145 | tests := []struct {
146 | name string
147 | checker Monitor
148 | expectedError bool
149 | }{
150 | {
151 | name: "on failure command with valid command",
152 | checker: Monitor{
153 | Delay: 1 * time.Second,
154 | Checks: []Check{
155 | {
156 | Name: "test-check-failure",
157 | Cron: "* * * * *",
158 | Method: "GET",
159 | URL: "http://example.com",
160 | Timeout: 5 * time.Second,
161 | OnFailure: "echo check '$check_name' failed at $date: $error",
162 | },
163 | },
164 | },
165 | expectedError: false,
166 | },
167 | {
168 | name: "on failure command with invalid command",
169 | checker: Monitor{
170 | Delay: 1 * time.Second,
171 | Checks: []Check{
172 | {
173 | Name: "test-check-failure",
174 | Cron: "* * * * *",
175 | Method: "GET",
176 | URL: "http://example.com",
177 | Timeout: 5 * time.Second,
178 | OnFailure: "invalidCommand $error",
179 | },
180 | },
181 | },
182 | expectedError: true,
183 | },
184 | }
185 |
186 | for _, tt := range tests {
187 | t.Run(tt.name, func(t *testing.T) {
188 | // Simulate a failure scenario by injecting an error
189 | err := errors.New("Connection timeout")
190 | err = handleFailure(tt.checker.Checks[0], err)
191 |
192 | // Check if an error was returned and if it matches the expected result
193 | if (err != nil) != tt.expectedError {
194 | t.Errorf("handleFailure() error = %v, expectedError %v", err, tt.expectedError)
195 | }
196 |
197 | // Clean up scheduler if it was created
198 | if tt.checker.scheduler != nil {
199 | tt.checker.scheduler.Stop()
200 | }
201 | })
202 | }
203 | }
204 |
205 | // TestCheckWithMockServer tests the Check struct with a mock HTTP server
206 | func TestCheckWithMockServer(t *testing.T) {
207 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
208 | // Verify method
209 | if r.Method != http.MethodGet {
210 | t.Errorf("Expected method %s, got %s", http.MethodGet, r.Method)
211 | }
212 |
213 | // Verify headers
214 | if r.Header.Get("X-Test") != "test-value" {
215 | t.Errorf("Expected header X-Test: test-value, got %s", r.Header.Get("X-Test"))
216 | }
217 |
218 | w.WriteHeader(http.StatusOK)
219 | })
220 |
221 | server := httptest.NewServer(handler)
222 | defer server.Close()
223 |
224 | check := Check{
225 | Name: "test-check",
226 | Method: http.MethodGet,
227 | URL: server.URL,
228 | Timeout: 5 * time.Second,
229 | Headers: map[string]string{"X-Test": "test-value"},
230 | }
231 |
232 | check.run(func(error) {})
233 | }
234 |
--------------------------------------------------------------------------------
/internal/middlewares/openapi_test.go:
--------------------------------------------------------------------------------
1 | package middlewares_test
2 |
3 | import (
4 | "encoding/json"
5 | "net/http"
6 | "net/http/httptest"
7 | "os"
8 | "testing"
9 |
10 | "github.com/hvuhsg/gatego/internal/middlewares"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | // Helper function to compare JSON
16 | func assertJSONEqual(t *testing.T, expected, actual string) {
17 | var expectedMap, actualMap map[string]interface{}
18 | err := json.Unmarshal([]byte(expected), &expectedMap)
19 | require.NoError(t, err, "Error unmarshaling expected JSON")
20 | err = json.Unmarshal([]byte(actual), &actualMap)
21 | require.NoError(t, err, "Error unmarshaling actual JSON")
22 | assert.Equal(t, expectedMap, actualMap)
23 | }
24 |
25 | func TestOpenAPIValidationMiddleware(t *testing.T) {
26 | // Create a temporary OpenAPI spec file for testing
27 | specFile, err := os.CreateTemp("", "openapi-spec-*.yaml")
28 | require.NoError(t, err)
29 | defer os.Remove(specFile.Name())
30 |
31 | // Write a simple OpenAPI spec to the file
32 | specContent := `
33 | openapi: 3.0.0
34 | info:
35 | title: Test API
36 | version: 1.0.0
37 | paths:
38 | /test:
39 | get:
40 | parameters:
41 | - name: param
42 | in: query
43 | required: true
44 | schema:
45 | type: string
46 | responses:
47 | '200':
48 | description: OK
49 | content:
50 | application/json:
51 | schema:
52 | type: object
53 | required:
54 | - message
55 | properties:
56 | message:
57 | type: string
58 | status:
59 | type: string
60 | `
61 | _, err = specFile.Write([]byte(specContent))
62 | require.NoError(t, err)
63 | specFile.Close()
64 |
65 | // Create the middleware
66 | middleware, err := middlewares.NewOpenAPIValidationMiddleware(specFile.Name())
67 | require.NoError(t, err)
68 |
69 | tests := []struct {
70 | name string
71 | url string
72 | handler http.HandlerFunc
73 | expectedStatus int
74 | expectedBody string
75 | isJSON bool
76 | }{
77 | {
78 | name: "Valid request and response",
79 | url: "/test?param=value",
80 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81 | w.Header().Set("Content-Type", "application/json")
82 | w.WriteHeader(http.StatusOK)
83 | json.NewEncoder(w).Encode(map[string]string{"message": "Hello, World!", "status": "ok"})
84 | }),
85 | expectedStatus: http.StatusOK,
86 | expectedBody: `{"message":"Hello, World!","status":"ok"}`,
87 | isJSON: true,
88 | },
89 | {
90 | name: "Valid request but invalid response (missing required field)",
91 | url: "/test?param=value",
92 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
93 | w.Header().Set("Content-Type", "application/json")
94 | w.WriteHeader(http.StatusOK)
95 | json.NewEncoder(w).Encode(map[string]string{"status": "error"})
96 | }),
97 | expectedStatus: http.StatusInternalServerError,
98 | expectedBody: "Invalid response:",
99 | isJSON: false,
100 | },
101 | {
102 | name: "Valid request but invalid response (wrong content type)",
103 | url: "/test?param=value",
104 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 | w.Header().Set("Content-Type", "text/plain")
106 | w.WriteHeader(http.StatusOK)
107 | w.Write([]byte("Hello, World!"))
108 | }),
109 | expectedStatus: http.StatusInternalServerError,
110 | expectedBody: "Invalid response:",
111 | isJSON: false,
112 | },
113 | {
114 | name: "Valid request but response with extra field",
115 | url: "/test?param=value",
116 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
117 | w.Header().Set("Content-Type", "application/json")
118 | w.WriteHeader(http.StatusOK)
119 | json.NewEncoder(w).Encode(map[string]string{"message": "Hello, World!", "extra": "field"})
120 | }),
121 | expectedStatus: http.StatusOK,
122 | expectedBody: `{"message":"Hello, World!","extra":"field"}`,
123 | isJSON: true,
124 | },
125 | {
126 | name: "Invalid path",
127 | url: "/invalid",
128 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
129 | expectedStatus: http.StatusBadRequest,
130 | expectedBody: "Error finding route:",
131 | isJSON: false,
132 | },
133 | {
134 | name: "Missing required parameter",
135 | url: "/test",
136 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
137 | expectedStatus: http.StatusBadRequest,
138 | expectedBody: "Invalid request:",
139 | isJSON: false,
140 | },
141 | }
142 |
143 | for _, tt := range tests {
144 | t.Run(tt.name, func(t *testing.T) {
145 | req, err := http.NewRequest("GET", tt.url, nil)
146 | require.NoError(t, err)
147 |
148 | rr := httptest.NewRecorder()
149 | wrappedHandler := middleware(tt.handler)
150 | wrappedHandler.ServeHTTP(rr, req)
151 |
152 | assert.Equal(t, tt.expectedStatus, rr.Code)
153 |
154 | if tt.isJSON {
155 | assertJSONEqual(t, tt.expectedBody, rr.Body.String())
156 | } else {
157 | assert.Contains(t, rr.Body.String(), tt.expectedBody)
158 | }
159 | })
160 | }
161 | }
162 |
163 | func TestNewOpenAPIValidationMiddleware(t *testing.T) {
164 | tests := []struct {
165 | name string
166 | specContent string
167 | expectError bool
168 | }{
169 | {
170 | name: "Valid OpenAPI spec",
171 | specContent: `
172 | openapi: 3.0.0
173 | info:
174 | title: Test API
175 | version: 1.0.0
176 | paths:
177 | /test:
178 | get:
179 | responses:
180 | '200':
181 | description: OK
182 | `,
183 | expectError: false,
184 | },
185 | {
186 | name: "Invalid OpenAPI spec",
187 | specContent: "invalid: yaml: content",
188 | expectError: true,
189 | },
190 | }
191 |
192 | for _, tt := range tests {
193 | t.Run(tt.name, func(t *testing.T) {
194 | specFile, err := os.CreateTemp("", "openapi-spec-*.yaml")
195 | require.NoError(t, err)
196 | defer os.Remove(specFile.Name())
197 |
198 | _, err = specFile.Write([]byte(tt.specContent))
199 | require.NoError(t, err)
200 | specFile.Close()
201 |
202 | middleware, err := middlewares.NewOpenAPIValidationMiddleware(specFile.Name())
203 |
204 | if tt.expectError {
205 | assert.Error(t, err)
206 | assert.Nil(t, middleware)
207 | } else {
208 | assert.NoError(t, err)
209 | assert.NotNil(t, middleware)
210 | }
211 | })
212 | }
213 | }
214 |
--------------------------------------------------------------------------------
/internal/middlewares/cache_test.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "net/http"
5 | "net/http/httptest"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestCacheMiddleware(t *testing.T) {
11 | t.Parallel()
12 | // Reset cache before each test
13 | responseCache.Flush()
14 |
15 | t.Run("Should not cache response with no cache headers", func(t *testing.T) {
16 | responseText := "test response"
17 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18 | w.WriteHeader(200)
19 | w.Write([]byte(responseText))
20 | })
21 |
22 | middleware := NewCacheMiddleware()(handler)
23 | req := httptest.NewRequest("GET", "/test", nil)
24 |
25 | // First request
26 | w1 := httptest.NewRecorder()
27 | middleware.ServeHTTP(w1, req)
28 |
29 | if w1.Body.String() != "test response" {
30 | t.Errorf("Expected 'test response', got '%s'", w1.Body.String())
31 | }
32 |
33 | responseText = "new response"
34 |
35 | // Second request - should be served from cache
36 | w2 := httptest.NewRecorder()
37 | middleware.ServeHTTP(w2, req)
38 |
39 | if w2.Body.String() != "new response" {
40 | t.Errorf("Expected not cached 'new response', got '%s'", w2.Body.String())
41 | }
42 | })
43 |
44 | t.Run("Should respect max-age Cache-Control header", func(t *testing.T) {
45 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46 | w.Header().Set("Cache-Control", "max-age=1")
47 | w.WriteHeader(200)
48 | w.Write([]byte("cache-control test"))
49 | })
50 |
51 | middleware := NewCacheMiddleware()(handler)
52 | req := httptest.NewRequest("GET", "/cache-control", nil)
53 |
54 | // First request
55 | w1 := httptest.NewRecorder()
56 | middleware.ServeHTTP(w1, req)
57 |
58 | // Wait for less than max-age
59 | time.Sleep(time.Millisecond * 500)
60 |
61 | // Should still be cached
62 | w2 := httptest.NewRecorder()
63 | middleware.ServeHTTP(w2, req)
64 |
65 | if w2.Body.String() != "cache-control test" {
66 | t.Errorf("Expected cached response before max-age expiration")
67 | }
68 |
69 | // Wait for cache to expire
70 | time.Sleep(time.Millisecond * 1500)
71 |
72 | if _, found := responseCache.Get("/cache-control"); found {
73 | t.Error("Cache should have expired")
74 | }
75 | })
76 |
77 | t.Run("Should respect Expires header", func(t *testing.T) {
78 | responseText := "expires test"
79 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80 | expiresTime := time.Now().Add(2 * time.Second)
81 | w.Header().Set("Expires", expiresTime.Format(time.RFC1123))
82 | w.WriteHeader(200)
83 | w.Write([]byte(responseText))
84 | })
85 |
86 | middleware := NewCacheMiddleware()(handler)
87 | req := httptest.NewRequest("GET", "/expires", nil)
88 |
89 | // First request
90 | w1 := httptest.NewRecorder()
91 | middleware.ServeHTTP(w1, req)
92 |
93 | // Wait for less than expiration
94 | time.Sleep(time.Second * 1)
95 |
96 | responseText = "something else"
97 |
98 | // Should still be cached
99 | w2 := httptest.NewRecorder()
100 | middleware.ServeHTTP(w2, req)
101 |
102 | if w2.Body.String() != "expires test" {
103 | t.Errorf("Expected cached response before expiration")
104 | }
105 |
106 | // Wait for cache to expire
107 | time.Sleep(time.Second * 2)
108 |
109 | if _, found := responseCache.Get("/expires"); found {
110 | t.Error("Cache should have expired")
111 | }
112 | })
113 |
114 | t.Run("Should preserve response headers", func(t *testing.T) {
115 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
116 | // Add expiration header
117 | expiresTime := time.Now().Add(50 * time.Second)
118 | w.Header().Set("Expires", expiresTime.Format(time.RFC1123))
119 |
120 | w.Header().Set("Content-Type", "application/json")
121 | w.Header().Set("X-Custom-Header", "test-value")
122 | w.WriteHeader(200)
123 | w.Write([]byte(`{"message":"test"}`))
124 | })
125 |
126 | middleware := NewCacheMiddleware()(handler)
127 | req := httptest.NewRequest("GET", "/headers", nil)
128 |
129 | // First request
130 | w1 := httptest.NewRecorder()
131 | middleware.ServeHTTP(w1, req)
132 |
133 | // Second request - should preserve headers
134 | w2 := httptest.NewRecorder()
135 | middleware.ServeHTTP(w2, req)
136 |
137 | expectedHeaders := map[string]string{
138 | "Content-Type": "application/json",
139 | "X-Custom-Header": "test-value",
140 | }
141 |
142 | for header, expectedValue := range expectedHeaders {
143 | if value := w2.Header().Get(header); value != expectedValue {
144 | t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
145 | }
146 | }
147 | })
148 |
149 | t.Run("Should handle invalid cache headers gracefully", func(t *testing.T) {
150 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
151 | w.Header().Set("Cache-Control", "max-age=invalid")
152 | w.Header().Set("Expires", "invalid-date")
153 | w.WriteHeader(200)
154 | w.Write([]byte("invalid headers test"))
155 | })
156 |
157 | middleware := NewCacheMiddleware()(handler)
158 | req := httptest.NewRequest("GET", "/invalid-headers", nil)
159 |
160 | w := httptest.NewRecorder()
161 | middleware.ServeHTTP(w, req)
162 |
163 | if w.Body.String() != "invalid headers test" {
164 | t.Errorf("Expected normal response despite invalid headers")
165 | }
166 | })
167 | }
168 |
169 | func TestGetCacheMaxAge(t *testing.T) {
170 | t.Parallel()
171 | tests := []struct {
172 | name string
173 | cacheControl string
174 | expected int
175 | }{
176 | {"Valid max-age", "max-age=60", 60},
177 | {"Multiple directives", "public, max-age=30", 30},
178 | {"Invalid max-age", "max-age=invalid", 0},
179 | {"No max-age", "public, private", 0},
180 | {"Empty string", "", 0},
181 | }
182 |
183 | for _, tt := range tests {
184 | t.Run(tt.name, func(t *testing.T) {
185 | result := getCacheMaxAge(tt.cacheControl)
186 | if result != tt.expected {
187 | t.Errorf("getCacheMaxAge(%s) = %d; want %d", tt.cacheControl, result, tt.expected)
188 | }
189 | })
190 | }
191 | }
192 |
193 | func TestGetCacheExpires(t *testing.T) {
194 | t.Parallel()
195 |
196 | now := time.Now()
197 | tests := []struct {
198 | name string
199 | expiresHeader string
200 | wantZero bool
201 | }{
202 | {"Valid date", now.Format(time.RFC1123), false},
203 | {"Invalid date", "invalid-date", true},
204 | {"Empty string", "", true},
205 | }
206 |
207 | for _, tt := range tests {
208 | t.Run(tt.name, func(t *testing.T) {
209 | result := getCacheExpires(tt.expiresHeader)
210 | if tt.wantZero && !result.IsZero() {
211 | t.Errorf("getCacheExpires(%s) expected zero time, got %v", tt.expiresHeader, result)
212 | }
213 | if !tt.wantZero && result.IsZero() {
214 | t.Errorf("getCacheExpires(%s) expected non-zero time, got zero time", tt.expiresHeader)
215 | }
216 | })
217 | }
218 | }
219 |
--------------------------------------------------------------------------------
/pkg/cron/cron_test.go:
--------------------------------------------------------------------------------
1 | package cron
2 |
3 | import (
4 | "encoding/json"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestCronNew(t *testing.T) {
10 | t.Parallel()
11 |
12 | c := New()
13 |
14 | expectedInterval := 1 * time.Minute
15 | if c.interval != expectedInterval {
16 | t.Fatalf("Expected default interval %v, got %v", expectedInterval, c.interval)
17 | }
18 |
19 | expectedTimezone := time.UTC
20 | if c.timezone.String() != expectedTimezone.String() {
21 | t.Fatalf("Expected default timezone %v, got %v", expectedTimezone, c.timezone)
22 | }
23 |
24 | if len(c.jobs) != 0 {
25 | t.Fatalf("Expected no jobs by default, got \n%v", c.jobs)
26 | }
27 |
28 | if c.ticker != nil {
29 | t.Fatal("Expected the ticker NOT to be initialized")
30 | }
31 | }
32 |
33 | func TestCronSetInterval(t *testing.T) {
34 | t.Parallel()
35 |
36 | c := New()
37 |
38 | interval := 2 * time.Minute
39 |
40 | c.SetInterval(interval)
41 |
42 | if c.interval != interval {
43 | t.Fatalf("Expected interval %v, got %v", interval, c.interval)
44 | }
45 | }
46 |
47 | func TestCronSetTimezone(t *testing.T) {
48 | t.Parallel()
49 |
50 | c := New()
51 |
52 | timezone, _ := time.LoadLocation("Asia/Tokyo")
53 |
54 | c.SetTimezone(timezone)
55 |
56 | if c.timezone.String() != timezone.String() {
57 | t.Fatalf("Expected timezone %v, got %v", timezone, c.timezone)
58 | }
59 | }
60 |
61 | func TestCronAddAndRemove(t *testing.T) {
62 | t.Parallel()
63 |
64 | c := New()
65 |
66 | if err := c.Add("test0", "* * * * *", nil); err == nil {
67 | t.Fatal("Expected nil function error")
68 | }
69 |
70 | if err := c.Add("test1", "invalid", func() {}); err == nil {
71 | t.Fatal("Expected invalid cron expression error")
72 | }
73 |
74 | if err := c.Add("test2", "* * * * *", func() {}); err != nil {
75 | t.Fatal(err)
76 | }
77 |
78 | if err := c.Add("test3", "* * * * *", func() {}); err != nil {
79 | t.Fatal(err)
80 | }
81 |
82 | if err := c.Add("test4", "* * * * *", func() {}); err != nil {
83 | t.Fatal(err)
84 | }
85 |
86 | // overwrite test2
87 | if err := c.Add("test2", "1 2 3 4 5", func() {}); err != nil {
88 | t.Fatal(err)
89 | }
90 |
91 | if err := c.Add("test5", "1 2 3 4 5", func() {}); err != nil {
92 | t.Fatal(err)
93 | }
94 |
95 | // mock job deletion
96 | c.Remove("test4")
97 |
98 | // try to remove non-existing (should be no-op)
99 | c.Remove("missing")
100 |
101 | // check job keys
102 | {
103 | expectedKeys := []string{"test3", "test2", "test5"}
104 |
105 | if v := len(c.jobs); v != len(expectedKeys) {
106 | t.Fatalf("Expected %d jobs, got %d", len(expectedKeys), v)
107 | }
108 |
109 | for _, k := range expectedKeys {
110 | if c.jobs[k] == nil {
111 | t.Fatalf("Expected job with key %s, got nil", k)
112 | }
113 | }
114 | }
115 |
116 | // check the jobs schedule
117 | {
118 | expectedSchedules := map[string]string{
119 | "test2": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
120 | "test3": `{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`,
121 | "test5": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`,
122 | }
123 | for k, v := range expectedSchedules {
124 | raw, err := json.Marshal(c.jobs[k].schedule)
125 | if err != nil {
126 | t.Fatal(err)
127 | }
128 |
129 | if string(raw) != v {
130 | t.Fatalf("Expected %q schedule \n%s, \ngot \n%s", k, v, raw)
131 | }
132 | }
133 | }
134 | }
135 |
136 | func TestCronMustAdd(t *testing.T) {
137 | t.Parallel()
138 |
139 | c := New()
140 |
141 | defer func() {
142 | if r := recover(); r == nil {
143 | t.Errorf("test1 didn't panic")
144 | }
145 | }()
146 |
147 | c.MustAdd("test1", "* * * * *", nil)
148 |
149 | c.MustAdd("test2", "* * * * *", func() {})
150 |
151 | if _, ok := c.jobs["test2"]; !ok {
152 | t.Fatal("Couldn't find job test2")
153 | }
154 | }
155 |
156 | func TestCronRemoveAll(t *testing.T) {
157 | t.Parallel()
158 |
159 | c := New()
160 |
161 | if err := c.Add("test1", "* * * * *", func() {}); err != nil {
162 | t.Fatal(err)
163 | }
164 |
165 | if err := c.Add("test2", "* * * * *", func() {}); err != nil {
166 | t.Fatal(err)
167 | }
168 |
169 | if err := c.Add("test3", "* * * * *", func() {}); err != nil {
170 | t.Fatal(err)
171 | }
172 |
173 | if v := len(c.jobs); v != 3 {
174 | t.Fatalf("Expected %d jobs, got %d", 3, v)
175 | }
176 |
177 | c.RemoveAll()
178 |
179 | if v := len(c.jobs); v != 0 {
180 | t.Fatalf("Expected %d jobs, got %d", 0, v)
181 | }
182 | }
183 |
184 | func TestCronTotal(t *testing.T) {
185 | t.Parallel()
186 |
187 | c := New()
188 |
189 | if v := c.Total(); v != 0 {
190 | t.Fatalf("Expected 0 jobs, got %v", v)
191 | }
192 |
193 | if err := c.Add("test1", "* * * * *", func() {}); err != nil {
194 | t.Fatal(err)
195 | }
196 |
197 | if err := c.Add("test2", "* * * * *", func() {}); err != nil {
198 | t.Fatal(err)
199 | }
200 |
201 | // overwrite
202 | if err := c.Add("test1", "* * * * *", func() {}); err != nil {
203 | t.Fatal(err)
204 | }
205 |
206 | if v := c.Total(); v != 2 {
207 | t.Fatalf("Expected 2 jobs, got %v", v)
208 | }
209 | }
210 |
211 | func TestCronStartStop(t *testing.T) {
212 | t.Parallel()
213 |
214 | test1 := 0
215 | test2 := 0
216 |
217 | c := New()
218 |
219 | c.SetInterval(500 * time.Millisecond)
220 |
221 | c.Add("test1", "* * * * *", func() {
222 | test1++
223 | })
224 |
225 | c.Add("test2", "* * * * *", func() {
226 | test2++
227 | })
228 |
229 | expectedCalls := 2
230 |
231 | // call twice Start to check if the previous ticker will be reseted
232 | c.Start()
233 | c.Start()
234 |
235 | time.Sleep(1 * time.Second)
236 |
237 | // call twice Stop to ensure that the second stop is no-op
238 | c.Stop()
239 | c.Stop()
240 |
241 | if test1 != expectedCalls {
242 | t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
243 | }
244 | if test2 != expectedCalls {
245 | t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
246 | }
247 |
248 | // resume for 2 seconds
249 | c.Start()
250 |
251 | time.Sleep(2 * time.Second)
252 |
253 | c.Stop()
254 |
255 | expectedCalls += 4
256 |
257 | if test1 != expectedCalls {
258 | t.Fatalf("Expected %d test1, got %d", expectedCalls, test1)
259 | }
260 | if test2 != expectedCalls {
261 | t.Fatalf("Expected %d test2, got %d", expectedCalls, test2)
262 | }
263 | }
264 |
--------------------------------------------------------------------------------
/internal/middlewares/responsecapture.go:
--------------------------------------------------------------------------------
1 | package middlewares
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "net/http"
8 | "net/textproto"
9 | "strconv"
10 | "strings"
11 |
12 | "golang.org/x/net/http/httpguts"
13 | )
14 |
15 | // ResponseRecorder is an implementation of [http.ResponseWriter].
16 | type ResponseRecorder struct {
17 | // Code is the HTTP response code set by WriteHeader.
18 | //
19 | // Note that if a Handler never calls WriteHeader or Write,
20 | // this might end up being 0, rather than the implicit
21 | // http.StatusOK. To get the implicit value, use the Result
22 | // method.
23 | Code int
24 |
25 | // HeaderMap contains the headers explicitly set by the Handler.
26 | // It is an internal detail.
27 | //
28 | // Deprecated: HeaderMap exists for historical compatibility
29 | // and should not be used. To access the headers returned by a handler,
30 | // use the Response.Header map as returned by the Result method.
31 | HeaderMap http.Header
32 |
33 | // Body is the buffer to which the Handler's Write calls are sent.
34 | // If nil, the Writes are silently discarded.
35 | Body *bytes.Buffer
36 |
37 | // Flushed is whether the Handler called Flush.
38 | Flushed bool
39 |
40 | result *http.Response // cache of Result's return value
41 | snapHeader http.Header // snapshot of HeaderMap at first Write
42 | wroteHeader bool
43 | }
44 |
45 | // NewRecorder returns an initialized [ResponseRecorder].
46 | func NewRecorder() *ResponseRecorder {
47 | return &ResponseRecorder{
48 | HeaderMap: make(http.Header),
49 | Body: new(bytes.Buffer),
50 | Code: 200,
51 | }
52 | }
53 |
54 | // DefaultRemoteAddr is the default remote address to return in RemoteAddr if
55 | // an explicit DefaultRemoteAddr isn't set on [ResponseRecorder].
56 | const DefaultRemoteAddr = "1.2.3.4"
57 |
58 | // Header implements [http.ResponseWriter]. It returns the response
59 | // headers to mutate within a handler.
60 | func (rw *ResponseRecorder) Header() http.Header {
61 | m := rw.HeaderMap
62 | if m == nil {
63 | m = make(http.Header)
64 | rw.HeaderMap = m
65 | }
66 | return m
67 | }
68 |
69 | // writeHeader writes a header if it was not written yet and
70 | // detects Content-Type if needed.
71 | //
72 | // bytes or str are the beginning of the response body.
73 | // We pass both to avoid unnecessarily generate garbage
74 | // in rw.WriteString which was created for performance reasons.
75 | // Non-nil bytes win.
76 | func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
77 | if rw.wroteHeader {
78 | return
79 | }
80 | if len(str) > 512 {
81 | str = str[:512]
82 | }
83 |
84 | m := rw.Header()
85 |
86 | _, hasType := m["Content-Type"]
87 | hasTE := m.Get("Transfer-Encoding") != ""
88 | if !hasType && !hasTE {
89 | if b == nil {
90 | b = []byte(str)
91 | }
92 | m.Set("Content-Type", http.DetectContentType(b))
93 | }
94 |
95 | rw.WriteHeader(200)
96 | }
97 |
98 | // Write implements http.ResponseWriter. The data in buf is written to
99 | // rw.Body, if not nil.
100 | func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
101 | rw.writeHeader(buf, "")
102 | if rw.Body != nil {
103 | rw.Body.Write(buf)
104 | }
105 | return len(buf), nil
106 | }
107 |
108 | // WriteString implements [io.StringWriter]. The data in str is written
109 | // to rw.Body, if not nil.
110 | func (rw *ResponseRecorder) WriteString(str string) (int, error) {
111 | rw.writeHeader(nil, str)
112 | if rw.Body != nil {
113 | rw.Body.WriteString(str)
114 | }
115 | return len(str), nil
116 | }
117 |
118 | func checkWriteHeaderCode(code int) {
119 | // Issue 22880: require valid WriteHeader status codes.
120 | // For now we only enforce that it's three digits.
121 | // In the future we might block things over 599 (600 and above aren't defined
122 | // at https://httpwg.org/specs/rfc7231.html#status.codes)
123 | // and we might block under 200 (once we have more mature 1xx support).
124 | // But for now any three digits.
125 | //
126 | // We used to send "HTTP/1.1 000 0" on the wire in responses but there's
127 | // no equivalent bogus thing we can realistically send in HTTP/2,
128 | // so we'll consistently panic instead and help people find their bugs
129 | // early. (We can't return an error from WriteHeader even if we wanted to.)
130 | if code < 100 || code > 999 {
131 | panic(fmt.Sprintf("invalid WriteHeader code %v", code))
132 | }
133 | }
134 |
135 | // WriteHeader implements [http.ResponseWriter].
136 | func (rw *ResponseRecorder) WriteHeader(code int) {
137 | if rw.wroteHeader {
138 | return
139 | }
140 |
141 | checkWriteHeaderCode(code)
142 | rw.Code = code
143 | rw.wroteHeader = true
144 | if rw.HeaderMap == nil {
145 | rw.HeaderMap = make(http.Header)
146 | }
147 | rw.snapHeader = rw.HeaderMap.Clone()
148 | }
149 |
150 | // Result returns the response generated by the handler.
151 | //
152 | // The returned Response will have at least its StatusCode,
153 | // Header, Body, and optionally Trailer populated.
154 | // More fields may be populated in the future, so callers should
155 | // not DeepEqual the result in tests.
156 | //
157 | // The Response.Header is a snapshot of the headers at the time of the
158 | // first write call, or at the time of this call, if the handler never
159 | // did a write.
160 | //
161 | // The Response.Body is guaranteed to be non-nil and Body.Read call is
162 | // guaranteed to not return any error other than [io.EOF].
163 | //
164 | // Result must only be called after the handler has finished running.
165 | func (rw *ResponseRecorder) Result() *http.Response {
166 | if rw.result != nil {
167 | return rw.result
168 | }
169 | if rw.snapHeader == nil {
170 | rw.snapHeader = rw.HeaderMap.Clone()
171 | }
172 | res := &http.Response{
173 | Proto: "HTTP/1.1",
174 | ProtoMajor: 1,
175 | ProtoMinor: 1,
176 | StatusCode: rw.Code,
177 | Header: rw.snapHeader,
178 | }
179 | rw.result = res
180 | if res.StatusCode == 0 {
181 | res.StatusCode = 200
182 | }
183 | res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
184 | if rw.Body != nil {
185 | res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
186 | } else {
187 | res.Body = http.NoBody
188 | }
189 | res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
190 |
191 | if trailers, ok := rw.snapHeader["Trailer"]; ok {
192 | res.Trailer = make(http.Header, len(trailers))
193 | for _, k := range trailers {
194 | for _, k := range strings.Split(k, ",") {
195 | k = http.CanonicalHeaderKey(textproto.TrimString(k))
196 | if !httpguts.ValidTrailerHeader(k) {
197 | // Ignore since forbidden by RFC 7230, section 4.1.2.
198 | continue
199 | }
200 | vv, ok := rw.HeaderMap[k]
201 | if !ok {
202 | continue
203 | }
204 | vv2 := make([]string, len(vv))
205 | copy(vv2, vv)
206 | res.Trailer[k] = vv2
207 | }
208 | }
209 | }
210 | for k, vv := range rw.HeaderMap {
211 | if !strings.HasPrefix(k, http.TrailerPrefix) {
212 | continue
213 | }
214 | if res.Trailer == nil {
215 | res.Trailer = make(http.Header)
216 | }
217 | for _, v := range vv {
218 | res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
219 | }
220 | }
221 | return res
222 | }
223 |
224 | func (rr *ResponseRecorder) WriteTo(rw http.ResponseWriter) {
225 | rr.WriteHeadersTo(rw)
226 | rw.WriteHeader(rr.Result().StatusCode)
227 | rw.Write(rr.Body.Bytes())
228 | }
229 |
230 | func (rr *ResponseRecorder) WriteHeadersTo(rw http.ResponseWriter) {
231 | for header := range rr.Result().Header {
232 | rw.Header().Set(header, rr.Result().Header.Get(header))
233 | }
234 | }
235 |
236 | // parseContentLength trims whitespace from s and returns -1 if no value
237 | // is set, or the value if it's >= 0.
238 | //
239 | // This a modified version of same function found in net/http/transfer.go. This
240 | // one just ignores an invalid header.
241 | func parseContentLength(cl string) int64 {
242 | cl = textproto.TrimString(cl)
243 | if cl == "" {
244 | return -1
245 | }
246 | n, err := strconv.ParseUint(cl, 10, 63)
247 | if err != nil {
248 | return -1
249 | }
250 | return int64(n)
251 | }
252 |
--------------------------------------------------------------------------------
/config-schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "http://json-schema.org/draft-07/schema#",
3 | "type": "object",
4 | "properties": {
5 | "version": {
6 | "type": "string",
7 | "description": "Version of the configuration."
8 | },
9 | "host": {
10 | "type": "string",
11 | "description": "The host where the service will run."
12 | },
13 | "port": {
14 | "type": "integer",
15 | "description": "The port for the service."
16 | },
17 | "ssl": {
18 | "type": "object",
19 | "properties": {
20 | "keyfile": {
21 | "type": "string",
22 | "description": "Path to SSL key file."
23 | },
24 | "certfile": {
25 | "type": "string",
26 | "description": "Path to SSL certificate file."
27 | }
28 | },
29 | "required": [
30 | "keyfile",
31 | "certfile"
32 | ],
33 | "description": "SSL configuration for the server."
34 | },
35 | "open_telemetry": {
36 | "type": "object",
37 | "properties": {
38 | "endpoint": {
39 | "type": "string",
40 | "description": "GRPC connection string for open telemetry collection agent"
41 | },
42 | "sample_ratio": {
43 | "type":"number",
44 | "exclusiveMinimum": 0,
45 | "maximum": 1
46 | }
47 | },
48 | "required": ["sample_ratio", "endpoint"]
49 | },
50 | "services": {
51 | "type": "array",
52 | "items": {
53 | "type": "object",
54 | "properties": {
55 | "domain": {
56 | "type": "string",
57 | "description": "Domain name for the service."
58 | },
59 | "anomaly_detection": {
60 | "type": "object",
61 | "description": "Adds header to downstream request with routing anomaly score between 0 to 1",
62 | "properties": {
63 | "header_name": {
64 | "type":"string",
65 | "description": "The header name that will hold the anomaly score [Default X-Anomaly-Score]"
66 | },
67 | "min_score": {
68 | "type":"integer",
69 | "default": 100,
70 | "description": "Below that score the anomaly score is 0",
71 | "minimum": 0
72 | },
73 | "max_score": {
74 | "type":"integer",
75 | "default": 200,
76 | "description": "Above that score the anomaly score is 1",
77 | "minimum": 0
78 | },
79 | "treshold_for_rating": {
80 | "type": "integer",
81 | "default": 100,
82 | "description": "How many requests to collect data from before starting to calculate anomaly score",
83 | "minimum": 0
84 | },
85 | "active": {
86 | "type":"boolean",
87 | "description": "Activate the anomaly detector"
88 | }
89 | }
90 | },
91 | "endpoints": {
92 | "type": "array",
93 | "items": {
94 | "type": "object",
95 | "properties": {
96 | "path": {
97 | "type": "string",
98 | "description": "Endpoint path that will be served."
99 | },
100 | "directory": {
101 | "type": "string",
102 | "description": "Directory to serve files from."
103 | },
104 | "destination": {
105 | "type": "string",
106 | "description": "Server URL to proxy the requests there."
107 | },
108 | "backend": {
109 | "type": "object",
110 | "properties": {
111 | "balance_policy": {
112 | "type": "string",
113 | "enum": [
114 | "round-robin",
115 | "random",
116 | "least-latency"
117 | ],
118 | "description": "Load balancing policy for backend servers."
119 | },
120 | "servers": {
121 | "type": "array",
122 | "items": {
123 | "type": "object",
124 | "properties": {
125 | "url": {
126 | "type": "string",
127 | "description": "URL of the backend server."
128 | },
129 | "weight": {
130 | "type": "integer",
131 | "description": "Weight of the backend server for load balancing."
132 | }
133 | },
134 | "required": [
135 | "url",
136 | "weight"
137 | ]
138 | }
139 | }
140 | },
141 | "required": [
142 | "balance_policy",
143 | "servers"
144 | ]
145 | },
146 | "omit_headers": {
147 | "type": "array",
148 | "description": "List of headers to omit for secrets protection.",
149 | "items": {
150 | "type": "string"
151 | }
152 | },
153 | "headers": {
154 | "type": "array",
155 | "description": "List of headers to add to request.",
156 | "items": {
157 | "type": "string"
158 | }
159 | },
160 | "minify": {
161 | "type": "array",
162 | "items": {
163 | "type": "string"
164 | }
165 | },
166 | "gzip": {
167 | "type": "boolean",
168 | "description": "Enable GZIP compression."
169 | },
170 | "timeout": {
171 | "type": "string",
172 | "description": "Custom timeout for backend responses."
173 | },
174 | "max_size": {
175 | "type": "integer",
176 | "description": "Max request size in bytes."
177 | },
178 | "ratelimits": {
179 | "type": "array",
180 | "items": {
181 | "type": "string",
182 | "description": "Rate limits in the format of requests per time period (e.g., ip-10/m)."
183 | }
184 | },
185 | "openapi": {
186 | "type": "string",
187 | "description": "Path to the OpenAPI specification for request/response validation."
188 | },
189 | "checks": {
190 | "type": "array",
191 | "description": "List of health check configurations",
192 | "items": {
193 | "type": "object",
194 | "required": [
195 | "name",
196 | "cron",
197 | "method",
198 | "url",
199 | "timeout"
200 | ],
201 | "properties": {
202 | "name": {
203 | "type": "string",
204 | "description": "Descriptive name for the health check",
205 | "minLength": 1
206 | },
207 | "cron": {
208 | "type": "string",
209 | "description": "Cron expression or macro for check frequency",
210 | "pattern": "^(@yearly|@annually|@monthly|@weekly|@daily|@hourly|@minutely|([*\\d,-/]+\\s){4}[*\\d,-/]+)$",
211 | "examples": [
212 | "* * * * *",
213 | "@hourly",
214 | "@daily",
215 | "0 0 * * *"
216 | ]
217 | },
218 | "method": {
219 | "type": "string",
220 | "description": "HTTP method for the health check",
221 | "enum": [
222 | "GET",
223 | "POST",
224 | "PUT",
225 | "DELETE",
226 | "HEAD",
227 | "OPTIONS",
228 | "PATCH",
229 | "CONNECT",
230 | "TRACE"
231 | ]
232 | },
233 | "url": {
234 | "type": "string",
235 | "description": "Health check endpoint URL",
236 | "format": "uri",
237 | "pattern": "^https?://"
238 | },
239 | "timeout": {
240 | "type": "string",
241 | "description": "Timeout duration for health check requests",
242 | "pattern": "^\\d+[smh]$",
243 | "default": "5s",
244 | "examples": [
245 | "5s",
246 | "1m",
247 | "1h"
248 | ]
249 | },
250 | "headers": {
251 | "type": "object",
252 | "description": "Custom headers to be sent with the health check request",
253 | "additionalProperties": {
254 | "type": "string"
255 | },
256 | "examples": [
257 | {
258 | "Host": "domain.org",
259 | "Authorization": "Bearer abc123"
260 | }
261 | ]
262 | },
263 | "on_failure": {
264 | "type": "string",
265 | "description": "Shell command to execute if the health check fails. Supports variable expansion: $date, $error, and $check_name.",
266 | "examples": [
267 | "echo Health check '$check_name' failed at $date with error: $error"
268 | ]
269 | }
270 | }
271 | }
272 | },
273 | "cache": {
274 | "type": "boolean",
275 | "description": "Enable caching of response that has cache headers"
276 | }
277 | },
278 | "required": [
279 | "path"
280 | ],
281 | "oneOf": [
282 | {
283 | "required": [
284 | "directory"
285 | ]
286 | },
287 | {
288 | "required": [
289 | "destination"
290 | ]
291 | },
292 | {
293 | "required": [
294 | "backend"
295 | ]
296 | }
297 | ]
298 | }
299 | }
300 | },
301 | "required": [
302 | "domain",
303 | "endpoints"
304 | ]
305 | }
306 | }
307 | },
308 | "required": [
309 | "version",
310 | "host",
311 | "port",
312 | "services"
313 | ]
314 | }
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
2 | github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5 | github.com/getkin/kin-openapi v0.128.0 h1:jqq3D9vC9pPq1dGcOCv7yOp1DaEe7c/T1vzcLbITSp4=
6 | github.com/getkin/kin-openapi v0.128.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM=
7 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
8 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
9 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
10 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
11 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
12 | github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
13 | github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
14 | github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
15 | github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
16 | github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
17 | github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
18 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
19 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
20 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
21 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
22 | github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
23 | github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
24 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys=
25 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I=
26 | github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
27 | github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
28 | github.com/invopop/yaml v0.3.1 h1:f0+ZpmhfBSS4MhG+4HYseMdJhoeeopbSKbq5Rpeelso=
29 | github.com/invopop/yaml v0.3.1/go.mod h1:PMOp3nn4/12yEZUFfmOuNHJsZToEEOwoWsT+D81KkeA=
30 | github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
31 | github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
32 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
33 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
34 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
35 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
36 | github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
37 | github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
38 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
39 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
40 | github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
41 | github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
42 | github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
43 | github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
44 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
45 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
46 | github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
47 | github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
48 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
49 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
50 | github.com/tdewolff/minify/v2 v2.21.0 h1:nAPP1UVx0aK1xsQh/JiG3xyEnnqWw+agPstn+V6Pkto=
51 | github.com/tdewolff/minify/v2 v2.21.0/go.mod h1:hGcthJ6Vj51NG+9QRIfN/DpWj5loHnY3bfhThzWWq08=
52 | github.com/tdewolff/parse/v2 v2.7.17 h1:uC10p6DaQQORDy72eaIyD+AvAkaIUOouQ0nWp4uD0D0=
53 | github.com/tdewolff/parse/v2 v2.7.17/go.mod h1:3FbJWZp3XT9OWVN3Hmfp0p/a08v4h8J9W1aghka0soA=
54 | github.com/tdewolff/test v1.0.11-0.20231101010635-f1265d231d52/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE=
55 | github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739 h1:IkjBCtQOOjIn03u/dMQK9g+Iw9ewps4mCl1nB8Sscbo=
56 | github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8=
57 | github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
58 | github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
59 | go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
60 | go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
61 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0 h1:iNba3cIZTDPB2+IAbVY/3TUN+pCCLrNYo2GaGtsKBak=
62 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0/go.mod h1:l5BDPiZ9FbeejzWTAX6BowMzQOM/GeaUQ6lr3sOcSkc=
63 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 h1:FZ6ei8GFW7kyPYdxJaV2rgI6M+4tvZzhYsQ2wgyVC08=
64 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0/go.mod h1:MdEu/mC6j3D+tTEfvI15b5Ci2Fn7NneJ71YMoiS3tpI=
65 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 h1:K0XaT3DwHAcV4nKLzcQvwAgSyisUghWoY20I7huthMk=
66 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0/go.mod h1:B5Ki776z/MBnVha1Nzwp5arlzBbE3+1jk+pGmaP5HME=
67 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 h1:FFeLy03iVTXP6ffeN2iXrxfGsZGCjVx0/4KlizjyBwU=
68 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0/go.mod h1:TMu73/k1CP8nBUpDLc71Wj/Kf7ZS9FK5b53VapRsP9o=
69 | go.opentelemetry.io/otel/log v0.7.0 h1:d1abJc0b1QQZADKvfe9JqqrfmPYQCz2tUSO+0XZmuV4=
70 | go.opentelemetry.io/otel/log v0.7.0/go.mod h1:2jf2z7uVfnzDNknKTO9G+ahcOAyWcp1fJmk/wJjULRo=
71 | go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
72 | go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
73 | go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
74 | go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
75 | go.opentelemetry.io/otel/sdk/log v0.7.0 h1:dXkeI2S0MLc5g0/AwxTZv6EUEjctiH8aG14Am56NTmQ=
76 | go.opentelemetry.io/otel/sdk/log v0.7.0/go.mod h1:oIRXpW+WD6M8BuGj5rtS0aRu/86cbDV/dAfNaZBIjYM=
77 | go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
78 | go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
79 | go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
80 | go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
81 | go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
82 | go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8=
83 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
84 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
85 | golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
86 | golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
87 | golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
88 | golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
89 | golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
90 | golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
91 | golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
92 | golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
93 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg=
94 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M=
95 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc=
96 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI=
97 | google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
98 | google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
99 | google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
100 | google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
101 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
102 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
103 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
104 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
105 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
106 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reverse Proxy Server
2 |
3 | [](https://github.com/hvuhsg/gatego/actions/workflows/go-tests.yml)
4 |
5 | ## Overview
6 |
7 | This reverse proxy server is designed to forward incoming requests to internal services, while offering advanced features such as SSL termination, rate limiting, content optimization, and OpenAPI-based request/response validation.
8 |
9 | ## Supported Features
10 |
11 | - 🔒 SSL Termination - HTTPS support with configurable SSL certificates
12 |
13 | - 🚀 Content Optimization
14 | - Minification for HTML, CSS, JS, XML, JSON, and SVG
15 | - GZIP compression support
16 |
17 |
18 | - ⚡ Performance Controls
19 | - Configurable request timeouts
20 | - Maximum request size limits
21 | - Response caching for cacheable content
22 |
23 |
24 | - 🛡️ Security & Protection
25 |
26 | - IP-based rate limiting (per minute/day)
27 | - Request/response validation via OpenAPI
28 | - Anomaly detection score (per session)
29 |
30 | - ⚖️ Load Balancing
31 |
32 | - Multiple backend server support
33 | - Round-robin, random, and least-latency policies
34 | - Weighted distribution options
35 |
36 |
37 | - 📁 File Serving - Static file serving with path stripping
38 |
39 | - 🏥 Health Monitoring
40 |
41 | - Automated health checks with cron scheduling
42 | Configurable failure notifications
43 |
44 |
45 | - 📊 Observability - OpenTelemetry integration for tracing and metrics
46 |
47 | ## More About The Features
48 | ### 1. SSL Termination
49 |
50 | The proxy supports secure connections through SSL, with configurable paths to the SSL key and certificate files. This allows for secure HTTPS communication between clients and the reverse proxy.
51 |
52 | ```yaml
53 | # Optional
54 | ssl:
55 | keyfile: /path/to/your/ssl/keyfile
56 | certfile: /path/to/your/ssl/certfile
57 | ```
58 |
59 | ### 2. Content Optimization
60 |
61 | - Minification: The server can minify content (e.g., HTML, CSS, JavaScript, XML, JSON, SVG) before forwarding it to the client, reducing response sizes and improving load times.
62 | - Compression: GZIP compression is supported to further reduce the size of responses, optimizing bandwidth usage.
63 |
64 | ```yaml
65 | - path: /
66 |
67 | # Optional
68 | minify: [js, html, css, json, xml, svg]
69 | # You can use 'all' instaed to enable all content-types
70 |
71 | # Optional
72 | gzip: true # Enable GZIP compression
73 | ```
74 |
75 |
76 | ### 3. Request Limits and Timeouts
77 |
78 | - Timeout: Custom timeouts can be set to avoid slow backend services from hanging client requests.
79 | - Maximum Request Size: Limits can be placed on the size of incoming requests to prevent excessively large payloads from overwhelming the server.
80 |
81 | ```yaml
82 | - path: /
83 | timeout: 5s # Custom timeout for backend responses (Default 30s)
84 | max_size: 2048 # Max request size in bytes (Default 10MB)
85 | ```
86 |
87 | ### 4. Rate Limiting
88 |
89 | Rate limiting can be applied to prevent abuse, restricting the number of requests an individual client (based on IP) can make within a specific time window. Multiple rate limit policies can be configured, such as:
90 | - Requests per minute from the same IP
91 | - Requests per day from the same IP
92 |
93 | ```yaml
94 | - path: /
95 |
96 | # Optional
97 | ratelimits:
98 | - ip-10/m # Limit to 10 requests per minute per IP
99 | - ip-500/d # Limit to 500 requests per day per IP
100 | ```
101 |
102 | ### 5. OpenAPI-based Request and Response Validation
103 |
104 | The server integrates OpenAPI for validating incoming requests and outgoing responses against an OpenAPI specification document. This ensures that:
105 |
106 | - Requests conform to the expected format, including parameters, headers, and body content.
107 | - Responses adhere to the defined API schema, ensuring consistent and reliable data exchange.
108 |
109 | You can specify the OpenAPI file path in the configuration, and the server will use it to validate the requests and responses automatically.
110 |
111 | ```yaml
112 | - path: /
113 |
114 | # Optional
115 | openapi: /path/to/openapi.yaml # OpenAPI file for request/response validation
116 | ```
117 |
118 |
119 | ### 6. Routing Anomaly Detection
120 |
121 | The Server will calculate an anomaly score for the request based on global avg routing and session avg routing.
122 | The score is added as a header to the request `X-Anomaly-Score`.
123 | The score ranging between 0 (normal request) to 1 (a-normal request)
124 |
125 | ```yaml
126 | services:
127 | - domain: your-domain.com
128 |
129 | # Will add to downstream request an header with routing anomaly score between 0 (normal) and 1 (suspicuse)
130 | anomaly_detection:
131 | active: true
132 | header_name: "X-Anomaly-Score" # (Optional) [Default: X-Anomaly-Score]
133 | min_score: 100 # (Optional) Every internal score below this number is 0 [Default: 100]
134 | max_score: 100 # (Optional) Every internal score above this number is 1 [Default: 200]
135 | treshold_for_rating: 100 # (Optional) The amount of requests to collect stats on before starting to rate anomaly [Default: 100]
136 | ```
137 |
138 |
139 | ### 7. Load Balancing and File Serving
140 |
141 | File serving is used when the `directory` field is set.
142 | > The endpoint path is removed from the request path before the file lookup. For example a path of /static and request path of /static/file.txt and a directory /var/www will search the file in /var/www/file.txt and not /var/www/static/file.txt
143 |
144 | ```yaml
145 | - path: /static
146 | directory: /var/www/
147 | ```
148 |
149 | The Server support load balancing between a number of backend servers and allow you to choose the balancing policy.
150 |
151 |
152 | ```yaml
153 | - path: /static
154 | backend:
155 | balance_policy: 'round-robin'
156 | servers:
157 | - url: http://backend-server-1/
158 | weight: 1
159 | - url: http://backend-server-2/
160 | weight: 2
161 | ```
162 |
163 | #### Supported Policies:
164 | - `round-robin` (affected by weights)
165 | - `random` (affected by weights)
166 | - `least-latency` (**not** affected by weights)
167 |
168 |
169 | ### 8. Health Checks
170 |
171 | The server supports automated health checks for backend services. You can configure periodic checks to monitor the health of your backend servers under each endpoint's configuration.
172 |
173 | ```yaml
174 | - path: /
175 | checks:
176 | - name: "Health Check" # Descriptive name for the check
177 | cron: "* * * * *" # Cron expression for check frequency
178 | # Supported cron macros:
179 | # - @yearly (or @annually) - Run once a year
180 | # - @monthly - Run once a month
181 | # - @weekly - Run once a week
182 | # - @daily - Run once a day
183 | # - @hourly - Run once an hour
184 | # - @minutely - Run once a minute
185 | method: GET # HTTP method for the health check
186 | url: "http://backend-server-1/up" # Health check endpoint
187 | timeout: 5s # Timeout for health check requests
188 | headers: # Optional custom headers
189 | Host: domain.org
190 | Authorization: "Bearer abc123"
191 | ```
192 |
193 | ### 9. OpenTelemetry Integration
194 | The server includes built-in support for OpenTelemetry, enabling comprehensive observability through distributed tracing, metrics, and logging. This integration helps monitor application performance, troubleshoot issues, and understand system behavior in distributed environments.
195 |
196 | ```yaml
197 | version: '...'
198 |
199 | open_telemetry:
200 | endpoint: "localhost:4317"
201 | sample_ratio: 0.01 # == 1%
202 | ```
203 |
204 | ## Configuration Example
205 |
206 | Here’s a generic example of how you can configure the reverse proxy:
207 |
208 | ```yaml
209 | version: '0.0.1'
210 | host: your-host
211 | port: your-port
212 |
213 | ssl:
214 | keyfile: /path/to/your/ssl/keyfile
215 | certfile: /path/to/your/ssl/certfile
216 |
217 | open_telemetry:
218 | endpoint: "localhost:4317"
219 | sample_ratio: 0.01 # == 1%
220 |
221 | services:
222 | - domain: your-domain.com
223 |
224 | # Will add to downstream request an header with routing anomaly score between 0 (normal) and 1 (suspicuse)
225 | anomaly_detection:
226 | active: true
227 | header_name: "X-Anomaly-Score" # (Optional) [Default: X-Anomaly-Score]
228 | min_score: 100 # (Optional) Every internal score below this number is 0 [Default: 100]
229 | max_score: 100 # (Optional) Every internal score above this number is 1 [Default: 200]
230 | treshold_for_rating: 100 # (Optional) The amount of requests to collect stats on before starting to rate anomaly [Default: 100]
231 |
232 | endpoints:
233 | - path: /your-endpoint # will be served for every request with path that start with /your-endpoint (Example: /your-endpoint/1)
234 |
235 | # directory: /home/yoyo/ # For static files serving
236 | # destination: http://your-backend-service/
237 | backend:
238 | balance_policy: 'round-robin' # Can be 'round-robin', 'random', or 'least-latency'
239 | servers:
240 | - url: http://backend-server-1/
241 | weight: 1
242 | - url: http://backend-server-2/
243 | weight: 2
244 |
245 | minify: [js, html, css, json, xml, svg]
246 | # You can use 'all' instaed to enable all content-types
247 |
248 | gzip: true # Enable GZIP compression
249 |
250 | timeout: 5s # Custom timeout for backend responses (Default 30s)
251 | max_size: 2048 # Max request size in bytes (Default 10MB)
252 |
253 | ratelimits:
254 | - ip-10/m # Limit to 10 requests per minute per IP
255 | - ip-500/d # Limit to 500 requests per day per IP
256 |
257 | openapi: /path/to/openapi.yaml # OpenAPI file for request/response validation
258 |
259 | omit_headers: [Server] # Omit response headers
260 |
261 | checks:
262 | - name: "Health Check"
263 |
264 | cron: "* * * * *" # == @minutely
265 | # Support cron format and macros.
266 | # Macros:
267 | # - @yearly
268 | # - @annually
269 | # - @monthly
270 | # - @weekly
271 | # - @daily
272 | # - @hourly
273 | # - @minutely
274 |
275 | method: GET # HTTP Method
276 | url: "http://backend-server-1/up"
277 | timeout: 5s
278 | headers:
279 | Host: domain.org
280 | Authorization: "Bearer abc123"
281 |
282 | # on_failure runs a shell command if the check fails. Expands $date, $error, $check_name.
283 | on_failure: |
284 | curl -d "Health check '$check_name' failed at $date due to: $error" ntfy.sh/gatego
285 | cache: true # Cache responses that has cache headers (Cache-Control and Expire)
286 |
287 | ```
288 |
289 | ### Breakdown
290 | The configuration is organized into three main sections:
291 |
292 | - Global Settings:
293 | - Server configuration (host, port)
294 | - SSL settings
295 | - OpenTelemetry configuration
296 |
297 |
298 | - Services
299 | - Domain-based routing
300 | - Multiple endpoints per domain
301 | - Path-based matching with longest-prefix wins
302 |
303 |
304 | - Endpoints
305 | - Backend service configuration
306 | - Performance optimizations
307 | - Security controls
308 | - Monitoring settings
309 |
310 | Each endpoint can be independently configured with its own set of features, allowing for flexible and granular control over different parts of your application.
311 |
312 | ## License
313 |
314 | This project is licensed under the MIT License.
315 |
--------------------------------------------------------------------------------
/internal/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "log"
7 | "net"
8 | "net/http"
9 | "net/url"
10 | "os"
11 | "regexp"
12 | "slices"
13 | "strconv"
14 | "strings"
15 | "time"
16 |
17 | "github.com/hashicorp/go-version"
18 | "github.com/hvuhsg/gatego/internal/middlewares"
19 | "github.com/hvuhsg/gatego/pkg/cron"
20 | "gopkg.in/yaml.v3"
21 | )
22 |
23 | const DefaultTimeout = time.Second * 30
24 | const DefaultMaxRequestSize = 1024 * 10 // 10 MB
25 | var SupportedBalancePolicies = []string{"round-robin", "random", "least-latency"}
26 |
27 | type Backend struct {
28 | BalancePolicy string `yaml:"balance_policy"`
29 | Servers []struct {
30 | URL string `yaml:"url"`
31 | Weight uint `yaml:"weight"`
32 | }
33 | }
34 |
35 | func (b Backend) validate() error {
36 | if !slices.Contains(SupportedBalancePolicies, b.BalancePolicy) {
37 | return fmt.Errorf("balance policy '%s' is not supported", b.BalancePolicy)
38 | }
39 |
40 | if len(b.Servers) == 0 {
41 | return errors.New("backend require at least one server")
42 | }
43 |
44 | for _, server := range b.Servers {
45 | if !isValidURL(server.URL) {
46 | return fmt.Errorf("invalid backend server url '%s'", server.URL)
47 | }
48 | }
49 |
50 | return nil
51 | }
52 |
53 | type Check struct {
54 | Name string `yaml:"name"`
55 | Cron string `yaml:"cron"`
56 | URL string `yaml:"url"`
57 | Method string `yaml:"method"`
58 | Timeout time.Duration `yaml:"timeout"`
59 | Headers map[string]string `yaml:"headers"`
60 | OnFailure string `yaml:"on_failure"`
61 | }
62 |
63 | func (c Check) validate() error {
64 | if len(c.Name) == 0 {
65 | return errors.New("check requires a name")
66 | }
67 |
68 | if _, err := cron.NewSchedule(c.Cron); err != nil {
69 | return errors.New("invalid check cron expression")
70 | }
71 |
72 | if !isValidURL(c.URL) {
73 | return errors.New("invalid check url")
74 | }
75 |
76 | if !isValidMethod(c.Method) {
77 | return errors.New("invalid check method")
78 | }
79 |
80 | return nil
81 | }
82 |
83 | type Path struct {
84 | Path string `yaml:"path"`
85 | Destination *string `yaml:"destination"` // The domain / url of the service server
86 | Directory *string `yaml:"directory"` // path to dir you want to serve
87 | Backend *Backend `yaml:"backend"` // List of servers to load balance between
88 | Headers *map[string]string `yaml:"headers"`
89 | OmitHeaders []string `yaml:"omit_headers"` // Omit specified headers
90 | Minify []string `yaml:"minify"`
91 | Gzip *bool `yaml:"gzip"`
92 | Timeout time.Duration `yaml:"timeout"`
93 | MaxSize uint64 `yaml:"max_size"`
94 | OpenAPI *string `yaml:"openapi"`
95 | RateLimits []string `yaml:"ratelimits"`
96 | Checks []Check `yaml:"checks"` // Automated checks
97 | Cache bool `yaml:"cache"` // Cache responses that has cache headers
98 | }
99 |
100 | func (p Path) validate() error {
101 | if p.Path[0] != '/' {
102 | return errors.New("path must start with '/'")
103 | }
104 |
105 | if p.Destination != nil {
106 | if !isValidURL(*p.Destination) {
107 | return errors.New("invalid destination url")
108 | }
109 |
110 | if p.Directory != nil {
111 | return errors.New("can't have destination and directory for the same path")
112 | }
113 | }
114 |
115 | if p.Directory != nil {
116 | if !isValidDir(*p.Directory) {
117 | return errors.New("invalid directory path")
118 | }
119 |
120 | if p.Cache {
121 | log.Println("[WARNING] Using cache while serving static files is not recommanded")
122 | }
123 | }
124 |
125 | if p.Backend != nil {
126 | if err := p.Backend.validate(); err != nil {
127 | return err
128 | }
129 | }
130 |
131 | if p.Destination == nil && p.Directory == nil && p.Backend == nil {
132 | return errors.New("path must have destination or directory or backend")
133 | }
134 |
135 | if p.OpenAPI != nil {
136 | if *p.OpenAPI == "" {
137 | return errors.New("openapi can't be empty (remove or fill)")
138 | }
139 |
140 | if !isValidFile(*p.OpenAPI) {
141 | return errors.New("invalid openapi spec path")
142 | }
143 | }
144 |
145 | for _, ratelimit := range p.RateLimits {
146 | _, err := middlewares.ParseLimitConfig(ratelimit)
147 | if err != nil {
148 | return fmt.Errorf("invalid ratelimit: %s", err.Error())
149 | }
150 | }
151 |
152 | for _, check := range p.Checks {
153 | if err := check.validate(); err != nil {
154 | return err
155 | }
156 | }
157 |
158 | return nil
159 | }
160 |
161 | type AnomalyDetection struct {
162 | HeaderName string `yaml:"header_name"`
163 | MinScore int `yaml:"min_score"`
164 | MaxScore int `yaml:"max_score"`
165 | TresholdForRating int `yaml:"treshold_for_rating"`
166 | Active bool `yaml:"active"`
167 | }
168 |
169 | func (a *AnomalyDetection) validate() error {
170 | if a.HeaderName == "" {
171 | a.HeaderName = "X-Anomaly-Score"
172 | }
173 |
174 | if a.MinScore == 0 {
175 | a.MinScore = 100
176 | }
177 |
178 | if a.MaxScore == 0 {
179 | a.MaxScore = 200
180 | }
181 |
182 | if a.TresholdForRating == 0 {
183 | a.TresholdForRating = 100
184 | }
185 |
186 | if a.MaxScore <= a.MinScore {
187 | return errors.New("anomaly detection maxScore MUST be grater the minScore")
188 | }
189 |
190 | return nil
191 | }
192 |
193 | type Service struct {
194 | Domain string `yaml:"domain"` // The domain / host the request was sent to
195 | Paths []Path `yaml:"endpoints"`
196 | AnomalyDetection *AnomalyDetection `yaml:"anomaly_detection"`
197 | }
198 |
199 | func (s Service) validate() error {
200 | if !isValidHostname(s.Domain) {
201 | return errors.New("invalid domain")
202 | }
203 |
204 | for _, path := range s.Paths {
205 | if err := path.validate(); err != nil {
206 | return err
207 | }
208 | }
209 |
210 | if s.AnomalyDetection != nil {
211 | if err := s.AnomalyDetection.validate(); err != nil {
212 | return err
213 | }
214 | }
215 |
216 | return nil
217 | }
218 |
219 | type TLS struct {
220 | Auto bool `yaml:"auto"`
221 | Domains []string `yaml:"domain"`
222 | Email *string `yaml:"email"`
223 | KeyFile *string `yaml:"keyfile"`
224 | CertFile *string `yaml:"certfile"`
225 | }
226 |
227 | func (tls TLS) validate() error {
228 | if tls.Auto {
229 | if len(tls.Domains) == 0 {
230 | return errors.New("when using the auto tls feature you MUST include a list of domains to issue certificates for")
231 | }
232 | if tls.Email == nil || len(*tls.Email) == 0 || !isValidEmail(*tls.Email) {
233 | return errors.New("when using the auto tls feature you MUST include a valid email for the lets-encrypt registration")
234 | }
235 | }
236 |
237 | if tls.CertFile != nil {
238 | if tls.KeyFile == nil {
239 | return errors.New("you MUST provide certfile AND keyfile")
240 | }
241 | }
242 |
243 | if tls.KeyFile != nil {
244 | if tls.CertFile == nil {
245 | return errors.New("you MUST provide certfile AND keyfile")
246 | }
247 |
248 | if !isValidFile(*tls.CertFile) {
249 | return errors.New("certfile path is invalid")
250 | }
251 |
252 | if !isValidFile(*tls.KeyFile) {
253 | return errors.New("keyfile path is invalid")
254 | }
255 | }
256 |
257 | return nil
258 | }
259 |
260 | type OTEL struct {
261 | Endpoint string `yaml:"endpoint"`
262 | SampleRatio float64 `yaml:"sample_ratio"`
263 | }
264 |
265 | func (otel OTEL) validate() error {
266 | if len(otel.Endpoint) > 0 {
267 | if err := isValidGRPCAddress(otel.Endpoint); err != nil {
268 | return err
269 | }
270 | }
271 |
272 | if otel.SampleRatio < 0 {
273 | return errors.New("OpenTelemetry sample ratio MUST be above 0")
274 | }
275 |
276 | if otel.SampleRatio == 0 {
277 | return errors.New("OpenTelemetry sample ratio is missing or equales to 0")
278 | }
279 |
280 | if otel.SampleRatio > 1 {
281 | return errors.New("OpenTelemetry sample ratio CAN NOT be above 1")
282 | }
283 |
284 | return nil
285 | }
286 |
287 | type Config struct {
288 | Version string `yaml:"version"`
289 | Host string `yaml:"host"` // listen host
290 | Port uint16 `yaml:"port"` // listen port
291 |
292 | OTEL *OTEL `yaml:"open_telemetry"`
293 |
294 | // TLS options
295 | TLS TLS `yaml:"ssl"`
296 |
297 | Services []Service `yaml:"services"`
298 | }
299 |
300 | func (c Config) Validate(currentVersion string) error {
301 | if c.Version == "" {
302 | return errors.New("version is required")
303 | }
304 |
305 | progVersion, _ := version.NewVersion(currentVersion)
306 | configVersion, err := version.NewVersion(c.Version)
307 | if err != nil {
308 | return errors.New("version is invalid")
309 | }
310 |
311 | if configVersion.Compare(progVersion) > 0 {
312 | return errors.New("config version is not supported (too advanced)")
313 | }
314 |
315 | if c.Host == "" {
316 | return errors.New("host is required")
317 | }
318 |
319 | if c.OTEL != nil {
320 | if err := (*c.OTEL).validate(); err != nil {
321 | return err
322 | }
323 | }
324 |
325 | if c.Port == 0 {
326 | return errors.New("port is required")
327 | }
328 |
329 | if err := c.TLS.validate(); err != nil {
330 | return err
331 | }
332 |
333 | if c.TLS.Auto && c.Port != 443 {
334 | return errors.New("the auto tls feature is only available if the server runs on port 443")
335 | }
336 |
337 | for _, service := range c.Services {
338 | if err := service.validate(); err != nil {
339 | return err
340 | }
341 | }
342 |
343 | return nil
344 | }
345 |
346 | func ParseConfig(filepath string, currentVersion string) (Config, error) {
347 | // Read the YAML file
348 | data, err := os.ReadFile(filepath)
349 | if err != nil {
350 | return Config{}, err
351 | }
352 |
353 | // Defaults
354 | c := Config{Port: 80}
355 |
356 | // Unmarshal the YAML data into the struct
357 | err = yaml.Unmarshal(data, &c)
358 | if err != nil {
359 | return Config{}, err
360 | }
361 |
362 | if err := c.Validate(currentVersion); err != nil {
363 | return Config{}, err
364 | }
365 |
366 | return c, nil
367 | }
368 |
369 | func isValidHostname(hostname string) bool {
370 | // Remove leading/trailing whitespace
371 | hostname = strings.TrimSpace(hostname)
372 |
373 | // Check if the hostname is empty
374 | if hostname == "" {
375 | return false
376 | }
377 |
378 | // Check if the hostname is too long (max 253 characters)
379 | if len(hostname) > 253 {
380 | return false
381 | }
382 |
383 | // Check for localhost
384 | if hostname == "localhost" {
385 | return true
386 | }
387 |
388 | // Check if it's an IP address (IPv4 or IPv6)
389 | if ip := net.ParseIP(hostname); ip != nil {
390 | return true
391 | }
392 |
393 | // Regular expression for domain validation
394 | // This regex allows for domains with multiple subdomains and supports IDNs
395 | domainRegex := regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$`)
396 |
397 | return domainRegex.MatchString(hostname)
398 | }
399 |
400 | func isValidURL(str string) bool {
401 | u, err := url.Parse(str)
402 | return err == nil && u.Scheme != "" && u.Host != ""
403 | }
404 |
405 | func isValidDir(path string) bool {
406 | if path == "" {
407 | return false
408 | }
409 |
410 | fileInfo, err := os.Stat(path)
411 | if err != nil {
412 | return false
413 | }
414 | return fileInfo.IsDir()
415 | }
416 |
417 | func isValidFile(path string) bool {
418 | if path == "" {
419 | return false
420 | }
421 |
422 | fileInfo, err := os.Stat(path)
423 | if err != nil {
424 | return false
425 | }
426 | return !fileInfo.IsDir()
427 | }
428 |
429 | func isValidMethod(method string) bool {
430 | methods := []string{
431 | http.MethodGet,
432 | http.MethodHead,
433 | http.MethodPost,
434 | http.MethodPut,
435 | http.MethodPatch,
436 | http.MethodDelete,
437 | http.MethodConnect,
438 | http.MethodOptions,
439 | http.MethodTrace,
440 | }
441 |
442 | return slices.Contains(methods, method)
443 | }
444 |
445 | func isValidGRPCAddress(address string) error {
446 | if address == "" {
447 | return fmt.Errorf("address cannot be empty")
448 | }
449 |
450 | // Split host and port
451 | host, portStr, err := net.SplitHostPort(address)
452 | if err != nil {
453 | return fmt.Errorf("invalid address format: %v", err)
454 | }
455 |
456 | // Validate port
457 | port, err := strconv.Atoi(portStr)
458 | if err != nil {
459 | return fmt.Errorf("invalid port number: %v", err)
460 | }
461 | if port < 1 || port > 65535 {
462 | return fmt.Errorf("port number must be between 1 and 65535")
463 | }
464 |
465 | // Empty host means localhost/0.0.0.0, which is valid
466 | if host == "" {
467 | return nil
468 | }
469 |
470 | // Check if host is IPv4 or IPv6
471 | if ip := net.ParseIP(host); ip != nil {
472 | return nil
473 | }
474 |
475 | // Validate hostname format
476 | hostnameRegex := regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-\.]*[a-zA-Z0-9])?$`)
477 | if !hostnameRegex.MatchString(host) {
478 | return fmt.Errorf("invalid hostname format")
479 | }
480 |
481 | // Check hostname length
482 | if len(host) > 253 {
483 | return fmt.Errorf("hostname too long")
484 | }
485 |
486 | // Validate hostname parts
487 | parts := strings.Split(host, ".")
488 | for _, part := range parts {
489 | if len(part) > 63 {
490 | return fmt.Errorf("hostname label too long")
491 | }
492 | }
493 |
494 | return nil
495 | }
496 |
497 | func isValidEmail(email string) bool {
498 | // Define a regular expression for valid email addresses
499 | var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
500 |
501 | // Match the email string with the regular expression
502 | return emailRegex.MatchString(email)
503 | }
504 |
--------------------------------------------------------------------------------