├── .github └── workflows │ └── go-tests.yml ├── LICENCE ├── README.md ├── checks.go ├── cmd ├── config.yaml ├── main.go └── openapi.yaml ├── config-schema.json ├── dev ├── client │ └── main.go ├── otel │ ├── docker-compose.yaml │ └── otel-collector-config.yaml └── server │ └── main.go ├── gatego.go ├── go.mod ├── go.sum ├── handler.go ├── internal ├── config │ ├── config.go │ └── config_test.go ├── contextvalues │ ├── tracer.go │ └── version.go ├── handlers │ ├── balancer.go │ ├── balancer_test.go │ ├── files.go │ ├── files_test.go │ └── proxy.go └── middlewares │ ├── addheader.go │ ├── cache.go │ ├── cache_test.go │ ├── gzip.go │ ├── gzip_test.go │ ├── logging.go │ ├── logging_test.go │ ├── middleware.go │ ├── minify.go │ ├── minify_test.go │ ├── omit_headers.go │ ├── omit_headers_test.go │ ├── openapi.go │ ├── openapi_test.go │ ├── otel.go │ ├── ratelimit.go │ ├── ratelimit_test.go │ ├── responsecapture.go │ ├── security │ ├── routing_anomaly_score.go │ ├── routing_anomaly_score_test.go │ └── trackerhistory.go │ ├── sizelimiter.go │ ├── sizelimiter_test.go │ ├── timeout.go │ └── timeout_test.go ├── otel.go ├── pkg ├── cron │ ├── README.md │ ├── cron.go │ ├── cron_test.go │ ├── macros.go │ ├── schedule.go │ └── schedule_test.go ├── monitor │ ├── monitor.go │ └── monitor_test.go ├── multimux │ ├── multimux.go │ └── multimux_test.go ├── pathgraph │ └── pathgraph.go └── tracker │ ├── tracker.go │ └── tracker_test.go └── server.go /.github/workflows/go-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Go 20 | uses: 
actions/setup-go@v4 21 | with: 22 | go-version: '1.22' 23 | 24 | - name: Test 25 | run: go test -v ./... 26 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 yehoyada 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reverse Proxy Server 2 | 3 | [![Tests](https://github.com/hvuhsg/gatego/actions/workflows/go-tests.yml/badge.svg?branch=main)](https://github.com/hvuhsg/gatego/actions/workflows/go-tests.yml) 4 | 5 | ## Overview 6 | 7 | This reverse proxy server is designed to forward incoming requests to internal services, while offering advanced features such as SSL termination, rate limiting, content optimization, and OpenAPI-based request/response validation. 8 | 9 | ## Supported Features 10 | 11 | - 🔒 SSL Termination - HTTPS support with configurable SSL certificates 12 | 13 | - 🚀 Content Optimization 14 | - Minification for HTML, CSS, JS, XML, JSON, and SVG 15 | - GZIP compression support 16 | 17 | 18 | - ⚡ Performance Controls 19 | - Configurable request timeouts 20 | - Maximum request size limits 21 | - Response caching for cacheable content 22 | 23 | 24 | - 🛡️ Security & Protection 25 | 26 | - IP-based rate limiting (per minute/day) 27 | - Request/response validation via OpenAPI 28 | - Anomaly detection score (per session) 29 | 30 | - ⚖️ Load Balancing 31 | 32 | - Multiple backend server support 33 | - Round-robin, random, and least-latency policies 34 | - Weighted distribution options 35 | 36 | 37 | - 📁 File Serving - Static file serving with path stripping 38 | 39 | - 🏥 Health Monitoring 40 | 41 | - Automated health checks with cron scheduling 42 | - Configurable failure notifications 43 | 44 | 45 | - 📊 Observability - OpenTelemetry integration for tracing and metrics 46 | 47 | ## More About The Features 48 | ### 1. SSL Termination 49 | 50 | The proxy supports secure connections through SSL, with configurable paths to the SSL key and certificate files. This allows for secure HTTPS communication between clients and the reverse proxy.
51 | 52 | ```yaml 53 | # Optional 54 | ssl: 55 | keyfile: /path/to/your/ssl/keyfile 56 | certfile: /path/to/your/ssl/certfile 57 | ``` 58 | 59 | ### 2. Content Optimization 60 | 61 | - Minification: The server can minify content (e.g., HTML, CSS, JavaScript, XML, JSON, SVG) before forwarding it to the client, reducing response sizes and improving load times. 62 | - Compression: GZIP compression is supported to further reduce the size of responses, optimizing bandwidth usage. 63 | 64 | ```yaml 65 | - path: / 66 | 67 | # Optional 68 | minify: [js, html, css, json, xml, svg] 69 | # You can use 'all' instead to enable all content-types 70 | 71 | # Optional 72 | gzip: true # Enable GZIP compression 73 | ``` 74 | 75 | 76 | ### 3. Request Limits and Timeouts 77 | 78 | - Timeout: Custom timeouts can be set to prevent slow backend services from hanging client requests. 79 | - Maximum Request Size: Limits can be placed on the size of incoming requests to prevent excessively large payloads from overwhelming the server. 80 | 81 | ```yaml 82 | - path: / 83 | timeout: 5s # Custom timeout for backend responses (Default 30s) 84 | max_size: 2048 # Max request size in bytes (Default 10MB) 85 | ``` 86 | 87 | ### 4. Rate Limiting 88 | 89 | Rate limiting can be applied to prevent abuse, restricting the number of requests an individual client (based on IP) can make within a specific time window. Multiple rate limit policies can be configured, such as: 90 | - Requests per minute from the same IP 91 | - Requests per day from the same IP 92 | 93 | ```yaml 94 | - path: / 95 | 96 | # Optional 97 | ratelimits: 98 | - ip-10/m # Limit to 10 requests per minute per IP 99 | - ip-500/d # Limit to 500 requests per day per IP 100 | ``` 101 | 102 | ### 5. OpenAPI-based Request and Response Validation 103 | 104 | The server integrates OpenAPI for validating incoming requests and outgoing responses against an OpenAPI specification document.
This ensures that: 105 | 106 | - Requests conform to the expected format, including parameters, headers, and body content. 107 | - Responses adhere to the defined API schema, ensuring consistent and reliable data exchange. 108 | 109 | You can specify the OpenAPI file path in the configuration, and the server will use it to validate the requests and responses automatically. 110 | 111 | ```yaml 112 | - path: / 113 | 114 | # Optional 115 | openapi: /path/to/openapi.yaml # OpenAPI file for request/response validation 116 | ``` 117 | 118 | 119 | ### 6. Routing Anomaly Detection 120 | 121 | The server calculates an anomaly score for each request by comparing its routing against the global average routing and the session's average routing. 122 | The score is added to the request as the `X-Anomaly-Score` header (configurable via `header_name`). 123 | The score ranges from 0 (normal request) to 1 (anomalous request). 124 | 125 | ```yaml 126 | services: 127 | - domain: your-domain.com 128 | 129 | # Adds a header with the routing anomaly score, between 0 (normal) and 1 (suspicious), to the downstream request 130 | anomaly_detection: 131 | active: true 132 | header_name: "X-Anomaly-Score" # (Optional) [Default: X-Anomaly-Score] 133 | min_score: 100 # (Optional) Every internal score below this number is 0 [Default: 100] 134 | max_score: 200 # (Optional) Every internal score above this number is 1 [Default: 200] 135 | treshold_for_rating: 100 # (Optional) The number of requests to collect stats on before starting to rate anomalies [Default: 100] 136 | ``` 137 | 138 | 139 | ### 7. Load Balancing and File Serving 140 | 141 | File serving is used when the `directory` field is set. 142 | > The endpoint path is removed from the request path before the file lookup.
For example, with an endpoint path of /static, a request to /static/file.txt and a directory of /var/www, the file is looked up at /var/www/file.txt, not /var/www/static/file.txt. 143 | 144 | ```yaml 145 | - path: /static 146 | directory: /var/www/ 147 | ``` 148 | 149 | The server supports load balancing across multiple backend servers and lets you choose the balancing policy. 150 | 151 | 152 | ```yaml 153 | - path: /static 154 | backend: 155 | balance_policy: 'round-robin' 156 | servers: 157 | - url: http://backend-server-1/ 158 | weight: 1 159 | - url: http://backend-server-2/ 160 | weight: 2 161 | ``` 162 | 163 | #### Supported Policies: 164 | - `round-robin` (affected by weights) 165 | - `random` (affected by weights) 166 | - `least-latency` (**not** affected by weights) 167 | 168 | 169 | ### 8. Health Checks 170 | 171 | The server supports automated health checks for backend services. You can configure periodic checks to monitor the health of your backend servers under each endpoint's configuration. 172 | 173 | ```yaml 174 | - path: / 175 | checks: 176 | - name: "Health Check" # Descriptive name for the check 177 | cron: "* * * * *" # Cron expression for check frequency 178 | # Supported cron macros: 179 | # - @yearly (or @annually) - Run once a year 180 | # - @monthly - Run once a month 181 | # - @weekly - Run once a week 182 | # - @daily - Run once a day 183 | # - @hourly - Run once an hour 184 | # - @minutely - Run once a minute 185 | method: GET # HTTP method for the health check 186 | url: "http://backend-server-1/up" # Health check endpoint 187 | timeout: 5s # Timeout for health check requests 188 | headers: # Optional custom headers 189 | Host: domain.org 190 | Authorization: "Bearer abc123" 191 | ``` 192 | 193 | ### 9. OpenTelemetry Integration 194 | The server includes built-in support for OpenTelemetry, enabling comprehensive observability through distributed tracing, metrics, and logging.
This integration helps monitor application performance, troubleshoot issues, and understand system behavior in distributed environments. 195 | 196 | ```yaml 197 | version: '...' 198 | 199 | open_telemetry: 200 | endpoint: "localhost:4317" 201 | sample_ratio: 0.01 # == 1% 202 | ``` 203 | 204 | ## Configuration Example 205 | 206 | Here’s a generic example of how you can configure the reverse proxy: 207 | 208 | ```yaml 209 | version: '0.0.1' 210 | host: your-host 211 | port: your-port 212 | 213 | ssl: 214 | keyfile: /path/to/your/ssl/keyfile 215 | certfile: /path/to/your/ssl/certfile 216 | 217 | open_telemetry: 218 | endpoint: "localhost:4317" 219 | sample_ratio: 0.01 # == 1% 220 | 221 | services: 222 | - domain: your-domain.com 223 | 224 | # Adds a header with the routing anomaly score, between 0 (normal) and 1 (suspicious), to the downstream request 225 | anomaly_detection: 226 | active: true 227 | header_name: "X-Anomaly-Score" # (Optional) [Default: X-Anomaly-Score] 228 | min_score: 100 # (Optional) Every internal score below this number is 0 [Default: 100] 229 | max_score: 200 # (Optional) Every internal score above this number is 1 [Default: 200] 230 | treshold_for_rating: 100 # (Optional) The number of requests to collect stats on before starting to rate anomalies [Default: 100] 231 | 232 | endpoints: 233 | - path: /your-endpoint # will be served for every request whose path starts with /your-endpoint (Example: /your-endpoint/1) 234 | 235 | # directory: /home/yoyo/ # For static file serving 236 | # destination: http://your-backend-service/ 237 | backend: 238 | balance_policy: 'round-robin' # Can be 'round-robin', 'random', or 'least-latency' 239 | servers: 240 | - url: http://backend-server-1/ 241 | weight: 1 242 | - url: http://backend-server-2/ 243 | weight: 2 244 | 245 | minify: [js, html, css, json, xml, svg] 246 | # You can use 'all' instead to enable all content-types 247 | 248 | gzip: true # Enable GZIP compression 249 | 250 | timeout: 5s # Custom timeout
for backend responses (Default 30s) 251 | max_size: 2048 # Max request size in bytes (Default 10MB) 252 | 253 | ratelimits: 254 | - ip-10/m # Limit to 10 requests per minute per IP 255 | - ip-500/d # Limit to 500 requests per day per IP 256 | 257 | openapi: /path/to/openapi.yaml # OpenAPI file for request/response validation 258 | 259 | omit_headers: [Server] # Omit response headers 260 | 261 | checks: 262 | - name: "Health Check" 263 | 264 | cron: "* * * * *" # == @minutely 265 | # Supports cron format and macros. 266 | # Macros: 267 | # - @yearly 268 | # - @annually 269 | # - @monthly 270 | # - @weekly 271 | # - @daily 272 | # - @hourly 273 | # - @minutely 274 | 275 | method: GET # HTTP Method 276 | url: "http://backend-server-1/up" 277 | timeout: 5s 278 | headers: 279 | Host: domain.org 280 | Authorization: "Bearer abc123" 281 | 282 | # on_failure runs a shell command if the check fails. Expands $date, $error, $check_name. 283 | on_failure: | 284 | curl -d "Health check '$check_name' failed at $date due to: $error" ntfy.sh/gatego 285 | cache: true # Cache responses that have cache headers (Cache-Control and Expires) 286 | 287 | ``` 288 | 289 | ### Breakdown 290 | The configuration is organized into three main sections: 291 | 292 | - Global Settings: 293 | - Server configuration (host, port) 294 | - SSL settings 295 | - OpenTelemetry configuration 296 | 297 | 298 | - Services 299 | - Domain-based routing 300 | - Multiple endpoints per domain 301 | - Path-based matching (longest prefix wins) 302 | 303 | 304 | - Endpoints 305 | - Backend service configuration 306 | - Performance optimizations 307 | - Security controls 308 | - Monitoring settings 309 | 310 | Each endpoint can be independently configured with its own set of features, allowing for flexible and granular control over different parts of your application. 311 | 312 | ## License 313 | 314 | This project is licensed under the MIT License.
315 | -------------------------------------------------------------------------------- /checks.go: -------------------------------------------------------------------------------- 1 | package gatego 2 | 3 | import ( 4 | "github.com/hvuhsg/gatego/internal/config" 5 | "github.com/hvuhsg/gatego/pkg/monitor" 6 | ) 7 | 8 | func createMonitorChecks(services []config.Service) []monitor.Check { // createMonitorChecks collects the health checks declared on every path of every service and copies each one field-for-field into a monitor.Check. 9 | checks := make([]monitor.Check, 0) // non-nil even when no checks are configured 10 | for _, service := range services { 11 | for _, path := range service.Paths { 12 | for _, checkConfig := range path.Checks { 13 | check := monitor.Check{ 14 | Name: checkConfig.Name, 15 | Cron: checkConfig.Cron, 16 | URL: checkConfig.URL, 17 | Method: checkConfig.Method, 18 | Timeout: checkConfig.Timeout, 19 | Headers: checkConfig.Headers, 20 | OnFailure: checkConfig.OnFailure, 21 | } 22 | 23 | checks = append(checks, check) 24 | } 25 | } 26 | } 27 | 28 | return checks 29 | } 30 | -------------------------------------------------------------------------------- /cmd/config.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://raw.githubusercontent.com/hvuhsg/gatego/refs/heads/main/config-schema.json 2 | 3 | version: '0.0.1' 4 | host: localhost 5 | port: 8004 6 | 7 | # open_telemetry: 8 | # endpoint: "localhost:4317" 9 | # sample_ratio: 1 10 | 11 | services: 12 | - domain: localhost 13 | 14 | anomaly_detection: 15 | active: true 16 | 17 | endpoints: 18 | - path: / 19 | # directory: /home/yoyo/ # Instead of destination 20 | destination: http://127.0.0.1:4007/ 21 | # backend: 22 | # balance_policy: 'least-latency' # Can be 'round-robin', 'random', or 'least-latency' 23 | # servers: 24 | # - url: http://127.0.0.1:4007/ 25 | # weight: 1 26 | # - url: http://127.0.0.1:4008/ 27 | # weight: 2 28 | 29 | minify: [js, html, css, json, xml, svg] 30 | 31 | gzip: true 32 | 33 | timeout: 3s # Default (30s) 34 | max_size: 1024 # Default (10MB) 35 | 36 | ratelimits: 37 | - ip-60/m # 
Limit requests from the same IP to 60 requests per minute. 38 | - ip-100/d # Limit requests from the same IP to 100 requests per day. 39 | 40 | openapi: openapi.yaml 41 | 42 | checks: 43 | - name: "DB Health" 44 | cron: "* * * * *" 45 | method: GET 46 | url: "http://127.0.0.1:4007/check_db" 47 | timeout: 5s 48 | headers: 49 | Host: domain.org 50 | Authorization: "Bearer abc123" 51 | on_failure: | 52 | echo Health check '$check_name' failed at $date with error: $error 53 | 54 | omit_headers: [Authorization, X-API-Key, X-Secret-Token] 55 | 56 | cache: true 57 | -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | "os/signal" 8 | 9 | "github.com/hvuhsg/gatego" 10 | "github.com/hvuhsg/gatego/internal/config" 11 | ) 12 | 13 | const version = "0.0.1" 14 | 15 | func main() { // main loads config.yaml, builds the gateway with a context that is cancelled on SIGINT, and runs it; any error from parsing or Run is fatal. 16 | // Handle SIGINT (CTRL+C) gracefully. 17 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 18 | defer stop() 19 | 20 | config, err := config.ParseConfig("config.yaml", version) // NOTE: this variable shadows the imported config package for the rest of main; the relative path is presumably resolved against the working directory — confirm. 21 | if err != nil { 22 | log.Fatal(err) 23 | } 24 | 25 | log.Default().Println("Config loaded successfully") 26 | 27 | server := gatego.New(ctx, config, version) 28 | 29 | err = server.Run() 30 | if err != nil { 31 | log.Fatalln(err) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /cmd/openapi.yaml: -------------------------------------------------------------------------------- 1 | openapi: 3.1.0 2 | 3 | info: 4 | title: Simple API 5 | version: 1.0.0 6 | description: A simple API with one root path and one query parameter 7 | 8 | paths: 9 | /: 10 | post: 11 | summary: Root endpoint 12 | description: Returns a greeting message 13 | parameters: 14 | - in: query 15 | name: name 16 | schema: 17 | type: string 18 | maxLength: 10 19 | required: true 20 | description: Name of the person to greet 21 | responses: 22 | '200':
23 | description: Successful response 24 | content: 25 | application/json: 26 | schema: 27 | type: object 28 | properties: 29 | message: 30 | type: string 31 | example: "Hello, World!" 32 | '400': 33 | description: Bad request 34 | content: 35 | application/json: 36 | schema: 37 | type: object 38 | properties: 39 | error: 40 | type: string 41 | example: "Invalid query parameter" -------------------------------------------------------------------------------- /config-schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "type": "object", 4 | "properties": { 5 | "version": { 6 | "type": "string", 7 | "description": "Version of the configuration." 8 | }, 9 | "host": { 10 | "type": "string", 11 | "description": "The host where the service will run." 12 | }, 13 | "port": { 14 | "type": "integer", 15 | "description": "The port for the service." 16 | }, 17 | "ssl": { 18 | "type": "object", 19 | "properties": { 20 | "keyfile": { 21 | "type": "string", 22 | "description": "Path to SSL key file." 23 | }, 24 | "certfile": { 25 | "type": "string", 26 | "description": "Path to SSL certificate file." 27 | } 28 | }, 29 | "required": [ 30 | "keyfile", 31 | "certfile" 32 | ], 33 | "description": "SSL configuration for the server." 34 | }, 35 | "open_telemetry": { 36 | "type": "object", 37 | "properties": { 38 | "endpoint": { 39 | "type": "string", 40 | "description": "GRPC connection string for open telemetry collection agent" 41 | }, 42 | "sample_ratio": { 43 | "type":"number", 44 | "exclusiveMinimum": 0, 45 | "maximum": 1 46 | } 47 | }, 48 | "required": ["sample_ratio", "endpoint"] 49 | }, 50 | "services": { 51 | "type": "array", 52 | "items": { 53 | "type": "object", 54 | "properties": { 55 | "domain": { 56 | "type": "string", 57 | "description": "Domain name for the service." 
58 | }, 59 | "anomaly_detection": { 60 | "type": "object", 61 | "description": "Adds header to downstream request with routing anomaly score between 0 to 1", 62 | "properties": { 63 | "header_name": { 64 | "type":"string", 65 | "description": "The header name that will hold the anomaly score [Default X-Anomaly-Score]" 66 | }, 67 | "min_score": { 68 | "type":"integer", 69 | "default": 100, 70 | "description": "Below that score the anomaly score is 0", 71 | "minimum": 0 72 | }, 73 | "max_score": { 74 | "type":"integer", 75 | "default": 200, 76 | "description": "Above that score the anomaly score is 1", 77 | "minimum": 0 78 | }, 79 | "treshold_for_rating": { 80 | "type": "integer", 81 | "default": 100, 82 | "description": "How many requests to collect data from before starting to calculate anomaly score", 83 | "minimum": 0 84 | }, 85 | "active": { 86 | "type":"boolean", 87 | "description": "Activate the anomaly detector" 88 | } 89 | } 90 | }, 91 | "endpoints": { 92 | "type": "array", 93 | "items": { 94 | "type": "object", 95 | "properties": { 96 | "path": { 97 | "type": "string", 98 | "description": "Endpoint path that will be served." 99 | }, 100 | "directory": { 101 | "type": "string", 102 | "description": "Directory to serve files from." 103 | }, 104 | "destination": { 105 | "type": "string", 106 | "description": "Server URL to proxy the requests there." 107 | }, 108 | "backend": { 109 | "type": "object", 110 | "properties": { 111 | "balance_policy": { 112 | "type": "string", 113 | "enum": [ 114 | "round-robin", 115 | "random", 116 | "least-latency" 117 | ], 118 | "description": "Load balancing policy for backend servers." 119 | }, 120 | "servers": { 121 | "type": "array", 122 | "items": { 123 | "type": "object", 124 | "properties": { 125 | "url": { 126 | "type": "string", 127 | "description": "URL of the backend server." 128 | }, 129 | "weight": { 130 | "type": "integer", 131 | "description": "Weight of the backend server for load balancing." 
132 | } 133 | }, 134 | "required": [ 135 | "url", 136 | "weight" 137 | ] 138 | } 139 | } 140 | }, 141 | "required": [ 142 | "balance_policy", 143 | "servers" 144 | ] 145 | }, 146 | "omit_headers": { 147 | "type": "array", 148 | "description": "List of headers to omit for secrets protection.", 149 | "items": { 150 | "type": "string" 151 | } 152 | }, 153 | "headers": { 154 | "type": "array", 155 | "description": "List of headers to add to request.", 156 | "items": { 157 | "type": "string" 158 | } 159 | }, 160 | "minify": { 161 | "type": "array", 162 | "items": { 163 | "type": "string" 164 | } 165 | }, 166 | "gzip": { 167 | "type": "boolean", 168 | "description": "Enable GZIP compression." 169 | }, 170 | "timeout": { 171 | "type": "string", 172 | "description": "Custom timeout for backend responses." 173 | }, 174 | "max_size": { 175 | "type": "integer", 176 | "description": "Max request size in bytes." 177 | }, 178 | "ratelimits": { 179 | "type": "array", 180 | "items": { 181 | "type": "string", 182 | "description": "Rate limits in the format of requests per time period (e.g., ip-10/m)." 183 | } 184 | }, 185 | "openapi": { 186 | "type": "string", 187 | "description": "Path to the OpenAPI specification for request/response validation." 
188 | }, 189 | "checks": { 190 | "type": "array", 191 | "description": "List of health check configurations", 192 | "items": { 193 | "type": "object", 194 | "required": [ 195 | "name", 196 | "cron", 197 | "method", 198 | "url", 199 | "timeout" 200 | ], 201 | "properties": { 202 | "name": { 203 | "type": "string", 204 | "description": "Descriptive name for the health check", 205 | "minLength": 1 206 | }, 207 | "cron": { 208 | "type": "string", 209 | "description": "Cron expression or macro for check frequency", 210 | "pattern": "^(@yearly|@annually|@monthly|@weekly|@daily|@hourly|@minutely|([*\\d,-/]+\\s){4}[*\\d,-/]+)$", 211 | "examples": [ 212 | "* * * * *", 213 | "@hourly", 214 | "@daily", 215 | "0 0 * * *" 216 | ] 217 | }, 218 | "method": { 219 | "type": "string", 220 | "description": "HTTP method for the health check", 221 | "enum": [ 222 | "GET", 223 | "POST", 224 | "PUT", 225 | "DELETE", 226 | "HEAD", 227 | "OPTIONS", 228 | "PATCH", 229 | "CONNECT", 230 | "TRACE" 231 | ] 232 | }, 233 | "url": { 234 | "type": "string", 235 | "description": "Health check endpoint URL", 236 | "format": "uri", 237 | "pattern": "^https?://" 238 | }, 239 | "timeout": { 240 | "type": "string", 241 | "description": "Timeout duration for health check requests", 242 | "pattern": "^\\d+[smh]$", 243 | "default": "5s", 244 | "examples": [ 245 | "5s", 246 | "1m", 247 | "1h" 248 | ] 249 | }, 250 | "headers": { 251 | "type": "object", 252 | "description": "Custom headers to be sent with the health check request", 253 | "additionalProperties": { 254 | "type": "string" 255 | }, 256 | "examples": [ 257 | { 258 | "Host": "domain.org", 259 | "Authorization": "Bearer abc123" 260 | } 261 | ] 262 | }, 263 | "on_failure": { 264 | "type": "string", 265 | "description": "Shell command to execute if the health check fails. 
Supports variable expansion: $date, $error, and $check_name.", 266 | "examples": [ 267 | "echo Health check '$check_name' failed at $date with error: $error" 268 | ] 269 | } 270 | } 271 | } 272 | }, 273 | "cache": { 274 | "type": "boolean", 275 | "description": "Enable caching of response that has cache headers" 276 | } 277 | }, 278 | "required": [ 279 | "path" 280 | ], 281 | "oneOf": [ 282 | { 283 | "required": [ 284 | "directory" 285 | ] 286 | }, 287 | { 288 | "required": [ 289 | "destination" 290 | ] 291 | }, 292 | { 293 | "required": [ 294 | "backend" 295 | ] 296 | } 297 | ] 298 | } 299 | } 300 | }, 301 | "required": [ 302 | "domain", 303 | "endpoints" 304 | ] 305 | } 306 | } 307 | }, 308 | "required": [ 309 | "version", 310 | "host", 311 | "port", 312 | "services" 313 | ] 314 | } -------------------------------------------------------------------------------- /dev/client/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net/http" 10 | ) 11 | 12 | func sendRequest() *http.Response { // sendRequest POSTs a small JSON payload to the local gateway (http://localhost:8004/?name=yoyo) and returns the raw response; any error terminates the process via log.Fatal. 13 | // Sample data to send in JSON format 14 | data := map[string]interface{}{ 15 | "key1": "value1", 16 | "key2": "value2", 17 | "key3": 123, 18 | } 19 | 20 | // Convert the data to JSON 21 | jsonData, err := json.Marshal(data) 22 | if err != nil { 23 | log.Fatal("Error marshaling JSON:", err) 24 | } 25 | 26 | // Create a new POST request with the JSON payload 27 | req, err := http.NewRequest(http.MethodPost, "http://localhost:8004/?name=yoyo", bytes.NewBuffer(jsonData)) 28 | if err != nil { 29 | log.Fatal("Error creating request:", err) 30 | } 31 | 32 | // Set the appropriate Content-Type header for JSON 33 | req.Header.Set("Content-Type", "application/json") 34 | 35 | // Send the POST request 36 | client := http.DefaultClient // NOTE(review): DefaultClient has no timeout, so a stalled gateway blocks this client indefinitely. 37 | response, err := client.Do(req) 38 | if err != nil { 39 | log.Fatal("Error sending request:", err) 40 | } 41 | 
42 | return response 43 | } 44 | func main() { // main sends one request and prints the response followed by its body; status codes above 299 are logged but do not abort. 45 | resp := sendRequest() 46 | defer resp.Body.Close() // Always defer closing the response body 47 | 48 | // Check the response status code 49 | if resp.StatusCode > 299 { 50 | log.Printf("Error: received status code %d", resp.StatusCode) 51 | } 52 | 53 | fmt.Println(resp) 54 | 55 | // Read the response body 56 | data, err := io.ReadAll(resp.Body) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | 61 | // Print the response body 62 | fmt.Println(string(data)) 63 | } 64 | -------------------------------------------------------------------------------- /dev/otel/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | # Jaeger 3 | jaeger: 4 | image: jaegertracing/all-in-one:latest 5 | ports: 6 | - "16686:16686" # Jaeger UI 7 | - "14250:14250" # Model used by collector 8 | environment: 9 | - COLLECTOR_OTLP_ENABLED=true 10 | 11 | # OpenTelemetry Collector 12 | otel-collector: 13 | image: otel/opentelemetry-collector-contrib:latest 14 | command: ["--config=/etc/otel-collector-config.yaml"] 15 | volumes: 16 | - ./otel-collector-config.yaml:/etc/otel-collector-config.yaml 17 | ports: 18 | - "4317:4317" # OTLP gRPC receiver 19 | - "4318:4318" # OTLP http receiver 20 | - "8888:8888" # Prometheus metrics exposed by the collector 21 | - "8889:8889" # Prometheus exporter metrics 22 | - "13133:13133" # Health check extension 23 | depends_on: 24 | - jaeger -------------------------------------------------------------------------------- /dev/otel/otel-collector-config.yaml: -------------------------------------------------------------------------------- 1 | receivers: 2 | otlp: 3 | protocols: 4 | grpc: 5 | endpoint: 0.0.0.0:4317 6 | http: 7 | endpoint: 0.0.0.0:4318 8 | 9 | processors: 10 | batch: 11 | timeout: 1s 12 | send_batch_size: 1024 13 | 14 | memory_limiter: 15 | check_interval: 1s 16 | limit_mib: 1000 17 | spike_limit_mib: 200 18 | 19 | exporters: 20 | otlp:
21 | endpoint: "jaeger:4317" 22 | tls: 23 | insecure: true 24 | 25 | debug: 26 | verbosity: detailed 27 | 28 | extensions: 29 | health_check: 30 | endpoint: 0.0.0.0:13133 31 | 32 | service: 33 | extensions: [health_check] 34 | pipelines: 35 | traces: 36 | receivers: [otlp] 37 | processors: [memory_limiter, batch] 38 | exporters: [otlp, debug] -------------------------------------------------------------------------------- /dev/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | ) 8 | 9 | func main() { // main serves a stub JSON backend on 127.0.0.1:4007 that prints each request's protocol, host, URL and headers. 10 | server := http.NewServeMux() 11 | 12 | server.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 13 | fmt.Printf("%s, %s, %s, %v\n", r.Proto, r.Host, r.URL, r.Header) 14 | w.Header().Set("Content-Type", "application/json") 15 | w.WriteHeader(200) 16 | w.Write([]byte(`{ "hello" : 1.5 , "good" : true }`)) // NOTE(review): write error is ignored 17 | }) 18 | 19 | fmt.Println("Running server at '127.0.0.1:4007'") 20 | 21 | err := http.ListenAndServe("127.0.0.1:4007", server) 22 | if err != nil { 23 | log.Fatal(err) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /gatego.go: -------------------------------------------------------------------------------- 1 | package gatego 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/hvuhsg/gatego/internal/config" 9 | "github.com/hvuhsg/gatego/internal/contextvalues" 10 | "github.com/hvuhsg/gatego/pkg/monitor" 11 | ) 12 | 13 | const serviceName = "gatego" 14 | 15 | type GateGo struct { // GateGo holds the parsed configuration, the health-check monitor and a context annotated with the gateway version. 16 | config config.Config 17 | monitor *monitor.Monitor 18 | ctx context.Context 19 | } 20 | 21 | func New(ctx context.Context, config config.Config, version string) *GateGo { // New stores config and a copy of ctx annotated with version; nothing is started until Run. 22 | ctx = contextvalues.AddVersionToContext(ctx, version) 23 | return &GateGo{config: config, ctx: ctx} 24 | } 25 | 26 | func (gg GateGo) Run() error { // Run optionally sets up OpenTelemetry, starts the health-check monitor and the server, then blocks until the server reports an error or gg.ctx is cancelled. NOTE(review): value receiver — the gg.monitor assignment below is made on a copy and is not visible through the caller's *GateGo. 27 | useOtel := gg.config.OTEL != nil 28 | if useOtel { 29 | 
otelConfig := otelConfig{ 30 | ServiceName: serviceName, 31 | SampleRatio: gg.config.OTEL.SampleRatio, 32 | CollectorTimeout: time.Second * 5, // TODO: Add to config 33 | TraceCollectorEndpoint: gg.config.OTEL.Endpoint, 34 | MetricCollectorEndpoint: gg.config.OTEL.Endpoint, 35 | LogsCollectorEndpoint: gg.config.OTEL.Endpoint, 36 | } 37 | shutdown, err := setupOTelSDK(gg.ctx, otelConfig) 38 | if err != nil { 39 | return err 40 | } 41 | defer shutdown(context.Background()) // error returned by the OTel shutdown is discarded 42 | } 43 | 44 | // Create the health checks and start monitoring 45 | healthChecks := createMonitorChecks(gg.config.Services) 46 | gg.monitor = monitor.New(time.Second*5, healthChecks...) 47 | gg.monitor.Start() // NOTE(review): nothing in Run stops the monitor — confirm whether it should be stopped on return 48 | 49 | server, err := newServer(gg.ctx, gg.config, useOtel) 50 | if err != nil { 51 | return err 52 | } 53 | defer server.Shutdown(gg.ctx) // NOTE(review): on the ctx.Done path this runs after the explicit Shutdown below, so Shutdown is called twice (the second time with an already-cancelled ctx) 54 | 55 | serveErrChan, err := server.serve(gg.config.TLS.CertFile, gg.config.TLS.KeyFile) // NOTE(review): gg.config.TLS is dereferenced unconditionally; if it is a pointer and ssl is omitted this would panic — confirm against config.Config 56 | if err != nil { 57 | return err 58 | } 59 | 60 | // Wait for a serve error or for gg.ctx to be cancelled (e.g. SIGINT in cmd/main.go). 61 | select { 62 | case err = <-serveErrChan: 63 | return err 64 | case <-gg.ctx.Done(): 65 | fmt.Println("\nShutting down...") 66 | return server.Shutdown(context.Background()) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hvuhsg/gatego 2 | 3 | go 1.22.0 4 | 5 | require gopkg.in/yaml.v3 v3.0.1 6 | 7 | require ( 8 | github.com/hashicorp/go-version v1.7.0 9 | github.com/tdewolff/minify/v2 v2.21.0 10 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 11 | go.opentelemetry.io/otel/log v0.7.0 12 | go.opentelemetry.io/otel/sdk/log v0.7.0 13 | go.opentelemetry.io/otel/trace v1.31.0 14 | ) 15 | 16 | require ( 17 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 18 | github.com/go-logr/logr v1.4.2 // indirect 19 | github.com/go-logr/stdr v1.2.2 // indirect 20 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect 21 | 
go.opentelemetry.io/otel/metric v1.31.0 // indirect 22 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect 23 | golang.org/x/sys v0.26.0 // indirect 24 | golang.org/x/text v0.19.0 // indirect 25 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect 26 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect 27 | google.golang.org/grpc v1.67.1 // indirect 28 | google.golang.org/protobuf v1.35.1 // indirect 29 | ) 30 | 31 | require ( 32 | github.com/davecgh/go-spew v1.1.1 // indirect 33 | github.com/getkin/kin-openapi v0.128.0 34 | github.com/go-openapi/jsonpointer v0.21.0 // indirect 35 | github.com/go-openapi/swag v0.23.0 // indirect 36 | github.com/google/uuid v1.6.0 37 | github.com/gorilla/mux v1.8.0 // indirect 38 | github.com/invopop/yaml v0.3.1 // indirect 39 | github.com/josharian/intern v1.0.0 // indirect 40 | github.com/mailru/easyjson v0.7.7 // indirect 41 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect 42 | github.com/patrickmn/go-cache v2.1.0+incompatible 43 | github.com/perimeterx/marshmallow v1.1.5 // indirect 44 | github.com/pmezard/go-difflib v1.0.0 // indirect 45 | github.com/stretchr/testify v1.9.0 46 | github.com/tdewolff/parse/v2 v2.7.17 // indirect 47 | go.opentelemetry.io/otel v1.31.0 48 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0 49 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 50 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 51 | go.opentelemetry.io/otel/sdk v1.31.0 52 | go.opentelemetry.io/otel/sdk/metric v1.31.0 53 | golang.org/x/net v0.30.0 54 | golang.org/x/time v0.7.0 55 | ) 56 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= 2 | 
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/getkin/kin-openapi v0.128.0 h1:jqq3D9vC9pPq1dGcOCv7yOp1DaEe7c/T1vzcLbITSp4= 6 | github.com/getkin/kin-openapi v0.128.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM= 7 | github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 8 | github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= 9 | github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 10 | github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= 11 | github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= 12 | github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= 13 | github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= 14 | github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= 15 | github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= 16 | github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= 17 | github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= 18 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 19 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 20 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 21 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 22 | github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= 23 | github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= 24 | 
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= 25 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= 26 | github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= 27 | github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= 28 | github.com/invopop/yaml v0.3.1 h1:f0+ZpmhfBSS4MhG+4HYseMdJhoeeopbSKbq5Rpeelso= 29 | github.com/invopop/yaml v0.3.1/go.mod h1:PMOp3nn4/12yEZUFfmOuNHJsZToEEOwoWsT+D81KkeA= 30 | github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= 31 | github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= 32 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 33 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 34 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 35 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 36 | github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= 37 | github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= 38 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= 39 | github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= 40 | github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= 41 | github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= 42 | github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= 43 | github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= 44 | github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 45 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 46 | github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= 47 | github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= 48 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 49 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 50 | github.com/tdewolff/minify/v2 v2.21.0 h1:nAPP1UVx0aK1xsQh/JiG3xyEnnqWw+agPstn+V6Pkto= 51 | github.com/tdewolff/minify/v2 v2.21.0/go.mod h1:hGcthJ6Vj51NG+9QRIfN/DpWj5loHnY3bfhThzWWq08= 52 | github.com/tdewolff/parse/v2 v2.7.17 h1:uC10p6DaQQORDy72eaIyD+AvAkaIUOouQ0nWp4uD0D0= 53 | github.com/tdewolff/parse/v2 v2.7.17/go.mod h1:3FbJWZp3XT9OWVN3Hmfp0p/a08v4h8J9W1aghka0soA= 54 | github.com/tdewolff/test v1.0.11-0.20231101010635-f1265d231d52/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE= 55 | github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739 h1:IkjBCtQOOjIn03u/dMQK9g+Iw9ewps4mCl1nB8Sscbo= 56 | github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= 57 | github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= 58 | github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= 59 | go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= 60 | go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= 61 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0 h1:iNba3cIZTDPB2+IAbVY/3TUN+pCCLrNYo2GaGtsKBak= 62 | go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.7.0/go.mod h1:l5BDPiZ9FbeejzWTAX6BowMzQOM/GeaUQ6lr3sOcSkc= 63 | go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0 h1:FZ6ei8GFW7kyPYdxJaV2rgI6M+4tvZzhYsQ2wgyVC08= 64 | 
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.31.0/go.mod h1:MdEu/mC6j3D+tTEfvI15b5Ci2Fn7NneJ71YMoiS3tpI= 65 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 h1:K0XaT3DwHAcV4nKLzcQvwAgSyisUghWoY20I7huthMk= 66 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0/go.mod h1:B5Ki776z/MBnVha1Nzwp5arlzBbE3+1jk+pGmaP5HME= 67 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 h1:FFeLy03iVTXP6ffeN2iXrxfGsZGCjVx0/4KlizjyBwU= 68 | go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0/go.mod h1:TMu73/k1CP8nBUpDLc71Wj/Kf7ZS9FK5b53VapRsP9o= 69 | go.opentelemetry.io/otel/log v0.7.0 h1:d1abJc0b1QQZADKvfe9JqqrfmPYQCz2tUSO+0XZmuV4= 70 | go.opentelemetry.io/otel/log v0.7.0/go.mod h1:2jf2z7uVfnzDNknKTO9G+ahcOAyWcp1fJmk/wJjULRo= 71 | go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= 72 | go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= 73 | go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk= 74 | go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0= 75 | go.opentelemetry.io/otel/sdk/log v0.7.0 h1:dXkeI2S0MLc5g0/AwxTZv6EUEjctiH8aG14Am56NTmQ= 76 | go.opentelemetry.io/otel/sdk/log v0.7.0/go.mod h1:oIRXpW+WD6M8BuGj5rtS0aRu/86cbDV/dAfNaZBIjYM= 77 | go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc= 78 | go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8= 79 | go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= 80 | go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= 81 | go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= 82 | go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= 83 | go.uber.org/goleak v1.3.0 
h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 84 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 85 | golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= 86 | golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= 87 | golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= 88 | golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 89 | golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= 90 | golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 91 | golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= 92 | golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 93 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg= 94 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M= 95 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc= 96 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= 97 | google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= 98 | google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= 99 | google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= 100 | google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 101 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 102 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 103 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod 
h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 104 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 105 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 106 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package gatego 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "os" 8 | "slices" 9 | 10 | "github.com/hvuhsg/gatego/internal/config" 11 | "github.com/hvuhsg/gatego/internal/handlers" 12 | "github.com/hvuhsg/gatego/internal/middlewares" 13 | "github.com/hvuhsg/gatego/internal/middlewares/security" 14 | ) 15 | 16 | var ErrUnsupportedBaseHandler = errors.New("base handler unsupported") 17 | 18 | func GetBaseHandler(service config.Service, path config.Path) (http.Handler, error) { 19 | if path.Destination != nil && *path.Destination != "" { 20 | return handlers.NewProxy(service, path) 21 | } else if path.Directory != nil && *path.Directory != "" { 22 | handler := handlers.NewFiles(*path.Directory, path.Path) 23 | return handler, nil 24 | } else if path.Backend != nil { 25 | return handlers.NewBalancer(service, path) 26 | } else { 27 | // Should not be reached (early validation should prevent it) 28 | return nil, ErrUnsupportedBaseHandler 29 | } 30 | } 31 | 32 | func NewHandler(ctx context.Context, useOtel bool, service config.Service, path config.Path) (http.Handler, error) { 33 | handler, err := GetBaseHandler(service, path) 34 | if err != nil { 35 | return nil, err 36 | } 37 | 38 | handlerWithMiddlewares := middlewares.NewHandlerWithMiddleware(handler) 39 | 40 | handlerWithMiddlewares.Add(middlewares.NewLoggingMiddleware(os.Stdout)) 41 | 42 | // Open Telemetry 43 | if useOtel { 44 | otelMiddleware, err := middlewares.NewOpenTelemetryMiddleware( 45 | ctx, 46 | middlewares.OTELConfig{ 47 | ServiceDomain: service.Domain, 48 | BasePath: path.Path, 49 | }, 50 | ) 51 | 
if err != nil { 52 | return nil, err 53 | } 54 | handlerWithMiddlewares.Add(otelMiddleware) 55 | } 56 | 57 | // Timeout 58 | if path.Timeout == 0 { 59 | path.Timeout = config.DefaultTimeout 60 | } 61 | handlerWithMiddlewares.Add(middlewares.NewTimeoutMiddleware(path.Timeout)) 62 | 63 | // Max request size 64 | if path.MaxSize == 0 { 65 | path.MaxSize = config.DefaultMaxRequestSize 66 | } 67 | handlerWithMiddlewares.Add(middlewares.NewRequestSizeLimitMiddleware(path.MaxSize)) 68 | 69 | // Rate limits 70 | if len(path.RateLimits) > 0 { 71 | ratelimiter, err := middlewares.NewRateLimitMiddleware(path.RateLimits) 72 | if err != nil { 73 | return nil, err 74 | } 75 | handlerWithMiddlewares.Add(ratelimiter) 76 | } 77 | 78 | // Add anomaly detector 79 | if service.AnomalyDetection != nil { 80 | handlerWithMiddlewares.Add( 81 | security.NewRoutingAnomalyDetector( 82 | service.AnomalyDetection.HeaderName, 83 | service.AnomalyDetection.TresholdForRating, 84 | service.AnomalyDetection.MinScore, 85 | service.AnomalyDetection.MaxScore).AddAnomalyScore, 86 | ) 87 | } 88 | 89 | // Add headers 90 | if path.Headers != nil { 91 | handlerWithMiddlewares.Add(middlewares.NewAddHeadersMiddleware(*path.Headers)) 92 | } 93 | 94 | // GZIP compression 95 | if path.Gzip != nil && *path.Gzip { 96 | handlerWithMiddlewares.Add(middlewares.GzipMiddleware) 97 | } 98 | 99 | // Remove response headers 100 | if len(path.OmitHeaders) > 0 { 101 | handlerWithMiddlewares.Add(middlewares.NewOmitHeadersMiddleware(path.OmitHeaders)) 102 | } 103 | 104 | // Minify files 105 | minifyConfig := middlewares.MinifyConfig{ 106 | ALL: slices.Contains(path.Minify, "all"), 107 | JS: slices.Contains(path.Minify, "js"), 108 | HTML: slices.Contains(path.Minify, "html"), 109 | CSS: slices.Contains(path.Minify, "css"), 110 | JSON: slices.Contains(path.Minify, "json"), 111 | SVG: slices.Contains(path.Minify, "svg"), 112 | XML: slices.Contains(path.Minify, "xml"), 113 | } 114 | 
handlerWithMiddlewares.Add(middlewares.NewMinifyMiddleware(minifyConfig)) 115 | 116 | // OpenAPI validation 117 | if path.OpenAPI != nil { 118 | openapiMiddleware, err := middlewares.NewOpenAPIValidationMiddleware(*path.OpenAPI) 119 | if err != nil { 120 | return nil, err 121 | } 122 | handlerWithMiddlewares.Add(openapiMiddleware) 123 | } 124 | 125 | // Response cache 126 | if path.Cache { 127 | handlerWithMiddlewares.Add(middlewares.NewCacheMiddleware()) 128 | } 129 | 130 | return handlerWithMiddlewares, nil 131 | } 132 | -------------------------------------------------------------------------------- /internal/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | "os" 11 | "regexp" 12 | "slices" 13 | "strconv" 14 | "strings" 15 | "time" 16 | 17 | "github.com/hashicorp/go-version" 18 | "github.com/hvuhsg/gatego/internal/middlewares" 19 | "github.com/hvuhsg/gatego/pkg/cron" 20 | "gopkg.in/yaml.v3" 21 | ) 22 | 23 | const DefaultTimeout = time.Second * 30 24 | const DefaultMaxRequestSize = 1024 * 10 // 1024 * 10 = 10240, i.e. 10 KiB if the unit is bytes (not 10 MB as previously noted); NOTE(review): confirm the intended limit 25 | var SupportedBalancePolicies = []string{"round-robin", "random", "least-latency"} 26 | 27 | type Backend struct { 28 | BalancePolicy string `yaml:"balance_policy"` 29 | Servers []struct { 30 | URL string `yaml:"url"` 31 | Weight uint `yaml:"weight"` 32 | } 33 | } 34 | 35 | func (b Backend) validate() error { 36 | if !slices.Contains(SupportedBalancePolicies, b.BalancePolicy) { 37 | return fmt.Errorf("balance policy '%s' is not supported", b.BalancePolicy) 38 | } 39 | 40 | if len(b.Servers) == 0 { 41 | return errors.New("backend require at least one server") 42 | } 43 | 44 | for _, server := range b.Servers { 45 | if !isValidURL(server.URL) { 46 | return fmt.Errorf("invalid backend server url '%s'", server.URL) 47 | } 48 | } 49 | 50 | return nil 51 | } 52 | 53 | type Check struct { 54 | Name string
`yaml:"name"` 55 | Cron string `yaml:"cron"` 56 | URL string `yaml:"url"` 57 | Method string `yaml:"method"` 58 | Timeout time.Duration `yaml:"timeout"` 59 | Headers map[string]string `yaml:"headers"` 60 | OnFailure string `yaml:"on_failure"` 61 | } 62 | 63 | func (c Check) validate() error { 64 | if len(c.Name) == 0 { 65 | return errors.New("check requires a name") 66 | } 67 | 68 | if _, err := cron.NewSchedule(c.Cron); err != nil { 69 | return errors.New("invalid check cron expression") 70 | } 71 | 72 | if !isValidURL(c.URL) { 73 | return errors.New("invalid check url") 74 | } 75 | 76 | if !isValidMethod(c.Method) { 77 | return errors.New("invalid check method") 78 | } 79 | 80 | return nil 81 | } 82 | 83 | type Path struct { 84 | Path string `yaml:"path"` 85 | Destination *string `yaml:"destination"` // The domain / url of the service server 86 | Directory *string `yaml:"directory"` // path to dir you want to serve 87 | Backend *Backend `yaml:"backend"` // List of servers to load balance between 88 | Headers *map[string]string `yaml:"headers"` 89 | OmitHeaders []string `yaml:"omit_headers"` // Omit specified headers 90 | Minify []string `yaml:"minify"` 91 | Gzip *bool `yaml:"gzip"` 92 | Timeout time.Duration `yaml:"timeout"` 93 | MaxSize uint64 `yaml:"max_size"` 94 | OpenAPI *string `yaml:"openapi"` 95 | RateLimits []string `yaml:"ratelimits"` 96 | Checks []Check `yaml:"checks"` // Automated checks 97 | Cache bool `yaml:"cache"` // Cache responses that has cache headers 98 | } 99 | 100 | func (p Path) validate() error { 101 | if p.Path[0] != '/' { 102 | return errors.New("path must start with '/'") 103 | } 104 | 105 | if p.Destination != nil { 106 | if !isValidURL(*p.Destination) { 107 | return errors.New("invalid destination url") 108 | } 109 | 110 | if p.Directory != nil { 111 | return errors.New("can't have destination and directory for the same path") 112 | } 113 | } 114 | 115 | if p.Directory != nil { 116 | if !isValidDir(*p.Directory) { 117 | return 
errors.New("invalid directory path") 118 | } 119 | 120 | if p.Cache { 121 | log.Println("[WARNING] Using cache while serving static files is not recommanded") 122 | } 123 | } 124 | 125 | if p.Backend != nil { 126 | if err := p.Backend.validate(); err != nil { 127 | return err 128 | } 129 | } 130 | 131 | if p.Destination == nil && p.Directory == nil && p.Backend == nil { 132 | return errors.New("path must have destination or directory or backend") 133 | } 134 | 135 | if p.OpenAPI != nil { 136 | if *p.OpenAPI == "" { 137 | return errors.New("openapi can't be empty (remove or fill)") 138 | } 139 | 140 | if !isValidFile(*p.OpenAPI) { 141 | return errors.New("invalid openapi spec path") 142 | } 143 | } 144 | 145 | for _, ratelimit := range p.RateLimits { 146 | _, err := middlewares.ParseLimitConfig(ratelimit) 147 | if err != nil { 148 | return fmt.Errorf("invalid ratelimit: %s", err.Error()) 149 | } 150 | } 151 | 152 | for _, check := range p.Checks { 153 | if err := check.validate(); err != nil { 154 | return err 155 | } 156 | } 157 | 158 | return nil 159 | } 160 | 161 | type AnomalyDetection struct { 162 | HeaderName string `yaml:"header_name"` 163 | MinScore int `yaml:"min_score"` 164 | MaxScore int `yaml:"max_score"` 165 | TresholdForRating int `yaml:"treshold_for_rating"` 166 | Active bool `yaml:"active"` 167 | } 168 | 169 | func (a *AnomalyDetection) validate() error { 170 | if a.HeaderName == "" { 171 | a.HeaderName = "X-Anomaly-Score" 172 | } 173 | 174 | if a.MinScore == 0 { 175 | a.MinScore = 100 176 | } 177 | 178 | if a.MaxScore == 0 { 179 | a.MaxScore = 200 180 | } 181 | 182 | if a.TresholdForRating == 0 { 183 | a.TresholdForRating = 100 184 | } 185 | 186 | if a.MaxScore <= a.MinScore { 187 | return errors.New("anomaly detection maxScore MUST be grater the minScore") 188 | } 189 | 190 | return nil 191 | } 192 | 193 | type Service struct { 194 | Domain string `yaml:"domain"` // The domain / host the request was sent to 195 | Paths []Path `yaml:"endpoints"` 196 
| AnomalyDetection *AnomalyDetection `yaml:"anomaly_detection"` 197 | } 198 | 199 | func (s Service) validate() error { 200 | if !isValidHostname(s.Domain) { 201 | return errors.New("invalid domain") 202 | } 203 | 204 | for _, path := range s.Paths { 205 | if err := path.validate(); err != nil { 206 | return err 207 | } 208 | } 209 | 210 | if s.AnomalyDetection != nil { 211 | if err := s.AnomalyDetection.validate(); err != nil { 212 | return err 213 | } 214 | } 215 | 216 | return nil 217 | } 218 | 219 | type TLS struct { 220 | Auto bool `yaml:"auto"` 221 | Domains []string `yaml:"domain"` 222 | Email *string `yaml:"email"` 223 | KeyFile *string `yaml:"keyfile"` 224 | CertFile *string `yaml:"certfile"` 225 | } 226 | 227 | func (tls TLS) validate() error { 228 | if tls.Auto { 229 | if len(tls.Domains) == 0 { 230 | return errors.New("when using the auto tls feature you MUST include a list of domains to issue certificates for") 231 | } 232 | if tls.Email == nil || len(*tls.Email) == 0 || !isValidEmail(*tls.Email) { 233 | return errors.New("when using the auto tls feature you MUST include a valid email for the lets-encrypt registration") 234 | } 235 | } 236 | 237 | if tls.CertFile != nil { 238 | if tls.KeyFile == nil { 239 | return errors.New("you MUST provide certfile AND keyfile") 240 | } 241 | } 242 | 243 | if tls.KeyFile != nil { 244 | if tls.CertFile == nil { 245 | return errors.New("you MUST provide certfile AND keyfile") 246 | } 247 | 248 | if !isValidFile(*tls.CertFile) { 249 | return errors.New("certfile path is invalid") 250 | } 251 | 252 | if !isValidFile(*tls.KeyFile) { 253 | return errors.New("keyfile path is invalid") 254 | } 255 | } 256 | 257 | return nil 258 | } 259 | 260 | type OTEL struct { 261 | Endpoint string `yaml:"endpoint"` 262 | SampleRatio float64 `yaml:"sample_ratio"` 263 | } 264 | 265 | func (otel OTEL) validate() error { 266 | if len(otel.Endpoint) > 0 { 267 | if err := isValidGRPCAddress(otel.Endpoint); err != nil { 268 | return err 269 | } 
270 | } 271 | 272 | if otel.SampleRatio < 0 { 273 | return errors.New("OpenTelemetry sample ratio MUST be above 0") 274 | } 275 | 276 | if otel.SampleRatio == 0 { 277 | return errors.New("OpenTelemetry sample ratio is missing or equales to 0") 278 | } 279 | 280 | if otel.SampleRatio > 1 { 281 | return errors.New("OpenTelemetry sample ratio CAN NOT be above 1") 282 | } 283 | 284 | return nil 285 | } 286 | 287 | type Config struct { 288 | Version string `yaml:"version"` 289 | Host string `yaml:"host"` // listen host 290 | Port uint16 `yaml:"port"` // listen port 291 | 292 | OTEL *OTEL `yaml:"open_telemetry"` 293 | 294 | // TLS options 295 | TLS TLS `yaml:"ssl"` 296 | 297 | Services []Service `yaml:"services"` 298 | } 299 | 300 | func (c Config) Validate(currentVersion string) error { 301 | if c.Version == "" { 302 | return errors.New("version is required") 303 | } 304 | 305 | progVersion, _ := version.NewVersion(currentVersion) 306 | configVersion, err := version.NewVersion(c.Version) 307 | if err != nil { 308 | return errors.New("version is invalid") 309 | } 310 | 311 | if configVersion.Compare(progVersion) > 0 { 312 | return errors.New("config version is not supported (too advanced)") 313 | } 314 | 315 | if c.Host == "" { 316 | return errors.New("host is required") 317 | } 318 | 319 | if c.OTEL != nil { 320 | if err := (*c.OTEL).validate(); err != nil { 321 | return err 322 | } 323 | } 324 | 325 | if c.Port == 0 { 326 | return errors.New("port is required") 327 | } 328 | 329 | if err := c.TLS.validate(); err != nil { 330 | return err 331 | } 332 | 333 | if c.TLS.Auto && c.Port != 443 { 334 | return errors.New("the auto tls feature is only available if the server runs on port 443") 335 | } 336 | 337 | for _, service := range c.Services { 338 | if err := service.validate(); err != nil { 339 | return err 340 | } 341 | } 342 | 343 | return nil 344 | } 345 | 346 | func ParseConfig(filepath string, currentVersion string) (Config, error) { 347 | // Read the YAML file 348 
| data, err := os.ReadFile(filepath) 349 | if err != nil { 350 | return Config{}, err 351 | } 352 | 353 | // Defaults 354 | c := Config{Port: 80} 355 | 356 | // Unmarshal the YAML data into the struct 357 | err = yaml.Unmarshal(data, &c) 358 | if err != nil { 359 | return Config{}, err 360 | } 361 | 362 | if err := c.Validate(currentVersion); err != nil { 363 | return Config{}, err 364 | } 365 | 366 | return c, nil 367 | } 368 | 369 | func isValidHostname(hostname string) bool { 370 | // Remove leading/trailing whitespace 371 | hostname = strings.TrimSpace(hostname) 372 | 373 | // Check if the hostname is empty 374 | if hostname == "" { 375 | return false 376 | } 377 | 378 | // Check if the hostname is too long (max 253 characters) 379 | if len(hostname) > 253 { 380 | return false 381 | } 382 | 383 | // Check for localhost 384 | if hostname == "localhost" { 385 | return true 386 | } 387 | 388 | // Check if it's an IP address (IPv4 or IPv6) 389 | if ip := net.ParseIP(hostname); ip != nil { 390 | return true 391 | } 392 | 393 | // Regular expression for domain validation 394 | // This regex requires at least two dot-separated ASCII labels ending in an alphabetic TLD; IDNs match only in punycode (xn--) form, and punycode TLDs are rejected 395 | domainRegex := regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$`) 396 | 397 | return domainRegex.MatchString(hostname) 398 | } 399 | 400 | func isValidURL(str string) bool { 401 | u, err := url.Parse(str) 402 | return err == nil && u.Scheme != "" && u.Host != "" 403 | } 404 | 405 | func isValidDir(path string) bool { 406 | if path == "" { 407 | return false 408 | } 409 | 410 | fileInfo, err := os.Stat(path) 411 | if err != nil { 412 | return false 413 | } 414 | return fileInfo.IsDir() 415 | } 416 | 417 | func isValidFile(path string) bool { 418 | if path == "" { 419 | return false 420 | } 421 | 422 | fileInfo, err := os.Stat(path) 423 | if err != nil { 424 | return false 425 | } 426 | return !fileInfo.IsDir() 427 | } 428 | 429 | func isValidMethod(method string) bool { 430 |
methods := []string{ 431 | http.MethodGet, 432 | http.MethodHead, 433 | http.MethodPost, 434 | http.MethodPut, 435 | http.MethodPatch, 436 | http.MethodDelete, 437 | http.MethodConnect, 438 | http.MethodOptions, 439 | http.MethodTrace, 440 | } 441 | 442 | return slices.Contains(methods, method) 443 | } 444 | 445 | func isValidGRPCAddress(address string) error { 446 | if address == "" { 447 | return fmt.Errorf("address cannot be empty") 448 | } 449 | 450 | // Split host and port 451 | host, portStr, err := net.SplitHostPort(address) 452 | if err != nil { 453 | return fmt.Errorf("invalid address format: %v", err) 454 | } 455 | 456 | // Validate port 457 | port, err := strconv.Atoi(portStr) 458 | if err != nil { 459 | return fmt.Errorf("invalid port number: %v", err) 460 | } 461 | if port < 1 || port > 65535 { 462 | return fmt.Errorf("port number must be between 1 and 65535") 463 | } 464 | 465 | // Empty host means localhost/0.0.0.0, which is valid 466 | if host == "" { 467 | return nil 468 | } 469 | 470 | // Check if host is IPv4 or IPv6 471 | if ip := net.ParseIP(host); ip != nil { 472 | return nil 473 | } 474 | 475 | // Validate hostname format 476 | hostnameRegex := regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-\.]*[a-zA-Z0-9])?$`) 477 | if !hostnameRegex.MatchString(host) { 478 | return fmt.Errorf("invalid hostname format") 479 | } 480 | 481 | // Check hostname length 482 | if len(host) > 253 { 483 | return fmt.Errorf("hostname too long") 484 | } 485 | 486 | // Validate hostname parts 487 | parts := strings.Split(host, ".") 488 | for _, part := range parts { 489 | if len(part) > 63 { 490 | return fmt.Errorf("hostname label too long") 491 | } 492 | } 493 | 494 | return nil 495 | } 496 | 497 | func isValidEmail(email string) bool { 498 | // Define a regular expression for valid email addresses 499 | var emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) 500 | 501 | // Match the email string with the regular expression 502 | return 
emailRegex.MatchString(email) 503 | } 504 | -------------------------------------------------------------------------------- /internal/config/config_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | ) 8 | 9 | func TestPathValidate(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | path Path 13 | wantErr bool 14 | }{ 15 | {"Valid path with destination", Path{Path: "/api", Destination: ptr("http://example.com")}, false}, 16 | {"Valid path with directory", Path{Path: "/static", Directory: ptr("/var")}, false}, 17 | {"Invalid path without leading slash", Path{Path: "api", Destination: ptr("http://example.com")}, true}, 18 | {"Invalid destination URL", Path{Path: "/api", Destination: ptr("not-a-url")}, true}, 19 | {"Invalid with both destination and directory", Path{Path: "/both", Destination: ptr("http://example.com"), Directory: ptr("/var/www")}, true}, 20 | {"Invalid with neither destination nor directory", Path{Path: "/empty"}, true}, 21 | } 22 | 23 | for _, tt := range tests { 24 | t.Run(tt.name, func(t *testing.T) { 25 | err := tt.path.validate() 26 | if (err != nil) != tt.wantErr { 27 | t.Errorf("Path.validate() error = %v, wantErr %v", err, tt.wantErr) 28 | } 29 | }) 30 | } 31 | } 32 | 33 | func TestServiceValidate(t *testing.T) { 34 | tests := []struct { 35 | name string 36 | service Service 37 | wantErr bool 38 | }{ 39 | {"Valid service", Service{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}, false}, 40 | {"Invalid domain", Service{Domain: "not a domain", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}, true}, 41 | {"Invalid path", Service{Domain: "example.com", Paths: []Path{{Path: "invalid", Destination: ptr("http://api.example.com")}}}, true}, 42 | } 43 | 44 | for _, tt := range tests { 45 | t.Run(tt.name, func(t *testing.T) { 46 | err := 
tt.service.validate() 47 | if (err != nil) != tt.wantErr { 48 | t.Errorf("Service.validate() service = %v, error = %v, wantErr %v", err, tt.service, tt.wantErr) 49 | } 50 | }) 51 | } 52 | } 53 | 54 | func TestConfigValidate(t *testing.T) { 55 | tests := []struct { 56 | name string 57 | config Config 58 | currentVersion string 59 | wantErr bool 60 | }{ 61 | {"Valid config", Config{Version: "1.0.0", Host: "localhost", Port: 80, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", false}, 62 | {"AutoTLS with port != 443", Config{Version: "1.0.0", Host: "localhost", Port: 80, TLS: TLS{Auto: true, Domains: []string{"example.com"}}, Services: []Service{{Domain: "example.com", Paths: []Path{{Path: "/api", Destination: ptr("http://api.example.com")}}}}}, "1.0.0", true}, 63 | {"Missing version", Config{Host: "localhost"}, "1.0.0", true}, 64 | {"Invalid version", Config{Version: "invalid", Host: "localhost"}, "1.0.0", true}, 65 | {"Future version", Config{Version: "2.0.0", Host: "localhost"}, "1.0.0", true}, 66 | {"Missing host", Config{Version: "1.0.0"}, "1.0.0", true}, 67 | } 68 | 69 | for _, tt := range tests { 70 | t.Run(tt.name, func(t *testing.T) { 71 | err := tt.config.Validate(tt.currentVersion) 72 | if (err != nil) != tt.wantErr { 73 | t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) 74 | } 75 | }) 76 | } 77 | } 78 | 79 | func TestParseConfig(t *testing.T) { 80 | // Create a temporary directory for test files 81 | tempDir, err := os.MkdirTemp("", "config_test") 82 | if err != nil { 83 | t.Fatalf("Failed to create temp dir: %v", err) 84 | } 85 | defer os.RemoveAll(tempDir) 86 | 87 | // Create a valid config file 88 | validConfig := ` 89 | version: "1.0.0" 90 | host: "localhost" 91 | port: 8080 92 | services: 93 | - domain: "example.com" 94 | endpoints: 95 | - path: "/api" 96 | destination: "http://api.example.com" 97 | ` 98 | validConfigPath := filepath.Join(tempDir, 
"valid_config.yaml") 99 | err = os.WriteFile(validConfigPath, []byte(validConfig), 0644) 100 | if err != nil { 101 | t.Fatalf("Failed to write valid config file: %v", err) 102 | } 103 | 104 | // Create an invalid config file 105 | invalidConfig := ` 106 | version: "invalid" 107 | host: "localhost" 108 | ` 109 | invalidConfigPath := filepath.Join(tempDir, "invalid_config.yaml") 110 | err = os.WriteFile(invalidConfigPath, []byte(invalidConfig), 0644) 111 | if err != nil { 112 | t.Fatalf("Failed to write invalid config file: %v", err) 113 | } 114 | 115 | tests := []struct { 116 | name string 117 | filepath string 118 | currentVersion string 119 | wantErr bool 120 | }{ 121 | {"Valid config", validConfigPath, "1.0.0", false}, 122 | {"Invalid config", invalidConfigPath, "1.0.0", true}, 123 | {"Non-existent file", filepath.Join(tempDir, "non_existent.yaml"), "1.0.0", true}, 124 | } 125 | 126 | for _, tt := range tests { 127 | t.Run(tt.name, func(t *testing.T) { 128 | _, err := ParseConfig(tt.filepath, tt.currentVersion) 129 | if (err != nil) != tt.wantErr { 130 | t.Errorf("ParseConfig() error = %v, wantErr %v", err, tt.wantErr) 131 | } 132 | }) 133 | } 134 | } 135 | 136 | func TestIsValidURL(t *testing.T) { 137 | tests := []struct { 138 | name string 139 | url string 140 | want bool 141 | }{ 142 | {"Valid URL", "http://example.com", true}, 143 | {"Valid URL with path", "https://example.com/path", true}, 144 | {"Invalid URL", "not-a-url", false}, 145 | {"Invalid URL", "not a domain", false}, 146 | {"Missing scheme", "example.com", false}, 147 | } 148 | 149 | for _, tt := range tests { 150 | t.Run(tt.name, func(t *testing.T) { 151 | if got := isValidURL(tt.url); got != tt.want { 152 | t.Errorf("isValidURL() = %v, want %v", got, tt.want) 153 | } 154 | }) 155 | } 156 | } 157 | 158 | func TestIsValidDir(t *testing.T) { 159 | // Create a temporary directory for the test 160 | tempDir, err := os.MkdirTemp("", "dir_test") 161 | if err != nil { 162 | t.Fatalf("Failed to create 
temp dir: %v", err) 163 | } 164 | defer os.RemoveAll(tempDir) 165 | 166 | tests := []struct { 167 | name string 168 | path string 169 | want bool 170 | }{ 171 | {"Valid directory", tempDir, true}, 172 | {"Non-existent directory", filepath.Join(tempDir, "non_existent"), false}, 173 | {"Empty path", "", false}, 174 | } 175 | 176 | for _, tt := range tests { 177 | t.Run(tt.name, func(t *testing.T) { 178 | if got := isValidDir(tt.path); got != tt.want { 179 | t.Errorf("isValidDir() = %v, want %v", got, tt.want) 180 | } 181 | }) 182 | } 183 | } 184 | 185 | // Helper function to create string pointers 186 | func ptr(s string) *string { 187 | return &s 188 | } 189 | -------------------------------------------------------------------------------- /internal/contextvalues/tracer.go: -------------------------------------------------------------------------------- 1 | package contextvalues 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/trace" 7 | ) 8 | 9 | // Define a custom type for context keys to avoid collisions 10 | type tracerKeyType string 11 | 12 | var tracerKey = tracerKeyType("tracer") 13 | 14 | // Add tracer to context 15 | func AddTracerToContext(ctx context.Context, tracer trace.Tracer) context.Context { 16 | return context.WithValue(ctx, tracerKey, tracer) 17 | } 18 | 19 | // Retrieve tracer from context 20 | func TracerFromContext(ctx context.Context) trace.Tracer { 21 | var tracer trace.Tracer = nil 22 | if t, ok := ctx.Value(tracerKey).(trace.Tracer); ok { 23 | tracer = t 24 | } 25 | return tracer 26 | } 27 | -------------------------------------------------------------------------------- /internal/contextvalues/version.go: -------------------------------------------------------------------------------- 1 | package contextvalues 2 | 3 | import "context" 4 | 5 | // Define a custom type for context keys to avoid collisions 6 | type versionKeyType string 7 | 8 | var versionKey = versionKeyType("version") 9 | 10 | // Add version to context 11 | 
func AddVersionToContext(ctx context.Context, version string) context.Context { 12 | return context.WithValue(ctx, versionKey, version) 13 | } 14 | 15 | // Retrieve version from context 16 | func VersionFromContext(ctx context.Context) string { 17 | version := "" 18 | if v, ok := ctx.Value(versionKey).(string); ok { 19 | version = v 20 | } 21 | return version 22 | } 23 | -------------------------------------------------------------------------------- /internal/handlers/balancer.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | "net/http" 7 | "net/http/httputil" 8 | "net/url" 9 | "time" 10 | 11 | "github.com/hvuhsg/gatego/internal/config" 12 | "github.com/hvuhsg/gatego/internal/contextvalues" 13 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0" 14 | ) 15 | 16 | type ServerAndWeight struct { 17 | server *httputil.ReverseProxy 18 | weight int 19 | url string 20 | } 21 | 22 | type BalancePolicy interface { 23 | GetNext() *httputil.ReverseProxy 24 | } 25 | 26 | type Balancer struct { 27 | policy BalancePolicy 28 | } 29 | 30 | func NewBalancer(service config.Service, path config.Path) (*Balancer, error) { 31 | serversConfig := path.Backend.Servers 32 | 33 | serversAndWeights := make([]ServerAndWeight, 0, len(serversConfig)) 34 | for _, serverConfig := range serversConfig { 35 | serverURL, err := url.Parse(serverConfig.URL) 36 | if err != nil { 37 | return &Balancer{}, err 38 | } 39 | 40 | server := httputil.NewSingleHostReverseProxy(serverURL) 41 | 42 | serverWeight := int(serverConfig.Weight) 43 | if serverWeight < 1 { 44 | serverWeight = 1 45 | } 46 | serversAndWeights = append(serversAndWeights, ServerAndWeight{server: server, weight: serverWeight, url: serverConfig.URL}) 47 | } 48 | 49 | var policy BalancePolicy 50 | switch path.Backend.BalancePolicy { 51 | case "round-robin": 52 | policy = NewRoundRobinPolicy(serversAndWeights) 53 | case "random": 54 | policy = 
NewRandomPolicy(serversAndWeights) 55 | case "least-latency": 56 | policy = NewLeastLatencyPolicy(serversAndWeights) 57 | } 58 | 59 | balancer := Balancer{policy: policy} 60 | 61 | return &balancer, nil 62 | } 63 | 64 | func (b *Balancer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 65 | proxy := b.policy.GetNext() 66 | 67 | tracer := contextvalues.TracerFromContext(r.Context()) 68 | if tracer != nil { 69 | ctx, span := tracer.Start(r.Context(), "request.upstream") 70 | span.SetAttributes(semconv.HTTPServerAttributesFromHTTPRequest(r.Host, r.URL.Path, r)...) 71 | r = r.WithContext(ctx) 72 | defer span.End() 73 | } 74 | 75 | proxy.ServeHTTP(w, r) 76 | } 77 | 78 | type RoundRobinPolicy struct { 79 | current int 80 | weightsSum int 81 | servers []ServerAndWeight 82 | } 83 | 84 | func NewRoundRobinPolicy(servers []ServerAndWeight) *RoundRobinPolicy { 85 | weightsSum := 0 86 | for _, server := range servers { 87 | weightsSum += server.weight 88 | } 89 | 90 | policy := &RoundRobinPolicy{current: 0, weightsSum: weightsSum, servers: servers} 91 | return policy 92 | } 93 | 94 | // The servers provided must be provided in the same order for accurate results 95 | func (rrp *RoundRobinPolicy) GetNext() *httputil.ReverseProxy { 96 | serverIndex := rrp.current 97 | 98 | for _, server := range rrp.servers { 99 | serverIndex -= server.weight 100 | if serverIndex < 0 { 101 | rrp.current += 1 102 | return server.server 103 | } 104 | } 105 | 106 | rrp.current = (rrp.current % rrp.weightsSum) + 1 107 | return rrp.servers[0].server 108 | } 109 | 110 | type RandomPolicy struct { 111 | weightsSum int 112 | servers []ServerAndWeight 113 | } 114 | 115 | func NewRandomPolicy(servers []ServerAndWeight) *RandomPolicy { 116 | weightsSum := 0 117 | for _, server := range servers { 118 | weightsSum += server.weight 119 | } 120 | 121 | return &RandomPolicy{weightsSum: weightsSum, servers: servers} 122 | } 123 | 124 | func (rp *RandomPolicy) GetNext() *httputil.ReverseProxy { 125 | 
randomServerIndex := rand.Intn(rp.weightsSum) 126 | 127 | for _, server := range rp.servers { 128 | randomServerIndex -= server.weight 129 | if randomServerIndex <= 0 { 130 | return server.server 131 | } 132 | } 133 | 134 | return rp.servers[0].server 135 | } 136 | 137 | type LeastLatencyPolicy struct { 138 | serversLatency map[string]int64 139 | servers []ServerAndWeight 140 | } 141 | 142 | func NewLeastLatencyPolicy(serversAndURLs []ServerAndWeight) *LeastLatencyPolicy { 143 | serversLatency := make(map[string]int64, len(serversAndURLs)) 144 | 145 | for _, serverAndWeight := range serversAndURLs { 146 | serversLatency[serverAndWeight.url] = 0 147 | } 148 | 149 | return &LeastLatencyPolicy{servers: serversAndURLs, serversLatency: serversLatency} 150 | } 151 | 152 | func (llp *LeastLatencyPolicy) GetNext() *httputil.ReverseProxy { 153 | 154 | bestServerURL := llp.servers[0].url 155 | var bestLatency int64 = math.MaxInt64 156 | 157 | for url, latency := range llp.serversLatency { 158 | if latency < bestLatency { 159 | bestServerURL = url 160 | bestLatency = latency 161 | } 162 | } 163 | 164 | var chosenServer ServerAndWeight 165 | for _, server := range llp.servers { 166 | if server.url == bestServerURL { 167 | chosenServer = server 168 | break 169 | } 170 | } 171 | 172 | // TODO: use decaing latency for extream latency conditions 173 | 174 | startTime := time.Now().UnixMicro() 175 | chosenServer.server.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { 176 | llp.serversLatency[chosenServer.url] = time.Now().UnixMicro() - startTime 177 | } 178 | chosenServer.server.ModifyResponse = func(r *http.Response) error { 179 | llp.serversLatency[chosenServer.url] = time.Now().UnixMicro() - startTime 180 | return nil 181 | } 182 | 183 | return chosenServer.server 184 | } 185 | -------------------------------------------------------------------------------- /internal/handlers/balancer_test.go: 
-------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "net/http/httputil" 7 | "net/url" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/hvuhsg/gatego/internal/config" 13 | ) 14 | 15 | func TestNewBalancer(t *testing.T) { 16 | service := config.Service{} 17 | path := config.Path{ 18 | Backend: &config.Backend{ 19 | BalancePolicy: "round-robin", 20 | Servers: []struct { 21 | URL string "yaml:\"url\"" 22 | Weight uint "yaml:\"weight\"" 23 | }{ 24 | {URL: "http://localhost:8001", Weight: 1}, 25 | {URL: "http://localhost:8002", Weight: 2}, 26 | }, 27 | }, 28 | } 29 | 30 | balancer, err := NewBalancer(service, path) 31 | if err != nil { 32 | t.Fatalf("Failed to create balancer: %v", err) 33 | } 34 | 35 | if balancer == nil { 36 | t.Fatal("Balancer is nil") 37 | } 38 | } 39 | 40 | func TestRoundRobinPolicy(t *testing.T) { 41 | servers := []ServerAndWeight{ 42 | {server: createDummyProxy("http://localhost:8001/"), weight: 1, url: "http://localhost:8001/"}, 43 | {server: createDummyProxy("http://localhost:8002/"), weight: 1, url: "http://localhost:8002/"}, 44 | } 45 | 46 | policy := NewRoundRobinPolicy(servers) 47 | 48 | // Test the round-robin behavior 49 | expectedOrder := []string{"http://localhost:8001/", "http://localhost:8002/", "http://localhost:8001/", "http://localhost:8002/"} 50 | for i, expected := range expectedOrder { 51 | server := policy.GetNext() 52 | if server.Director == nil { 53 | t.Fatalf("Server %d is nil", i) 54 | } 55 | serverURL := getProxyURL(server) 56 | if serverURL != expected { 57 | t.Errorf("index = %d Expected server %s, got %s", i, expected, serverURL) 58 | } 59 | } 60 | } 61 | 62 | func TestRandomPolicy(t *testing.T) { 63 | servers := []ServerAndWeight{ 64 | {server: createDummyProxy("http://localhost:8001"), weight: 1, url: "http://localhost:8001"}, 65 | {server: createDummyProxy("http://localhost:8002"), weight: 1, url: 
"http://localhost:8002"}, 66 | } 67 | 68 | policy := NewRandomPolicy(servers) 69 | 70 | // Test that we get a valid server (we can't test randomness easily) 71 | for i := 0; i < 10; i++ { 72 | server := policy.GetNext() 73 | if server == nil { 74 | t.Fatal("Got nil server from RandomPolicy") 75 | } 76 | } 77 | } 78 | 79 | func TestLeastLatencyPolicy(t *testing.T) { 80 | // Create mock servers 81 | server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 82 | time.Sleep(20 * time.Millisecond) 83 | w.WriteHeader(http.StatusOK) 84 | w.Write([]byte("Slow response from server 1")) 85 | })) 86 | defer server1.Close() 87 | 88 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 89 | w.WriteHeader(http.StatusOK) 90 | w.Write([]byte("Fast response from server 2")) 91 | })) 92 | defer server2.Close() 93 | 94 | servers := []ServerAndWeight{ 95 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server1.URL)), weight: 1, url: server1.URL}, 96 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server2.URL)), weight: 1, url: server2.URL}, 97 | } 98 | 99 | policy := NewLeastLatencyPolicy(servers) 100 | 101 | // Initially, all servers should have 0 latency 102 | server := policy.GetNext() 103 | if server == nil { 104 | t.Fatal("Got nil server from LeastLatencyPolicy") 105 | } 106 | 107 | // Simulate a request and update latency 108 | w := httptest.NewRecorder() 109 | r, _ := http.NewRequest("GET", server1.URL, nil) 110 | server.ServeHTTP(w, r) 111 | 112 | // The policy should now prefer the fast second server 113 | server = policy.GetNext() 114 | serverURL := strings.TrimSuffix(getProxyURL(server), "/") 115 | if serverURL != strings.TrimSuffix(server2.URL, "/") { 116 | t.Errorf("LeastLatencyPolicy did not choose the server with least latency Got %s Want %s", serverURL, server2.URL) 117 | } 118 | } 119 | 120 | func TestBalancerServeHTTP(t *testing.T) { 121 | // Create mock servers 122 | 
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 123 | w.WriteHeader(http.StatusOK) 124 | w.Write([]byte("Response from server 1")) 125 | })) 126 | defer server1.Close() 127 | 128 | server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 | w.WriteHeader(http.StatusOK) 130 | w.Write([]byte("Response from server 2")) 131 | })) 132 | defer server2.Close() 133 | 134 | // Create ServerAndWeight structs using the mock servers 135 | servers := []ServerAndWeight{ 136 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server1.URL)), weight: 1, url: server1.URL}, 137 | {server: httputil.NewSingleHostReverseProxy(mustParseURL(server2.URL)), weight: 1, url: server2.URL}, 138 | } 139 | 140 | policy := NewRoundRobinPolicy(servers) 141 | balancer := &Balancer{policy: policy} 142 | 143 | w := httptest.NewRecorder() 144 | r, _ := http.NewRequest("GET", "http://example.com", nil) 145 | 146 | balancer.ServeHTTP(w, r) 147 | 148 | if w.Code != http.StatusOK { 149 | t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) 150 | } 151 | 152 | w = httptest.NewRecorder() 153 | r, _ = http.NewRequest("GET", "http://example.com", nil) 154 | 155 | balancer.ServeHTTP(w, r) 156 | 157 | if w.Code != http.StatusOK { 158 | t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) 159 | } 160 | } 161 | 162 | // Helper function to create a dummy reverse proxy 163 | func createDummyProxy(targetURL string) *httputil.ReverseProxy { 164 | url, _ := url.Parse(targetURL) 165 | return httputil.NewSingleHostReverseProxy(url) 166 | } 167 | 168 | // Helper function to get the target URL of a reverse proxy 169 | func getProxyURL(proxy *httputil.ReverseProxy) string { 170 | req, _ := http.NewRequest("GET", "http://example.com", nil) 171 | proxy.Director(req) 172 | return req.URL.String() 173 | } 174 | 175 | // Helper function to parse URL and panic on error 176 | func mustParseURL(rawURL string) 
*url.URL { 177 | u, err := url.Parse(rawURL) 178 | if err != nil { 179 | panic(err) 180 | } 181 | return u 182 | } 183 | -------------------------------------------------------------------------------- /internal/handlers/files.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "path" 7 | "strings" 8 | ) 9 | 10 | type Files struct { 11 | basePath string 12 | handler http.Handler 13 | } 14 | 15 | func NewFiles(dirPath string, basePath string) Files { 16 | return Files{handler: http.FileServer(http.Dir(dirPath)), basePath: basePath} 17 | } 18 | 19 | func (f Files) ServeHTTP(w http.ResponseWriter, r *http.Request) { 20 | cleanedPath, err := removeBaseURLPath(f.basePath, r.URL.Path) 21 | if err == nil { 22 | r.URL.Path = cleanedPath 23 | } 24 | 25 | f.handler.ServeHTTP(w, r) 26 | } 27 | 28 | func removeBaseURLPath(basePath, fullPath string) (string, error) { 29 | // Ensure paths start with "/" 30 | basePath = "/" + strings.Trim(basePath, "/") 31 | fullPath = "/" + strings.Trim(fullPath, "/") 32 | 33 | // Normalize paths 34 | basePath = path.Clean(basePath) 35 | fullPath = path.Clean(fullPath) 36 | 37 | // Check if the full path starts with the base path 38 | if !strings.HasPrefix(fullPath, basePath) { 39 | return "", fmt.Errorf("full path %s is not in base path %s", fullPath, basePath) 40 | } 41 | 42 | // Remove the base path 43 | relPath := strings.TrimPrefix(fullPath, basePath) 44 | 45 | // Ensure the relative path starts with "/" 46 | relPath = "/" + strings.TrimPrefix(relPath, "/") 47 | 48 | return relPath, nil 49 | } 50 | -------------------------------------------------------------------------------- /internal/handlers/files_test.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "os" 7 | "path/filepath" 8 | "testing" 9 | ) 10 | 11 | func 
TestRemoveBaseURLPath(t *testing.T) { 12 | tests := []struct { 13 | name string 14 | basePath string 15 | fullPath string 16 | want string 17 | wantErr bool 18 | }{ 19 | { 20 | name: "simple path", 21 | basePath: "/api", 22 | fullPath: "/api/file.txt", 23 | want: "/file.txt", 24 | wantErr: false, 25 | }, 26 | { 27 | name: "path with multiple segments", 28 | basePath: "/api/v1", 29 | fullPath: "/api/v1/docs/file.txt", 30 | want: "/docs/file.txt", 31 | wantErr: false, 32 | }, 33 | { 34 | name: "paths with trailing slashes", 35 | basePath: "/api/", 36 | fullPath: "/api/file.txt/", 37 | want: "/file.txt", 38 | wantErr: false, 39 | }, 40 | { 41 | name: "paths without leading slashes", 42 | basePath: "api", 43 | fullPath: "api/file.txt", 44 | want: "/file.txt", 45 | wantErr: false, 46 | }, 47 | { 48 | name: "path not in base path", 49 | basePath: "/api", 50 | fullPath: "/other/file.txt", 51 | want: "", 52 | wantErr: true, 53 | }, 54 | { 55 | name: "empty paths", 56 | basePath: "", 57 | fullPath: "/file.txt", 58 | want: "/file.txt", 59 | wantErr: false, 60 | }, 61 | { 62 | name: "identical paths", 63 | basePath: "/api", 64 | fullPath: "/api", 65 | want: "/", 66 | wantErr: false, 67 | }, 68 | } 69 | 70 | for _, tt := range tests { 71 | t.Run(tt.name, func(t *testing.T) { 72 | got, err := removeBaseURLPath(tt.basePath, tt.fullPath) 73 | if (err != nil) != tt.wantErr { 74 | t.Errorf("removeBaseURLPath() error = %v, wantErr %v", err, tt.wantErr) 75 | return 76 | } 77 | if got != tt.want { 78 | t.Errorf("removeBaseURLPath() = %v, want %v", got, tt.want) 79 | } 80 | }) 81 | } 82 | } 83 | 84 | func TestFiles_ServeHTTP(t *testing.T) { 85 | // Create a temporary directory for test files 86 | tmpDir, err := os.MkdirTemp("", "files_test") 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | defer os.RemoveAll(tmpDir) 91 | 92 | // Create a test file 93 | testContent := []byte("test file content") 94 | testFilePath := filepath.Join(tmpDir, "test.txt") 95 | if err := 
os.WriteFile(testFilePath, testContent, 0644); err != nil { 96 | t.Fatal(err) 97 | } 98 | 99 | tests := []struct { 100 | name string 101 | basePath string 102 | requestPath string 103 | expectedStatus int 104 | expectedBody string 105 | }{ 106 | { 107 | name: "valid file request", 108 | basePath: "/files", 109 | requestPath: "/files/test.txt", 110 | expectedStatus: http.StatusOK, 111 | expectedBody: "test file content", 112 | }, 113 | { 114 | name: "file not found", 115 | basePath: "/files", 116 | requestPath: "/files/nonexistent.txt", 117 | expectedStatus: http.StatusNotFound, 118 | expectedBody: "404 page not found\n", 119 | }, 120 | { 121 | name: "path outside base path", 122 | basePath: "/files", 123 | requestPath: "/other/test.txt", 124 | expectedStatus: http.StatusNotFound, 125 | expectedBody: "404 page not found\n", 126 | }, 127 | } 128 | 129 | for _, tt := range tests { 130 | t.Run(tt.name, func(t *testing.T) { 131 | // Create a new Files handler 132 | files := NewFiles(tmpDir, tt.basePath) 133 | 134 | // Create a test request 135 | req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) 136 | w := httptest.NewRecorder() 137 | 138 | // Serve the request 139 | files.ServeHTTP(w, req) 140 | 141 | // Check status code 142 | if w.Code != tt.expectedStatus { 143 | t.Errorf("ServeHTTP() status = %v, want %v", w.Code, tt.expectedStatus) 144 | } 145 | 146 | // Check response body 147 | if w.Body.String() != tt.expectedBody { 148 | t.Errorf("ServeHTTP() body = %v, want %v", w.Body.String(), tt.expectedBody) 149 | } 150 | }) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /internal/handlers/proxy.go: -------------------------------------------------------------------------------- 1 | package handlers 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httputil" 6 | "net/url" 7 | 8 | "github.com/hvuhsg/gatego/internal/config" 9 | "github.com/hvuhsg/gatego/internal/contextvalues" 10 | semconv 
"go.opentelemetry.io/otel/semconv/v1.4.0" 11 | ) 12 | 13 | type Proxy struct { 14 | proxy *httputil.ReverseProxy 15 | } 16 | 17 | func NewProxy(service config.Service, path config.Path) (Proxy, error) { 18 | serviceURL, err := url.Parse(*path.Destination) 19 | if err != nil { 20 | return Proxy{}, err 21 | } 22 | 23 | proxy := httputil.NewSingleHostReverseProxy(serviceURL) 24 | 25 | server := Proxy{proxy: proxy} 26 | return server, nil 27 | } 28 | 29 | func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 30 | tracer := contextvalues.TracerFromContext(r.Context()) 31 | if tracer != nil { 32 | ctx, span := tracer.Start(r.Context(), "request.upstream") 33 | span.SetAttributes(semconv.HTTPServerAttributesFromHTTPRequest(r.Host, r.URL.Path, r)...) 34 | r = r.WithContext(ctx) 35 | defer span.End() 36 | } 37 | p.proxy.ServeHTTP(w, r) 38 | } 39 | -------------------------------------------------------------------------------- /internal/middlewares/addheader.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "go.opentelemetry.io/otel/trace" 8 | ) 9 | 10 | func NewAddHeadersMiddleware(headers map[string]string) Middleware { 11 | return func(next http.Handler) http.Handler { 12 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 13 | span := trace.SpanFromContext(r.Context()) 14 | 15 | for header, value := range headers { 16 | r.Header.Set(header, value) 17 | span.AddEvent(fmt.Sprintf("Added header %s to request", header)) 18 | } 19 | next.ServeHTTP(w, r) 20 | }) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /internal/middlewares/cache.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "strings" 7 | "time" 8 | 9 | "github.com/patrickmn/go-cache" 10 | 
"go.opentelemetry.io/otel/trace" 11 | ) 12 | 13 | const DEFAULT_CACHE_TTL = time.Minute * 1 14 | const CLEANUP_CACHE_INTERVAL = time.Minute * 10 15 | 16 | var responseCache = cache.New(DEFAULT_CACHE_TTL, CLEANUP_CACHE_INTERVAL) // Default cache with a placeholder TTL 17 | 18 | type CachedResponse struct { 19 | statusCode int 20 | body []byte 21 | headers http.Header 22 | } 23 | 24 | func NewCacheMiddleware() Middleware { 25 | return func(next http.Handler) http.Handler { 26 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | span := trace.SpanFromContext(r.Context()) 28 | 29 | // Check if response response is already cached 30 | cachedResponse, found := responseCache.Get(r.URL.String()) 31 | if found { 32 | span.AddEvent("Cache hit") 33 | response := cachedResponse.(CachedResponse) 34 | for header := range response.headers { 35 | w.Header().Set(header, response.headers.Get(header)) 36 | } 37 | w.WriteHeader(response.statusCode) 38 | w.Write(response.body) 39 | return 40 | } 41 | 42 | // Serve the next handler and capture the response 43 | rc := NewRecorder() 44 | next.ServeHTTP(rc, r) 45 | 46 | // Get cache control headers 47 | cacheControl := rc.Header().Get("Cache-Control") 48 | maxAge := getCacheMaxAge(cacheControl) 49 | expires := getCacheExpires(rc.Header().Get("Expires")) 50 | 51 | // Determine TTL based on cache headers 52 | ttl := time.Second * 0 53 | if maxAge > 0 { 54 | ttl = time.Duration(maxAge) * time.Second 55 | } else if !expires.IsZero() { 56 | ttl = time.Until(expires) 57 | } 58 | 59 | // Cache the response if it's cacheable 60 | if ttl > 0 { 61 | cachedResponse := CachedResponse{statusCode: rc.Result().StatusCode, body: rc.Body.Bytes(), headers: rc.Result().Header} 62 | responseCache.Set(r.URL.String(), cachedResponse, ttl) 63 | span.AddEvent("Response stored in cache") 64 | } 65 | 66 | // Write the captured response (original or cached) 67 | rc.WriteTo(w) 68 | }) 69 | } 70 | } 71 | 72 | func getCacheMaxAge(cacheControl 
string) int { 73 | for _, directive := range strings.Split(cacheControl, ",") { 74 | directive = strings.TrimSpace(directive) 75 | if strings.HasPrefix(directive, "max-age=") { 76 | maxAge, err := strconv.Atoi(strings.TrimPrefix(directive, "max-age=")) 77 | if err == nil { 78 | return maxAge 79 | } 80 | } 81 | } 82 | return 0 83 | } 84 | 85 | func getCacheExpires(expiresHeader string) time.Time { 86 | expires, err := time.Parse(time.RFC1123, expiresHeader) 87 | if err != nil { 88 | return time.Time{} 89 | } 90 | return expires 91 | } 92 | -------------------------------------------------------------------------------- /internal/middlewares/cache_test.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestCacheMiddleware(t *testing.T) { 11 | t.Parallel() 12 | // Reset cache before each test 13 | responseCache.Flush() 14 | 15 | t.Run("Should not cache response with no cache headers", func(t *testing.T) { 16 | responseText := "test response" 17 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 18 | w.WriteHeader(200) 19 | w.Write([]byte(responseText)) 20 | }) 21 | 22 | middleware := NewCacheMiddleware()(handler) 23 | req := httptest.NewRequest("GET", "/test", nil) 24 | 25 | // First request 26 | w1 := httptest.NewRecorder() 27 | middleware.ServeHTTP(w1, req) 28 | 29 | if w1.Body.String() != "test response" { 30 | t.Errorf("Expected 'test response', got '%s'", w1.Body.String()) 31 | } 32 | 33 | responseText = "new response" 34 | 35 | // Second request - should be served from cache 36 | w2 := httptest.NewRecorder() 37 | middleware.ServeHTTP(w2, req) 38 | 39 | if w2.Body.String() != "new response" { 40 | t.Errorf("Expected not cached 'new response', got '%s'", w2.Body.String()) 41 | } 42 | }) 43 | 44 | t.Run("Should respect max-age Cache-Control header", func(t *testing.T) { 45 | handler 
:= http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 46 | w.Header().Set("Cache-Control", "max-age=1") 47 | w.WriteHeader(200) 48 | w.Write([]byte("cache-control test")) 49 | }) 50 | 51 | middleware := NewCacheMiddleware()(handler) 52 | req := httptest.NewRequest("GET", "/cache-control", nil) 53 | 54 | // First request 55 | w1 := httptest.NewRecorder() 56 | middleware.ServeHTTP(w1, req) 57 | 58 | // Wait for less than max-age 59 | time.Sleep(time.Millisecond * 500) 60 | 61 | // Should still be cached 62 | w2 := httptest.NewRecorder() 63 | middleware.ServeHTTP(w2, req) 64 | 65 | if w2.Body.String() != "cache-control test" { 66 | t.Errorf("Expected cached response before max-age expiration") 67 | } 68 | 69 | // Wait for cache to expire 70 | time.Sleep(time.Millisecond * 1500) 71 | 72 | if _, found := responseCache.Get("/cache-control"); found { 73 | t.Error("Cache should have expired") 74 | } 75 | }) 76 | 77 | t.Run("Should respect Expires header", func(t *testing.T) { 78 | responseText := "expires test" 79 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 | expiresTime := time.Now().Add(2 * time.Second) 81 | w.Header().Set("Expires", expiresTime.Format(time.RFC1123)) 82 | w.WriteHeader(200) 83 | w.Write([]byte(responseText)) 84 | }) 85 | 86 | middleware := NewCacheMiddleware()(handler) 87 | req := httptest.NewRequest("GET", "/expires", nil) 88 | 89 | // First request 90 | w1 := httptest.NewRecorder() 91 | middleware.ServeHTTP(w1, req) 92 | 93 | // Wait for less than expiration 94 | time.Sleep(time.Second * 1) 95 | 96 | responseText = "something else" 97 | 98 | // Should still be cached 99 | w2 := httptest.NewRecorder() 100 | middleware.ServeHTTP(w2, req) 101 | 102 | if w2.Body.String() != "expires test" { 103 | t.Errorf("Expected cached response before expiration") 104 | } 105 | 106 | // Wait for cache to expire 107 | time.Sleep(time.Second * 2) 108 | 109 | if _, found := responseCache.Get("/expires"); found { 110 | 
t.Error("Cache should have expired") 111 | } 112 | }) 113 | 114 | t.Run("Should preserve response headers", func(t *testing.T) { 115 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 | // Add expiration header 117 | expiresTime := time.Now().Add(50 * time.Second) 118 | w.Header().Set("Expires", expiresTime.Format(time.RFC1123)) 119 | 120 | w.Header().Set("Content-Type", "application/json") 121 | w.Header().Set("X-Custom-Header", "test-value") 122 | w.WriteHeader(200) 123 | w.Write([]byte(`{"message":"test"}`)) 124 | }) 125 | 126 | middleware := NewCacheMiddleware()(handler) 127 | req := httptest.NewRequest("GET", "/headers", nil) 128 | 129 | // First request 130 | w1 := httptest.NewRecorder() 131 | middleware.ServeHTTP(w1, req) 132 | 133 | // Second request - should preserve headers 134 | w2 := httptest.NewRecorder() 135 | middleware.ServeHTTP(w2, req) 136 | 137 | expectedHeaders := map[string]string{ 138 | "Content-Type": "application/json", 139 | "X-Custom-Header": "test-value", 140 | } 141 | 142 | for header, expectedValue := range expectedHeaders { 143 | if value := w2.Header().Get(header); value != expectedValue { 144 | t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value) 145 | } 146 | } 147 | }) 148 | 149 | t.Run("Should handle invalid cache headers gracefully", func(t *testing.T) { 150 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 151 | w.Header().Set("Cache-Control", "max-age=invalid") 152 | w.Header().Set("Expires", "invalid-date") 153 | w.WriteHeader(200) 154 | w.Write([]byte("invalid headers test")) 155 | }) 156 | 157 | middleware := NewCacheMiddleware()(handler) 158 | req := httptest.NewRequest("GET", "/invalid-headers", nil) 159 | 160 | w := httptest.NewRecorder() 161 | middleware.ServeHTTP(w, req) 162 | 163 | if w.Body.String() != "invalid headers test" { 164 | t.Errorf("Expected normal response despite invalid headers") 165 | } 166 | }) 167 | } 168 | 169 | func 
TestGetCacheMaxAge(t *testing.T) { 170 | t.Parallel() 171 | tests := []struct { 172 | name string 173 | cacheControl string 174 | expected int 175 | }{ 176 | {"Valid max-age", "max-age=60", 60}, 177 | {"Multiple directives", "public, max-age=30", 30}, 178 | {"Invalid max-age", "max-age=invalid", 0}, 179 | {"No max-age", "public, private", 0}, 180 | {"Empty string", "", 0}, 181 | } 182 | 183 | for _, tt := range tests { 184 | t.Run(tt.name, func(t *testing.T) { 185 | result := getCacheMaxAge(tt.cacheControl) 186 | if result != tt.expected { 187 | t.Errorf("getCacheMaxAge(%s) = %d; want %d", tt.cacheControl, result, tt.expected) 188 | } 189 | }) 190 | } 191 | } 192 | 193 | func TestGetCacheExpires(t *testing.T) { 194 | t.Parallel() 195 | 196 | now := time.Now() 197 | tests := []struct { 198 | name string 199 | expiresHeader string 200 | wantZero bool 201 | }{ 202 | {"Valid date", now.Format(time.RFC1123), false}, 203 | {"Invalid date", "invalid-date", true}, 204 | {"Empty string", "", true}, 205 | } 206 | 207 | for _, tt := range tests { 208 | t.Run(tt.name, func(t *testing.T) { 209 | result := getCacheExpires(tt.expiresHeader) 210 | if tt.wantZero && !result.IsZero() { 211 | t.Errorf("getCacheExpires(%s) expected zero time, got %v", tt.expiresHeader, result) 212 | } 213 | if !tt.wantZero && result.IsZero() { 214 | t.Errorf("getCacheExpires(%s) expected non-zero time, got zero time", tt.expiresHeader) 215 | } 216 | }) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /internal/middlewares/gzip.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "compress/gzip" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | // GzipMiddleware compresses the response using gzip if the client supports it 10 | func GzipMiddleware(next http.Handler) http.Handler { 11 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 12 | // Check if the 
client accepts gzip encoding 13 | if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { 14 | // Client doesn't support gzip, serve the next handler 15 | next.ServeHTTP(w, r) 16 | return 17 | } 18 | 19 | // Create a gzip.Writer 20 | gzipWriter := gzip.NewWriter(w) 21 | defer gzipWriter.Close() 22 | 23 | // Serve the next handler, writing the response into the ResponseCapture 24 | rc := NewRecorder() 25 | next.ServeHTTP(rc, r) 26 | 27 | rc.WriteHeadersTo(w) 28 | 29 | w.Header().Del("Content-Length") 30 | w.Header().Set("Content-Encoding", "gzip") // Set Content-Encoding header 31 | 32 | w.WriteHeader(rc.Result().StatusCode) 33 | 34 | gzipWriter.Write(rc.Body.Bytes()) 35 | }) 36 | } 37 | -------------------------------------------------------------------------------- /internal/middlewares/gzip_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "io" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "github.com/hvuhsg/gatego/internal/middlewares" 12 | ) 13 | 14 | // Helper to decode gzip data 15 | func decodeGzip(t *testing.T, gzippedBody []byte) string { 16 | gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedBody)) 17 | if err != nil { 18 | t.Fatalf("failed to create gzip reader: %v", err) 19 | } 20 | defer gzipReader.Close() 21 | 22 | var decodedBody bytes.Buffer 23 | if _, err := io.Copy(&decodedBody, gzipReader); err != nil { 24 | t.Fatalf("failed to decode gzip body: %v", err) 25 | } 26 | 27 | return decodedBody.String() 28 | } 29 | 30 | // TestGzipMiddleware_NoGzipSupport tests the middleware when the client does not support gzip 31 | func TestGzipMiddleware_NoGzipSupport(t *testing.T) { 32 | // Create a test handler to be wrapped by the middleware 33 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 | w.Write([]byte("Hello, World")) 35 | }) 36 | 37 | // Wrap the handler with 
GzipMiddleware 38 | handler := middlewares.GzipMiddleware(nextHandler) 39 | 40 | // Create a new HTTP request without gzip support 41 | req := httptest.NewRequest(http.MethodGet, "/", nil) 42 | req.Header.Set("Accept-Encoding", "identity") 43 | 44 | // Record the response 45 | rr := httptest.NewRecorder() 46 | handler.ServeHTTP(rr, req) 47 | 48 | // Check that the response is not gzipped 49 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "" { 50 | t.Errorf("expected no gzip encoding, got %s", encoding) 51 | } 52 | 53 | // Check the body content 54 | if rr.Body.String() != "Hello, World" { 55 | t.Errorf("expected 'Hello, World', got %s", rr.Body.String()) 56 | } 57 | } 58 | 59 | // TestGzipMiddleware_WithGzipSupport tests the middleware when the client supports gzip 60 | func TestGzipMiddleware_WithGzipSupport(t *testing.T) { 61 | // Create a test handler to be wrapped by the middleware 62 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 | w.WriteHeader(http.StatusOK) 64 | w.Write([]byte("Hello, World")) 65 | }) 66 | 67 | // Wrap the handler with GzipMiddleware 68 | handler := middlewares.GzipMiddleware(nextHandler) 69 | 70 | // Create a new HTTP request with gzip support 71 | req := httptest.NewRequest(http.MethodGet, "/", nil) 72 | req.Header.Set("Accept-Encoding", "gzip") 73 | 74 | // Record the response 75 | rr := httptest.NewRecorder() 76 | handler.ServeHTTP(rr, req) 77 | 78 | // Check that the response is gzipped 79 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "gzip" { 80 | t.Errorf("expected gzip encoding, got %s", encoding) 81 | } 82 | 83 | // Decode the gzipped response body 84 | gzippedBody := rr.Body.Bytes() 85 | decodedBody := decodeGzip(t, gzippedBody) 86 | 87 | // Check the body content 88 | if decodedBody != "Hello, World" { 89 | t.Errorf("expected 'Hello, World', got %s", decodedBody) 90 | } 91 | } 92 | 93 | // TestGzipMiddleware_StatusCode tests that the middleware preserves 
status codes 94 | func TestGzipMiddleware_StatusCode(t *testing.T) { 95 | // Create a test handler to be wrapped by the middleware 96 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 97 | w.WriteHeader(http.StatusCreated) 98 | w.Write([]byte("Created")) 99 | }) 100 | 101 | // Wrap the handler with GzipMiddleware 102 | handler := middlewares.GzipMiddleware(nextHandler) 103 | 104 | // Create a new HTTP request with gzip support 105 | req := httptest.NewRequest(http.MethodGet, "/", nil) 106 | req.Header.Set("Accept-Encoding", "gzip") 107 | 108 | // Record the response 109 | rr := httptest.NewRecorder() 110 | handler.ServeHTTP(rr, req) 111 | 112 | // Check that the response is gzipped 113 | if encoding := rr.Header().Get("Content-Encoding"); encoding != "gzip" { 114 | t.Errorf("expected gzip encoding, got %s", encoding) 115 | } 116 | 117 | // Check that the status code is preserved 118 | if status := rr.Result().StatusCode; status != http.StatusCreated { 119 | t.Errorf("expected status code %d, got %d", http.StatusCreated, status) 120 | } 121 | 122 | // Decode the gzipped response body 123 | gzippedBody := rr.Body.Bytes() 124 | decodedBody := decodeGzip(t, gzippedBody) 125 | 126 | // Check the body content 127 | if decodedBody != "Created" { 128 | t.Errorf("expected 'Created', got %s", decodedBody) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /internal/middlewares/logging.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "time" 8 | ) 9 | 10 | func formatDuration(ms int64) string { 11 | if ms < 1000 { 12 | return fmt.Sprintf("%dms", ms) 13 | } 14 | return fmt.Sprintf("%.1fs", float64(ms)/1000) 15 | } 16 | 17 | // Logging middleware log the request / response with the log style of nginx 18 | func NewLoggingMiddleware(out io.Writer) func(http.Handler) http.Handler { 19 | 
return func(next http.Handler) http.Handler { 20 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 | start := time.Now().UnixMilli() 22 | 23 | rh := &responseHook{ResponseWriter: w, respSize: 0} 24 | next.ServeHTTP(rh, r) 25 | 26 | end := time.Now().UnixMilli() 27 | 28 | scheme := "http" 29 | if r.TLS != nil { 30 | scheme = "https" 31 | } 32 | fullURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String()) 33 | 34 | method := r.Method 35 | path := r.URL.Path 36 | responseSize := rh.respSize 37 | remoteAddr := r.RemoteAddr 38 | date := time.Now().Format("2006-01-02 15:04:05") 39 | userAgent := r.UserAgent() 40 | statusCode := rh.statusCode 41 | duration := formatDuration(end - start) 42 | 43 | fmt.Fprintf(out, "%s - - [%s] \"%s %s %s\" %d %d %s \"%s\" \"%s\"\n", remoteAddr, date, method, path, r.Proto, statusCode, responseSize, duration, fullURL, userAgent) 44 | }) 45 | } 46 | } 47 | 48 | type responseHook struct { 49 | http.ResponseWriter 50 | respSize int 51 | statusCode int 52 | } 53 | 54 | func (rh *responseHook) Write(b []byte) (int, error) { 55 | // Save the length of the response 56 | rh.respSize += len(b) 57 | 58 | return rh.ResponseWriter.Write(b) 59 | } 60 | 61 | func (rh *responseHook) WriteHeader(statusCode int) { 62 | // Save status code 63 | rh.statusCode = statusCode 64 | 65 | rh.ResponseWriter.WriteHeader(statusCode) 66 | } 67 | -------------------------------------------------------------------------------- /internal/middlewares/logging_test.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "regexp" 8 | "strings" 9 | "testing" 10 | ) 11 | 12 | func TestLoggingMiddleware(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | method string 16 | path string 17 | statusCode int 18 | responseBody string 19 | expectedLogParts []string 20 | }{ 21 | { 22 | name: "GET request", 23 | method: 
"GET", 24 | path: "/test", 25 | statusCode: 200, 26 | responseBody: "OK", 27 | expectedLogParts: []string{ 28 | "GET", 29 | "/test", 30 | "HTTP/1.1", 31 | "200", 32 | "2", 33 | "http://example.com/test", 34 | }, 35 | }, 36 | { 37 | name: "POST request with 404", 38 | method: "POST", 39 | path: "/notfound", 40 | statusCode: 404, 41 | responseBody: "Not Found", 42 | expectedLogParts: []string{ 43 | "POST", 44 | "/notfound", 45 | "HTTP/1.1", 46 | "404", 47 | "9", 48 | "http://example.com/notfound", 49 | }, 50 | }, 51 | } 52 | 53 | for _, tt := range tests { 54 | t.Run(tt.name, func(t *testing.T) { 55 | // Create a buffer to capture the log output 56 | buf := &bytes.Buffer{} 57 | 58 | // Create a test handler that returns the specified status code and body 59 | testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 | w.WriteHeader(tt.statusCode) 61 | w.Write([]byte(tt.responseBody)) 62 | }) 63 | 64 | // Create the logging middleware 65 | loggingMiddleware := NewLoggingMiddleware(buf) 66 | 67 | // Create a test server with the logging middleware 68 | ts := httptest.NewServer(loggingMiddleware(testHandler)) 69 | defer ts.Close() 70 | 71 | // Create and send the request 72 | req, _ := http.NewRequest(tt.method, ts.URL+tt.path, nil) 73 | req.Host = "example.com" // Set a consistent host for testing 74 | resp, err := http.DefaultClient.Do(req) 75 | if err != nil { 76 | t.Fatalf("Error making request: %v", err) 77 | } 78 | defer resp.Body.Close() 79 | 80 | // Check the response 81 | if resp.StatusCode != tt.statusCode { 82 | t.Errorf("Expected status code %d, got %d", tt.statusCode, resp.StatusCode) 83 | } 84 | 85 | // Check the log output 86 | logOutput := buf.String() 87 | for _, expectedPart := range tt.expectedLogParts { 88 | if !strings.Contains(logOutput, expectedPart) { 89 | t.Errorf("Expected log to contain '%s', but it didn't. 
Log: %s", expectedPart, logOutput) 90 | } 91 | } 92 | 93 | // Check for the presence of a timestamp in the expected format 94 | timeStampFormat := "[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}" 95 | if matched, _ := regexp.MatchString(timeStampFormat, logOutput); !matched { 96 | t.Errorf("Expected log to contain a timestamp in format YYYY-MM-DD HH:MM:SS, but it didn't. Log: %s", logOutput) 97 | } 98 | 99 | // Check for the presence of a duration 100 | if !strings.Contains(logOutput, "ms") && !strings.Contains(logOutput, "s") { 101 | t.Errorf("Expected log to contain a duration, but it didn't. Log: %s", logOutput) 102 | } 103 | }) 104 | } 105 | } 106 | 107 | func TestFormatDuration(t *testing.T) { 108 | tests := []struct { 109 | name string 110 | duration int64 111 | expected string 112 | }{ 113 | {"Less than a second", 500, "500ms"}, 114 | {"Exactly one second", 1000, "1.0s"}, 115 | {"More than a second", 1500, "1.5s"}, 116 | {"Multiple seconds", 3750, "3.8s"}, 117 | } 118 | 119 | for _, tt := range tests { 120 | t.Run(tt.name, func(t *testing.T) { 121 | result := formatDuration(tt.duration) 122 | if result != tt.expected { 123 | t.Errorf("formatDuration(%d) = %s; want %s", tt.duration, result, tt.expected) 124 | } 125 | }) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /internal/middlewares/middleware.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import "net/http" 4 | 5 | type Middleware func(http.Handler) http.Handler 6 | 7 | type HandlerWithMiddleware struct { 8 | finalHandler http.Handler 9 | middlewares []Middleware 10 | } 11 | 12 | func NewHandlerWithMiddleware(handler http.Handler) *HandlerWithMiddleware { 13 | return &HandlerWithMiddleware{ 14 | finalHandler: handler, 15 | middlewares: []Middleware{}, 16 | } 17 | } 18 | 19 | func (h *HandlerWithMiddleware) Add(middleware Middleware) { 20 | h.middlewares = 
append(h.middlewares, middleware) 21 | } 22 | 23 | func (h *HandlerWithMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { 24 | // Chain the middlewares around the final handler 25 | handler := h.finalHandler 26 | for i := len(h.middlewares) - 1; i >= 0; i-- { 27 | handler = h.middlewares[i](handler) 28 | } 29 | handler.ServeHTTP(w, r) 30 | } 31 | -------------------------------------------------------------------------------- /internal/middlewares/minify.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strconv" 7 | 8 | "github.com/tdewolff/minify/v2" 9 | "github.com/tdewolff/minify/v2/css" 10 | "github.com/tdewolff/minify/v2/html" 11 | "github.com/tdewolff/minify/v2/js" 12 | "github.com/tdewolff/minify/v2/json" 13 | "github.com/tdewolff/minify/v2/svg" 14 | "github.com/tdewolff/minify/v2/xml" 15 | "go.opentelemetry.io/otel/trace" 16 | ) 17 | 18 | type MinifyConfig struct { 19 | ALL bool 20 | JS bool 21 | CSS bool 22 | HTML bool 23 | JSON bool 24 | SVG bool 25 | XML bool 26 | } 27 | 28 | func NewMinifyMiddleware(config MinifyConfig) Middleware { 29 | m := minify.New() 30 | 31 | // Add minifiers for the different content types 32 | if config.HTML || config.ALL { 33 | m.AddFunc("text/html", html.Minify) 34 | } 35 | if config.CSS || config.ALL { 36 | m.AddFunc("text/css", css.Minify) 37 | } 38 | if config.JS || config.ALL { 39 | m.AddFunc("application/javascript", js.Minify) 40 | } 41 | if config.JSON || config.ALL { 42 | m.AddFunc("application/json", json.Minify) 43 | } 44 | if config.SVG || config.ALL { 45 | m.AddFunc("image/svg+xml", svg.Minify) 46 | } 47 | if config.XML || config.ALL { 48 | m.AddFunc("application/xml", xml.Minify) 49 | } 50 | 51 | return func(next http.Handler) http.Handler { 52 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 53 | span := trace.SpanFromContext(r.Context()) 54 | 55 | // Create a custom 
ResponseWriter to capture the response 56 | rc := NewRecorder() 57 | 58 | // Serve the next handler 59 | next.ServeHTTP(rc, r) 60 | 61 | // Get the content type of the response 62 | contentType := rc.Header().Get("Content-Type") 63 | 64 | minifiedContent, err := m.Bytes(contentType, rc.Body.Bytes()) 65 | if err != nil { 66 | rc.WriteTo(w) // Return the original response 67 | return 68 | } 69 | 70 | span.AddEvent(fmt.Sprintf("Minified response content, content-type = %s", contentType)) 71 | 72 | // Write the minified content to the response 73 | w.Header().Set("Content-Length", strconv.Itoa(len(minifiedContent))) 74 | rc.WriteHeadersTo(w) 75 | w.WriteHeader(rc.Result().StatusCode) 76 | 77 | w.Write(minifiedContent) 78 | }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /internal/middlewares/minify_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/hvuhsg/gatego/internal/middlewares" 9 | ) 10 | 11 | // Helper function to create a basic next handler that returns content with a specific content type 12 | func createHandler(contentType, content string) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | w.Header().Set("Content-Type", contentType) 15 | w.WriteHeader(http.StatusOK) 16 | w.Write([]byte(content)) 17 | }) 18 | } 19 | 20 | // TestMinifyMiddleware_HTML tests HTML minification 21 | func TestMinifyMiddleware_HTML(t *testing.T) { 22 | handler := createHandler("text/html", "

Hello World

") 23 | 24 | config := middlewares.MinifyConfig{HTML: true} 25 | middleware := middlewares.NewMinifyMiddleware(config) 26 | minifiedHandler := middleware(handler) 27 | 28 | req := httptest.NewRequest(http.MethodGet, "/", nil) 29 | rr := httptest.NewRecorder() 30 | 31 | minifiedHandler.ServeHTTP(rr, req) 32 | 33 | expected := "

Hello World

" 34 | if rr.Body.String() != expected { 35 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String()) 36 | } 37 | } 38 | 39 | // TestMinifyMiddleware_CSS tests CSS minification 40 | func TestMinifyMiddleware_CSS(t *testing.T) { 41 | handler := createHandler("text/css", "body { color: red; }") 42 | 43 | config := middlewares.MinifyConfig{CSS: true} 44 | middleware := middlewares.NewMinifyMiddleware(config) 45 | minifiedHandler := middleware(handler) 46 | 47 | req := httptest.NewRequest(http.MethodGet, "/", nil) 48 | rr := httptest.NewRecorder() 49 | 50 | minifiedHandler.ServeHTTP(rr, req) 51 | 52 | expected := "body{color:red}" 53 | if rr.Body.String() != expected { 54 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String()) 55 | } 56 | } 57 | 58 | // TestMinifyMiddleware_JS tests JS minification 59 | func TestMinifyMiddleware_JS(t *testing.T) { 60 | handler := createHandler("application/javascript", "function test() { return 1; }") 61 | 62 | config := middlewares.MinifyConfig{JS: true} 63 | middleware := middlewares.NewMinifyMiddleware(config) 64 | minifiedHandler := middleware(handler) 65 | 66 | req := httptest.NewRequest(http.MethodGet, "/", nil) 67 | rr := httptest.NewRecorder() 68 | 69 | minifiedHandler.ServeHTTP(rr, req) 70 | 71 | expected := "function test(){return 1}" 72 | if rr.Body.String() != expected { 73 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String()) 74 | } 75 | } 76 | 77 | // TestMinifyMiddleware_JSON tests JSON minification 78 | func TestMinifyMiddleware_JSON(t *testing.T) { 79 | handler := createHandler("application/json", `{ 80 | "name": "John", 81 | "age": 30 82 | }`) 83 | 84 | config := middlewares.MinifyConfig{JSON: true} 85 | middleware := middlewares.NewMinifyMiddleware(config) 86 | minifiedHandler := middleware(handler) 87 | 88 | req := httptest.NewRequest(http.MethodGet, "/", nil) 89 | rr := httptest.NewRecorder() 90 | 91 | minifiedHandler.ServeHTTP(rr, req) 92 | 93 | expected := `{"name":"John","age":30}` 
94 | if rr.Body.String() != expected { 95 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String()) 96 | } 97 | } 98 | 99 | // TestMinifyMiddleware_SkipUnsupported tests that unsupported content types are not minified 100 | func TestMinifyMiddleware_SkipUnsupported(t *testing.T) { 101 | handler := createHandler("text/plain", "This is a plain text file.") 102 | 103 | config := middlewares.MinifyConfig{HTML: true, CSS: true, JS: true} 104 | middleware := middlewares.NewMinifyMiddleware(config) 105 | minifiedHandler := middleware(handler) 106 | 107 | req := httptest.NewRequest(http.MethodGet, "/", nil) 108 | rr := httptest.NewRecorder() 109 | 110 | minifiedHandler.ServeHTTP(rr, req) 111 | 112 | expected := "This is a plain text file." 113 | if rr.Body.String() != expected { 114 | t.Errorf("expected '%s', got '%s'", expected, rr.Body.String()) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /internal/middlewares/omit_headers.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "go.opentelemetry.io/otel/trace" 8 | ) 9 | 10 | // OmitHeaders middleware removes specified headers from the response to enhance security. 
11 | func NewOmitHeadersMiddleware(headersToOmit []string) func(http.Handler) http.Handler { 12 | return func(next http.Handler) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | span := trace.SpanFromContext(r.Context()) 15 | 16 | rc := NewRecorder() 17 | next.ServeHTTP(rc, r) 18 | 19 | // Omit headers from response 20 | for _, header := range headersToOmit { 21 | if rc.Result().Header.Get(header) != "" { 22 | rc.Result().Header.Del(header) 23 | span.AddEvent(fmt.Sprintf("Removed response header %s", header)) 24 | } 25 | } 26 | 27 | rc.WriteHeadersTo(w) 28 | w.WriteHeader(rc.Result().StatusCode) 29 | w.Write(rc.Body.Bytes()) 30 | }) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /internal/middlewares/omit_headers_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/hvuhsg/gatego/internal/middlewares" 9 | ) 10 | 11 | // TestOmitHeadersMiddleware_OmitResponseHeaders tests that headers are omitted from the response 12 | func TestOmitHeadersMiddleware_OmitResponseHeaders(t *testing.T) { 13 | nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | w.Header().Set("Authorization", "Bearer some-secret-token") 15 | w.Header().Set("X-API-Key", "secret-api-key") 16 | w.WriteHeader(http.StatusOK) 17 | w.Write([]byte("OK")) 18 | }) 19 | 20 | headers := []string{"Authorization", "X-API-Key"} 21 | handler := middlewares.NewOmitHeadersMiddleware(headers)(nextHandler) 22 | 23 | req := httptest.NewRequest(http.MethodGet, "/", nil) 24 | 25 | rr := httptest.NewRecorder() 26 | handler.ServeHTTP(rr, req) 27 | 28 | if rr.Header().Get("Authorization") != "" { 29 | t.Errorf("expected 'Authorization' header to be omitted, got %s", rr.Header().Get("Authorization")) 30 | } 31 | if rr.Header().Get("X-API-Key") != 
"" { 32 | t.Errorf("expected 'X-API-Key' header to be omitted, got %s", rr.Header().Get("X-API-Key")) 33 | } 34 | 35 | if rr.Body.String() != "OK" { 36 | t.Errorf("expected 'OK', got %s", rr.Body.String()) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /internal/middlewares/openapi.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/getkin/kin-openapi/openapi3" 8 | "github.com/getkin/kin-openapi/openapi3filter" 9 | "github.com/getkin/kin-openapi/routers/gorillamux" 10 | "go.opentelemetry.io/otel/trace" 11 | ) 12 | 13 | func NewOpenAPIValidationMiddleware(specPath string) (Middleware, error) { 14 | loader := &openapi3.Loader{IsExternalRefsAllowed: true} 15 | doc, err := loader.LoadFromFile(specPath) 16 | if err != nil { 17 | return nil, fmt.Errorf("error loading OpenAPI spec: %w", err) 18 | } 19 | 20 | if err := doc.Validate(loader.Context); err != nil { 21 | return nil, fmt.Errorf("error validating OpenAPI spec: %w", err) 22 | } 23 | 24 | router, err := gorillamux.NewRouter(doc) 25 | if err != nil { 26 | return nil, fmt.Errorf("error creating router: %w", err) 27 | } 28 | 29 | return func(next http.Handler) http.Handler { 30 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 31 | span := trace.SpanFromContext(r.Context()) 32 | 33 | route, pathParams, err := router.FindRoute(r) 34 | if err != nil { 35 | span.AddEvent("Request path not found in openapi spec") 36 | http.Error(w, fmt.Sprintf("Error finding route: %v", err), http.StatusBadRequest) 37 | return 38 | } 39 | 40 | requestValidationInput := &openapi3filter.RequestValidationInput{ 41 | Request: r, 42 | PathParams: pathParams, 43 | Route: route, 44 | } 45 | 46 | if err := openapi3filter.ValidateRequest(r.Context(), requestValidationInput); err != nil { 47 | span.AddEvent(fmt.Sprintf("Error while validating request with 
openapi spec. err = %v", err)) 48 | http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) 49 | return 50 | } 51 | 52 | span.AddEvent("Request validated by openapi spec") 53 | 54 | rc := NewRecorder() 55 | next.ServeHTTP(rc, r) 56 | 57 | responseValidationInput := &openapi3filter.ResponseValidationInput{ 58 | RequestValidationInput: requestValidationInput, 59 | Status: rc.Result().StatusCode, 60 | Header: rc.Header(), 61 | } 62 | 63 | if rc.Body.Bytes() != nil { 64 | responseValidationInput.SetBodyBytes(rc.Body.Bytes()) 65 | } 66 | 67 | if err := openapi3filter.ValidateResponse(r.Context(), responseValidationInput); err != nil { 68 | span.AddEvent(fmt.Sprintf("Error while validating response with openapi spec. err = %v", err)) 69 | http.Error(w, fmt.Sprintf("Invalid response: %v", err), http.StatusInternalServerError) 70 | return 71 | } 72 | 73 | span.AddEvent("Response validated by openapi spec") 74 | 75 | rc.WriteHeadersTo(w) 76 | w.WriteHeader(rc.Result().StatusCode) 77 | if rc.Body.Bytes() != nil { 78 | w.Write(rc.Body.Bytes()) 79 | } 80 | }) 81 | }, nil 82 | } 83 | -------------------------------------------------------------------------------- /internal/middlewares/openapi_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "testing" 9 | 10 | "github.com/hvuhsg/gatego/internal/middlewares" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // Helper function to compare JSON 16 | func assertJSONEqual(t *testing.T, expected, actual string) { 17 | var expectedMap, actualMap map[string]interface{} 18 | err := json.Unmarshal([]byte(expected), &expectedMap) 19 | require.NoError(t, err, "Error unmarshaling expected JSON") 20 | err = json.Unmarshal([]byte(actual), &actualMap) 21 | require.NoError(t, err, "Error unmarshaling actual JSON") 22 
| assert.Equal(t, expectedMap, actualMap) 23 | } 24 | 25 | func TestOpenAPIValidationMiddleware(t *testing.T) { 26 | // Create a temporary OpenAPI spec file for testing 27 | specFile, err := os.CreateTemp("", "openapi-spec-*.yaml") 28 | require.NoError(t, err) 29 | defer os.Remove(specFile.Name()) 30 | 31 | // Write a simple OpenAPI spec to the file 32 | specContent := ` 33 | openapi: 3.0.0 34 | info: 35 | title: Test API 36 | version: 1.0.0 37 | paths: 38 | /test: 39 | get: 40 | parameters: 41 | - name: param 42 | in: query 43 | required: true 44 | schema: 45 | type: string 46 | responses: 47 | '200': 48 | description: OK 49 | content: 50 | application/json: 51 | schema: 52 | type: object 53 | required: 54 | - message 55 | properties: 56 | message: 57 | type: string 58 | status: 59 | type: string 60 | ` 61 | _, err = specFile.Write([]byte(specContent)) 62 | require.NoError(t, err) 63 | specFile.Close() 64 | 65 | // Create the middleware 66 | middleware, err := middlewares.NewOpenAPIValidationMiddleware(specFile.Name()) 67 | require.NoError(t, err) 68 | 69 | tests := []struct { 70 | name string 71 | url string 72 | handler http.HandlerFunc 73 | expectedStatus int 74 | expectedBody string 75 | isJSON bool 76 | }{ 77 | { 78 | name: "Valid request and response", 79 | url: "/test?param=value", 80 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 | w.Header().Set("Content-Type", "application/json") 82 | w.WriteHeader(http.StatusOK) 83 | json.NewEncoder(w).Encode(map[string]string{"message": "Hello, World!", "status": "ok"}) 84 | }), 85 | expectedStatus: http.StatusOK, 86 | expectedBody: `{"message":"Hello, World!","status":"ok"}`, 87 | isJSON: true, 88 | }, 89 | { 90 | name: "Valid request but invalid response (missing required field)", 91 | url: "/test?param=value", 92 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 93 | w.Header().Set("Content-Type", "application/json") 94 | w.WriteHeader(http.StatusOK) 95 | 
json.NewEncoder(w).Encode(map[string]string{"status": "error"}) 96 | }), 97 | expectedStatus: http.StatusInternalServerError, 98 | expectedBody: "Invalid response:", 99 | isJSON: false, 100 | }, 101 | { 102 | name: "Valid request but invalid response (wrong content type)", 103 | url: "/test?param=value", 104 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 105 | w.Header().Set("Content-Type", "text/plain") 106 | w.WriteHeader(http.StatusOK) 107 | w.Write([]byte("Hello, World!")) 108 | }), 109 | expectedStatus: http.StatusInternalServerError, 110 | expectedBody: "Invalid response:", 111 | isJSON: false, 112 | }, 113 | { 114 | name: "Valid request but response with extra field", 115 | url: "/test?param=value", 116 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 117 | w.Header().Set("Content-Type", "application/json") 118 | w.WriteHeader(http.StatusOK) 119 | json.NewEncoder(w).Encode(map[string]string{"message": "Hello, World!", "extra": "field"}) 120 | }), 121 | expectedStatus: http.StatusOK, 122 | expectedBody: `{"message":"Hello, World!","extra":"field"}`, 123 | isJSON: true, 124 | }, 125 | { 126 | name: "Invalid path", 127 | url: "/invalid", 128 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), 129 | expectedStatus: http.StatusBadRequest, 130 | expectedBody: "Error finding route:", 131 | isJSON: false, 132 | }, 133 | { 134 | name: "Missing required parameter", 135 | url: "/test", 136 | handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), 137 | expectedStatus: http.StatusBadRequest, 138 | expectedBody: "Invalid request:", 139 | isJSON: false, 140 | }, 141 | } 142 | 143 | for _, tt := range tests { 144 | t.Run(tt.name, func(t *testing.T) { 145 | req, err := http.NewRequest("GET", tt.url, nil) 146 | require.NoError(t, err) 147 | 148 | rr := httptest.NewRecorder() 149 | wrappedHandler := middleware(tt.handler) 150 | wrappedHandler.ServeHTTP(rr, req) 151 | 152 | 
assert.Equal(t, tt.expectedStatus, rr.Code) 153 | 154 | if tt.isJSON { 155 | assertJSONEqual(t, tt.expectedBody, rr.Body.String()) 156 | } else { 157 | assert.Contains(t, rr.Body.String(), tt.expectedBody) 158 | } 159 | }) 160 | } 161 | } 162 | 163 | func TestNewOpenAPIValidationMiddleware(t *testing.T) { 164 | tests := []struct { 165 | name string 166 | specContent string 167 | expectError bool 168 | }{ 169 | { 170 | name: "Valid OpenAPI spec", 171 | specContent: ` 172 | openapi: 3.0.0 173 | info: 174 | title: Test API 175 | version: 1.0.0 176 | paths: 177 | /test: 178 | get: 179 | responses: 180 | '200': 181 | description: OK 182 | `, 183 | expectError: false, 184 | }, 185 | { 186 | name: "Invalid OpenAPI spec", 187 | specContent: "invalid: yaml: content", 188 | expectError: true, 189 | }, 190 | } 191 | 192 | for _, tt := range tests { 193 | t.Run(tt.name, func(t *testing.T) { 194 | specFile, err := os.CreateTemp("", "openapi-spec-*.yaml") 195 | require.NoError(t, err) 196 | defer os.Remove(specFile.Name()) 197 | 198 | _, err = specFile.Write([]byte(tt.specContent)) 199 | require.NoError(t, err) 200 | specFile.Close() 201 | 202 | middleware, err := middlewares.NewOpenAPIValidationMiddleware(specFile.Name()) 203 | 204 | if tt.expectError { 205 | assert.Error(t, err) 206 | assert.Nil(t, middleware) 207 | } else { 208 | assert.NoError(t, err) 209 | assert.NotNil(t, middleware) 210 | } 211 | }) 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /internal/middlewares/otel.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/hvuhsg/gatego/internal/contextvalues" 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | "go.opentelemetry.io/otel/codes" 12 | "go.opentelemetry.io/otel/propagation" 13 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0" 14 | 
"go.opentelemetry.io/otel/trace" 15 | ) 16 | 17 | const tracerName = "request" 18 | const spanName = "middlewares" 19 | 20 | type OTELConfig struct { 21 | ServiceDomain string 22 | BasePath string 23 | } 24 | 25 | func NewOpenTelemetryMiddleware(ctx context.Context, config OTELConfig) (Middleware, error) { 26 | tp := otel.GetTracerProvider() 27 | tracer := tp.Tracer(tracerName) 28 | 29 | return func(next http.Handler) http.Handler { 30 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 31 | // Add tracer to request context 32 | r = r.WithContext(contextvalues.AddTracerToContext(r.Context(), tracer)) 33 | 34 | // Create span for request 35 | ctx, span := tracer.Start( 36 | r.Context(), 37 | spanName, 38 | trace.WithAttributes(semconv.NetAttributesFromHTTPRequest("", r)...), 39 | trace.WithSpanKind(trace.SpanKindServer), 40 | ) 41 | defer span.End() 42 | 43 | // Add request-specific attributes 44 | attrs := make([]attribute.KeyValue, 0) 45 | attrs = append(attrs, semconv.HTTPUserAgentKey.String(r.UserAgent())) 46 | attrs = append(attrs, semconv.HTTPServerAttributesFromHTTPRequest(config.ServiceDomain, config.BasePath, r)...) 47 | span.SetAttributes( 48 | attrs..., 49 | ) 50 | 51 | // Handle panic recovery 52 | defer func() { 53 | if err := recover(); err != nil { 54 | span.SetStatus(codes.Error, fmt.Sprintf("panic: %v", err)) 55 | span.RecordError(fmt.Errorf("%v", err)) 56 | panic(err) // Re-panic after recording error 57 | } 58 | }() 59 | 60 | // Propegate open telemetry context via the request to the upstream service 61 | otel.GetTextMapPropagator().Inject(r.Context(), propagation.HeaderCarrier(r.Header)) 62 | 63 | // Add span to request context 64 | rc := NewRecorder() 65 | next.ServeHTTP(rc, r.WithContext(ctx)) 66 | 67 | // Set status and attributes based on response code 68 | statusCode := rc.Result().StatusCode 69 | span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(statusCode)...) 
70 | if statusCode >= 400 { 71 | span.SetStatus(codes.Error, http.StatusText(statusCode)) 72 | if statusCode >= 500 { 73 | span.RecordError(fmt.Errorf("server error: %d", statusCode)) 74 | } 75 | } else { 76 | span.SetStatus(codes.Ok, "") 77 | } 78 | 79 | // Add response information 80 | span.SetAttributes( 81 | attribute.Int64("http.response_size", rc.Result().ContentLength), 82 | attribute.String("http.response_content_type", rc.Result().Header.Get("Content-Type")), 83 | ) 84 | 85 | // Return response 86 | rc.WriteTo(w) 87 | }) 88 | }, nil 89 | } 90 | -------------------------------------------------------------------------------- /internal/middlewares/ratelimit.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "net/http" 7 | "slices" 8 | "strconv" 9 | "strings" 10 | "sync" 11 | "time" 12 | 13 | "golang.org/x/time/rate" 14 | ) 15 | 16 | var SupportedZones = []string{"ip"} 17 | var ErrZoneNotSupported = errors.New("rate limit zone is not supported") 18 | 19 | type RateLimiter struct { 20 | limiters sync.Map 21 | } 22 | 23 | type LimitConfig struct { 24 | Zone string 25 | Requests int 26 | Per time.Duration 27 | } 28 | 29 | func (lc LimitConfig) GetKey(r *http.Request) (key string, err error) { 30 | err = nil 31 | switch lc.Zone { 32 | case "ip": 33 | parts := strings.Split(r.RemoteAddr, ":") 34 | ip := parts[0] 35 | key = "ip:" + ip 36 | default: 37 | err = errors.New("rate limit zone is not supported") 38 | } 39 | key = strconv.Itoa(int(lc.Per.Seconds())) + "|" + strconv.Itoa(lc.Requests) + "!" 
+ key 40 | return 41 | } 42 | 43 | func NewRateLimiter() *RateLimiter { 44 | return &RateLimiter{} 45 | } 46 | 47 | func (rl *RateLimiter) addLimiter(key string, limit rate.Limit, burst int) { 48 | limiter := rate.NewLimiter(limit, burst) 49 | rl.limiters.Store(key, limiter) 50 | } 51 | 52 | func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { 53 | if limiter, ok := rl.limiters.Load(key); ok { 54 | return limiter.(*rate.Limiter) 55 | } 56 | return nil 57 | } 58 | 59 | func ParseLimitConfig(config string) (LimitConfig, error) { 60 | parts := strings.Split(config, "-") 61 | if len(parts) != 2 { 62 | return LimitConfig{}, fmt.Errorf("invalid limit config: %s", config) 63 | } 64 | zone := parts[0] 65 | if !slices.Contains(SupportedZones, strings.ToLower(zone)) { 66 | return LimitConfig{}, ErrZoneNotSupported 67 | } 68 | 69 | limitParts := strings.Split(parts[1], "/") 70 | if len(limitParts) != 2 { 71 | return LimitConfig{}, fmt.Errorf("invalid limit config: %s", config) 72 | } 73 | 74 | requests, err := strconv.Atoi(limitParts[0]) 75 | if err != nil { 76 | return LimitConfig{}, fmt.Errorf("invalid requests number: %s", limitParts[0]) 77 | } 78 | 79 | var duration time.Duration 80 | switch limitParts[1] { 81 | case "s": 82 | duration = time.Second 83 | case "m": 84 | duration = time.Minute 85 | case "h": 86 | duration = time.Hour 87 | case "d": 88 | duration = time.Hour * 24 89 | default: 90 | return LimitConfig{}, fmt.Errorf("invalid time unit: %s", limitParts[1]) 91 | } 92 | 93 | return LimitConfig{ 94 | Zone: zone, 95 | Requests: requests, 96 | Per: duration, 97 | }, nil 98 | } 99 | 100 | func NewRateLimitMiddleware(limits []string) (func(http.Handler) http.Handler, error) { 101 | rateLimiter := NewRateLimiter() 102 | 103 | // Pre-process ratelimit configs 104 | parsedLimits := make([]LimitConfig, 0, len(limits)) 105 | for _, limit := range limits { 106 | parsedLimit, err := ParseLimitConfig(limit) 107 | if err != nil { 108 | return nil, err 109 | } 110 | 
parsedLimits = append(parsedLimits, parsedLimit) 111 | } 112 | 113 | return func(next http.Handler) http.Handler { 114 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | for _, config := range parsedLimits { 116 | key, err := config.GetKey(r) 117 | if err != nil { 118 | // Should never reach here (validation should prevent it) 119 | http.Error(w, err.Error(), http.StatusInternalServerError) 120 | } 121 | 122 | limiter := rateLimiter.getLimiter(key) 123 | if limiter == nil { 124 | rateLimiter.addLimiter(key, rate.Every(config.Per), config.Requests) 125 | limiter = rateLimiter.getLimiter(key) 126 | } 127 | 128 | if !limiter.Allow() { 129 | setRateLimitHeaders(w, limiter, config) 130 | http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) 131 | return 132 | } 133 | } 134 | next.ServeHTTP(w, r) 135 | }) 136 | }, nil 137 | } 138 | 139 | func setRateLimitHeaders(w http.ResponseWriter, limiter *rate.Limiter, config LimitConfig) { 140 | now := time.Now() 141 | limit := config.Requests 142 | remaining := int(limiter.Tokens()) 143 | reset := now.Add(config.Per).Unix() 144 | 145 | w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit)) 146 | w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) 147 | w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", reset)) 148 | } 149 | -------------------------------------------------------------------------------- /internal/middlewares/ratelimit_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "sync" 8 | "testing" 9 | 10 | "github.com/hvuhsg/gatego/internal/middlewares" 11 | ) 12 | 13 | func TestRateLimitExceeded(t *testing.T) { 14 | limits := []string{"ip-1/s"} 15 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits) 16 | if err != nil { 17 | t.Fatalf("Error creating middleware: %v", err) 18 | } 19 | 20 | handler := 
rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 21 | w.WriteHeader(http.StatusOK) 22 | })) 23 | 24 | // Create a test server 25 | req := httptest.NewRequest("GET", "http://example.com", nil) 26 | req.RemoteAddr = "192.168.1.1:12345" 27 | rr := httptest.NewRecorder() 28 | 29 | // First request should pass 30 | handler.ServeHTTP(rr, req) 31 | if rr.Code != http.StatusOK { 32 | t.Errorf("expected status OK, got %v", rr.Code) 33 | } 34 | 35 | // Second request should fail (rate limit exceeded) 36 | rr = httptest.NewRecorder() 37 | handler.ServeHTTP(rr, req) 38 | if rr.Code != http.StatusTooManyRequests { 39 | t.Errorf("expected status TooManyRequests, got %v", rr.Code) 40 | } 41 | } 42 | 43 | func TestRateLimitHeaders(t *testing.T) { 44 | limits := []string{"ip-1/s"} 45 | rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits) 46 | if err != nil { 47 | t.Fatalf("Error creating middleware: %v", err) 48 | } 49 | 50 | handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 51 | w.WriteHeader(http.StatusOK) 52 | })) 53 | 54 | req := httptest.NewRequest("GET", "http://example.com", nil) 55 | req.RemoteAddr = "192.168.1.1:12345" 56 | rr := httptest.NewRecorder() 57 | 58 | // First request should pass 59 | handler.ServeHTTP(rr, req) 60 | 61 | // Second request should fail 62 | handler.ServeHTTP(rr, req) 63 | 64 | if rr.Header().Get("X-RateLimit-Limit") != "1" { 65 | t.Errorf("expected X-RateLimit-Limit 1, got %s", rr.Header().Get("X-RateLimit-Limit")) 66 | } 67 | if rr.Header().Get("X-RateLimit-Remaining") != "0" { 68 | t.Errorf("expected X-RateLimit-Remaining 0, got %s", rr.Header().Get("X-RateLimit-Remaining")) 69 | } 70 | } 71 | 72 | func TestInvalidRateLimitConfig(t *testing.T) { 73 | limits := []string{"ip-invalid/s"} 74 | _, err := middlewares.NewRateLimitMiddleware(limits) 75 | if err == nil { 76 | t.Errorf("expected error for invalid config, got none") 77 | } 78 | } 79 | 80 | 
func TestUnsupportedZone(t *testing.T) {
	limits := []string{"unsupported-10/s"}
	_, err := middlewares.NewRateLimitMiddleware(limits)
	if !errors.Is(err, middlewares.ErrZoneNotSupported) {
		t.Errorf("expected ErrZoneNotSupported, got %v", err)
	}
}

// TestRateLimitConcurrentRequests fires five concurrent requests from one
// client IP against a 3/s limit and expects exactly two to be rejected.
func TestRateLimitConcurrentRequests(t *testing.T) {
	limits := []string{"ip-3/s"}
	rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
	if err != nil {
		t.Fatalf("Error creating middleware: %v", err)
	}

	handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest("GET", "http://example.com", nil)
	req.RemoteAddr = "192.168.1.1:12345"

	var wg sync.WaitGroup
	var rateLimitedCount int
	var mu sync.Mutex // guards rateLimitedCount across the goroutines

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			rr := httptest.NewRecorder()
			handler.ServeHTTP(rr, req)

			mu.Lock()
			if rr.Code == http.StatusTooManyRequests {
				rateLimitedCount++
			}
			mu.Unlock()
		}()
	}

	wg.Wait()

	if rateLimitedCount != 2 {
		t.Errorf("expected 2 rate limited requests, got %d", rateLimitedCount)
	}
}

