├── .gitattributes ├── .github ├── codecov.yml ├── dependabot.yml └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── authorization.go ├── authorization_test.go ├── cache.go ├── client.go ├── client_test.go ├── clientprovider.go ├── clientprovider_test.go ├── gate.go ├── gate_bench_test.go ├── gate_test.go ├── go.mod ├── go.sum ├── ratelimiter.go └── ratelimiter_test.go /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=lf -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | coverage: 3 | status: 4 | patch: off 5 | project: 6 | default: 7 | target: 75% 8 | threshold: null -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | labels: ["dependencies"] 6 | schedule: 7 | interval: "weekly" 8 | day: "saturday" 9 | - package-ecosystem: "gomod" 10 | directory: "/" 11 | labels: ["dependencies"] 12 | schedule: 13 | interval: "weekly" 14 | day: "saturday" 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | pull_request: 4 | paths-ignore: 5 | - '*.md' 6 | push: 7 | branches: 8 | - master 9 | paths-ignore: 10 | - '*.md' 11 | jobs: 12 | test: 13 | name: test 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 3 16 | steps: 17 | - uses: actions/setup-go@v5 18 | with: 19 | go-version: 1.23.3 20 | - uses: actions/checkout@v4 21 | - name: Test (race) 22 | run: go test ./... -race 23 | - name: Test (coverage) 24 | run: go test ./... 
-coverprofile=coverage.txt -covermode=atomic 25 | - name: Codecov 26 | uses: codecov/codecov-action@v5.4.3 27 | with: 28 | files: ./coverage.txt 29 | token: ${{ secrets.CODECOV_TOKEN }} 30 | 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | /vendor -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-2025 TwiN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # g8 2 | 3 | ![test](https://github.com/TwiN/g8/actions/workflows/test.yml/badge.svg?branch=master) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/TwiN/g8)](https://goreportcard.com/report/github.com/TwiN/g8/v3) 5 | [![codecov](https://codecov.io/gh/TwiN/g8/branch/master/graph/badge.svg)](https://codecov.io/gh/TwiN/g8) 6 | [![Go version](https://img.shields.io/github/go-mod/go-version/TwiN/g8.svg)](https://github.com/TwiN/g8) 7 | [![Go Reference](https://pkg.go.dev/badge/github.com/TwiN/g8.svg)](https://pkg.go.dev/github.com/TwiN/g8/v3) 8 | [![Follow TwiN](https://img.shields.io/github/followers/TwiN?label=Follow&style=social)](https://github.com/TwiN) 9 | 10 | g8, pronounced gate, is a simple Go library for protecting HTTP handlers. 11 | 12 | Tired of constantly re-implementing a security layer for each application? Me too, that's why I made g8. 13 | 14 | 15 | ## Installation 16 | ```console 17 | go get -u github.com/TwiN/g8/v3 18 | ``` 19 | 20 | 21 | ## Usage 22 | Because the entire purpose of g8 is to NOT waste time configuring the layer of security, the primary emphasis is to 23 | keep it as simple as possible. 24 | 25 | 26 | ### Simple 27 | Just want a simple layer of security without the need for advanced permissions? This configuration is what you're 28 | looking for. 29 | 30 | ```go 31 | authorizationService := g8.NewAuthorizationService().WithToken("mytoken") 32 | gate := g8.New().WithAuthorizationService(authorizationService) 33 | 34 | router := http.NewServeMux() 35 | router.Handle("/unprotected", yourHandler) 36 | router.Handle("/protected", gate.Protect(yourHandler)) 37 | 38 | http.ListenAndServe(":8080", router) 39 | ``` 40 | 41 | The endpoint `/protected` is now only accessible if you pass the header `Authorization: Bearer mytoken`. 
42 | 43 | If you use `http.HandleFunc` instead of `http.Handle`, you may use `gate.ProtectFunc(yourHandler)` instead. 44 | 45 | If you're not using the `Authorization` header, you can specify a custom token extractor. 46 | This enables use cases like [Protecting a handler using session cookie](#protecting-a-handler-using-session-cookie) 47 | 48 | 49 | ### Advanced permissions 50 | If you have tokens with more permissions than others, g8's permission system will make managing authorization a breeze. 51 | 52 | Rather than registering tokens, think of it as registering clients, the only difference being that clients may be 53 | configured with permissions while tokens cannot. 54 | 55 | ```go 56 | authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermission("admin")) 57 | gate := g8.New().WithAuthorizationService(authorizationService) 58 | 59 | router := http.NewServeMux() 60 | router.Handle("/unprotected", yourHandler) 61 | router.Handle("/protected-with-admin", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) 62 | 63 | http.ListenAndServe(":8080", router) 64 | ``` 65 | 66 | The endpoint `/protected-with-admin` is now only accessible if you pass the header `Authorization: Bearer mytoken`, 67 | because the client with the token `mytoken` has the permission `admin`. Note that the following handler would also be 68 | accessible with that token: 69 | ```go 70 | router.Handle("/protected", gate.Protect(yourHandler)) 71 | ``` 72 | 73 | To clarify, both clients and tokens have access to handlers that aren't protected with extra permissions, and 74 | essentially, tokens are registered as clients with no extra permissions in the background. 
75 | 76 | Creating a token like so: 77 | ```go 78 | authorizationService := g8.NewAuthorizationService().WithToken("mytoken") 79 | ``` 80 | is the equivalent of creating the following client: 81 | ```go 82 | authorizationService := g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken")) 83 | ``` 84 | 85 | 86 | ### With client provider 87 | A client provider's task is to retrieve a Client from an external source (e.g. a database) when provided with a token. 88 | You should use a client provider when you have a lot of tokens and it wouldn't make sense to register all of them using 89 | `AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`. 90 | 91 | Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4 92 | aforementioned functions, the client provider will not be used. 93 | 94 | ```go 95 | clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 96 | // We'll assume that the following function calls your database and returns a struct "User" that 97 | // has the user's token as well as the permissions granted to said user 98 | user := database.GetUserByToken(token) 99 | if user != nil { 100 | return g8.NewClient(user.Token).WithPermissions(user.Permissions) 101 | } 102 | return nil 103 | }) 104 | authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) 105 | gate := g8.New().WithAuthorizationService(authorizationService) 106 | ``` 107 | 108 | You can also configure the client provider to cache the output of the function you provide to retrieve clients by token: 109 | ```go 110 | clientProvider := g8.NewClientProvider(...).WithCache(ttl, maxSize) 111 | ``` 112 | 113 | Since g8 leverages [TwiN/gocache](https://github.com/TwiN/gocache) (unless you're using `WithCustomCache`), 114 | you can also use gocache's constants for configuring the TTL and the maximum size: 115 | - Setting the TTL to `gocache.NoExpiration` (-1) will disable 
the TTL. 116 | - Setting the maximum size to `gocache.NoMaxSize` (0) will disable the maximum cache size. 117 | 118 | To avoid any misunderstandings, using a client provider is not mandatory. If you only have a few tokens and you can load 119 | them on application start, you can just leverage `AuthorizationService`'s `WithToken`/`WithTokens`/`WithClient`/`WithClients`. 120 | 121 | 122 | ## AuthorizationService 123 | As the previous examples may have hinted, there are several ways to create clients. The one thing they have 124 | in common is that they all go through AuthorizationService, which is in charge of both managing clients and determining 125 | whether a request should be blocked or allowed through. 126 | 127 | | Function | Description | 128 | |:-------------------|:---------------------------------------------------------------------------------------------------------------------------------| 129 | | WithToken | Creates a single static client with no extra permissions | 130 | | WithTokens | Creates a slice of static clients with no extra permissions | 131 | | WithClient | Creates a single static client | 132 | | WithClients | Creates a slice of static clients | 133 | | WithClientProvider | Sets a client provider, allowing a fallback to a dynamic source (e.g. to a database) when a static client is not found | 134 | 135 | Except for `WithClientProvider`, which replaces any previously configured provider, every function listed above can be called more than once. 136 | As a result, you may safely perform actions like this: 137 | ```go 138 | authorizationService := g8.NewAuthorizationService(). 139 | WithToken("123"). 140 | WithToken("456"). 141 | WithClient(g8.NewClient("789").WithPermission("admin")) 142 | gate := g8.New().WithAuthorizationService(authorizationService) 143 | ``` 144 | 145 | Be aware that g8.Client supports a list of permissions as well. You may call `WithPermission` several times, or call 146 | `WithPermissions` with a slice of permissions instead.
147 | 148 | 149 | ### Permissions 150 | Unlike client permissions, handler permissions are requirements. 151 | 152 | A client may have as many permissions as you want, but for said client to have access to a handler protected by 153 | permissions, the client must have every permission required by said handler. 154 | 155 | In other words, a client with the permissions `create`, `read`, `update` and `delete` would have access to all of these handlers: 156 | ```go 157 | gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"}))) 158 | router := http.NewServeMux() 159 | router.Handle("/", gate.Protect(homeHandler)) // equivalent of gate.ProtectWithPermissions(homeHandler, []string{}) 160 | router.Handle("/create", gate.ProtectWithPermissions(createHandler, []string{"create"})) 161 | router.Handle("/read", gate.ProtectWithPermissions(readHandler, []string{"read"})) 162 | router.Handle("/update", gate.ProtectWithPermissions(updateHandler, []string{"update"})) 163 | router.Handle("/delete", gate.ProtectWithPermissions(deleteHandler, []string{"delete"})) 164 | router.Handle("/crud", gate.ProtectWithPermissions(crudHandler, []string{"create", "read", "update", "delete"})) 165 | ``` 166 | But it would not have access to the following handler, because while `mytoken` has the `read` permission, it does not 167 | have the `backup` permission: 168 | ```go 169 | router.Handle("/backup", gate.ProtectWithPermissions(backupHandler, []string{"read", "backup"})) 170 | ``` 171 | 172 | If you're using an HTTP library that supports middlewares like [mux](https://github.com/gorilla/mux), you can instead protect 173 | an entire group of handlers using `gate.Protect` or `gate.PermissionMiddleware()`: 174 | ```go 175 | router := mux.NewRouter() 176 | 177 | userRouter := router.PathPrefix("/").Subrouter() 178 | userRouter.Use(gate.Protect) 179 |
userRouter.HandleFunc("/api/v1/users/me", getUserProfile).Methods("GET") 180 | userRouter.HandleFunc("/api/v1/users/me/friends", getUserFriends).Methods("GET") 181 | userRouter.HandleFunc("/api/v1/users/me/email", updateUserEmail).Methods("PATCH") 182 | 183 | adminRouter := router.PathPrefix("/").Subrouter() 184 | adminRouter.Use(gate.PermissionMiddleware("admin")) 185 | adminRouter.HandleFunc("/api/v1/users/{id}/ban", banUserByID).Methods("POST") 186 | adminRouter.HandleFunc("/api/v1/users/{id}/delete", deleteUserByID).Methods("DELETE") 187 | ``` 188 | 189 | 190 | ## Rate limiting 191 | To add a rate limit of 100 requests per second: 192 | ```go 193 | gate := g8.New().WithRateLimit(100) 194 | ``` 195 | 196 | 197 | ## Accessing the token from the protected handlers 198 | If you need to access the token from the handlers you are protecting with g8, you can retrieve it from the 199 | request context by using the key `g8.TokenContextKey`: 200 | ```go 201 | http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { 202 | token, _ := r.Context().Value(g8.TokenContextKey).(string) 203 | // ... 204 | })) 205 | ``` 206 | 207 | ## Examples 208 | ### Protecting a handler using session cookie 209 | If you want to only allow authenticated users to access a handler, you can use a custom token extractor function 210 | combined with a client provider. 211 | 212 | First, we'll create a function to extract the session ID from the session cookie. While a session ID does not 213 | theoretically refer to a token, g8 uses the term `token` as a blanket term to refer to any string that can be used to 214 | identify a client. 
215 | ```go 216 | customTokenExtractorFunc := func(request *http.Request) string { 217 | sessionCookie, err := request.Cookie("session") 218 | if err != nil { 219 | return "" 220 | } 221 | return sessionCookie.Value 222 | } 223 | ``` 224 | 225 | Next, we need to create a client provider that will validate our token, which refers to the session ID in this case. 226 | ```go 227 | clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 228 | // We'll assume that the following function calls your database and validates whether the session is valid. 229 | isSessionValid := database.CheckIfSessionIsValid(token) 230 | if !isSessionValid { 231 | return nil // Returning nil will cause the gate to return a 401 Unauthorized. 232 | } 233 | // You could also retrieve the user and their permissions if you wanted instead, but for this example, 234 | // all we care about is confirming whether the session is valid or not. 235 | return g8.NewClient(token) 236 | }) 237 | ``` 238 | 239 | Keep in mind that you can get really creative with the client provider above. 240 | For instance, you could refresh the session's expiration time, which will allow the user to stay logged in for 241 | as long as they're active. 242 | 243 | You're also not limited to using something stateful like the example above. You could use a JWT and have your client 244 | provider validate said JWT. 
245 | 246 | Finally, we can create the authorization service and the gate: 247 | ```go 248 | authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) 249 | gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) 250 | ``` 251 | 252 | If you need to access the token (session ID in this case) from the protected handlers, you can retrieve it from the 253 | request context by using the key `g8.TokenContextKey`: 254 | ```go 255 | http.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { 256 | sessionID, _ := r.Context().Value(g8.TokenContextKey).(string) 257 | // ... 258 | })) 259 | ``` 260 | 261 | ### Using a custom header 262 | The logic is the same as the example above: 263 | ```go 264 | customTokenExtractorFunc := func(request *http.Request) string { 265 | return request.Header.Get("X-API-Token") 266 | } 267 | 268 | clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 269 | // We'll assume that the following function calls your database and returns a struct "User" that 270 | // has the user's token as well as the permissions granted to said user 271 | user := database.GetUserByToken(token) 272 | if user != nil { 273 | return g8.NewClient(user.Token).WithPermissions(user.Permissions) 274 | } 275 | return nil 276 | }) 277 | authorizationService := g8.NewAuthorizationService().WithClientProvider(clientProvider) 278 | gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) 279 | ``` 280 | 281 | ### Using a custom cache 282 | 283 | ```go 284 | package main 285 | 286 | import ( 287 | g8 "github.com/TwiN/g8/v3" 288 | ) 289 | 290 | type customCache struct { 291 | entries map[string]any 292 | sync.Mutex 293 | } 294 | 295 | func (c *customCache) Get(key string) (value any, exists bool) { 296 | return nil, false 297 | } 298 | 299 | func (c *customCache) Set(key string, value any) { 
300 | // ... 301 | } 302 | 303 | // To verify the implementation 304 | var _ g8.Cache = (*customCache)(nil) 305 | 306 | func main() { 307 | getClientByTokenFunc := func(token string) *g8.Client { 308 | // We'll assume that the following function calls your database and returns a struct "User" that 309 | // has the user's token as well as the permissions granted to said user 310 | user := database.GetUserByToken(token) 311 | if user != nil { 312 | return g8.NewClient(user.Token).WithPermissions(user.Permissions).WithData(user.Data) 313 | } 314 | return nil 315 | } 316 | // Create the provider with the custom cache 317 | provider := g8.NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{}) 318 | } 319 | ``` 320 | -------------------------------------------------------------------------------- /authorization.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // AuthorizationService is the service that manages client/token registry and client fallback as well as the service 8 | // that determines whether a token meets the specific requirements to be authorized by a Gate or not. 9 | type AuthorizationService struct { 10 | clients map[string]*Client 11 | clientProvider *ClientProvider 12 | 13 | mutex sync.RWMutex // guards clients; note that clientProvider is written and read without it 14 | } 15 | 16 | // NewAuthorizationService creates a new AuthorizationService with an empty client registry and no client provider 17 | func NewAuthorizationService() *AuthorizationService { 18 | return &AuthorizationService{ 19 | clients: make(map[string]*Client), 20 | } 21 | } 22 | 23 | // WithToken is used to specify a single token for which authorization will be granted 24 | // 25 | // The client that will be created from this token will have access to all handlers that are not protected with a 26 | // specific permission.
27 | // 28 | // In other words, if you were to do the following: 29 | // 30 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("12345")) 31 | // 32 | // The following handler would be accessible with the token 12345: 33 | // 34 | // router.Handle("/1st-handler", gate.Protect(yourHandler)) 35 | // 36 | // But this one would not be accessible with the token 12345: 37 | // 38 | // router.Handle("/2nd-handler", gate.ProtectWithPermissions(yourOtherHandler, []string{"admin"})) 39 | // 40 | // Calling this function multiple times will add multiple clients, though you may want to use WithTokens instead 41 | // if you plan to add multiple clients. Registering a token that is already registered replaces its existing client, along with any permissions that client had. 42 | // 43 | // If you wish to configure advanced permissions, consider using WithClient instead. 44 | func (authorizationService *AuthorizationService) WithToken(token string) *AuthorizationService { 45 | authorizationService.mutex.Lock() 46 | authorizationService.clients[token] = NewClient(token) 47 | authorizationService.mutex.Unlock() 48 | return authorizationService 49 | } 50 | 51 | // WithTokens is used to specify a slice of tokens for which authorization will be granted. Each token replaces any client previously registered with the same token. 52 | func (authorizationService *AuthorizationService) WithTokens(tokens []string) *AuthorizationService { 53 | authorizationService.mutex.Lock() 54 | for _, token := range tokens { 55 | authorizationService.clients[token] = NewClient(token) 56 | } 57 | authorizationService.mutex.Unlock() 58 | return authorizationService 59 | } 60 | 61 | // WithClient is used to specify a single client for which authorization will be granted 62 | // 63 | // When compared to WithToken, the advantage of using this function is that you may specify the client's 64 | // permissions and thus, be a lot more granular with what endpoint a token has access to.
65 | // 66 | // In other words, if you were to do the following: 67 | // 68 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("12345").WithPermission("mod"))) 69 | // 70 | // The following handlers would be accessible with the token 12345: 71 | // 72 | // router.Handle("/1st-handler", gate.ProtectWithPermissions(yourHandler, []string{"mod"})) 73 | // router.Handle("/2nd-handler", gate.Protect(yourOtherHandler)) 74 | // 75 | // But not this one, because the client does not have the permission "admin": 76 | // 77 | // router.Handle("/3rd-handler", gate.ProtectWithPermissions(yetAnotherHandler, []string{"admin"})) 78 | // 79 | // Calling this function multiple times will add multiple clients, though you may want to use WithClients instead 80 | // if you plan to add multiple clients. The client is stored by pointer (not copied) under client.Token, replacing any client previously registered with that token. NOTE(review): a nil client panics here, since client.Token is dereferenced. 81 | func (authorizationService *AuthorizationService) WithClient(client *Client) *AuthorizationService { 82 | authorizationService.mutex.Lock() 83 | authorizationService.clients[client.Token] = client 84 | authorizationService.mutex.Unlock() 85 | return authorizationService 86 | } 87 | 88 | // WithClients is used to specify a slice of clients for which authorization will be granted. Like WithClient, each client is stored by pointer under its Token and must not be nil. 89 | func (authorizationService *AuthorizationService) WithClients(clients []*Client) *AuthorizationService { 90 | authorizationService.mutex.Lock() 91 | for _, client := range clients { 92 | authorizationService.clients[client.Token] = client 93 | } 94 | authorizationService.mutex.Unlock() 95 | return authorizationService 96 | } 97 | 98 | // WithClientProvider allows specifying a custom provider to fetch clients by token. 
99 | // 100 | // For example, you can use it to fall back to making a call in your database when a request is made with a token that 101 | // hasn't been specified via WithToken, WithTokens, WithClient or WithClients. Calling this again replaces the previously configured provider. NOTE(review): unlike the client registry, clientProvider is set here and read in Authorize without holding mutex, so configure it before the service starts handling requests.
102 | func (authorizationService *AuthorizationService) WithClientProvider(provider *ClientProvider) *AuthorizationService { 103 | authorizationService.clientProvider = provider 104 | return authorizationService 105 | } 106 | 107 | // Authorize checks whether a client with a given token exists and has the permissions required. Clients registered directly on the service take precedence; the client provider, if one is configured, is only consulted when no such client exists. 108 | // 109 | // If permissionsRequired is nil or empty and a client with the given token exists, said client will have access to all 110 | // handlers that are not protected by a given permission. 111 | // 112 | // Returns the authorized client (or nil if no client was authorized), as well as whether the token is authorized. An empty token is never authorized. 113 | func (authorizationService *AuthorizationService) Authorize(token string, permissionsRequired []string) (client *Client, authorized bool) { 114 | if len(token) == 0 { 115 | return nil, false 116 | } 117 | authorizationService.mutex.RLock() 118 | client, _ = authorizationService.clients[token] // a missing token yields a nil client; NOTE(review): the blank identifier is redundant (staticcheck S1005) 119 | authorizationService.mutex.RUnlock() 120 | // If there is no client with the given token directly stored in the AuthorizationService, fall back to the 121 | // client provider, if there's one configured.
122 | if client == nil && authorizationService.clientProvider != nil { 123 | client = authorizationService.clientProvider.GetClientByToken(token) 124 | } 125 | if client != nil && client.HasPermissions(permissionsRequired) { 126 | // If the client has the required permissions, return true and the client 127 | return client, true 128 | } 129 | return nil, false // unknown token, or a client lacking at least one required permission 130 | } 131 | -------------------------------------------------------------------------------- /authorization_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import "testing" 4 | 5 | func TestAuthorizationService_Authorize(t *testing.T) { 6 | authorizationService := NewAuthorizationService().WithToken("token") 7 | if _, authorized := authorizationService.Authorize("token", nil); !authorized { 8 | t.Error("should've returned true") 9 | } 10 | if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { 11 | t.Error("should've returned false") 12 | } 13 | if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized { 14 | t.Error("should've returned false") 15 | } 16 | if _, authorized := authorizationService.Authorize("", nil); authorized { 17 | t.Error("should've returned false") 18 | } 19 | } 20 | 21 | func TestAuthorizationService_AuthorizeWithPermissions(t *testing.T) { 22 | authorizationService := NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"})) 23 | if _, authorized := authorizationService.Authorize("token", nil); !authorized { 24 | t.Error("should've returned true") 25 | } 26 | if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized { 27 | t.Error("should've returned true") 28 | } 29 | if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized { 30 | t.Error("should've returned true") 31 | } 32 | if _, authorized := 
authorizationService.Authorize("token", []string{"a", "b"});
!authorized { 33 | t.Error("should've returned true") 34 | } 35 | if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized { 36 | t.Error("should've returned false") 37 | } 38 | if _, authorized := authorizationService.Authorize("token", []string{"a", "c"}); authorized { 39 | t.Error("should've returned false") 40 | } 41 | if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { 42 | t.Error("should've returned false") 43 | } 44 | if _, authorized := authorizationService.Authorize("bad-token", []string{"a"}); authorized { 45 | t.Error("should've returned false") 46 | } 47 | if _, authorized := authorizationService.Authorize("", []string{"a"}); authorized { 48 | t.Error("should've returned false") 49 | } 50 | } 51 | 52 | func TestAuthorizationService_WithToken(t *testing.T) { 53 | authorizationService := NewAuthorizationService().WithToken("token") 54 | if _, authorized := authorizationService.Authorize("token", nil); !authorized { 55 | t.Error("should've returned true") 56 | } 57 | if _, authorized := authorizationService.Authorize("bad-token", nil); authorized { 58 | t.Error("should've returned false") 59 | } 60 | if _, authorized := authorizationService.Authorize("token", []string{"admin"}); authorized { 61 | t.Error("should've returned false") 62 | } 63 | } 64 | 65 | func TestAuthorizationService_WithTokens(t *testing.T) { 66 | authorizationService := NewAuthorizationService().WithTokens([]string{"1", "2"}) 67 | if _, authorized := authorizationService.Authorize("1", nil); !authorized { 68 | t.Error("should've returned true") 69 | } 70 | if _, authorized := authorizationService.Authorize("2", nil); !authorized { 71 | t.Error("should've returned true") 72 | } 73 | if _, authorized := authorizationService.Authorize("3", nil); authorized { 74 | t.Error("should've returned false") 75 | } 76 | } 77 | 78 | func TestAuthorizationService_WithClient(t *testing.T) { 79 | authorizationService := 
NewAuthorizationService().WithClient(NewClient("token").WithPermissions([]string{"a", "b"})) 80 | if _, authorized := authorizationService.Authorize("token", []string{"a", "b"}); !authorized { 81 | t.Error("should've returned true") 82 | } 83 | if _, authorized := authorizationService.Authorize("token", []string{"a"}); !authorized { 84 | t.Error("should've returned true") 85 | } 86 | if _, authorized := authorizationService.Authorize("token", []string{"b"}); !authorized { 87 | t.Error("should've returned true") 88 | } 89 | if _, authorized := authorizationService.Authorize("token", []string{"c"}); authorized { 90 | t.Error("should've returned false") 91 | } 92 | } 93 | 94 | func TestAuthorizationService_WithClients(t *testing.T) { 95 | authorizationService := NewAuthorizationService().WithClients([]*Client{NewClient("1").WithPermission("a"), NewClient("2").WithPermission("b")}) 96 | if _, authorized := authorizationService.Authorize("1", []string{"a"}); !authorized { 97 | t.Error("should've returned true") 98 | } 99 | if _, authorized := authorizationService.Authorize("2", []string{"b"}); !authorized { 100 | t.Error("should've returned true") 101 | } 102 | if _, authorized := authorizationService.Authorize("1", []string{"b"}); authorized { 103 | t.Error("should've returned false") 104 | } 105 | if _, authorized := authorizationService.Authorize("2", []string{"a"}); authorized { 106 | t.Error("should've returned false") 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "github.com/TwiN/gocache/v2" 5 | ) 6 | // Cache is a minimal key-value cache interface; it is satisfied by gocache.Cache (see the assertion below) and can be implemented to supply a custom cache to a ClientProvider via WithCustomCache. 
7 | type Cache interface { 8 | Get(key string) (value any, exists bool) 9 | Set(key string, value any) 10 | } 11 | 12 | // Make sure that gocache.Cache is compatible with the interface 13 | var _ Cache = (*gocache.Cache)(nil) 14 |
-------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | // Client is a struct containing both a Token and a slice of extra Permissions that said token has. 4 | type Client struct { 5 | // Token is the value used to authenticate with the API. 6 | Token string 7 | 8 | // Permissions is a slice of extra permissions that may be used for more granular access control. 9 | // 10 | // If you only wish to use Gate.Protect and Gate.ProtectFunc, you do not have to worry about this, 11 | // since they're only used by Gate.ProtectWithPermissions and Gate.ProtectFuncWithPermissions 12 | Permissions []string 13 | 14 | // Data is a field that can be used to store any data you want to associate with the client. 15 | Data any 16 | } 17 | 18 | // NewClient creates a Client with a given token 19 | func NewClient(token string) *Client { 20 | return &Client{ 21 | Token: token, 22 | } 23 | } 24 | 25 | // NewClientWithPermissions creates a Client with a slice of permissions 26 | // Equivalent to using NewClient and WithPermissions 27 | func NewClientWithPermissions(token string, permissions []string) *Client { 28 | return NewClient(token).WithPermissions(permissions) 29 | } 30 | 31 | // NewClientWithData creates a Client with some data 32 | // Equivalent to using NewClient and WithData 33 | func NewClientWithData(token string, data any) *Client { 34 | return NewClient(token).WithData(data) 35 | } 36 | 37 | // NewClientWithPermissionsAndData creates a Client with a slice of permissions and some data 38 | // Equivalent to using NewClient, WithPermissions and WithData 39 | func NewClientWithPermissionsAndData(token string, permissions []string, data any) *Client { 40 | return NewClient(token).WithPermissions(permissions).WithData(data) 41 | } 42 | 43 | // WithPermissions adds a slice of permissions to a client 44 | func (client *Client) 
WithPermissions(permissions []string) *Client { 45 | client.Permissions = append(client.Permissions, permissions...) 46 | return client 47 | } 48 | 49 | // WithPermission adds a permission to a client 50 | func (client *Client) WithPermission(permission string) *Client { 51 | client.Permissions = append(client.Permissions, permission) 52 | return client 53 | } 54 | 55 | // WithData attaches data to a client 56 | func (client *Client) WithData(data any) *Client { 57 | client.Data = data 58 | return client 59 | } 60 | 61 | // HasPermission checks whether a client has a given permission 62 | func (client *Client) HasPermission(permissionRequired string) bool { 63 | for _, permission := range client.Permissions { 64 | if permissionRequired == permission { 65 | return true 66 | } 67 | } 68 | return false 69 | } 70 | 71 | // HasPermissions checks whether a client has all the permissions passed 72 | func (client *Client) HasPermissions(permissionsRequired []string) bool { 73 | for _, permissionRequired := range permissionsRequired { 74 | if !client.HasPermission(permissionRequired) { 75 | return false 76 | } 77 | } 78 | return true 79 | } 80 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import "testing" 4 | 5 | func TestClient_HasPermission(t *testing.T) { 6 | client := NewClientWithPermissions("token", []string{"a", "b"}) 7 | if !client.HasPermission("a") { 8 | t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions) 9 | } 10 | if !client.HasPermission("b") { 11 | t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions) 12 | } 13 | if client.HasPermission("c") { 14 | t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions) 15 | } 16 | if client.HasPermission("ab") { 17 |
t.Errorf("client has permissions %s, therefore HasPermission(ab) should've been false", client.Permissions) 18 | } 19 | } 20 | 21 | func TestClient_HasPermissions(t *testing.T) { 22 | client := NewClientWithPermissions("token", []string{"a", "b"}) 23 | if !client.HasPermissions(nil) { 24 | t.Errorf("client has permissions %s, therefore HasPermissions(nil) should've been true", client.Permissions) 25 | } 26 | if !client.HasPermissions([]string{"a"}) { 27 | t.Errorf("client has permissions %s, therefore HasPermissions([a]) should've been true", client.Permissions) 28 | } 29 | if !client.HasPermissions([]string{"b"}) { 30 | t.Errorf("client has permissions %s, therefore HasPermissions([b]) should've been true", client.Permissions) 31 | } 32 | if !client.HasPermissions([]string{"a", "b"}) { 33 | t.Errorf("client has permissions %s, therefore HasPermissions([a, b]) should've been true", client.Permissions) 34 | } 35 | if client.HasPermissions([]string{"a", "b", "c"}) { 36 | t.Errorf("client has permissions %s, therefore HasPermissions([a, b, c]) should've been false", client.Permissions) 37 | } 38 | } 39 | 40 | func TestClient_WithData(t *testing.T) { 41 | client := NewClient("token") 42 | if client.Data != nil { 43 | t.Error("expected client data to be nil") 44 | } 45 | client.WithData(5) 46 | if client.Data != 5 { 47 | t.Errorf("expected client data to be 5, got %d", client.Data) 48 | } 49 | client.WithData(map[string]string{"key": "value"}) 50 | if data, ok := client.Data.(map[string]string); !ok || data["key"] != "value" { 51 | t.Errorf("expected client data to be map[string]string{key: value}, got %v", client.Data) 52 | } 53 | } 54 | 55 | func TestNewClientWithData(t *testing.T) { 56 | client := NewClientWithData("token", 5) 57 | if client.Data != 5 { 58 | t.Errorf("expected client data to be 5, got %d", client.Data) 59 | } 60 | } 61 | 62 | func TestNewClientWithPermissionsAndData(t *testing.T) { 63 | client := NewClientWithPermissionsAndData("token", []string{"a", 
"b"}, 5) 64 | if client.Data != 5 { 65 | t.Errorf("expected client data to be 5, got %d", client.Data) 66 | } 67 | if !client.HasPermission("a") { 68 | t.Errorf("client has permissions %s, therefore HasPermission(a) should've been true", client.Permissions) 69 | } 70 | if !client.HasPermission("b") { 71 | t.Errorf("client has permissions %s, therefore HasPermission(b) should've been true", client.Permissions) 72 | } 73 | if client.HasPermission("c") { 74 | t.Errorf("client has permissions %s, therefore HasPermission(c) should've been false", client.Permissions) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /clientprovider.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/TwiN/gocache/v2" 7 | ) 8 | 9 | // ClientProvider has the task of retrieving a Client from an external source (e.g. a database) when provided with a 10 | // token. It should be used when you have a lot of tokens, and it wouldn't make sense to register all of them using 11 | // AuthorizationService's WithToken, WithTokens, WithClient or WithClients. 12 | // 13 | // Note that the provider is used as a fallback source. As such, if a token is explicitly registered using one of the 4 14 | // aforementioned functions, the client provider will not be used by the AuthorizationService when a request is made 15 | // with said token. It will, however, be called upon if a token that is not explicitly registered in 16 | // AuthorizationService is sent alongside a request going through the Gate. 
17 | // 18 | // clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 19 | // // We'll assume that the following function calls your database and returns a struct "User" that 20 | // // has the user's token as well as the permissions granted to said user 21 | // user := database.GetUserByToken(token) 22 | // if user != nil { 23 | // return g8.NewClient(user.Token).WithPermissions(user.Permissions) 24 | // } 25 | // return nil 26 | // }) 27 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider)) 28 | type ClientProvider struct { 29 | getClientByTokenFunc func(token string) *Client 30 | 31 | cache Cache 32 | } 33 | 34 | // NewClientProvider creates a ClientProvider 35 | // The parameter that must be passed is a function that the provider will use to retrieve a client by a given token 36 | // 37 | // Example: 38 | // 39 | // clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 40 | // // We'll assume that the following function calls your database and returns a struct "User" that 41 | // // has the user's token as well as the permissions granted to said user 42 | // user := database.GetUserByToken(token) 43 | // if user == nil { 44 | // return nil 45 | // } 46 | // return g8.NewClient(user.Token).WithPermissions(user.Permissions) 47 | // }) 48 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider)) 49 | func NewClientProvider(getClientByTokenFunc func(token string) *Client) *ClientProvider { 50 | return &ClientProvider{ 51 | getClientByTokenFunc: getClientByTokenFunc, 52 | } 53 | } 54 | 55 | // WithCache enables an in-memory cache for the ClientProvider. 
56 | // 57 | // Example: 58 | // 59 | // clientProvider := g8.NewClientProvider(func(token string) *g8.Client { 60 | // // We'll assume that the following function calls your database and returns a struct "User" that 61 | // // has the user's token as well as the permissions granted to said user 62 | // user := database.GetUserByToken(token) 63 | // if user != nil { 64 | // return g8.NewClient(user.Token).WithPermissions(user.Permissions) 65 | // } 66 | // return nil 67 | // }) 68 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClientProvider(clientProvider.WithCache(time.Hour, 70000))) 69 | func (provider *ClientProvider) WithCache(ttl time.Duration, maxSize int) *ClientProvider { 70 | return provider.WithCustomCache( 71 | gocache.NewCache().WithEvictionPolicy(gocache.LeastRecentlyUsed).WithMaxSize(maxSize).WithDefaultTTL(ttl), 72 | ) 73 | } 74 | 75 | // WithCustomCache allows you to use a custom cache implementation instead of the default one. 76 | // By default, using WithCache will leverage gocache. 77 | // 78 | // Note that the custom cache must implement the Cache interface 79 | func (provider *ClientProvider) WithCustomCache(cache Cache) *ClientProvider { 80 | provider.cache = cache 81 | return provider 82 | } 83 | 84 | // GetClientByToken retrieves a client by its token through the provided getClientByTokenFunc, checking the cache first if one is configured. Note that nil results are cached as well. 85 | func (provider *ClientProvider) GetClientByToken(token string) *Client { 86 | if provider.cache == nil { 87 | return provider.getClientByTokenFunc(token) 88 | } 89 | if cachedClient, exists := provider.cache.Get(token); exists { 90 | if cachedClient == nil { 91 | return nil 92 | } 93 | // Safely typecast the client. 94 | // Regardless of whether the typecast is successful or not, we return client since it'll be either a *Client or 95 | // nil. A nil *Client is expected here when getClientByTokenFunc returned nil for this token, since such results are cached too.
96 | client, _ := cachedClient.(*Client) 97 | return client 98 | } 99 | client := provider.getClientByTokenFunc(token) 100 | provider.cache.Set(token, client) 101 | return client 102 | } 103 | -------------------------------------------------------------------------------- /clientprovider_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/TwiN/gocache/v2" 9 | ) 10 | 11 | var ( 12 | getClientByTokenFunc = func(token string) *Client { 13 | if token == "valid-token" { 14 | return NewClient("valid-token").WithData("client-data") 15 | } 16 | return nil 17 | } 18 | ) 19 | 20 | func TestClientProvider_GetClientByToken(t *testing.T) { 21 | provider := NewClientProvider(getClientByTokenFunc) 22 | if client := provider.GetClientByToken("valid-token"); client == nil { 23 | t.Error("should've returned a client") 24 | } else if client.Data != "client-data" { 25 | t.Error("expected client data to be 'client-data', got", client.Data) 26 | } 27 | if client := provider.GetClientByToken("invalid-token"); client != nil { 28 | t.Error("should've returned nil") 29 | } 30 | } 31 | 32 | func TestClientProvider_WithCache(t *testing.T) { 33 | provider := NewClientProvider(getClientByTokenFunc).WithCache(gocache.NoExpiration, 10000) 34 | if provider.cache.(*gocache.Cache).Count() != 0 { 35 | t.Error("expected cache to be empty") 36 | } 37 | if client := provider.GetClientByToken("valid-token"); client == nil { 38 | t.Error("expected client, got nil") 39 | } 40 | if provider.cache.(*gocache.Cache).Count() != 1 { 41 | t.Error("expected cache size to be 1") 42 | } 43 | if client := provider.GetClientByToken("valid-token"); client == nil { 44 | t.Error("expected client, got nil") 45 | } 46 | if provider.cache.(*gocache.Cache).Count() != 1 { 47 | t.Error("expected cache size to be 1") 48 | } 49 | if client := provider.GetClientByToken("invalid-token"); client != nil { 50 
| t.Error("expected nil, got", client) 51 | } 52 | if provider.cache.(*gocache.Cache).Count() != 2 { 53 | t.Error("expected cache size to be 2") 54 | } 55 | if client := provider.GetClientByToken("invalid-token"); client != nil { 56 | t.Error("expected nil, got", client) 57 | } 58 | if client := provider.GetClientByToken("invalid-token"); client != nil { 59 | t.Error("should've returned nil (cached)") 60 | } 61 | } 62 | 63 | func TestClientProvider_WithCacheAndExpiration(t *testing.T) { 64 | provider := NewClientProvider(getClientByTokenFunc).WithCache(10*time.Millisecond, 10) 65 | provider.GetClientByToken("token") 66 | if provider.cache.(*gocache.Cache).Count() != 1 { 67 | t.Error("expected cache size to be 1") 68 | } 69 | if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 0 { 70 | t.Error("expected cache statistics to report 0 expired key") 71 | } 72 | time.Sleep(15 * time.Millisecond) 73 | provider.GetClientByToken("token") 74 | if provider.cache.(*gocache.Cache).Stats().ExpiredKeys != 1 { 75 | t.Error("expected cache statistics to report 1 expired key") 76 | } 77 | } 78 | 79 | type customCache struct { 80 | entries map[string]any 81 | sync.Mutex 82 | } 83 | 84 | func (c *customCache) Get(key string) (value any, exists bool) { 85 | c.Lock() 86 | v, exists := c.entries[key] 87 | c.Unlock() 88 | return v, exists 89 | } 90 | 91 | func (c *customCache) Set(key string, value any) { 92 | c.Lock() 93 | if c.entries == nil { 94 | c.entries = make(map[string]any) 95 | } 96 | c.entries[key] = value 97 | c.Unlock() 98 | } 99 | 100 | var _ Cache = (*customCache)(nil) 101 | 102 | func TestClientProvider_WithCustomCache(t *testing.T) { 103 | provider := NewClientProvider(getClientByTokenFunc).WithCustomCache(&customCache{}) 104 | if len(provider.cache.(*customCache).entries) != 0 { 105 | t.Error("expected cache to be empty") 106 | } 107 | if client := provider.GetClientByToken("valid-token"); client == nil { 108 | t.Error("expected client, got nil") 109 | } 110 | if 
len(provider.cache.(*customCache).entries) != 1 { 111 | t.Error("expected cache size to be 1") 112 | } 113 | if client := provider.GetClientByToken("valid-token"); client == nil { 114 | t.Error("expected client, got nil") 115 | } 116 | if len(provider.cache.(*customCache).entries) != 1 { 117 | t.Error("expected cache size to be 1") 118 | } 119 | if client := provider.GetClientByToken("invalid-token"); client != nil { 120 | t.Error("expected nil, got", client) 121 | } 122 | if len(provider.cache.(*customCache).entries) != 2 { 123 | t.Error("expected cache size to be 2") 124 | } 125 | if client := provider.GetClientByToken("invalid-token"); client != nil { 126 | t.Error("expected nil, got", client) 127 | } 128 | if client := provider.GetClientByToken("invalid-token"); client != nil { 129 | t.Error("should've returned nil (cached)") 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /gate.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | // AuthorizationHeader is the header in which g8 looks for the authorization bearer token 11 | AuthorizationHeader = "Authorization" 12 | 13 | // DefaultUnauthorizedResponseBody is the default response body returned if a request was sent with a missing or invalid token 14 | DefaultUnauthorizedResponseBody = "token is missing or invalid" 15 | 16 | // DefaultTooManyRequestsResponseBody is the default response body returned if a request exceeded the allowed rate limit 17 | DefaultTooManyRequestsResponseBody = "too many requests" 18 | 19 | // TokenContextKey is the key used to store the client's token in the context. 20 | TokenContextKey = "g8.token" 21 | 22 | // DataContextKey is the key used to store the client's data in the context. 
23 | DataContextKey = "g8.data" 24 | ) 25 | 26 | // Gate is a lock on the front door of your API, letting only those you allow through. 27 | type Gate struct { 28 | authorizationService *AuthorizationService 29 | unauthorizedResponseBody []byte 30 | 31 | customTokenExtractorFunc func(request *http.Request) string 32 | 33 | rateLimiter *RateLimiter 34 | tooManyRequestsResponseBody []byte 35 | } 36 | 37 | // Deprecated: use New instead. 38 | func NewGate(authorizationService *AuthorizationService) *Gate { 39 | return &Gate{ 40 | authorizationService: authorizationService, 41 | unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody), 42 | tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody), 43 | } 44 | } 45 | 46 | // New creates a new Gate. 47 | func New() *Gate { 48 | return &Gate{ 49 | unauthorizedResponseBody: []byte(DefaultUnauthorizedResponseBody), 50 | tooManyRequestsResponseBody: []byte(DefaultTooManyRequestsResponseBody), 51 | } 52 | } 53 | 54 | // WithAuthorizationService sets the authorization service to use. 55 | // 56 | // If there is no authorization service, Gate will not enforce authorization. 57 | func (gate *Gate) WithAuthorizationService(authorizationService *AuthorizationService) *Gate { 58 | gate.authorizationService = authorizationService 59 | return gate 60 | } 61 | 62 | // WithCustomUnauthorizedResponseBody sets a custom response body when Gate determines that a request must be blocked 63 | func (gate *Gate) WithCustomUnauthorizedResponseBody(unauthorizedResponseBody []byte) *Gate { 64 | gate.unauthorizedResponseBody = unauthorizedResponseBody 65 | return gate 66 | } 67 | 68 | // WithCustomTokenExtractor allows the specification of a custom function to extract a token from a request. 69 | // If a custom token extractor is not specified, the token will be extracted from the Authorization header.
70 | // 71 | // For instance, if you're using a session cookie, you can extract the token from the cookie like so: 72 | // 73 | // authorizationService := g8.NewAuthorizationService() 74 | // customTokenExtractorFunc := func(request *http.Request) string { 75 | // sessionCookie, err := request.Cookie("session") 76 | // if err != nil { 77 | // return "" 78 | // } 79 | // return sessionCookie.Value 80 | // } 81 | // gate := g8.New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) 82 | // 83 | // You would normally use this with a client provider that matches whatever need you have. 84 | // For example, if you're using a session cookie, your client provider would retrieve the user from the session ID 85 | // extracted by this custom token extractor. 86 | // 87 | // Note that for the sake of convenience, the token extracted from the request is passed to the protected handler's request 88 | // context under the key TokenContextKey. This is especially useful if the token is in fact a session ID. 89 | func (gate *Gate) WithCustomTokenExtractor(customTokenExtractorFunc func(request *http.Request) string) *Gate { 90 | gate.customTokenExtractorFunc = customTokenExtractorFunc 91 | return gate 92 | } 93 | 94 | // WithRateLimit adds rate limiting to the Gate 95 | // 96 | // If you just want to use a gate for rate limiting purposes: 97 | // 98 | // gate := g8.New().WithRateLimit(50) 99 | func (gate *Gate) WithRateLimit(maximumRequestsPerSecond int) *Gate { 100 | gate.rateLimiter = NewRateLimiter(maximumRequestsPerSecond) 101 | return gate 102 | } 103 | 104 | // Protect secures a handler, requiring requests going through to have a valid Authorization Bearer token. 105 | // Unlike ProtectWithPermissions, Protect will allow access to any registered tokens, regardless of their permissions 106 | // or lack thereof.
107 | // 108 | // Example: 109 | // 110 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) 111 | // router := http.NewServeMux() 112 | // // Without protection 113 | // router.Handle("/handle", yourHandler) 114 | // // With protection 115 | // router.Handle("/handle", gate.Protect(yourHandler)) 116 | // 117 | // The token extracted from the request is passed to the handler's request context under the key TokenContextKey 118 | func (gate *Gate) Protect(handler http.Handler) http.Handler { 119 | return gate.ProtectWithPermissions(handler, nil) 120 | } 121 | 122 | // ProtectWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer token 123 | // as well as a slice of permissions that must be met. 124 | // 125 | // Example: 126 | // 127 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) 128 | // router := http.NewServeMux() 129 | // // Without protection 130 | // router.Handle("/handle", yourHandler) 131 | // // With protection 132 | // router.Handle("/handle", gate.ProtectWithPermissions(yourHandler, []string{"admin"})) 133 | // 134 | // The token extracted from the request is passed to the handler's request context under the key TokenContextKey 135 | func (gate *Gate) ProtectWithPermissions(handler http.Handler, permissions []string) http.Handler { 136 | return gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) { 137 | handler.ServeHTTP(writer, request) 138 | }, permissions) 139 | } 140 | 141 | // ProtectWithPermission does the same thing as ProtectWithPermissions, but for a single permission instead of a 142 | // slice of permissions 143 | // 144 | // See ProtectWithPermissions for further documentation 145 | func (gate *Gate) ProtectWithPermission(handler http.Handler, permission string) http.Handler { 146 | return
gate.ProtectFuncWithPermissions(func(writer http.ResponseWriter, request *http.Request) { 147 | handler.ServeHTTP(writer, request) 148 | }, []string{permission}) 149 | } 150 | 151 | // ProtectFunc secures a handlerFunc, requiring requests going through to have a valid Authorization Bearer token. 152 | // Unlike ProtectFuncWithPermissions, ProtectFunc will allow access to any registered tokens, regardless of their 153 | // permissions or lack thereof. 154 | // 155 | // Example: 156 | // 157 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithToken("token")) 158 | // router := http.NewServeMux() 159 | // // Without protection 160 | // router.HandleFunc("/handle", yourHandlerFunc) 161 | // // With protection 162 | // router.HandleFunc("/handle", gate.ProtectFunc(yourHandlerFunc)) 163 | // 164 | // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey 165 | func (gate *Gate) ProtectFunc(handlerFunc http.HandlerFunc) http.HandlerFunc { 166 | return gate.ProtectFuncWithPermissions(handlerFunc, nil) 167 | } 168 | 169 | // ProtectFuncWithPermissions secures a handler, requiring requests going through to have a valid Authorization Bearer 170 | // token as well as a slice of permissions that must be met. 
171 | // 172 | // Example: 173 | // 174 | // gate := g8.New().WithAuthorizationService(g8.NewAuthorizationService().WithClient(g8.NewClient("token").WithPermission("admin"))) 175 | // router := http.NewServeMux() 176 | // // Without protection 177 | // router.HandleFunc("/handle", yourHandlerFunc) 178 | // // With protection 179 | // router.HandleFunc("/handle", gate.ProtectFuncWithPermissions(yourHandlerFunc, []string{"admin"})) 180 | // 181 | // The token extracted from the request is passed to the handlerFunc request context under the key TokenContextKey 182 | func (gate *Gate) ProtectFuncWithPermissions(handlerFunc http.HandlerFunc, permissions []string) http.HandlerFunc { 183 | return func(writer http.ResponseWriter, request *http.Request) { 184 | if gate.rateLimiter != nil { 185 | if !gate.rateLimiter.Try() { 186 | writer.WriteHeader(http.StatusTooManyRequests) 187 | _, _ = writer.Write(gate.tooManyRequestsResponseBody) 188 | return 189 | } 190 | } 191 | if gate.authorizationService != nil { 192 | token := gate.ExtractTokenFromRequest(request) 193 | if client, authorized := gate.authorizationService.Authorize(token, permissions); !authorized { 194 | writer.WriteHeader(http.StatusUnauthorized) 195 | _, _ = writer.Write(gate.unauthorizedResponseBody) 196 | return 197 | } else { 198 | request = request.WithContext(context.WithValue(request.Context(), TokenContextKey, token)) 199 | if client != nil && client.Data != nil { 200 | request = request.WithContext(context.WithValue(request.Context(), DataContextKey, client.Data)) 201 | } 202 | } 203 | } 204 | handlerFunc(writer, request) 205 | } 206 | } 207 | 208 | // ProtectFuncWithPermission does the same thing as ProtectFuncWithPermissions, but for a single permission instead of a 209 | // slice of permissions 210 | // 211 | // See ProtectFuncWithPermissions for further documentation 212 | func (gate *Gate) ProtectFuncWithPermission(handlerFunc http.HandlerFunc, permission string) http.HandlerFunc { 213 | return 
gate.ProtectFuncWithPermissions(handlerFunc, []string{permission}) 214 | } 215 | 216 | // ExtractTokenFromRequest extracts a token from a request. 217 | // 218 | // By default, it returns the value of the AuthorizationHeader with its "Bearer " prefix stripped (if present), but if a customTokenExtractorFunc is defined, 219 | // it will use that instead. 220 | // 221 | // Note that this method is internally used by ProtectFuncWithPermissions (and therefore by every other Protect* method and 222 | // PermissionMiddleware, which all delegate to it), but it is exposed in case you need to use it directly. 223 | func (gate *Gate) ExtractTokenFromRequest(request *http.Request) string { 224 | if gate.customTokenExtractorFunc != nil { 225 | // A custom token extractor function is defined, so we'll use it instead of the default token extraction logic 226 | return gate.customTokenExtractorFunc(request) 227 | } 228 | return strings.TrimPrefix(request.Header.Get(AuthorizationHeader), "Bearer ") 229 | } 230 | 231 | // PermissionMiddleware is a middleware that behaves like ProtectWithPermissions, but it is meant to be used 232 | // as a middleware for libraries that support such a feature. 233 | // 234 | // For instance, if you are using github.com/gorilla/mux, you can use PermissionMiddleware like so: 235 | // 236 | // router := mux.NewRouter() 237 | // router.Use(gate.PermissionMiddleware("admin")) 238 | // router.Handle("/admin/handle", adminHandler) 239 | // 240 | // If you do not want to protect a router with a specific permission, you can use Gate.Protect instead.
241 | func (gate *Gate) PermissionMiddleware(permissions ...string) func(http.Handler) http.Handler { 242 | return func(next http.Handler) http.Handler { 243 | return gate.ProtectWithPermissions(next, permissions) 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /gate_bench_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | ) 9 | 10 | var handler http.Handler = &testHandler{} 11 | 12 | func BenchmarkTestHandler(b *testing.B) { 13 | request, _ := http.NewRequest("GET", "/handle", nil) 14 | 15 | router := http.NewServeMux() 16 | router.Handle("/handle", handler) 17 | 18 | for n := 0; n < b.N; n++ { 19 | responseRecorder := httptest.NewRecorder() 20 | router.ServeHTTP(responseRecorder, request) 21 | if responseRecorder.Code != http.StatusOK { 22 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 23 | } 24 | } 25 | b.ReportAllocs() 26 | } 27 | 28 | func BenchmarkGate_ProtectWhenNoAuthorizationHeader(b *testing.B) { 29 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 30 | request, _ := http.NewRequest("GET", "/handle", nil) 31 | 32 | router := http.NewServeMux() 33 | router.Handle("/handle", gate.Protect(handler)) 34 | 35 | for n := 0; n < b.N; n++ { 36 | responseRecorder := httptest.NewRecorder() 37 | router.ServeHTTP(responseRecorder, request) 38 | if responseRecorder.Code != http.StatusUnauthorized { 39 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 40 | } 41 | } 42 | b.ReportAllocs() 43 | } 44 | 45 | func BenchmarkGate_ProtectWithInvalidToken(b *testing.B) { 46 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 47 
| request, _ := http.NewRequest("GET", "/handle", nil) 48 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 49 | 50 | router := http.NewServeMux() 51 | router.Handle("/handle", gate.Protect(handler)) 52 | 53 | for n := 0; n < b.N; n++ { 54 | responseRecorder := httptest.NewRecorder() 55 | router.ServeHTTP(responseRecorder, request) 56 | if responseRecorder.Code != http.StatusUnauthorized { 57 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 58 | } 59 | } 60 | b.ReportAllocs() 61 | } 62 | 63 | func BenchmarkGate_ProtectWithValidToken(b *testing.B) { 64 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 65 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 66 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) 67 | 68 | router := http.NewServeMux() 69 | router.Handle("/handle", gate.Protect(handler)) 70 | 71 | for n := 0; n < b.N; n++ { 72 | responseRecorder := httptest.NewRecorder() 73 | router.ServeHTTP(responseRecorder, request) 74 | if responseRecorder.Code != http.StatusOK { 75 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 76 | } 77 | } 78 | b.ReportAllocs() 79 | } 80 | 81 | func BenchmarkGate_ProtectWithPermissionsAndValidToken(b *testing.B) { 82 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) 83 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 84 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 85 | 86 | router := http.NewServeMux() 87 | router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"})) 88 | 89 | for n := 0; n < b.N; n++ { 90 | responseRecorder := httptest.NewRecorder() 91 | router.ServeHTTP(responseRecorder, request) 92 | if 
responseRecorder.Code != http.StatusOK { 93 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 94 | } 95 | } 96 | b.ReportAllocs() 97 | } 98 | 99 | func BenchmarkGate_ProtectWithPermissionsAndValidTokenButInsufficientPermissions(b *testing.B) { 100 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("mod"))) 101 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 102 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 103 | 104 | router := http.NewServeMux() 105 | router.Handle("/handle", gate.ProtectWithPermissions(handler, []string{"admin"})) 106 | 107 | for n := 0; n < b.N; n++ { 108 | responseRecorder := httptest.NewRecorder() 109 | router.ServeHTTP(responseRecorder, request) 110 | if responseRecorder.Code != http.StatusUnauthorized { 111 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 112 | } 113 | } 114 | b.ReportAllocs() 115 | } 116 | 117 | func BenchmarkGate_ProtectConcurrently(b *testing.B) { 118 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 119 | 120 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 121 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) 122 | 123 | badRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) 124 | badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 125 | 126 | router := http.NewServeMux() 127 | router.Handle("/handle", gate.Protect(handler)) 128 | 129 | b.RunParallel(func(pb *testing.PB) { 130 | for pb.Next() { 131 | responseRecorder := httptest.NewRecorder() 132 | router.ServeHTTP(responseRecorder, request) 133 | if responseRecorder.Code != http.StatusOK { 134 | b.Fatalf("%s %s should have returned %d, but returned %d instead", 
request.Method, request.URL, http.StatusOK, responseRecorder.Code) 135 | } 136 | responseRecorder = httptest.NewRecorder() 137 | router.ServeHTTP(responseRecorder, badRequest) 138 | if responseRecorder.Code != http.StatusUnauthorized { 139 | b.Fatalf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusUnauthorized, responseRecorder.Code) 140 | } 141 | } 142 | }) 143 | b.ReportAllocs() 144 | } 145 | 146 | func BenchmarkGate_ProtectWithClientProviderConcurrently(b *testing.B) { 147 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) 148 | 149 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 150 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) 151 | 152 | firstBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) 153 | firstBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-1")) 154 | 155 | secondBadRequest, _ := http.NewRequest("GET", "/handle", http.NoBody) 156 | secondBadRequest.Header.Set("Authorization", fmt.Sprintf("Bearer %s", "bad-token-2")) 157 | 158 | router := http.NewServeMux() 159 | router.Handle("/handle", gate.Protect(handler)) 160 | 161 | b.RunParallel(func(pb *testing.PB) { 162 | for pb.Next() { 163 | responseRecorder := httptest.NewRecorder() 164 | router.ServeHTTP(responseRecorder, request) 165 | if responseRecorder.Code != http.StatusOK { 166 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 167 | } 168 | responseRecorder = httptest.NewRecorder() 169 | router.ServeHTTP(responseRecorder, firstBadRequest) 170 | if responseRecorder.Code != http.StatusUnauthorized { 171 | b.Fatalf("%s %s should have returned %d, but returned %d instead", firstBadRequest.Method, firstBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code) 172 | } 173 | responseRecorder = 
httptest.NewRecorder() 174 | router.ServeHTTP(responseRecorder, secondBadRequest) 175 | if responseRecorder.Code != http.StatusUnauthorized { 176 | b.Fatalf("%s %s should have returned %d, but returned %d instead", secondBadRequest.Method, secondBadRequest.URL, http.StatusUnauthorized, responseRecorder.Code) 177 | } 178 | } 179 | }) 180 | b.ReportAllocs() 181 | } 182 | 183 | func BenchmarkGate_ProtectWithValidTokenAndCustomTokenExtractorFuncConcurrently(b *testing.B) { 184 | customTokenExtractorFunc := func(request *http.Request) string { 185 | sessionCookie, err := request.Cookie("session") 186 | if err != nil { 187 | return "" 188 | } 189 | return sessionCookie.Value 190 | } 191 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")).WithCustomTokenExtractor(customTokenExtractorFunc) 192 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 193 | request.AddCookie(&http.Cookie{Name: "session", Value: "good-token"}) 194 | 195 | router := http.NewServeMux() 196 | router.Handle("/handle", gate.Protect(handler)) 197 | 198 | b.RunParallel(func(pb *testing.PB) { 199 | for pb.Next() { 200 | responseRecorder := httptest.NewRecorder() 201 | router.ServeHTTP(responseRecorder, request) 202 | if responseRecorder.Code != http.StatusOK { 203 | b.Fatalf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 204 | } 205 | } 206 | }) 207 | b.ReportAllocs() 208 | } 209 | -------------------------------------------------------------------------------- /gate_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "strings" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | const ( 14 | FirstTestProviderClientPermission = "permission-1" 15 | SecondTestProviderClientPermission = "permission-2" 16 | TestProviderClientToken = "client-token-from-provider" 
17 | TestProviderClientData = "client-data-from-provider" 18 | ) 19 | 20 | var ( 21 | mockClientProvider = NewClientProvider(func(token string) *Client { 22 | // We'll pretend that there's only one token that's valid in the client provider, every other token 23 | // returns nil 24 | if token == TestProviderClientToken { 25 | return &Client{ 26 | Token: TestProviderClientToken, 27 | Data: TestProviderClientData, 28 | Permissions: []string{FirstTestProviderClientPermission, SecondTestProviderClientPermission}, 29 | } 30 | } 31 | return nil 32 | }) 33 | ) 34 | 35 | type testHandler struct { 36 | } 37 | 38 | func (handler *testHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) { 39 | writer.WriteHeader(http.StatusOK) 40 | } 41 | 42 | func testHandlerFunc(writer http.ResponseWriter, _ *http.Request) { 43 | writer.WriteHeader(http.StatusOK) 44 | } 45 | 46 | func TestUsability(t *testing.T) { 47 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 48 | 49 | var handler http.Handler = &testHandler{} 50 | handlerFunc := func(writer http.ResponseWriter, request *http.Request) { 51 | writer.WriteHeader(http.StatusOK) 52 | } 53 | 54 | router := http.NewServeMux() 55 | router.Handle("/handle", handler) 56 | router.Handle("/handle-protected", gate.Protect(handler)) 57 | router.HandleFunc("/handlefunc", handlerFunc) 58 | router.HandleFunc("/handlefunc-protected", gate.ProtectFunc(handlerFunc)) 59 | } 60 | 61 | func TestNewGate(t *testing.T) { 62 | gate := NewGate(nil) 63 | if gate == nil { 64 | t.Error("gate should not be nil") 65 | } 66 | } 67 | 68 | func TestUnprotectedHandler(t *testing.T) { 69 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 70 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 71 | responseRecorder := httptest.NewRecorder() 72 | 73 | router := http.NewServeMux() 74 | router.Handle("/handle", &testHandler{}) 75 | router.ServeHTTP(responseRecorder, request) 76 | 77 | if 
responseRecorder.Code != http.StatusOK { 78 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 79 | } 80 | } 81 | 82 | func TestGate_ProtectWithInvalidToken(t *testing.T) { 83 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 84 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 85 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 86 | responseRecorder := httptest.NewRecorder() 87 | 88 | router := http.NewServeMux() 89 | router.Handle("/handle", gate.Protect(&testHandler{})) 90 | router.ServeHTTP(responseRecorder, request) 91 | 92 | if responseRecorder.Code != http.StatusUnauthorized { 93 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 94 | } 95 | } 96 | 97 | func TestGate_ProtectWithValidToken(t *testing.T) { 98 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 99 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 100 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) 101 | responseRecorder := httptest.NewRecorder() 102 | 103 | router := http.NewServeMux() 104 | router.Handle("/handle", gate.Protect(&testHandler{})) 105 | router.ServeHTTP(responseRecorder, request) 106 | 107 | if responseRecorder.Code != http.StatusOK { 108 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 109 | } 110 | } 111 | 112 | func TestGate_ProtectMultipleTimes(t *testing.T) { 113 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 114 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 115 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) 116 | badRequest, _ := http.NewRequest("GET", 
"/handle", http.NoBody) 117 | badRequest.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 118 | 119 | router := http.NewServeMux() 120 | router.Handle("/handle", gate.Protect(&testHandler{})) 121 | 122 | for i := 0; i < 100; i++ { 123 | responseRecorder := httptest.NewRecorder() 124 | router.ServeHTTP(responseRecorder, request) 125 | if responseRecorder.Code != http.StatusOK { 126 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 127 | } 128 | responseRecorder = httptest.NewRecorder() 129 | router.ServeHTTP(responseRecorder, badRequest) 130 | if responseRecorder.Code != http.StatusUnauthorized { 131 | t.Errorf("%s %s should have returned %d, but returned %d instead", badRequest.Method, badRequest.URL, http.StatusOK, responseRecorder.Code) 132 | } 133 | } 134 | } 135 | 136 | func TestGate_ProtectWithValidTokenExposedThroughClientProvider(t *testing.T) { 137 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) 138 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 139 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) 140 | responseRecorder := httptest.NewRecorder() 141 | 142 | router := http.NewServeMux() 143 | router.Handle("/handle", gate.Protect(&testHandler{})) 144 | router.ServeHTTP(responseRecorder, request) 145 | 146 | if responseRecorder.Code != http.StatusOK { 147 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 148 | } 149 | } 150 | 151 | func TestGate_ProtectWithValidTokenExposedThroughClientProviderWithCache(t *testing.T) { 152 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider.WithCache(60*time.Minute, 70000))) 153 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 154 | 
request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) 155 | responseRecorder := httptest.NewRecorder() 156 | 157 | router := http.NewServeMux() 158 | router.Handle("/handle", gate.Protect(&testHandler{})) 159 | router.ServeHTTP(responseRecorder, request) 160 | 161 | if responseRecorder.Code != http.StatusOK { 162 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 163 | } 164 | } 165 | 166 | func TestGate_ProtectWithInvalidTokenWhenUsingClientProvider(t *testing.T) { 167 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) 168 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 169 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 170 | responseRecorder := httptest.NewRecorder() 171 | 172 | router := http.NewServeMux() 173 | router.Handle("/handle", gate.Protect(&testHandler{})) 174 | router.ServeHTTP(responseRecorder, request) 175 | 176 | if responseRecorder.Code != http.StatusUnauthorized { 177 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 178 | } 179 | } 180 | 181 | func TestGate_ProtectWithPermissionsWhenValidTokenAndSufficientPermissionsWhileUsingClientProvider(t *testing.T) { 182 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) 183 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 184 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) 185 | responseRecorder := httptest.NewRecorder() 186 | 187 | router := http.NewServeMux() 188 | router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{SecondTestProviderClientPermission})) 189 | router.ServeHTTP(responseRecorder, request) 190 | 191 | // Since the client returned from the 
mockClientProvider has FirstTestProviderClientPermission and 192 | // SecondTestProviderClientPermission and the testHandler is protected by SecondTestProviderClientPermission, 193 | // the request should be authorized 194 | if responseRecorder.Code != http.StatusOK { 195 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 196 | } 197 | } 198 | 199 | func TestGate_ProtectWithPermissionsWhenValidTokenAndInsufficientPermissionsWhileUsingClientProvider(t *testing.T) { 200 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClientProvider(mockClientProvider)) 201 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 202 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", TestProviderClientToken)) 203 | responseRecorder := httptest.NewRecorder() 204 | 205 | router := http.NewServeMux() 206 | router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"unrelated-permission"})) 207 | router.ServeHTTP(responseRecorder, request) 208 | 209 | // Since the client returned from the mockClientProvider has FirstTestProviderClientPermission and 210 | // SecondTestProviderClientPermission and the testHandler is protected by a permission that the client does not 211 | // have, the request should be not be authorized 212 | if responseRecorder.Code != http.StatusUnauthorized { 213 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 214 | } 215 | } 216 | 217 | func TestGate_ProtectWithPermissionsWhenClientHasSufficientPermissions(t *testing.T) { 218 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) 219 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 220 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 221 | responseRecorder := 
httptest.NewRecorder() 222 | 223 | router := http.NewServeMux() 224 | router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"})) 225 | router.ServeHTTP(responseRecorder, request) 226 | 227 | // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler 228 | // is protected by the permission "admin", the request should be authorized 229 | if responseRecorder.Code != http.StatusOK { 230 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 231 | } 232 | } 233 | 234 | func TestGate_ProtectWithPermissionsWhenClientHasInsufficientPermissions(t *testing.T) { 235 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) 236 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 237 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 238 | responseRecorder := httptest.NewRecorder() 239 | 240 | router := http.NewServeMux() 241 | router.Handle("/handle", gate.ProtectWithPermissions(&testHandler{}, []string{"admin"})) 242 | router.ServeHTTP(responseRecorder, request) 243 | 244 | // Since the client registered directly in the AuthorizationService has the permission "mod" and the 245 | // testHandler is protected by the permission "admin", the request should be not be authorized 246 | if responseRecorder.Code != http.StatusUnauthorized { 247 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 248 | } 249 | } 250 | 251 | func TestGate_ProtectWithPermissions(t *testing.T) { 252 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("mytoken").WithPermissions([]string{"create", "read", "update", "delete"}))) 253 | 254 | router := http.NewServeMux() 255 | router.Handle("/create", 
gate.ProtectWithPermissions(&testHandler{}, []string{"create"})) 256 | router.Handle("/read", gate.ProtectWithPermissions(&testHandler{}, []string{"read"})) 257 | router.Handle("/update", gate.ProtectWithPermissions(&testHandler{}, []string{"update"})) 258 | router.Handle("/delete", gate.ProtectWithPermissions(&testHandler{}, []string{"delete"})) 259 | router.Handle("/crud", gate.ProtectWithPermissions(&testHandler{}, []string{"create", "read", "update", "delete"})) 260 | router.Handle("/backup", gate.ProtectWithPermissions(&testHandler{}, []string{"read", "backup"})) 261 | 262 | checkRouterOutput := func(t *testing.T, router *http.ServeMux, url string, expectedResponseCode int) { 263 | t.Run(strings.TrimPrefix(url, "/"), func(t *testing.T) { 264 | request, _ := http.NewRequest("GET", url, http.NoBody) 265 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "mytoken")) 266 | responseRecorder := httptest.NewRecorder() 267 | router.ServeHTTP(responseRecorder, request) 268 | if responseRecorder.Code != expectedResponseCode { 269 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, expectedResponseCode, responseRecorder.Code) 270 | } 271 | }) 272 | } 273 | 274 | checkRouterOutput(t, router, "/create", http.StatusOK) 275 | checkRouterOutput(t, router, "/read", http.StatusOK) 276 | checkRouterOutput(t, router, "/update", http.StatusOK) 277 | checkRouterOutput(t, router, "/delete", http.StatusOK) 278 | checkRouterOutput(t, router, "/crud", http.StatusOK) 279 | checkRouterOutput(t, router, "/backup", http.StatusUnauthorized) 280 | } 281 | 282 | func TestGate_ProtectWithPermissionWhenClientHasSufficientPermissions(t *testing.T) { 283 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) 284 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 285 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 286 | responseRecorder 
:= httptest.NewRecorder() 287 | 288 | router := http.NewServeMux() 289 | router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin")) 290 | router.ServeHTTP(responseRecorder, request) 291 | 292 | // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler 293 | // is protected by the permission "admin", the request should be authorized 294 | if responseRecorder.Code != http.StatusOK { 295 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 296 | } 297 | } 298 | 299 | func TestGate_ProtectWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) { 300 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) 301 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 302 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 303 | responseRecorder := httptest.NewRecorder() 304 | 305 | router := http.NewServeMux() 306 | router.Handle("/handle", gate.ProtectWithPermission(&testHandler{}, "admin")) 307 | router.ServeHTTP(responseRecorder, request) 308 | 309 | // Since the client registered directly in the AuthorizationService has the permission "mod" and the 310 | // testHandler is protected by the permission "admin", the request should be not be authorized 311 | if responseRecorder.Code != http.StatusUnauthorized { 312 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 313 | } 314 | } 315 | 316 | func TestGate_PermissionMiddlewareWhenClientHasSufficientPermissions(t *testing.T) { 317 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) 318 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 319 | request.Header.Add("Authorization", 
fmt.Sprintf("Bearer %s", "token")) 320 | responseRecorder := httptest.NewRecorder() 321 | 322 | router := http.NewServeMux() 323 | router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) 324 | router.ServeHTTP(responseRecorder, request) 325 | 326 | // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler 327 | // is protected by the permission "admin", the request should be authorized 328 | if responseRecorder.Code != http.StatusOK { 329 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 330 | } 331 | } 332 | 333 | func TestGate_PermissionMiddlewareWhenClientHasInsufficientPermissions(t *testing.T) { 334 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) 335 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 336 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 337 | responseRecorder := httptest.NewRecorder() 338 | 339 | router := http.NewServeMux() 340 | router.Handle("/handle", gate.PermissionMiddleware("admin")(&testHandler{})) 341 | router.ServeHTTP(responseRecorder, request) 342 | 343 | // Since the client registered directly in the AuthorizationService has the permission "mod" and the 344 | // testHandler is protected by the permission "admin", the request should be not be authorized 345 | if responseRecorder.Code != http.StatusUnauthorized { 346 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 347 | } 348 | } 349 | 350 | func TestGate_ProtectFuncWithInvalidToken(t *testing.T) { 351 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 352 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 353 | request.Header.Add("Authorization", 
fmt.Sprintf("Bearer %s", "bad-token")) 354 | responseRecorder := httptest.NewRecorder() 355 | 356 | router := http.NewServeMux() 357 | router.Handle("/handle", gate.ProtectFunc(testHandlerFunc)) 358 | router.ServeHTTP(responseRecorder, request) 359 | 360 | if responseRecorder.Code != http.StatusUnauthorized { 361 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 362 | } 363 | } 364 | 365 | func TestGate_ProtectFuncWithValidToken(t *testing.T) { 366 | gate := New().WithAuthorizationService(NewAuthorizationService().WithToken("good-token")) 367 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 368 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "good-token")) 369 | responseRecorder := httptest.NewRecorder() 370 | 371 | router := http.NewServeMux() 372 | router.Handle("/handle", gate.ProtectFunc(testHandlerFunc)) 373 | router.ServeHTTP(responseRecorder, request) 374 | 375 | if responseRecorder.Code != http.StatusOK { 376 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 377 | } 378 | } 379 | 380 | func TestGate_ProtectFuncWithPermissionWhenClientHasSufficientPermissions(t *testing.T) { 381 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClient("token").WithPermission("admin"))) 382 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 383 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 384 | responseRecorder := httptest.NewRecorder() 385 | 386 | router := http.NewServeMux() 387 | router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin")) 388 | router.ServeHTTP(responseRecorder, request) 389 | 390 | // Since the client registered directly in the AuthorizationService has the permission "admin" and the testHandler 391 | // is protected by the permission "admin", the 
request should be authorized 392 | if responseRecorder.Code != http.StatusOK { 393 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 394 | } 395 | } 396 | 397 | func TestGate_ProtectFuncWithPermissionWhenClientHasInsufficientPermissions(t *testing.T) { 398 | gate := New().WithAuthorizationService(NewAuthorizationService().WithClient(NewClientWithPermissions("token", []string{"mod"}))) 399 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 400 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "token")) 401 | responseRecorder := httptest.NewRecorder() 402 | 403 | router := http.NewServeMux() 404 | router.HandleFunc("/handle", gate.ProtectFuncWithPermission(testHandlerFunc, "admin")) 405 | router.ServeHTTP(responseRecorder, request) 406 | 407 | // Since the client registered directly in the AuthorizationService has the permission "mod" and the 408 | // testHandler is protected by the permission "admin", the request should be not be authorized 409 | if responseRecorder.Code != http.StatusUnauthorized { 410 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusUnauthorized, responseRecorder.Code) 411 | } 412 | } 413 | 414 | func TestGate_WithCustomUnauthorizedResponseBody(t *testing.T) { 415 | gate := New().WithAuthorizationService(NewAuthorizationService()).WithCustomUnauthorizedResponseBody([]byte("test")) 416 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 417 | request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "bad-token")) 418 | responseRecorder := httptest.NewRecorder() 419 | 420 | router := http.NewServeMux() 421 | router.Handle("/handle", gate.Protect(&testHandler{})) 422 | router.ServeHTTP(responseRecorder, request) 423 | 424 | if responseRecorder.Code != http.StatusUnauthorized { 425 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, 
request.URL, http.StatusUnauthorized, responseRecorder.Code) 426 | } 427 | if responseBody, _ := io.ReadAll(responseRecorder.Body); string(responseBody) != "test" { 428 | t.Errorf("%s %s should have returned %s, but returned %s instead", request.Method, request.URL, "test", string(responseBody)) 429 | } 430 | } 431 | 432 | func TestGate_ProtectWithNoAuthorizationService(t *testing.T) { 433 | gate := New() 434 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 435 | responseRecorder := httptest.NewRecorder() 436 | 437 | router := http.NewServeMux() 438 | router.Handle("/handle", gate.Protect(&testHandler{})) 439 | router.ServeHTTP(responseRecorder, request) 440 | 441 | if responseRecorder.Code != http.StatusOK { 442 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 443 | } 444 | } 445 | 446 | func TestGate_ProtectWithRateLimit(t *testing.T) { 447 | gate := New().WithRateLimit(2) 448 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 449 | router := http.NewServeMux() 450 | router.Handle("/handle", gate.Protect(&testHandler{})) 451 | 452 | responseRecorder := httptest.NewRecorder() 453 | router.ServeHTTP(responseRecorder, request) 454 | if responseRecorder.Code != http.StatusOK { 455 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 456 | } 457 | 458 | responseRecorder = httptest.NewRecorder() 459 | router.ServeHTTP(responseRecorder, request) 460 | if responseRecorder.Code != http.StatusOK { 461 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 462 | } 463 | 464 | responseRecorder = httptest.NewRecorder() 465 | router.ServeHTTP(responseRecorder, request) 466 | if responseRecorder.Code != http.StatusTooManyRequests { 467 | t.Errorf("%s %s should have returned %d, but returned %d instead", 
request.Method, request.URL, http.StatusTooManyRequests, responseRecorder.Code) 468 | } 469 | 470 | // Wait for rate limit time window to pass 471 | time.Sleep(time.Second) 472 | 473 | responseRecorder = httptest.NewRecorder() 474 | router.ServeHTTP(responseRecorder, request) 475 | if responseRecorder.Code != http.StatusOK { 476 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 477 | } 478 | } 479 | 480 | func TestGate_WithCustomTokenExtractor(t *testing.T) { 481 | authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider) 482 | customTokenExtractorFunc := func(request *http.Request) string { 483 | sessionCookie, err := request.Cookie("session") 484 | if err != nil { 485 | return "" 486 | } 487 | return sessionCookie.Value 488 | } 489 | gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) 490 | 491 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 492 | request.AddCookie(&http.Cookie{Name: "session", Value: TestProviderClientToken}) 493 | responseRecorder := httptest.NewRecorder() 494 | 495 | router := http.NewServeMux() 496 | router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { 497 | if r.Context().Value(TokenContextKey) != TestProviderClientToken { 498 | t.Errorf("token should have been passed to the request context") 499 | } 500 | if r.Context().Value(DataContextKey) != TestProviderClientData { 501 | t.Errorf("data should have been passed to the request context") 502 | } 503 | w.WriteHeader(http.StatusOK) 504 | })) 505 | router.ServeHTTP(responseRecorder, request) 506 | 507 | if responseRecorder.Code != http.StatusOK { 508 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 509 | } 510 | } 511 | 512 | func TestGateWithCustomHeader(t *testing.T) { 513 | 
authorizationService := NewAuthorizationService().WithClientProvider(mockClientProvider) 514 | customTokenExtractorFunc := func(request *http.Request) string { 515 | return request.Header.Get("X-API-Token") 516 | } 517 | gate := New().WithAuthorizationService(authorizationService).WithCustomTokenExtractor(customTokenExtractorFunc) 518 | 519 | request, _ := http.NewRequest("GET", "/handle", http.NoBody) 520 | request.Header.Add("X-API-Token", TestProviderClientToken) 521 | responseRecorder := httptest.NewRecorder() 522 | 523 | router := http.NewServeMux() 524 | router.Handle("/handle", gate.ProtectFunc(func(w http.ResponseWriter, r *http.Request) { 525 | if r.Context().Value(TokenContextKey) != TestProviderClientToken { 526 | t.Errorf("token should have been passed to the request context") 527 | } 528 | if r.Context().Value(DataContextKey) != TestProviderClientData { 529 | t.Errorf("data should have been passed to the request context") 530 | } 531 | w.WriteHeader(http.StatusOK) 532 | })) 533 | router.ServeHTTP(responseRecorder, request) 534 | 535 | if responseRecorder.Code != http.StatusOK { 536 | t.Errorf("%s %s should have returned %d, but returned %d instead", request.Method, request.URL, http.StatusOK, responseRecorder.Code) 537 | } 538 | } 539 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/TwiN/g8/v3 2 | 3 | go 1.23.3 4 | 5 | require github.com/TwiN/gocache/v2 v2.2.2 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/TwiN/gocache/v2 v2.2.2 h1:4HToPfDV8FSbaYO5kkbhLpEllUYse5rAf+hVU/mSsuI= 2 | github.com/TwiN/gocache/v2 v2.2.2/go.mod h1:WfIuwd7GR82/7EfQqEtmLFC3a2vqaKbs4Pe6neB7Gyc= 3 | -------------------------------------------------------------------------------- /ratelimiter.go: 
-------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | // RateLimiter is a fixed-window rate limiter: Try succeeds at most maximumExecutionsPerSecond times per window, and the first call made more than one second after the current window started opens a new window with a refilled quota. Try reads and writes the fields only while holding mutex. 9 | type RateLimiter struct { 10 | maximumExecutionsPerSecond int // quota restored at the start of every window 11 | executionsLeftInWindow int // successful Try calls still available in the current window 12 | windowStartTime time.Time // when the current window began 13 | mutex sync.Mutex // held by Try while it reads and updates the fields above 14 | } 15 | 16 | // NewRateLimiter creates a RateLimiter allowing up to maximumExecutionsPerSecond successful Try calls per window; the first window starts at construction time with its full quota available. 17 | func NewRateLimiter(maximumExecutionsPerSecond int) *RateLimiter { 18 | return &RateLimiter{ 19 | windowStartTime: time.Now(), 20 | executionsLeftInWindow: maximumExecutionsPerSecond, 21 | maximumExecutionsPerSecond: maximumExecutionsPerSecond, 22 | } 23 | } 24 | 25 | // Try updates the number of executions if the rate limit quota hasn't been reached and returns whether the 26 | // attempt was successful or not. The whole check-and-update runs while holding r.mutex. 27 | // 28 | // Returns false if the execution was not successful (rate limit quota has been reached) 29 | // Returns true if the execution was successful (rate limit quota has not been reached) 30 | func (r *RateLimiter) Try() bool { 31 | r.mutex.Lock() 32 | defer r.mutex.Unlock() 33 | if time.Now().Add(-time.Second).After(r.windowStartTime) { // more than one second since the window started: open a new window with a full quota 34 | r.windowStartTime = time.Now() 35 | r.executionsLeftInWindow = r.maximumExecutionsPerSecond 36 | } 37 | if r.executionsLeftInWindow == 0 { // NOTE(review): a negative maximumExecutionsPerSecond never equals 0 here (the decrement below only moves it further away), so every call would succeed - confirm whether negative limits should be rejected or treated as 0 38 | return false 39 | } 40 | r.executionsLeftInWindow-- // consume one unit of the current window's quota 41 | return true 42 | } 43 | -------------------------------------------------------------------------------- /ratelimiter_test.go: -------------------------------------------------------------------------------- 1 | package g8 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestNewRateLimiter(t *testing.T) { // checks the initial quota and how consecutive Try calls within one window consume it 9 | rl := NewRateLimiter(2) 10 | if rl.maximumExecutionsPerSecond != 2 { 11 | t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) 12 | } 13 | if rl.executionsLeftInWindow != 2 { 14 | t.Errorf("expected executionsLeftInWindow to be %d, got %d", 2,
rl.executionsLeftInWindow) 15 | } 16 | // First execution: should not be rate limited 17 | if notRateLimited := rl.Try(); !notRateLimited { 18 | t.Error("expected Try to return true") 19 | } 20 | if rl.maximumExecutionsPerSecond != 2 { 21 | t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) 22 | } 23 | if rl.executionsLeftInWindow != 1 { 24 | t.Errorf("expected executionsLeftInWindow to be %d, got %d", 1, rl.executionsLeftInWindow) 25 | } 26 | // Second execution: should not be rate limited 27 | if notRateLimited := rl.Try(); !notRateLimited { 28 | t.Error("expected Try to return true") 29 | } 30 | if rl.maximumExecutionsPerSecond != 2 { 31 | t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) 32 | } 33 | if rl.executionsLeftInWindow != 0 { 34 | t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow) 35 | } 36 | // Third execution: should be rate limited 37 | if notRateLimited := rl.Try(); notRateLimited { 38 | t.Error("expected Try to return false") 39 | } 40 | if rl.maximumExecutionsPerSecond != 2 { 41 | t.Errorf("expected maximumExecutionsPerSecond to be %d, got %d", 2, rl.maximumExecutionsPerSecond) 42 | } 43 | if rl.executionsLeftInWindow != 0 { 44 | t.Errorf("expected executionsLeftInWindow to be %d, got %d", 0, rl.executionsLeftInWindow) 45 | } 46 | } 47 | 48 | func TestRateLimiter_Try(t *testing.T) { // with a limit of 5, only the first 5 of 20 back-to-back calls should succeed 49 | rl := NewRateLimiter(5) 50 | for i := 0; i < 20; i++ { 51 | notRateLimited := rl.Try() 52 | if i < 5 { 53 | if !notRateLimited { 54 | t.Fatal("expected to not be rate limited") 55 | } 56 | } else { 57 | if notRateLimited { 58 | t.Fatal("expected to be rate limited") 59 | } 60 | } 61 | } 62 | } 63 | 64 | func TestRateLimiter_TryAlwaysUnderRateLimit(t *testing.T) { // spaces 45 calls ~51ms apart so that, assuming sleeps are not significantly delayed, no one-second window receives more than the 20 allowed calls 65 | rl := NewRateLimiter(20) 66 | for i := 0; i < 45; i++ { 67 | notRateLimited := rl.Try() 68 | if !notRateLimited { 69 | t.Fatal("expected to not be rate limited")
70 | } 71 | time.Sleep(51 * time.Millisecond) 72 | } 73 | } 74 | --------------------------------------------------------------------------------