// TestRateLimitDifferentTimeWindows applies a per-second and a per-minute
// limit together; the tighter per-second limit must reject the third request.
func TestRateLimitDifferentTimeWindows(t *testing.T) {
	limits := []string{"ip-2/s", "ip-5/m"}
	rateLimitMiddleware, err := middlewares.NewRateLimitMiddleware(limits)
	if err != nil {
		t.Fatalf("Error creating middleware: %v", err)
	}

	handler := rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest("GET", "http://example.com", nil)
	req.RemoteAddr = "192.168.1.1:12345"
	rr := httptest.NewRecorder()

	// First and second request should pass
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected status OK, got %v", rr.Code)
	}

	rr = httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("expected status OK, got %v", rr.Code)
	}

	// Third request should fail due to 1-second window limit
	rr = httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusTooManyRequests {
		t.Errorf("expected status TooManyRequests, got %v", rr.Code)
	}
}

--------------------------------------------------------------------------------
/internal/middlewares/responsecapture.go:
--------------------------------------------------------------------------------
package middlewares

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/textproto"
	"strconv"
	"strings"

	"golang.org/x/net/http/httpguts"
)

// ResponseRecorder is an implementation of [http.ResponseWriter] that keeps
// the status code, headers and body in memory instead of sending them, so a
// middleware can inspect the response (via Result) before replaying it onto
// the real writer with WriteTo.
//
// NOTE(review): this type appears to be adapted from net/http/httptest's
// ResponseRecorder; unlike that type it defines no Flush method, so it does
// not implement http.Flusher.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// NewRecorder initialises it to 200, so a handler that never calls
	// WriteHeader or Write reports 200. On a zero-value ResponseRecorder it
	// starts at 0; Result maps 0 to the implicit http.StatusOK.
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	//
	// NOTE(review): nothing in responsecapture.go assigns this field (there
	// is no Flush method), so it always stays false.
	Flushed bool

	result      *http.Response // cache of Result's return value
	snapHeader  http.Header    // snapshot of HeaderMap at first Write
	wroteHeader bool           // set once WriteHeader has recorded a status
}

// NewRecorder returns an initialized [ResponseRecorder] with an empty header
// map, an empty body buffer and a default status code of 200.
func NewRecorder() *ResponseRecorder {
	return &ResponseRecorder{
		HeaderMap: make(http.Header),
		Body:      new(bytes.Buffer),
		Code:      200,
	}
}

// DefaultRemoteAddr is not referenced anywhere in responsecapture.go.
// NOTE(review): it presumably mirrors net/http/httptest.DefaultRemoteAddr
// (the RemoteAddr given to requests built by httptest.NewRequest); confirm
// whether any caller still needs it before removing it.
const DefaultRemoteAddr = "1.2.3.4"

// Header implements [http.ResponseWriter]. It returns the response
// headers to mutate within a handler, creating the map lazily for a
// zero-value recorder.
func (rw *ResponseRecorder) Header() http.Header {
	m := rw.HeaderMap
	if m == nil {
		m = make(http.Header)
		rw.HeaderMap = m
	}
	return m
}

// writeHeader writes a header if it was not written yet and
// detects Content-Type if needed.
//
// bytes or str are the beginning of the response body.
// We pass both to avoid unnecessarily generate garbage
// in rw.WriteString which was created for performance reasons.
// Non-nil bytes win.
func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
	if rw.wroteHeader {
		return
	}
	// http.DetectContentType considers at most the first 512 bytes, so a
	// longer prefix would only be converted to []byte for nothing.
	if len(str) > 512 {
		str = str[:512]
	}

	m := rw.Header()

	// Sniff a Content-Type only when the handler set neither Content-Type nor
	// Transfer-Encoding.
	_, hasType := m["Content-Type"]
	hasTE := m.Get("Transfer-Encoding") != ""
	if !hasType && !hasTE {
		if b == nil {
			b = []byte(str)
		}
		m.Set("Content-Type", http.DetectContentType(b))
	}

	rw.WriteHeader(200)
}

// Write implements http.ResponseWriter. The data in buf is written to
// rw.Body, if not nil. The first write implicitly records status 200 (and
// sniffs a Content-Type) unless WriteHeader was called earlier.
100 | func (rw *ResponseRecorder) Write(buf []byte) (int, error) { 101 | rw.writeHeader(buf, "") 102 | if rw.Body != nil { 103 | rw.Body.Write(buf) 104 | } 105 | return len(buf), nil 106 | } 107 | 108 | // WriteString implements [io.StringWriter]. The data in str is written 109 | // to rw.Body, if not nil. 110 | func (rw *ResponseRecorder) WriteString(str string) (int, error) { 111 | rw.writeHeader(nil, str) 112 | if rw.Body != nil { 113 | rw.Body.WriteString(str) 114 | } 115 | return len(str), nil 116 | } 117 | 118 | func checkWriteHeaderCode(code int) { 119 | // Issue 22880: require valid WriteHeader status codes. 120 | // For now we only enforce that it's three digits. 121 | // In the future we might block things over 599 (600 and above aren't defined 122 | // at https://httpwg.org/specs/rfc7231.html#status.codes) 123 | // and we might block under 200 (once we have more mature 1xx support). 124 | // But for now any three digits. 125 | // 126 | // We used to send "HTTP/1.1 000 0" on the wire in responses but there's 127 | // no equivalent bogus thing we can realistically send in HTTP/2, 128 | // so we'll consistently panic instead and help people find their bugs 129 | // early. (We can't return an error from WriteHeader even if we wanted to.) 130 | if code < 100 || code > 999 { 131 | panic(fmt.Sprintf("invalid WriteHeader code %v", code)) 132 | } 133 | } 134 | 135 | // WriteHeader implements [http.ResponseWriter]. 136 | func (rw *ResponseRecorder) WriteHeader(code int) { 137 | if rw.wroteHeader { 138 | return 139 | } 140 | 141 | checkWriteHeaderCode(code) 142 | rw.Code = code 143 | rw.wroteHeader = true 144 | if rw.HeaderMap == nil { 145 | rw.HeaderMap = make(http.Header) 146 | } 147 | rw.snapHeader = rw.HeaderMap.Clone() 148 | } 149 | 150 | // Result returns the response generated by the handler. 151 | // 152 | // The returned Response will have at least its StatusCode, 153 | // Header, Body, and optionally Trailer populated. 
154 | // More fields may be populated in the future, so callers should 155 | // not DeepEqual the result in tests. 156 | // 157 | // The Response.Header is a snapshot of the headers at the time of the 158 | // first write call, or at the time of this call, if the handler never 159 | // did a write. 160 | // 161 | // The Response.Body is guaranteed to be non-nil and Body.Read call is 162 | // guaranteed to not return any error other than [io.EOF]. 163 | // 164 | // Result must only be called after the handler has finished running. 165 | func (rw *ResponseRecorder) Result() *http.Response { 166 | if rw.result != nil { 167 | return rw.result 168 | } 169 | if rw.snapHeader == nil { 170 | rw.snapHeader = rw.HeaderMap.Clone() 171 | } 172 | res := &http.Response{ 173 | Proto: "HTTP/1.1", 174 | ProtoMajor: 1, 175 | ProtoMinor: 1, 176 | StatusCode: rw.Code, 177 | Header: rw.snapHeader, 178 | } 179 | rw.result = res 180 | if res.StatusCode == 0 { 181 | res.StatusCode = 200 182 | } 183 | res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) 184 | if rw.Body != nil { 185 | res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) 186 | } else { 187 | res.Body = http.NoBody 188 | } 189 | res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) 190 | 191 | if trailers, ok := rw.snapHeader["Trailer"]; ok { 192 | res.Trailer = make(http.Header, len(trailers)) 193 | for _, k := range trailers { 194 | for _, k := range strings.Split(k, ",") { 195 | k = http.CanonicalHeaderKey(textproto.TrimString(k)) 196 | if !httpguts.ValidTrailerHeader(k) { 197 | // Ignore since forbidden by RFC 7230, section 4.1.2. 
198 | continue 199 | } 200 | vv, ok := rw.HeaderMap[k] 201 | if !ok { 202 | continue 203 | } 204 | vv2 := make([]string, len(vv)) 205 | copy(vv2, vv) 206 | res.Trailer[k] = vv2 207 | } 208 | } 209 | } 210 | for k, vv := range rw.HeaderMap { 211 | if !strings.HasPrefix(k, http.TrailerPrefix) { 212 | continue 213 | } 214 | if res.Trailer == nil { 215 | res.Trailer = make(http.Header) 216 | } 217 | for _, v := range vv { 218 | res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) 219 | } 220 | } 221 | return res 222 | } 223 | 224 | func (rr *ResponseRecorder) WriteTo(rw http.ResponseWriter) { 225 | rr.WriteHeadersTo(rw) 226 | rw.WriteHeader(rr.Result().StatusCode) 227 | rw.Write(rr.Body.Bytes()) 228 | } 229 | 230 | func (rr *ResponseRecorder) WriteHeadersTo(rw http.ResponseWriter) { 231 | for header := range rr.Result().Header { 232 | rw.Header().Set(header, rr.Result().Header.Get(header)) 233 | } 234 | } 235 | 236 | // parseContentLength trims whitespace from s and returns -1 if no value 237 | // is set, or the value if it's >= 0. 238 | // 239 | // This a modified version of same function found in net/http/transfer.go. This 240 | // one just ignores an invalid header. 
241 | func parseContentLength(cl string) int64 { 242 | cl = textproto.TrimString(cl) 243 | if cl == "" { 244 | return -1 245 | } 246 | n, err := strconv.ParseUint(cl, 10, 63) 247 | if err != nil { 248 | return -1 249 | } 250 | return int64(n) 251 | } 252 | -------------------------------------------------------------------------------- /internal/middlewares/security/routing_anomaly_score.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "math" 5 | "net/http" 6 | "net/url" 7 | "strconv" 8 | "sync" 9 | 10 | "github.com/hvuhsg/gatego/pkg/pathgraph" 11 | "github.com/hvuhsg/gatego/pkg/tracker" 12 | "go.opentelemetry.io/otel/attribute" 13 | "go.opentelemetry.io/otel/trace" 14 | ) 15 | 16 | const ( 17 | tracingCookieName = "sad-trc" 18 | cookieMaxAge = 24 * 60 * 60 // 24 hours in seconds 19 | refererHeaderName = "Referer" 20 | ) 21 | 22 | // RoutingAnomalyDetector handles path tracking logic and manages user sessions 23 | type RoutingAnomalyDetector struct { 24 | graph *pathgraph.PathGraph 25 | numberOfJumps int 26 | scoreSum float64 27 | avgDiviation float64 28 | lastPaths sync.Map // Maps trace_id to last path 29 | trackerRoutingHistory sync.Map 30 | tracker tracker.Tracker 31 | 32 | tresholdForRating int // The number of requests before starting to calculate anomaly score 33 | minScore int // If the diviation form the avg diviation is lower then this then the session is not suspicuse 34 | maxScore int // If the diviation form the avg diviation is larger then this then the session is fully suspicuse 35 | anomalyHeaderName string 36 | } 37 | 38 | func NewRoutingAnomalyDetector(headerName string, tresholdForRating, minScore, maxScore int) *RoutingAnomalyDetector { 39 | return &RoutingAnomalyDetector{ 40 | graph: pathgraph.NewPathGraph(), 41 | tracker: tracker.NewCookieTracker(tracingCookieName, cookieMaxAge, false), 42 | anomalyHeaderName: headerName, 43 | minScore: minScore, 44 | maxScore: 
maxScore, 45 | tresholdForRating: tresholdForRating, 46 | } 47 | } 48 | 49 | // NewPathTracker creates a new PathTracker instance 50 | func NewPathTracker(graph *pathgraph.PathGraph) *RoutingAnomalyDetector { 51 | return &RoutingAnomalyDetector{ 52 | graph: graph, 53 | lastPaths: sync.Map{}, 54 | } 55 | } 56 | 57 | // Claculate anomaly score based on global avg routing and tracker routing 58 | // This middleware uses a graph to represent every path called by users 59 | // Eeach source, destination path has a vertex with the score of how many requests jumpt it, 60 | // We save tracker (session) jumps history and calculate an anomaly score, and add it as header to the request. 61 | func (pt *RoutingAnomalyDetector) AddAnomalyScore(next http.Handler) http.Handler { 62 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 63 | span := trace.SpanFromContext(r.Context()) 64 | 65 | // Get or create trace ID 66 | traceID := pt.tracker.GetTrackerID(r) 67 | 68 | if traceID != "" { 69 | // We do not want the tracker to be sent to the downstream server 70 | pt.tracker.RemoveTracker(r) 71 | } else { // Create new tracker if not found 72 | var err error 73 | traceID, err = pt.tracker.SetTracker(w) 74 | if err != nil { 75 | // Log error but continue serving 76 | next.ServeHTTP(w, r) 77 | return 78 | } 79 | 80 | // Create tracker history 81 | trackerH := &trackerHistory{jumpsCount: 0, jumpsScoreSum: 0} 82 | pt.trackerRoutingHistory.Store(traceID, trackerH) 83 | } 84 | 85 | currentPath := r.URL.Path 86 | 87 | // Get last path for this trace ID 88 | lastPath, exists := pt.getLastPath(traceID, r) 89 | if !exists { 90 | lastPath = "" // empty path means the user has entered the site for the first time 91 | } 92 | 93 | jumpScore := pt.graph.AddJump(lastPath, currentPath) 94 | value, ok := pt.trackerRoutingHistory.Load(traceID) 95 | 96 | var trackerH *trackerHistory 97 | if ok { 98 | trackerH = value.(*trackerHistory) 99 | } 100 | 101 | // update tracker history with 
jump score 102 | trackerH.jumpsCount++ 103 | trackerH.jumpsScoreSum += jumpScore 104 | 105 | // update global stats 106 | pt.numberOfJumps++ 107 | pt.scoreSum += jumpScore 108 | 109 | pt.lastPaths.Store(traceID, currentPath) 110 | 111 | anomalyScore := pt.calcAnomalyRating(trackerH) 112 | span.SetAttributes(attribute.Float64("RoutingAnomalyScore", anomalyScore)) 113 | 114 | r.Header.Set(pt.anomalyHeaderName, strconv.FormatFloat(anomalyScore, 'f', 2, 64)) 115 | 116 | // Call the next handler 117 | next.ServeHTTP(w, r) 118 | }) 119 | } 120 | 121 | // GetLastPath retrieves the last path for a given trace ID from storage or referer header (in this order) 122 | func (pt *RoutingAnomalyDetector) getLastPath(traceID string, r *http.Request) (string, bool) { 123 | path, exists := pt.lastPaths.Load(traceID) 124 | 125 | if !exists { 126 | u := r.Header.Get(refererHeaderName) 127 | url, err := url.Parse(u) 128 | if err == nil { 129 | path = url.Path 130 | } 131 | } 132 | 133 | return path.(string), exists 134 | } 135 | 136 | // 0 - is fully normal, 1 - fully suspicuse 137 | func (pt *RoutingAnomalyDetector) calcAnomalyRating(trackerH *trackerHistory) float64 { 138 | avgGlobalScore := (pt.scoreSum / float64(pt.numberOfJumps)) * 2 139 | avgTrackerScore := trackerH.Avg() 140 | 141 | diviation := math.Abs(avgGlobalScore - avgTrackerScore) 142 | 143 | // If avg diviation is 0 it will return +Inf and get the correct result 144 | anomalyScore := (diviation / (pt.avgDiviation / 100)) 145 | 146 | // Update avgDiviation with new diviation 147 | pt.avgDiviation = ((pt.avgDiviation * float64(pt.numberOfJumps)) + diviation) / float64(pt.numberOfJumps) 148 | 149 | // Only return 0 until useage data is collected 150 | if pt.numberOfJumps < pt.tresholdForRating { 151 | return 0 152 | } 153 | 154 | if anomalyScore < float64(pt.minScore) { 155 | return 0 156 | } 157 | 158 | if anomalyScore > float64(pt.maxScore) { 159 | return 1 160 | } 161 | 162 | return (anomalyScore - float64(pt.minScore)) 
/ 100 163 | } 164 | -------------------------------------------------------------------------------- /internal/middlewares/security/routing_anomaly_score_test.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | ) 7 | 8 | func TestCalcAnomalyRating(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | detector *RoutingAnomalyDetector 12 | trackerHistory *trackerHistory 13 | want float64 14 | }{ 15 | { 16 | name: "Below threshold returns 0", 17 | detector: &RoutingAnomalyDetector{ 18 | numberOfJumps: 99, 19 | scoreSum: 100, 20 | avgDiviation: 10, 21 | tresholdForRating: 100, 22 | minScore: 100, 23 | maxScore: 200, 24 | anomalyHeaderName: "test", 25 | }, 26 | trackerHistory: &trackerHistory{jumpsScoreSum: 50, jumpsCount: 1}, 27 | want: 0, 28 | }, 29 | { 30 | name: "Score below minScore returns 0", 31 | detector: &RoutingAnomalyDetector{ 32 | numberOfJumps: 101, 33 | scoreSum: 500, 34 | avgDiviation: 100, 35 | tresholdForRating: 100, 36 | minScore: 100, 37 | maxScore: 200, 38 | anomalyHeaderName: "test", 39 | }, 40 | trackerHistory: &trackerHistory{jumpsScoreSum: 95, jumpsCount: 1}, 41 | want: 0, 42 | }, 43 | { 44 | name: "Score above maxScore returns 1", 45 | detector: &RoutingAnomalyDetector{ 46 | numberOfJumps: 101, 47 | scoreSum: 1000, 48 | avgDiviation: 1, 49 | tresholdForRating: 100, 50 | minScore: 100, 51 | maxScore: 200, 52 | anomalyHeaderName: "test", 53 | }, 54 | trackerHistory: &trackerHistory{jumpsScoreSum: 50, jumpsCount: 1}, 55 | want: 1, 56 | }, 57 | { 58 | name: "Normal score calculation", 59 | detector: &RoutingAnomalyDetector{ 60 | numberOfJumps: 100, 61 | scoreSum: 500, 62 | avgDiviation: 50, 63 | tresholdForRating: 100, 64 | minScore: 100, 65 | maxScore: 200, 66 | anomalyHeaderName: "test", 67 | }, 68 | trackerHistory: &trackerHistory{jumpsScoreSum: 90, jumpsCount: 1}, 69 | want: 0.6, // assuming avg score of 10 units 70 | }, 71 | { 72 | name: 
"Zero avgDiviation handling", 73 | detector: &RoutingAnomalyDetector{ 74 | numberOfJumps: 101, 75 | scoreSum: 100, 76 | avgDiviation: 0, 77 | tresholdForRating: 100, 78 | minScore: 100, 79 | maxScore: 200, 80 | anomalyHeaderName: "test", 81 | }, 82 | trackerHistory: &trackerHistory{jumpsScoreSum: 100, jumpsCount: 1}, 83 | want: 1, // should return max score due to division by zero protection 84 | }, 85 | } 86 | 87 | for _, tt := range tests { 88 | t.Run(tt.name, func(t *testing.T) { 89 | got := tt.detector.calcAnomalyRating(tt.trackerHistory) 90 | 91 | if math.Abs(got-tt.want) > 0.0001 { // Using small epsilon for float comparison 92 | t.Errorf("calcAnomalyRating() = %v, want %v", got, tt.want) 93 | } 94 | }) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /internal/middlewares/security/trackerhistory.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | type trackerHistory struct { 4 | jumpsCount int 5 | jumpsScoreSum float64 6 | } 7 | 8 | func (th trackerHistory) Avg() float64 { 9 | return th.jumpsScoreSum / float64(th.jumpsCount) 10 | } 11 | -------------------------------------------------------------------------------- /internal/middlewares/sizelimiter.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | // NewRequestSizeLimitMiddleware limits the size of the request body to the specified limit in bytes. 
11 | func NewRequestSizeLimitMiddleware(maxSize uint64) func(http.Handler) http.Handler { 12 | return func(next http.Handler) http.Handler { 13 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 | // Create a buffer to read the body 15 | buf := new(bytes.Buffer) 16 | 17 | // Use io.LimitReader to limit the size of the request body 18 | limitedReader := io.LimitReader(r.Body, int64(maxSize+1)) // Allow one extra byte for overflow detection 19 | _, err := io.Copy(buf, limitedReader) // Copy the limited input into the buffer 20 | 21 | // Check for errors 22 | if err != nil { 23 | http.Error(w, "Error reading request body", http.StatusInternalServerError) 24 | return 25 | } 26 | 27 | // Check if we exceeded the maximum size 28 | if buf.Len() > int(maxSize) { 29 | http.Error(w, fmt.Sprintf("Request body too large. Maximum allowed size is %d bytes.", maxSize), http.StatusRequestEntityTooLarge) 30 | return 31 | } 32 | 33 | // Restore the request body for further processing 34 | r.Body = io.NopCloser(bytes.NewReader(buf.Bytes())) 35 | 36 | // Proceed to the next handler 37 | next.ServeHTTP(w, r) 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /internal/middlewares/sizelimiter_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/hvuhsg/gatego/internal/middlewares" 10 | ) 11 | 12 | // TestRequestSizeLimitMiddleware tests the RequestSizeLimitMiddleware 13 | func TestRequestSizeLimitMiddleware(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | body []byte 17 | maxSize uint64 18 | expectedCode int 19 | expectedBody string 20 | }{ 21 | { 22 | name: "Within limit", 23 | body: []byte("This is within the limit."), 24 | maxSize: 30, 25 | expectedCode: http.StatusOK, 26 | }, 27 | { 28 | name: "Exactly at limit", 29 | body: 
bytes.Repeat([]byte("A"), 30), // 30 bytes 30 | maxSize: 30, 31 | expectedCode: http.StatusOK, 32 | }, 33 | { 34 | name: "Exceeds limit", 35 | body: bytes.Repeat([]byte("A"), 31), // 31 bytes 36 | maxSize: 30, 37 | expectedCode: http.StatusRequestEntityTooLarge, 38 | expectedBody: "Request body too large. Maximum allowed size is 30 bytes.\n", 39 | }, 40 | } 41 | 42 | for _, tt := range tests { 43 | t.Run(tt.name, func(t *testing.T) { 44 | // Create a request with the test body 45 | req := httptest.NewRequest("POST", "http://example.com", bytes.NewReader(tt.body)) 46 | rr := httptest.NewRecorder() 47 | 48 | // Use a simple handler that just returns OK 49 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 50 | w.WriteHeader(http.StatusOK) 51 | }) 52 | 53 | // Create the middleware with the specified maxSize 54 | middleware := middlewares.NewRequestSizeLimitMiddleware(tt.maxSize) 55 | 56 | // Serve the request through the middleware 57 | middleware(handler).ServeHTTP(rr, req) 58 | 59 | // Check the response code 60 | if rr.Code != tt.expectedCode { 61 | t.Errorf("expected status code %d, got %d", tt.expectedCode, rr.Code) 62 | } 63 | 64 | // Check the response body if applicable 65 | if tt.expectedBody != "" { 66 | if rr.Body.String() != tt.expectedBody { 67 | t.Errorf("expected response body %q, got %q", tt.expectedBody, rr.Body.String()) 68 | } 69 | } 70 | }) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /internal/middlewares/timeout.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | 8 | "go.opentelemetry.io/otel/trace" 9 | ) 10 | 11 | // NewTimeoutMiddleware returns an HTTP handler that wraps the provided handler with a timeout. 12 | // If the processing takes longer than the specified timeout, it returns a 503 Service Unavailable error. 
13 | func NewTimeoutMiddleware(timeout time.Duration) func(next http.Handler) http.Handler { 14 | return func(next http.Handler) http.Handler { 15 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 16 | span := trace.SpanFromContext(r.Context()) 17 | 18 | // Create a context with the specified timeout 19 | ctx, cancel := context.WithTimeout(r.Context(), timeout) 20 | defer cancel() // Make sure to cancel the context when done 21 | 22 | // Create a new request with the timeout context 23 | r = r.WithContext(ctx) 24 | 25 | // Channel to capture when the request processing finishes 26 | done := make(chan struct{}) 27 | 28 | go func() { 29 | // Serve the request 30 | next.ServeHTTP(w, r) 31 | // Signal that the request processing is done 32 | close(done) 33 | }() 34 | 35 | select { 36 | case <-ctx.Done(): 37 | // If the context is canceled (due to timeout), return an error response 38 | if ctx.Err() == context.DeadlineExceeded { 39 | span.AddEvent("Request timed out") 40 | http.Error(w, "Request timed out", http.StatusGatewayTimeout) 41 | } 42 | case <-done: 43 | // If the request finished within the timeout, return the result 44 | return 45 | } 46 | }) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /internal/middlewares/timeout_test.go: -------------------------------------------------------------------------------- 1 | package middlewares_test 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | 10 | "github.com/hvuhsg/gatego/internal/middlewares" 11 | ) 12 | 13 | func TestTimeoutMiddleware(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | timeout time.Duration 17 | handlerSleep time.Duration 18 | expectedStatus int 19 | }{ 20 | { 21 | name: "Request completes before timeout", 22 | timeout: 100 * time.Millisecond, 23 | handlerSleep: 50 * time.Millisecond, 24 | expectedStatus: http.StatusOK, 25 | }, 26 | { 27 | name: "Request times 
out", 28 | timeout: 50 * time.Millisecond, 29 | handlerSleep: 100 * time.Millisecond, 30 | expectedStatus: http.StatusGatewayTimeout, 31 | }, 32 | } 33 | 34 | for _, tt := range tests { 35 | t.Run(tt.name, func(t *testing.T) { 36 | // Create a test handler that sleeps for the specified duration 37 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 | time.Sleep(tt.handlerSleep) 39 | w.WriteHeader(http.StatusOK) 40 | }) 41 | 42 | // Wrap the handler with our middleware 43 | wrappedHandler := middlewares.NewTimeoutMiddleware(tt.timeout)(handler) 44 | 45 | // Create a test request 46 | req, err := http.NewRequest("GET", "/test", nil) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | 51 | // Create a ResponseRecorder to record the response 52 | rr := httptest.NewRecorder() 53 | 54 | // Serve the request using our wrapped handler 55 | wrappedHandler.ServeHTTP(rr, req) 56 | 57 | // Check the status code 58 | if status := rr.Code; status != tt.expectedStatus { 59 | t.Errorf("handler returned wrong status code: got %v want %v", 60 | status, tt.expectedStatus) 61 | } 62 | }) 63 | } 64 | } 65 | 66 | func TestTimeoutMiddlewareCancelContext(t *testing.T) { 67 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 | // Sleep until the context is canceled 69 | <-r.Context().Done() 70 | // Check if the context was canceled due to a timeout 71 | if r.Context().Err() == context.DeadlineExceeded { 72 | w.WriteHeader(http.StatusGatewayTimeout) 73 | } 74 | 75 | w.WriteHeader(http.StatusOK) 76 | }) 77 | 78 | wrappedHandler := middlewares.NewTimeoutMiddleware(50 * time.Millisecond)(handler) 79 | 80 | req, err := http.NewRequest("GET", "/test", nil) 81 | if err != nil { 82 | t.Fatal(err) 83 | } 84 | 85 | rr := httptest.NewRecorder() 86 | 87 | wrappedHandler.ServeHTTP(rr, req) 88 | 89 | if status := rr.Code; status != http.StatusGatewayTimeout { 90 | t.Errorf("handler returned wrong status code: got %v want %v", 91 | status, 
http.StatusGatewayTimeout) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /otel.go: -------------------------------------------------------------------------------- 1 | package gatego 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "time" 7 | 8 | "github.com/hvuhsg/gatego/internal/contextvalues" 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/attribute" 11 | "go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc" 12 | "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" 13 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace" 14 | "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" 15 | "go.opentelemetry.io/otel/log/global" 16 | "go.opentelemetry.io/otel/propagation" 17 | "go.opentelemetry.io/otel/sdk/log" 18 | "go.opentelemetry.io/otel/sdk/metric" 19 | "go.opentelemetry.io/otel/sdk/resource" 20 | "go.opentelemetry.io/otel/sdk/trace" 21 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0" 22 | ) 23 | 24 | type otelConfig struct { 25 | TraceCollectorEndpoint string 26 | MetricCollectorEndpoint string 27 | LogsCollectorEndpoint string 28 | CollectorTimeout time.Duration 29 | ServiceName string 30 | SampleRatio float64 31 | } 32 | 33 | // setupOTelSDK bootstraps the OpenTelemetry pipeline. 34 | // If it does not return an error, make sure to call shutdown for proper cleanup. 35 | func setupOTelSDK(ctx context.Context, conf otelConfig) (func(context.Context) error, error) { 36 | var shutdownFuncs []func(context.Context) error 37 | 38 | // shutdown calls cleanup functions registered via shutdownFuncs. 39 | // The errors from the calls are joined. 40 | // Each registered cleanup will be invoked once. 
41 | shutdown := func(ctx context.Context) error { 42 | var err error 43 | for _, fn := range shutdownFuncs { 44 | err = errors.Join(err, fn(ctx)) 45 | } 46 | shutdownFuncs = nil 47 | return err 48 | } 49 | 50 | // handleErr calls shutdown for cleanup and makes sure that all errors are returned. 51 | handleErr := func(inErr error) (func(context.Context) error, error) { 52 | err := errors.Join(inErr, shutdown(ctx)) 53 | return nil, err 54 | } 55 | 56 | // Set up propagator. 57 | prop := newPropagator() 58 | otel.SetTextMapPropagator(prop) 59 | 60 | // Set up resource 61 | resource := resource.NewWithAttributes( 62 | semconv.SchemaURL, 63 | semconv.ServiceNameKey.String(conf.ServiceName), 64 | semconv.TelemetrySDKLanguageGo, 65 | attribute.String("version", contextvalues.VersionFromContext(ctx)), 66 | ) 67 | 68 | // Set up trace provider. 69 | tracerProvider, err := newTraceProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout, conf.SampleRatio) 70 | if err != nil { 71 | return handleErr(err) 72 | } 73 | shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown) 74 | otel.SetTracerProvider(tracerProvider) 75 | 76 | // Set up meter provider. 77 | meterProvider, err := newMeterProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout) 78 | if err != nil { 79 | return handleErr(err) 80 | } 81 | shutdownFuncs = append(shutdownFuncs, meterProvider.Shutdown) 82 | otel.SetMeterProvider(meterProvider) 83 | 84 | // Set up logger provider. 
85 | loggerProvider, err := newLoggerProvider(ctx, resource, conf.TraceCollectorEndpoint, conf.CollectorTimeout) 86 | if err != nil { 87 | return handleErr(err) 88 | } 89 | shutdownFuncs = append(shutdownFuncs, loggerProvider.Shutdown) 90 | global.SetLoggerProvider(loggerProvider) 91 | 92 | return shutdown, err 93 | } 94 | 95 | func newPropagator() propagation.TextMapPropagator { 96 | return propagation.NewCompositeTextMapPropagator( 97 | propagation.TraceContext{}, 98 | propagation.Baggage{}, 99 | ) 100 | } 101 | 102 | func newTraceProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration, sampleRatio float64) (*trace.TracerProvider, error) { 103 | exporter, err := otlptrace.New( 104 | ctx, 105 | otlptracegrpc.NewClient( 106 | otlptracegrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint 107 | otlptracegrpc.WithTimeout(timeout), 108 | otlptracegrpc.WithInsecure(), 109 | ), 110 | ) 111 | if err != nil { 112 | return nil, err 113 | } 114 | 115 | traceProvider := trace.NewTracerProvider( 116 | trace.WithResource(resource), 117 | trace.WithBatcher(exporter), 118 | trace.WithSampler(trace.TraceIDRatioBased(sampleRatio)), 119 | ) 120 | return traceProvider, nil 121 | } 122 | 123 | func newMeterProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration) (*metric.MeterProvider, error) { 124 | exporter, err := otlpmetricgrpc.New( 125 | ctx, 126 | otlpmetricgrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint 127 | otlpmetricgrpc.WithTimeout(timeout), 128 | otlpmetricgrpc.WithInsecure(), 129 | ) 130 | if err != nil { 131 | return nil, err 132 | } 133 | 134 | meterProvider := metric.NewMeterProvider( 135 | metric.WithResource(resource), 136 | metric.WithReader( 137 | metric.NewPeriodicReader( 138 | exporter, 139 | ), 140 | ), 141 | ) 142 | 143 | return meterProvider, nil 144 | } 145 | 146 | func newLoggerProvider(ctx context.Context, resource *resource.Resource, endpoint string, timeout time.Duration) 
(*log.LoggerProvider, error) { 147 | exporter, err := otlploggrpc.New( 148 | ctx, 149 | otlploggrpc.WithEndpoint(endpoint), // OTLP gRPC endpoint 150 | otlploggrpc.WithTimeout(timeout), 151 | otlploggrpc.WithInsecure(), 152 | ) 153 | if err != nil { 154 | return nil, err 155 | } 156 | 157 | loggerProvider := log.NewLoggerProvider( 158 | log.WithResource(resource), 159 | log.WithProcessor(log.NewBatchProcessor(exporter)), 160 | ) 161 | return loggerProvider, nil 162 | } 163 | -------------------------------------------------------------------------------- /pkg/cron/README.md: -------------------------------------------------------------------------------- 1 | # Cron 2 | 3 | A Go package that implements a crontab-like service to execute and schedule repetitive tasks/jobs. 4 | 5 | ## Features 6 | 7 | - Supports cron expressions for flexible scheduling 8 | - Allows registering and managing multiple jobs 9 | - Provides macros for common schedule patterns 10 | - Supports custom timezones 11 | - Allows setting custom tick intervals 12 | - Supports starting and stopping the cron service 13 | 14 | ## Installation 15 | 16 | ```sh 17 | go get github.com/hvuhsg/gatego/pkg/cron 18 | ``` 19 | 20 | ## Usage 21 | 22 | ```go 23 | package main 24 | 25 | import ( 26 | "fmt" 27 | "time" 28 | 29 | "github.com/hvuhsg/gatego/pkg/cron" 30 | ) 31 | 32 | func main() { 33 | c := cron.New() 34 | 35 | // Register a job 36 | c.MustAdd("job1", "*/5 * * * *", func() { 37 | fmt.Println("Running job1...") 38 | }) 39 | 40 | // Set a custom timezone 41 | loc, _ := time.LoadLocation("Asia/Tokyo") 42 | c.SetTimezone(loc) 43 | 44 | // Set a custom tick interval 45 | c.SetInterval(5 * time.Second) 46 | 47 | // Start the cron service 48 | c.Start() 49 | 50 | // Stop the cron service after 30 seconds 51 | time.Sleep(30 * time.Second) 52 | c.Stop() 53 | } 54 | ``` 55 | 56 | ## Cron Expression Format 57 | 58 | The package supports the following cron expression format: 59 | 60 | ``` 61 | * * * * * 62 | │ │ │ 
│ │ 63 | │ │ │ │ └── Day of Week (0-6) 64 | │ │ │ └──── Month (1-12) 65 | │ │ └────── Day of Month (1-31) 66 | │ └──────── Hour (0-23) 67 | └────────── Minute (0-59) 68 | ``` 69 | 70 | It also supports the following macros: 71 | 72 | - `@yearly` or `@annually`: Run once a year at midnight on the first day of the year 73 | - `@monthly`: Run once a month at midnight on the first day of the month 74 | - `@weekly`: Run once a week at midnight on Sunday 75 | - `@daily` or `@midnight`: Run once a day at midnight 76 | - `@hourly`: Run once an hour at the beginning of the hour 77 | - `@minutely`: Run once a minute at the beginning of the minute 78 | -------------------------------------------------------------------------------- /pkg/cron/cron.go: -------------------------------------------------------------------------------- 1 | // Package cron implements a crontab-like service to execute and schedule 2 | // repeative tasks/jobs. 3 | // 4 | // Example: 5 | // 6 | // c := cron.New() 7 | // c.MustAdd("dailyReport", "0 0 * * *", func() { ... }) 8 | // c.Start() 9 | package cron 10 | 11 | import ( 12 | "errors" 13 | "fmt" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | type job struct { 19 | schedule *Schedule 20 | run func() 21 | } 22 | 23 | // Cron is a crontab-like struct for tasks/jobs scheduling. 24 | type Cron struct { 25 | timezone *time.Location 26 | ticker *time.Ticker 27 | startTimer *time.Timer 28 | jobs map[string]*job 29 | interval time.Duration 30 | 31 | sync.RWMutex 32 | } 33 | 34 | // New create a new Cron struct with default tick interval of 1 minute 35 | // and timezone in UTC. 36 | // 37 | // You can change the default tick interval with Cron.SetInterval(). 38 | // You can change the default timezone with Cron.SetTimezone(). 
39 | func New() *Cron { 40 | return &Cron{ 41 | interval: 1 * time.Minute, 42 | timezone: time.UTC, 43 | jobs: map[string]*job{}, 44 | } 45 | } 46 | 47 | // SetInterval changes the current cron tick interval 48 | // (it usually should be >= 1 minute). 49 | func (c *Cron) SetInterval(d time.Duration) { 50 | // update interval 51 | c.Lock() 52 | wasStarted := c.ticker != nil 53 | c.interval = d 54 | c.Unlock() 55 | 56 | // restart the ticker 57 | if wasStarted { 58 | c.Start() 59 | } 60 | } 61 | 62 | // SetTimezone changes the current cron tick timezone. 63 | func (c *Cron) SetTimezone(l *time.Location) { 64 | c.Lock() 65 | defer c.Unlock() 66 | 67 | c.timezone = l 68 | } 69 | 70 | // MustAdd is similar to Add() but panic on failure. 71 | func (c *Cron) MustAdd(jobId string, cronExpr string, run func()) { 72 | if err := c.Add(jobId, cronExpr, run); err != nil { 73 | panic(err) 74 | } 75 | } 76 | 77 | // Add registers a single cron job. 78 | // 79 | // If there is already a job with the provided id, then the old job 80 | // will be replaced with the new one. 81 | // 82 | // cronExpr is a regular cron expression, eg. "0 */3 * * *" (aka. at minute 0 past every 3rd hour). 83 | // Check cron.NewSchedule() for the supported tokens. 84 | func (c *Cron) Add(jobId string, cronExpr string, run func()) error { 85 | if run == nil { 86 | return errors.New("failed to add new cron job: run must be non-nil function") 87 | } 88 | 89 | c.Lock() 90 | defer c.Unlock() 91 | 92 | schedule, err := NewSchedule(cronExpr) 93 | if err != nil { 94 | return fmt.Errorf("failed to add new cron job: %w", err) 95 | } 96 | 97 | c.jobs[jobId] = &job{ 98 | schedule: schedule, 99 | run: run, 100 | } 101 | 102 | return nil 103 | } 104 | 105 | // Remove removes a single cron job by its id. 106 | func (c *Cron) Remove(jobId string) { 107 | c.Lock() 108 | defer c.Unlock() 109 | 110 | delete(c.jobs, jobId) 111 | } 112 | 113 | // RemoveAll removes all registered cron jobs. 
114 | func (c *Cron) RemoveAll() { 115 | c.Lock() 116 | defer c.Unlock() 117 | 118 | c.jobs = map[string]*job{} 119 | } 120 | 121 | // Total returns the current total number of registered cron jobs. 122 | func (c *Cron) Total() int { 123 | c.RLock() 124 | defer c.RUnlock() 125 | 126 | return len(c.jobs) 127 | } 128 | 129 | // Stop stops the current cron ticker (if not already). 130 | // 131 | // You can resume the ticker by calling Start(). 132 | func (c *Cron) Stop() { 133 | c.Lock() 134 | defer c.Unlock() 135 | 136 | if c.startTimer != nil { 137 | c.startTimer.Stop() 138 | c.startTimer = nil 139 | } 140 | 141 | if c.ticker == nil { 142 | return // already stopped 143 | } 144 | 145 | c.ticker.Stop() 146 | c.ticker = nil 147 | } 148 | 149 | // Start starts the cron ticker. 150 | // 151 | // Calling Start() on already started cron will restart the ticker. 152 | func (c *Cron) Start() { 153 | c.Stop() 154 | 155 | // delay the ticker to start at 00 of 1 c.interval duration 156 | now := time.Now() 157 | next := now.Add(c.interval).Truncate(c.interval) 158 | delay := next.Sub(now) 159 | 160 | c.Lock() 161 | c.startTimer = time.AfterFunc(delay, func() { 162 | c.Lock() 163 | c.ticker = time.NewTicker(c.interval) 164 | c.Unlock() 165 | 166 | // run immediately at 00 167 | c.runDue(time.Now()) 168 | 169 | // run after each tick 170 | go func() { 171 | for t := range c.ticker.C { 172 | c.runDue(t) 173 | } 174 | }() 175 | }) 176 | c.Unlock() 177 | } 178 | 179 | // HasStarted checks whether the current Cron ticker has been started. 180 | func (c *Cron) HasStarted() bool { 181 | c.RLock() 182 | defer c.RUnlock() 183 | 184 | return c.ticker != nil 185 | } 186 | 187 | // runDue runs all registered jobs that are scheduled for the provided time. 
188 | func (c *Cron) runDue(t time.Time) { 189 | c.RLock() 190 | defer c.RUnlock() 191 | 192 | moment := NewMoment(t.In(c.timezone)) 193 | 194 | for _, j := range c.jobs { 195 | if j.schedule.IsDue(moment) { 196 | go j.run() 197 | } 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /pkg/cron/cron_test.go: -------------------------------------------------------------------------------- 1 | package cron 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestCronNew(t *testing.T) { 10 | t.Parallel() 11 | 12 | c := New() 13 | 14 | expectedInterval := 1 * time.Minute 15 | if c.interval != expectedInterval { 16 | t.Fatalf("Expected default interval %v, got %v", expectedInterval, c.interval) 17 | } 18 | 19 | expectedTimezone := time.UTC 20 | if c.timezone.String() != expectedTimezone.String() { 21 | t.Fatalf("Expected default timezone %v, got %v", expectedTimezone, c.timezone) 22 | } 23 | 24 | if len(c.jobs) != 0 { 25 | t.Fatalf("Expected no jobs by default, got \n%v", c.jobs) 26 | } 27 | 28 | if c.ticker != nil { 29 | t.Fatal("Expected the ticker NOT to be initialized") 30 | } 31 | } 32 | 33 | func TestCronSetInterval(t *testing.T) { 34 | t.Parallel() 35 | 36 | c := New() 37 | 38 | interval := 2 * time.Minute 39 | 40 | c.SetInterval(interval) 41 | 42 | if c.interval != interval { 43 | t.Fatalf("Expected interval %v, got %v", interval, c.interval) 44 | } 45 | } 46 | 47 | func TestCronSetTimezone(t *testing.T) { 48 | t.Parallel() 49 | 50 | c := New() 51 | 52 | timezone, _ := time.LoadLocation("Asia/Tokyo") 53 | 54 | c.SetTimezone(timezone) 55 | 56 | if c.timezone.String() != timezone.String() { 57 | t.Fatalf("Expected timezone %v, got %v", timezone, c.timezone) 58 | } 59 | } 60 | 61 | func TestCronAddAndRemove(t *testing.T) { 62 | t.Parallel() 63 | 64 | c := New() 65 | 66 | if err := c.Add("test0", "* * * * *", nil); err == nil { 67 | t.Fatal("Expected nil function error") 68 | } 69 | 70 | 
if err := c.Add("test1", "invalid", func() {}); err == nil { 71 | t.Fatal("Expected invalid cron expression error") 72 | } 73 | 74 | if err := c.Add("test2", "* * * * *", func() {}); err != nil { 75 | t.Fatal(err) 76 | } 77 | 78 | if err := c.Add("test3", "* * * * *", func() {}); err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | if err := c.Add("test4", "* * * * *", func() {}); err != nil { 83 | t.Fatal(err) 84 | } 85 | 86 | // overwrite test2 87 | if err := c.Add("test2", "1 2 3 4 5", func() {}); err != nil { 88 | t.Fatal(err) 89 | } 90 | 91 | if err := c.Add("test5", "1 2 3 4 5", func() {}); err != nil { 92 | t.Fatal(err) 93 | } 94 | 95 | // mock job deletion 96 | c.Remove("test4") 97 | 98 | // try to remove non-existing (should be no-op) 99 | c.Remove("missing") 100 | 101 | // check job keys 102 | { 103 | expectedKeys := []string{"test3", "test2", "test5"} 104 | 105 | if v := len(c.jobs); v != len(expectedKeys) { 106 | t.Fatalf("Expected %d jobs, got %d", len(expectedKeys), v) 107 | } 108 | 109 | for _, k := range expectedKeys { 110 | if c.jobs[k] == nil { 111 | t.Fatalf("Expected job with key %s, got nil", k) 112 | } 113 | } 114 | } 115 | 116 | // check the jobs schedule 117 | { 118 | expectedSchedules := map[string]string{ 119 | "test2": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`, 120 | "test3": 
`{"minutes":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"32":{},"33":{},"34":{},"35":{},"36":{},"37":{},"38":{},"39":{},"4":{},"40":{},"41":{},"42":{},"43":{},"44":{},"45":{},"46":{},"47":{},"48":{},"49":{},"5":{},"50":{},"51":{},"52":{},"53":{},"54":{},"55":{},"56":{},"57":{},"58":{},"59":{},"6":{},"7":{},"8":{},"9":{}},"hours":{"0":{},"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"days":{"1":{},"10":{},"11":{},"12":{},"13":{},"14":{},"15":{},"16":{},"17":{},"18":{},"19":{},"2":{},"20":{},"21":{},"22":{},"23":{},"24":{},"25":{},"26":{},"27":{},"28":{},"29":{},"3":{},"30":{},"31":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"months":{"1":{},"10":{},"11":{},"12":{},"2":{},"3":{},"4":{},"5":{},"6":{},"7":{},"8":{},"9":{}},"daysOfWeek":{"0":{},"1":{},"2":{},"3":{},"4":{},"5":{},"6":{}}}`, 121 | "test5": `{"minutes":{"1":{}},"hours":{"2":{}},"days":{"3":{}},"months":{"4":{}},"daysOfWeek":{"5":{}}}`, 122 | } 123 | for k, v := range expectedSchedules { 124 | raw, err := json.Marshal(c.jobs[k].schedule) 125 | if err != nil { 126 | t.Fatal(err) 127 | } 128 | 129 | if string(raw) != v { 130 | t.Fatalf("Expected %q schedule \n%s, \ngot \n%s", k, v, raw) 131 | } 132 | } 133 | } 134 | } 135 | 136 | func TestCronMustAdd(t *testing.T) { 137 | t.Parallel() 138 | 139 | c := New() 140 | 141 | defer func() { 142 | if r := recover(); r == nil { 143 | t.Errorf("test1 didn't panic") 144 | } 145 | }() 146 | 147 | c.MustAdd("test1", "* * * * *", nil) 148 | 149 | c.MustAdd("test2", "* * * * *", func() {}) 150 | 151 | if _, ok := c.jobs["test2"]; !ok { 152 | t.Fatal("Couldn't find job test2") 153 | } 154 | } 155 | 156 | func TestCronRemoveAll(t *testing.T) { 157 | t.Parallel() 158 | 159 | c := New() 160 
| 161 | if err := c.Add("test1", "* * * * *", func() {}); err != nil { 162 | t.Fatal(err) 163 | } 164 | 165 | if err := c.Add("test2", "* * * * *", func() {}); err != nil { 166 | t.Fatal(err) 167 | } 168 | 169 | if err := c.Add("test3", "* * * * *", func() {}); err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | if v := len(c.jobs); v != 3 { 174 | t.Fatalf("Expected %d jobs, got %d", 3, v) 175 | } 176 | 177 | c.RemoveAll() 178 | 179 | if v := len(c.jobs); v != 0 { 180 | t.Fatalf("Expected %d jobs, got %d", 0, v) 181 | } 182 | } 183 | 184 | func TestCronTotal(t *testing.T) { 185 | t.Parallel() 186 | 187 | c := New() 188 | 189 | if v := c.Total(); v != 0 { 190 | t.Fatalf("Expected 0 jobs, got %v", v) 191 | } 192 | 193 | if err := c.Add("test1", "* * * * *", func() {}); err != nil { 194 | t.Fatal(err) 195 | } 196 | 197 | if err := c.Add("test2", "* * * * *", func() {}); err != nil { 198 | t.Fatal(err) 199 | } 200 | 201 | // overwrite 202 | if err := c.Add("test1", "* * * * *", func() {}); err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | if v := c.Total(); v != 2 { 207 | t.Fatalf("Expected 2 jobs, got %v", v) 208 | } 209 | } 210 | 211 | func TestCronStartStop(t *testing.T) { 212 | t.Parallel() 213 | 214 | test1 := 0 215 | test2 := 0 216 | 217 | c := New() 218 | 219 | c.SetInterval(500 * time.Millisecond) 220 | 221 | c.Add("test1", "* * * * *", func() { 222 | test1++ 223 | }) 224 | 225 | c.Add("test2", "* * * * *", func() { 226 | test2++ 227 | }) 228 | 229 | expectedCalls := 2 230 | 231 | // call twice Start to check if the previous ticker will be reseted 232 | c.Start() 233 | c.Start() 234 | 235 | time.Sleep(1 * time.Second) 236 | 237 | // call twice Stop to ensure that the second stop is no-op 238 | c.Stop() 239 | c.Stop() 240 | 241 | if test1 != expectedCalls { 242 | t.Fatalf("Expected %d test1, got %d", expectedCalls, test1) 243 | } 244 | if test2 != expectedCalls { 245 | t.Fatalf("Expected %d test2, got %d", expectedCalls, test2) 246 | } 247 | 248 | // resume 
for 2 seconds 249 | c.Start() 250 | 251 | time.Sleep(2 * time.Second) 252 | 253 | c.Stop() 254 | 255 | expectedCalls += 4 256 | 257 | if test1 != expectedCalls { 258 | t.Fatalf("Expected %d test1, got %d", expectedCalls, test1) 259 | } 260 | if test2 != expectedCalls { 261 | t.Fatalf("Expected %d test2, got %d", expectedCalls, test2) 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /pkg/cron/macros.go: -------------------------------------------------------------------------------- 1 | package cron 2 | 3 | const ( 4 | Yearly = "@yearly" 5 | Annually = "@annually" 6 | Monthly = "@monthly" 7 | Weekly = "@weekly" 8 | Daily = "@daily" 9 | Midnight = "@midnight" 10 | Hourly = "@hourly" 11 | Minutely = "@minutely" 12 | ) 13 | 14 | var macros = map[string]string{ 15 | Yearly: "0 0 1 1 *", 16 | Annually: "0 0 1 1 *", 17 | Monthly: "0 0 1 * *", 18 | Weekly: "0 0 * * 0", 19 | Daily: "0 0 * * *", 20 | Midnight: "0 0 * * *", 21 | Hourly: "0 * * * *", 22 | Minutely: "* * * * *", 23 | } 24 | -------------------------------------------------------------------------------- /pkg/cron/schedule.go: -------------------------------------------------------------------------------- 1 | package cron 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "time" 9 | ) 10 | 11 | // Moment represents a parsed single time moment. 12 | type Moment struct { 13 | Minute int `json:"minute"` 14 | Hour int `json:"hour"` 15 | Day int `json:"day"` 16 | Month int `json:"month"` 17 | DayOfWeek int `json:"dayOfWeek"` 18 | } 19 | 20 | // NewMoment creates a new Moment from the specified time. 21 | func NewMoment(t time.Time) *Moment { 22 | return &Moment{ 23 | Minute: t.Minute(), 24 | Hour: t.Hour(), 25 | Day: t.Day(), 26 | Month: int(t.Month()), 27 | DayOfWeek: int(t.Weekday()), 28 | } 29 | } 30 | 31 | // Schedule stores parsed information for each time component when a cron job should run. 
32 | type Schedule struct { 33 | Minutes map[int]struct{} `json:"minutes"` 34 | Hours map[int]struct{} `json:"hours"` 35 | Days map[int]struct{} `json:"days"` 36 | Months map[int]struct{} `json:"months"` 37 | DaysOfWeek map[int]struct{} `json:"daysOfWeek"` 38 | } 39 | 40 | // IsDue checks whether the provided Moment satisfies the current Schedule. 41 | func (s *Schedule) IsDue(m *Moment) bool { 42 | if _, ok := s.Minutes[m.Minute]; !ok { 43 | return false 44 | } 45 | 46 | if _, ok := s.Hours[m.Hour]; !ok { 47 | return false 48 | } 49 | 50 | if _, ok := s.Days[m.Day]; !ok { 51 | return false 52 | } 53 | 54 | if _, ok := s.DaysOfWeek[m.DayOfWeek]; !ok { 55 | return false 56 | } 57 | 58 | if _, ok := s.Months[m.Month]; !ok { 59 | return false 60 | } 61 | 62 | return true 63 | } 64 | 65 | // NewSchedule creates a new Schedule from a cron expression. 66 | // 67 | // A cron expression could be a macro OR 5 segments separated by space, 68 | // representing: minute, hour, day of the month, month and day of the week. 
69 | // 70 | // The following segment formats are supported: 71 | // - wildcard: * 72 | // - range: 1-30 73 | // - step: */n or 1-30/n 74 | // - list: 1,2,3,10-20/n 75 | // 76 | // The following macros are supported: 77 | // - @yearly (or @annually) 78 | // - @monthly 79 | // - @weekly 80 | // - @daily (or @midnight) 81 | // - @hourly 82 | // - @minutely 83 | func NewSchedule(cronExpr string) (*Schedule, error) { 84 | if v, ok := macros[cronExpr]; ok { 85 | cronExpr = v 86 | } 87 | 88 | segments := strings.Split(cronExpr, " ") 89 | if len(segments) != 5 { 90 | return nil, errors.New("invalid cron expression - must be a valid macro or to have exactly 5 space separated segments") 91 | } 92 | 93 | minutes, err := parseCronSegment(segments[0], 0, 59) 94 | if err != nil { 95 | return nil, err 96 | } 97 | 98 | hours, err := parseCronSegment(segments[1], 0, 23) 99 | if err != nil { 100 | return nil, err 101 | } 102 | 103 | days, err := parseCronSegment(segments[2], 1, 31) 104 | if err != nil { 105 | return nil, err 106 | } 107 | 108 | months, err := parseCronSegment(segments[3], 1, 12) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | daysOfWeek, err := parseCronSegment(segments[4], 0, 6) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | return &Schedule{ 119 | Minutes: minutes, 120 | Hours: hours, 121 | Days: days, 122 | Months: months, 123 | DaysOfWeek: daysOfWeek, 124 | }, nil 125 | } 126 | 127 | // parseCronSegment parses a single cron expression segment and 128 | // returns its time schedule slots. 
129 | func parseCronSegment(segment string, min int, max int) (map[int]struct{}, error) { 130 | slots := map[int]struct{}{} 131 | 132 | list := strings.Split(segment, ",") 133 | for _, p := range list { 134 | stepParts := strings.Split(p, "/") 135 | 136 | // step (*/n, 1-30/n) 137 | var step int 138 | switch len(stepParts) { 139 | case 1: 140 | step = 1 141 | case 2: 142 | parsedStep, err := strconv.Atoi(stepParts[1]) 143 | if err != nil { 144 | return nil, err 145 | } 146 | if parsedStep < 1 || parsedStep > max { 147 | return nil, fmt.Errorf("invalid segment step boundary - the step must be between 1 and the %d", max) 148 | } 149 | step = parsedStep 150 | default: 151 | return nil, errors.New("invalid segment step format - must be in the format */n or 1-30/n") 152 | } 153 | 154 | // find the min and max range of the segment part 155 | var rangeMin, rangeMax int 156 | if stepParts[0] == "*" { 157 | rangeMin = min 158 | rangeMax = max 159 | } else { 160 | // single digit (1) or range (1-30) 161 | rangeParts := strings.Split(stepParts[0], "-") 162 | switch len(rangeParts) { 163 | case 1: 164 | if step != 1 { 165 | return nil, errors.New("invalid segement step - step > 1 could be used only with the wildcard or range format") 166 | } 167 | parsed, err := strconv.Atoi(rangeParts[0]) 168 | if err != nil { 169 | return nil, err 170 | } 171 | if parsed < min || parsed > max { 172 | return nil, errors.New("invalid segment value - must be between the min and max of the segment") 173 | } 174 | rangeMin = parsed 175 | rangeMax = rangeMin 176 | case 2: 177 | parsedMin, err := strconv.Atoi(rangeParts[0]) 178 | if err != nil { 179 | return nil, err 180 | } 181 | if parsedMin < min || parsedMin > max { 182 | return nil, fmt.Errorf("invalid segment range minimum - must be between %d and %d", min, max) 183 | } 184 | rangeMin = parsedMin 185 | 186 | parsedMax, err := strconv.Atoi(rangeParts[1]) 187 | if err != nil { 188 | return nil, err 189 | } 190 | if parsedMax < parsedMin || 
parsedMax > max { 191 | return nil, fmt.Errorf("invalid segment range maximum - must be between %d and %d", rangeMin, max) 192 | } 193 | rangeMax = parsedMax 194 | default: 195 | return nil, errors.New("invalid segment range format - the range must have 1 or 2 parts") 196 | } 197 | } 198 | 199 | // fill the slots 200 | for i := rangeMin; i <= rangeMax; i += step { 201 | slots[i] = struct{}{} 202 | } 203 | } 204 | 205 | return slots, nil 206 | } 207 | -------------------------------------------------------------------------------- /pkg/monitor/monitor.go: -------------------------------------------------------------------------------- 1 | package monitor 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | "os/exec" 8 | "strings" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | "github.com/hvuhsg/gatego/pkg/cron" 13 | ) 14 | 15 | type Check struct { 16 | Name string 17 | Cron string 18 | URL string 19 | Method string 20 | Timeout time.Duration 21 | Headers map[string]string 22 | OnFailure string 23 | } 24 | 25 | func (c Check) run(onFailure func(error)) func() { 26 | return func() { 27 | // Create a client with timeout 28 | client := &http.Client{ 29 | Timeout: c.Timeout, 30 | } 31 | 32 | // Create new request 33 | req, err := http.NewRequest(c.Method, c.URL, nil) 34 | if err != nil { 35 | log.Default().Printf("Check <%s> error creating check request URL=%s Method=%s\n", c.Name, c.URL, c.Method) 36 | onFailure(err) 37 | return 38 | } 39 | 40 | // Add headers 41 | for key, value := range c.Headers { 42 | req.Header.Add(key, value) 43 | } 44 | 45 | // Send request 46 | resp, err := client.Do(req) 47 | if err != nil { 48 | log.Default().Printf("Check <%s> error sending request Error=%s\n", c.Name, err.Error()) 49 | onFailure(err) 50 | return 51 | } 52 | defer resp.Body.Close() 53 | 54 | // Check status code 55 | if resp.StatusCode != http.StatusOK { 56 | log.Default().Printf("Check <%s> failed. 
Expected status code 200 got %d\n", c.Name, resp.StatusCode) 57 | onFailure(fmt.Errorf("expected status code 200 got %d", resp.StatusCode)) 58 | return 59 | } 60 | } 61 | } 62 | 63 | // handleFailure expands the $date, $error and $check_name placeholders in 64 | // check.OnFailure and starts the resulting command without blocking on it. 65 | // The command is split on whitespace; shell quoting is not interpreted. 66 | // It returns an error if the command is empty or cannot be started. 67 | func handleFailure(check Check, err error) error { 68 | // Expand command 69 | command := check.OnFailure 70 | date := time.Now().UTC().Format("2006-01-02 15:04:05") 71 | command = strings.ReplaceAll(command, "$date", date) 72 | command = strings.ReplaceAll(command, "$error", err.Error()) 73 | command = strings.ReplaceAll(command, "$check_name", check.Name) 74 | 75 | // Run it. Fields (unlike Split on " ") never yields empty arguments. 76 | args := strings.Fields(command) 77 | if len(args) == 0 { 78 | return fmt.Errorf("empty on_failure command for check %q", check.Name) 79 | } 80 | cmd := exec.Command(args[0], args[1:]...) 81 | if err := cmd.Start(); err != nil { 82 | return err 83 | } 84 | 85 | // Reap the child in the background: without Wait the finished process 86 | // stays a zombie and its resources are never released. 87 | go func() { 88 | if err := cmd.Wait(); err != nil { 89 | log.Default().Printf("on_failure command for check %q exited with error: %s\n", check.Name, err) 90 | } 91 | }() 92 | return nil 93 | } 94 | 95 | // Monitor runs a set of HTTP checks on their cron schedules. 96 | type Monitor struct { 97 | Delay time.Duration 98 | Checks []Check 99 | scheduler *cron.Cron 100 | } 101 | 102 | // New returns a Monitor whose checks begin running delay after Start is called. 103 | func New(delay time.Duration, checks ...Check) *Monitor { 104 | return &Monitor{Delay: delay, Checks: checks, scheduler: cron.New()} 105 | } 106 | 107 | // Start registers every check with a fresh scheduler and, after m.Delay, 108 | // starts that scheduler in the background. It returns the first registration 109 | // error (e.g. an invalid cron expression), in which case m is left unchanged. 110 | // 111 | // Start needs a pointer receiver so the scheduler it creates is stored on m 112 | // itself, where the caller can later reach (and stop) it. 113 | func (m *Monitor) Start() error { 114 | scheduler := cron.New() 115 | 116 | for _, check := range m.Checks { 117 | check := check // per-iteration copy for the closure below (required before Go 1.22) 118 | err := scheduler.Add(uuid.NewString(), check.Cron, check.run(func(err error) { 119 | if check.OnFailure != "" { 120 | if err := handleFailure(check, err); err != nil { 121 | log.Default().Printf("Failed to spawn on_failure command: %s\n", err) 122 | } 123 | } 124 | })) 125 | if err != nil { 126 | return err 127 | } 128 | } 129 | 130 | m.scheduler = scheduler 131 | delay := m.Delay 132 | go func() { 133 | time.Sleep(delay) 134 | scheduler.Start() 135 | log.Default().Println("Started running automated checks.") 136 | }() 137 | 138 | return nil 139 | } 140 | -------------------------------------------------------------------------------- /pkg/monitor/monitor_test.go: -------------------------------------------------------------------------------- 1 | package monitor 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestCheck_run(t *testing.T) { 12 | tests :=
[]struct { 13 | name string 14 | server *httptest.Server 15 | check Check 16 | expectedError bool 17 | serverResponse int 18 | }{ 19 | { 20 | name: "successful check", 21 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | w.WriteHeader(http.StatusOK) 23 | })), 24 | check: Check{ 25 | Name: "test-check", 26 | Method: "GET", 27 | Timeout: 5 * time.Second, 28 | Headers: map[string]string{"X-Test": "test-value"}, 29 | }, 30 | expectedError: false, 31 | serverResponse: http.StatusOK, 32 | }, 33 | { 34 | name: "failed check - wrong status code", 35 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 36 | w.WriteHeader(http.StatusInternalServerError) 37 | })), 38 | check: Check{ 39 | Name: "test-check-fail", 40 | Method: "GET", 41 | Timeout: 5 * time.Second, 42 | }, 43 | expectedError: true, 44 | serverResponse: http.StatusInternalServerError, 45 | }, 46 | { 47 | name: "check with timeout", 48 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 49 | time.Sleep(2 * time.Second) 50 | w.WriteHeader(http.StatusOK) 51 | })), 52 | check: Check{ 53 | Name: "test-check-timeout", 54 | Method: "GET", 55 | Timeout: 1 * time.Second, 56 | }, 57 | expectedError: true, 58 | }, 59 | { 60 | name: "check with custom headers", 61 | server: httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 62 | if r.Header.Get("X-Custom") != "custom-value" { 63 | w.WriteHeader(http.StatusBadRequest) 64 | return 65 | } 66 | w.WriteHeader(http.StatusOK) 67 | })), 68 | check: Check{ 69 | Name: "test-check-headers", 70 | Method: "GET", 71 | Timeout: 5 * time.Second, 72 | Headers: map[string]string{"X-Custom": "custom-value"}, 73 | }, 74 | expectedError: false, 75 | serverResponse: http.StatusOK, 76 | }, 77 | } 78 | 79 | for _, tt := range tests { 80 | t.Run(tt.name, func(t *testing.T) { 81 | defer tt.server.Close() 82 | 83 | tt.check.URL = tt.server.URL
84 | tt.check.run(func(error) {}) 85 | }) 86 | } 87 | } 88 | 89 | func TestChecker_Start(t *testing.T) { 90 | tests := []struct { 91 | name string 92 | checker Monitor 93 | expectedError bool 94 | }{ 95 | { 96 | name: "successful start", 97 | checker: Monitor{ 98 | Delay: 1 * time.Second, 99 | Checks: []Check{ 100 | { 101 | Name: "test-check", 102 | Cron: "* * * * *", 103 | Method: "GET", 104 | URL: "http://example.com", 105 | Timeout: 5 * time.Second, 106 | }, 107 | }, 108 | }, 109 | expectedError: false, 110 | }, 111 | { 112 | name: "invalid cron expression", 113 | checker: Monitor{ 114 | Delay: 1 * time.Second, 115 | Checks: []Check{ 116 | { 117 | Name: "test-check-invalid-cron", 118 | Cron: "invalid", 119 | Method: "GET", 120 | URL: "http://example.com", 121 | Timeout: 5 * time.Second, 122 | }, 123 | }, 124 | }, 125 | expectedError: true, 126 | }, 127 | } 128 | 129 | for _, tt := range tests { 130 | t.Run(tt.name, func(t *testing.T) { 131 | err := tt.checker.Start() 132 | if (err != nil) != tt.expectedError { 133 | t.Errorf("Checker.Start() error = %v, expectedError %v", err, tt.expectedError) 134 | } 135 | 136 | // Clean up scheduler if it was created 137 | if tt.checker.scheduler != nil { 138 | tt.checker.scheduler.Stop() 139 | } 140 | }) 141 | } 142 | } 143 | 144 | func TestChecker_OnFailure(t *testing.T) { 145 | tests := []struct { 146 | name string 147 | checker Monitor 148 | expectedError bool 149 | }{ 150 | { 151 | name: "on failure command with valid command", 152 | checker: Monitor{ 153 | Delay: 1 * time.Second, 154 | Checks: []Check{ 155 | { 156 | Name: "test-check-failure", 157 | Cron: "* * * * *", 158 | Method: "GET", 159 | URL: "http://example.com", 160 | Timeout: 5 * time.Second, 161 | OnFailure: "echo check '$check_name' failed at $date: $error", 162 | }, 163 | }, 164 | }, 165 | expectedError: false, 166 | }, 167 | { 168 | name: "on failure command with invalid command", 169 | checker: Monitor{ 170 | Delay: 1 * time.Second, 171 | Checks: []Check{
172 | { 173 | Name: "test-check-failure", 174 | Cron: "* * * * *", 175 | Method: "GET", 176 | URL: "http://example.com", 177 | Timeout: 5 * time.Second, 178 | OnFailure: "invalidCommand $error", 179 | }, 180 | }, 181 | }, 182 | expectedError: true, 183 | }, 184 | } 185 | 186 | for _, tt := range tests { 187 | t.Run(tt.name, func(t *testing.T) { 188 | // Simulate a failure scenario by injecting an error 189 | err := errors.New("Connection timeout") 190 | err = handleFailure(tt.checker.Checks[0], err) 191 | 192 | // Check if an error was returned and if it matches the expected result 193 | if (err != nil) != tt.expectedError { 194 | t.Errorf("handleFailure() error = %v, expectedError %v", err, tt.expectedError) 195 | } 196 | 197 | // Clean up scheduler if it was created 198 | if tt.checker.scheduler != nil { 199 | tt.checker.scheduler.Stop() 200 | } 201 | }) 202 | } 203 | } 204 | 205 | // TestCheckWithMockServer tests the Check struct with a mock HTTP server 206 | func TestCheckWithMockServer(t *testing.T) { 207 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 208 | // Verify method 209 | if r.Method != http.MethodGet { 210 | t.Errorf("Expected method %s, got %s", http.MethodGet, r.Method) 211 | } 212 | 213 | // Verify headers 214 | if r.Header.Get("X-Test") != "test-value" { 215 | t.Errorf("Expected header X-Test: test-value, got %s", r.Header.Get("X-Test")) 216 | } 217 | 218 | w.WriteHeader(http.StatusOK) 219 | }) 220 | 221 | server := httptest.NewServer(handler) 222 | defer server.Close() 223 | 224 | check := Check{ 225 | Name: "test-check", 226 | Method: http.MethodGet, 227 | URL: server.URL, 228 | Timeout: 5 * time.Second, 229 | Headers: map[string]string{"X-Test": "test-value"}, 230 | } 231 | 232 | check.run(func(error) {}) 233 | } 234 | -------------------------------------------------------------------------------- /pkg/multimux/multimux.go: -------------------------------------------------------------------------------- 1 | //
This package implement a mutil-mux an http handler 2 | // that acts as seprate http.ServeMux for each registred host 3 | 4 | package multimux 5 | 6 | import ( 7 | "net/http" 8 | "strings" 9 | "sync" 10 | ) 11 | 12 | type MultiMux struct { 13 | Hosts sync.Map 14 | } 15 | 16 | func NewMultiMux() *MultiMux { 17 | return &MultiMux{Hosts: sync.Map{}} 18 | } 19 | 20 | func (mm *MultiMux) RegisterHandler(host string, pattern string, handler http.Handler) { 21 | cleanedHost := cleanHost(host) 22 | muxAny, _ := mm.Hosts.LoadOrStore(cleanedHost, http.NewServeMux()) 23 | mux := muxAny.(*http.ServeMux) 24 | 25 | cleanedPattern := strings.ToLower(pattern) 26 | 27 | mux.Handle(cleanedPattern, handler) 28 | } 29 | 30 | func (mm *MultiMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 31 | host := r.Host 32 | cleanedHost := cleanHost(host) 33 | muxAny, exists := mm.Hosts.Load(cleanedHost) 34 | 35 | if !exists { 36 | w.WriteHeader(http.StatusNotFound) 37 | return 38 | } 39 | 40 | mux := muxAny.(*http.ServeMux) 41 | mux.ServeHTTP(w, r) 42 | } 43 | 44 | func cleanHost(domain string) string { 45 | return removePort(strings.ToLower(domain)) 46 | } 47 | 48 | func removePort(addr string) string { 49 | if i := strings.LastIndex(addr, ":"); i != -1 { 50 | return addr[:i] 51 | } 52 | return addr 53 | } 54 | -------------------------------------------------------------------------------- /pkg/multimux/multimux_test.go: -------------------------------------------------------------------------------- 1 | package multimux 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | func TestRegisterHandler(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | host string 14 | pattern string 15 | }{ 16 | {"basic registration", "example.com", "/path"}, 17 | {"with port", "example.com:8080", "/path"}, 18 | {"uppercase host", "EXAMPLE.COM", "/path"}, 19 | {"uppercase pattern", "/PATH", "/path"}, 20 | {"with subdomain", "sub.example.com", "/path"}, 21 
| } 22 | 23 | for _, tt := range tests { 24 | t.Run(tt.name, func(t *testing.T) { 25 | mm := NewMultiMux() 26 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 27 | mm.RegisterHandler(tt.host, tt.pattern, handler) 28 | 29 | cleanedHost := cleanHost(tt.host) 30 | mux, exists := mm.Hosts.Load(cleanedHost) 31 | if !exists { 32 | t.Errorf("Host %s was not registered", cleanedHost) 33 | } 34 | if mux == nil { 35 | t.Errorf("ServeMux for host %s is nil", cleanedHost) 36 | } 37 | }) 38 | } 39 | } 40 | 41 | func TestServeHTTP(t *testing.T) { 42 | tests := []struct { 43 | name string 44 | host string 45 | path string 46 | expectedStatus int 47 | expectedBody string 48 | }{ 49 | { 50 | name: "existing host and path", 51 | host: "example.com", 52 | path: "/test", 53 | expectedStatus: http.StatusOK, 54 | expectedBody: "handler1", 55 | }, 56 | { 57 | name: "existing host with port", 58 | host: "example.com:8080", 59 | path: "/test", 60 | expectedStatus: http.StatusOK, 61 | expectedBody: "handler1", 62 | }, 63 | { 64 | name: "non-existing host", 65 | host: "unknown.com", 66 | path: "/test", 67 | expectedStatus: http.StatusNotFound, 68 | expectedBody: "", 69 | }, 70 | } 71 | 72 | for _, tt := range tests { 73 | t.Run(tt.name, func(t *testing.T) { 74 | mm := NewMultiMux() 75 | 76 | // Register a test handler 77 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 | fmt.Fprint(w, "handler1") 79 | }) 80 | mm.RegisterHandler("example.com", "/test", handler) 81 | 82 | // Create test request 83 | req := httptest.NewRequest("GET", "http://"+tt.host+tt.path, nil) 84 | req.Host = tt.host 85 | w := httptest.NewRecorder() 86 | 87 | // Serve the request 88 | mm.ServeHTTP(w, req) 89 | 90 | // Check status code 91 | if w.Code != tt.expectedStatus { 92 | t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) 93 | } 94 | 95 | // Check response body if expected 96 | if tt.expectedBody != "" && w.Body.String() != tt.expectedBody { 
97 | t.Errorf("expected body %q, got %q", tt.expectedBody, w.Body.String()) 98 | } 99 | }) 100 | } 101 | } 102 | 103 | func TestCleanHost(t *testing.T) { 104 | tests := []struct { 105 | input string 106 | expected string 107 | }{ 108 | {"example.com", "example.com"}, 109 | {"EXAMPLE.COM", "example.com"}, 110 | {"example.com:8080", "example.com"}, 111 | {"EXAMPLE.COM:8080", "example.com"}, 112 | {"sub.example.com:8080", "sub.example.com"}, 113 | {"localhost", "localhost"}, 114 | {"localhost:8080", "localhost"}, 115 | } 116 | 117 | for _, tt := range tests { 118 | t.Run(tt.input, func(t *testing.T) { 119 | result := cleanHost(tt.input) 120 | if result != tt.expected { 121 | t.Errorf("cleanHost(%q) = %q; want %q", tt.input, result, tt.expected) 122 | } 123 | }) 124 | } 125 | } 126 | 127 | func TestRemovePort(t *testing.T) { 128 | tests := []struct { 129 | input string 130 | expected string 131 | }{ 132 | {"example.com", "example.com"}, 133 | {"example.com:8080", "example.com"}, 134 | {"example.com:80", "example.com"}, 135 | {"localhost:8080", "localhost"}, 136 | {"127.0.0.1:8080", "127.0.0.1"}, 137 | {"[::1]:8080", "[::1]"}, 138 | } 139 | 140 | for _, tt := range tests { 141 | t.Run(tt.input, func(t *testing.T) { 142 | result := removePort(tt.input) 143 | if result != tt.expected { 144 | t.Errorf("removePort(%q) = %q; want %q", tt.input, result, tt.expected) 145 | } 146 | }) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /pkg/pathgraph/pathgraph.go: -------------------------------------------------------------------------------- 1 | package pathgraph 2 | 3 | import "strings" 4 | 5 | const incRate = 1 6 | const baseWeight = 0 7 | 8 | // PathVertex represents a vertex in the graph 9 | type PathVertex struct { 10 | Path string 11 | Weight float64 12 | } 13 | 14 | // PathGraph represents a weighted directed graph of navigation paths 15 | type PathGraph struct { 16 | // Map of source path to map of destination paths 
and their weights 17 | adjacencyList map[string]map[string]*PathVertex 18 | } 19 | 20 | // NewPathGraph creates a new instance of PathGraph 21 | func NewPathGraph() *PathGraph { 22 | return &PathGraph{ 23 | adjacencyList: make(map[string]map[string]*PathVertex), 24 | } 25 | } 26 | 27 | // AddJump adds or updates a path transition in the graph 28 | func (g *PathGraph) AddJump(sourcePath, destPath string) float64 { 29 | sourcePath = normalizePath(sourcePath) 30 | destPath = normalizePath(destPath) 31 | 32 | // Initialize source path if it doesn't exist 33 | if _, exists := g.adjacencyList[sourcePath]; !exists { 34 | g.adjacencyList[sourcePath] = make(map[string]*PathVertex) 35 | } 36 | 37 | // Get or create destination node 38 | vertex, exists := g.adjacencyList[sourcePath][destPath] 39 | if !exists { 40 | vertex = &PathVertex{ 41 | Path: destPath, 42 | Weight: baseWeight, 43 | } 44 | g.adjacencyList[sourcePath][destPath] = vertex 45 | } 46 | 47 | // Increment weight 48 | vertex.Weight += incRate 49 | 50 | return vertex.Weight - 1 // The original weight (before the jump) 51 | } 52 | 53 | // GetDestinations returns all destinations and their weights for a given source path 54 | func (g *PathGraph) GetDestinations(sourcePath string) map[string]float64 { 55 | sourcePath = normalizePath(sourcePath) 56 | 57 | result := make(map[string]float64) 58 | 59 | if vertexs, exists := g.adjacencyList[sourcePath]; exists { 60 | for path, vertex := range vertexs { 61 | result[path] = vertex.Weight 62 | } 63 | } 64 | 65 | return result 66 | } 67 | 68 | // GetAllPaths returns all unique paths in the graph 69 | func (g *PathGraph) GetAllPaths() []string { 70 | pathSet := make(map[string]bool) 71 | 72 | // Add all source paths 73 | for sourcePath := range g.adjacencyList { 74 | pathSet[sourcePath] = true 75 | 76 | // Add all destination paths 77 | for destPath := range g.adjacencyList[sourcePath] { 78 | pathSet[destPath] = true 79 | } 80 | } 81 | 82 | // Convert set to slice 83 | paths 
:= make([]string, 0, len(pathSet)) 84 | for path := range pathSet { 85 | paths = append(paths, path) 86 | } 87 | 88 | return paths 89 | } 90 | 91 | func normalizePath(path string) string { 92 | if len(path) == 0 || path[0] != '/' { 93 | path = "/" + path 94 | } 95 | 96 | path = strings.ToLower(path) 97 | 98 | return path 99 | } 100 | -------------------------------------------------------------------------------- /pkg/tracker/tracker.go: -------------------------------------------------------------------------------- 1 | package tracker 2 | 3 | import ( 4 | "crypto/rand" 5 | "encoding/hex" 6 | "net/http" 7 | ) 8 | 9 | type Tracker interface { 10 | GetTrackerID(*http.Request) string 11 | SetTracker(http.ResponseWriter) (string, error) 12 | RemoveTracker(*http.Request) 13 | } 14 | 15 | type cookieTracker struct { 16 | cookieName string 17 | trackerMaxAge int 18 | secureCookie bool 19 | } 20 | 21 | func NewCookieTracker(cookieName string, maxAge int, isSecure bool) cookieTracker { 22 | return cookieTracker{cookieName: cookieName, trackerMaxAge: maxAge, secureCookie: isSecure} 23 | } 24 | 25 | // Get the tracker id from request or return empty string if not found 26 | func (ct cookieTracker) GetTrackerID(r *http.Request) string { 27 | cookie, err := r.Cookie(ct.cookieName) 28 | 29 | if err != nil { 30 | return "" 31 | } 32 | 33 | return cookie.Value 34 | } 35 | 36 | // Set tracer into response and return the tracker id 37 | func (ct cookieTracker) SetTracker(w http.ResponseWriter) (string, error) { 38 | traceID, err := generateTraceID() 39 | if err != nil { 40 | return "", err 41 | } 42 | 43 | http.SetCookie(w, &http.Cookie{ 44 | Name: ct.cookieName, 45 | Value: traceID, 46 | Path: "/", 47 | MaxAge: ct.trackerMaxAge, 48 | HttpOnly: true, 49 | Secure: ct.secureCookie, 50 | SameSite: http.SameSiteLaxMode, 51 | }) 52 | 53 | return traceID, nil 54 | } 55 | 56 | func (ct cookieTracker) RemoveTracker(r *http.Request) { 57 | // Get existing cookies 58 | oldCookies := 
r.Cookies() 59 | 60 | // Create new headers without the cookie we want to remove 61 | r.Header.Del("Cookie") 62 | 63 | // Add back all cookies except the one we want to remove 64 | for _, cookie := range oldCookies { 65 | if cookie.Name != ct.cookieName { 66 | r.AddCookie(cookie) 67 | } 68 | } 69 | } 70 | 71 | func generateTraceID() (string, error) { 72 | bytes := make([]byte, 16) 73 | if _, err := rand.Read(bytes); err != nil { 74 | return "", err 75 | } 76 | return hex.EncodeToString(bytes), nil 77 | } 78 | -------------------------------------------------------------------------------- /pkg/tracker/tracker_test.go: -------------------------------------------------------------------------------- 1 | package tracker 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestNewCookieTracker(t *testing.T) { 11 | tracker := NewCookieTracker("testTracker", 3600, true) 12 | 13 | if tracker.cookieName != "testTracker" { 14 | t.Errorf("Expected cookieName to be 'testTracker', got %s", tracker.cookieName) 15 | } 16 | if tracker.trackerMaxAge != 3600 { 17 | t.Errorf("Expected trackerMaxAge to be 3600, got %d", tracker.trackerMaxAge) 18 | } 19 | if !tracker.secureCookie { 20 | t.Errorf("Expected secureCookie to be true") 21 | } 22 | } 23 | 24 | func TestGenerateTraceID(t *testing.T) { 25 | traceID1, err1 := generateTraceID() 26 | if err1 != nil { 27 | t.Fatalf("Unexpected error generating trace ID: %v", err1) 28 | } 29 | 30 | traceID2, err2 := generateTraceID() 31 | if err2 != nil { 32 | t.Fatalf("Unexpected error generating trace ID: %v", err2) 33 | } 34 | 35 | if len(traceID1) != 32 { 36 | t.Errorf("Expected trace ID length to be 32, got %d", len(traceID1)) 37 | } 38 | 39 | if traceID1 == traceID2 { 40 | t.Error("Generated trace IDs should be unique") 41 | } 42 | } 43 | 44 | func TestSetTracker(t *testing.T) { 45 | tracker := NewCookieTracker("testTracker", 3600, true) 46 | 47 | // Create a test response writer 48 | w := 
httptest.NewRecorder() 49 | 50 | // Set tracker 51 | traceID, err := tracker.SetTracker(w) 52 | if err != nil { 53 | t.Fatalf("Unexpected error setting tracker: %v", err) 54 | } 55 | 56 | // Check response headers 57 | cookies := w.Result().Cookies() 58 | if len(cookies) != 1 { 59 | t.Fatalf("Expected 1 cookie, got %d", len(cookies)) 60 | } 61 | 62 | cookie := cookies[0] 63 | if cookie.Name != "testTracker" { 64 | t.Errorf("Expected cookie name 'testTracker', got %s", cookie.Name) 65 | } 66 | if cookie.Value != traceID { 67 | t.Errorf("Cookie value does not match returned trace ID") 68 | } 69 | if cookie.Path != "/" { 70 | t.Errorf("Expected cookie path '/', got %s", cookie.Path) 71 | } 72 | if cookie.MaxAge != 3600 { 73 | t.Errorf("Expected MaxAge 3600, got %d", cookie.MaxAge) 74 | } 75 | if !cookie.HttpOnly { 76 | t.Errorf("Expected HttpOnly to be true") 77 | } 78 | } 79 | 80 | func TestGetTrackerID(t *testing.T) { 81 | tracker := NewCookieTracker("testTracker", 3600, true) 82 | 83 | // Test request without cookie 84 | req1 := httptest.NewRequest(http.MethodGet, "/", nil) 85 | trackerID1 := tracker.GetTrackerID(req1) 86 | if trackerID1 != "" { 87 | t.Errorf("Expected empty string when no cookie exists, got %s", trackerID1) 88 | } 89 | 90 | // Test request with cookie 91 | req2 := httptest.NewRequest(http.MethodGet, "/", nil) 92 | req2.AddCookie(&http.Cookie{ 93 | Name: "testTracker", 94 | Value: "test-trace-id", 95 | }) 96 | 97 | trackerID2 := tracker.GetTrackerID(req2) 98 | if trackerID2 != "test-trace-id" { 99 | t.Errorf("Expected 'test-trace-id', got %s", trackerID2) 100 | } 101 | } 102 | 103 | func TestRemoveTracker(t *testing.T) { 104 | tracker := NewCookieTracker("testTracker", 3600, true) 105 | 106 | // Create a request with multiple cookies 107 | req := httptest.NewRequest(http.MethodGet, "/", nil) 108 | req.AddCookie(&http.Cookie{Name: "testTracker", Value: "remove-me"}) 109 | req.AddCookie(&http.Cookie{Name: "otherCookie", Value: "keep-me"}) 110 | 111 | 
// Remove the specific tracker cookie 112 | tracker.RemoveTracker(req) 113 | 114 | // Check that the cookie header has been modified 115 | cookieHeader := req.Header.Get("Cookie") 116 | if strings.Contains(cookieHeader, "testTracker") { 117 | t.Errorf("testTracker cookie should have been removed") 118 | } 119 | if !strings.Contains(cookieHeader, "otherCookie=keep-me") { 120 | t.Errorf("Other cookies should be preserved") 121 | } 122 | } 123 | 124 | // Benchmark trace ID generation 125 | func BenchmarkGenerateTraceID(b *testing.B) { 126 | for i := 0; i < b.N; i++ { 127 | generateTraceID() 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package gatego 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net" 8 | "net/http" 9 | "os" 10 | "time" 11 | 12 | "github.com/hvuhsg/gatego/internal/config" 13 | "github.com/hvuhsg/gatego/pkg/multimux" 14 | ) 15 | 16 | type gategoServer struct { 17 | *http.Server 18 | } 19 | 20 | func newServer(ctx context.Context, config config.Config, useOtel bool) (*gategoServer, error) { 21 | multimuxer, err := createMultiMuxer(ctx, config.Services, useOtel) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | addr := fmt.Sprintf("%s:%d", config.Host, config.Port) 27 | 28 | // Start HTTP server. 
29 | server := &http.Server{ 30 | Addr: addr, 31 | BaseContext: func(_ net.Listener) context.Context { return ctx }, 32 | ReadTimeout: time.Second, 33 | WriteTimeout: 10 * time.Second, 34 | Handler: multimuxer, 35 | } 36 | 37 | return &gategoServer{Server: server}, nil 38 | } 39 | 40 | func createMultiMuxer(ctx context.Context, services []config.Service, useOtel bool) (*multimux.MultiMux, error) { 41 | mm := multimux.NewMultiMux() 42 | 43 | for _, service := range services { 44 | for _, path := range service.Paths { 45 | handler, err := NewHandler(ctx, useOtel, service, path) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | mm.RegisterHandler(service.Domain, path.Path, handler) 51 | } 52 | } 53 | 54 | return mm, nil 55 | } 56 | 57 | func (gs *gategoServer) serve(certfile *string, keyfile *string) (chan error, error) { 58 | supportTLS, err := checkTLSConfig(certfile, keyfile) 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | serveErr := make(chan error, 1) 64 | 65 | go func() { 66 | if supportTLS { 67 | log.Default().Printf("Serving proxy with TLS %s\n", gs.Addr) 68 | serveErr <- gs.ListenAndServeTLS(*certfile, *keyfile) 69 | } else { 70 | log.Default().Printf("Serving proxy %s\n", gs.Addr) 71 | serveErr <- gs.ListenAndServe() 72 | } 73 | }() 74 | 75 | return serveErr, nil 76 | } 77 | 78 | func checkTLSConfig(certfile *string, keyfile *string) (bool, error) { 79 | if keyfile == nil || certfile == nil || *keyfile == "" || *certfile == "" { 80 | return false, nil 81 | } 82 | 83 | if !fileExists(*keyfile) { 84 | return false, fmt.Errorf("can't find keyfile at '%s'", *keyfile) 85 | } 86 | 87 | if !fileExists(*certfile) { 88 | return false, fmt.Errorf("can't find certfile at '%s'", *certfile) 89 | } 90 | 91 | return true, nil 92 | } 93 | 94 | func fileExists(filepath string) bool { 95 | _, err := os.Stat(filepath) 96 | 97 | if os.IsNotExist(err) { 98 | return false 99 | } 100 | 101 | // If we cant check the file info we probably can't open the file 102 
| if err != nil { 103 | return false 104 | } 105 | 106 | return true 107 | } 108 | --------------------------------------------------------------------------------