├── .DEREK.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── auth ├── jwt_authenticator.go └── jwt_authenticator_test.go ├── concurrency-limiter ├── concurrency_limiter.go └── concurrency_limiter_test.go ├── go.mod └── go.sum /.DEREK.yml: -------------------------------------------------------------------------------- 1 | redirect: https://raw.githubusercontent.com/openfaas/faas/master/.DEREK.yml 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | This project follows the contribution process for [OpenFaaS](https://github.com/openfaas/faas/blob/master/CONTRIBUTING.md). 2 | 3 | You are required to raise a proposal issue and have it approved before working on changes or raising a PR. 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 OpenFaaS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # faas-middleware 2 | HTTP middleware for OpenFaaS 3 | 4 | ## Components 5 | ### Concurrency Limiter 6 | `concurrency-limiter` is a middleware that can be used to limit the number of active in-flight requests for a given HTTP request handler. 7 | 8 | ### JWT Authenticator middleware 9 | 10 | The JWT authenticator middleware can be used to authorize HTTP requests to functions with IAM for OpenFaaS. The middleware verifies the permissions in the `function` claim of an OpenFaaS function access token that is set in the `Authorization` header of the request. 
11 | 12 | This middleware is used by the [classic-watchdog](https://github.com/openfaas/classic-watchdog) and [of-watchdog](https://github.com/openfaas/of-watchdog) -------------------------------------------------------------------------------- /auth/jwt_authenticator.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "log" 8 | "net/http" 9 | "regexp" 10 | "strings" 11 | "time" 12 | 13 | "github.com/golang-jwt/jwt/v5" 14 | "github.com/rakutentech/jwk-go/jwk" 15 | ) 16 | 17 | const ( 18 | authorityURL = "http://gateway.openfaas:8080/.well-known/openid-configuration" 19 | localAuthorityURL = "http://127.0.0.1:8000/.well-known/openid-configuration" 20 | functionRealm = "IAM function invoke" 21 | ) 22 | 23 | type jwtAuth struct { // jwtAuth is the http.Handler returned by NewJWTAuthMiddleware. 24 | next http.Handler 25 | opts JWTAuthOptions 26 | 27 | keySet jwk.KeySpecSet 28 | issuer string 29 | } 30 | 31 | // JWTAuthOptions stores the configuration for JWT based function authentication 32 | type JWTAuthOptions struct { 33 | Name string // function name; with Namespace it forms the "namespace:name" ref checked against the token 34 | Namespace string // function namespace 35 | LocalAuthority bool // fetch the OpenID configuration from 127.0.0.1:8000 instead of the gateway 36 | Debug bool // enables extra logging of keys and token claims 37 | } 38 | 39 | func (a jwtAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) { // ServeHTTP replies 401 for a missing or invalid token and 403 when the token does not permit this function; otherwise it calls next. 40 | 41 | issuer := a.issuer 42 | 43 | st := time.Now() 44 | 45 | if a.opts.Debug { // logs every known key ID on each request 46 | for _, key := range a.keySet.Keys { 47 | log.Printf("%s: %v", issuer, key.KeyID) 48 | } 49 | } 50 | 51 | var bearer string 52 | if v := r.Header.Get("Authorization"); v != "" { 53 | bearer = strings.TrimPrefix(v, "Bearer ") // case-sensitive; a value without this prefix is passed to the parser unchanged 54 | } 55 | 56 | if bearer == "" { 57 | writeUnauthorized(w, "Bearer must be present in Authorization header") 58 | log.Printf("%s %s - %d ACCESS DENIED - (%s)", r.Method, r.URL.Path, http.StatusUnauthorized, time.Since(st).Round(time.Millisecond)) 59 | return 60 | } 61 | 62 | parseOptions := []jwt.ParserOption{ 63 | jwt.WithIssuer(issuer), 64 | // The OpenFaaS gateway is the expected audience but we can use the issuer url 65 | // since 
the gateway is also the issuer of function tokens and thus has the same url. 66 | jwt.WithAudience(issuer), 67 | jwt.WithLeeway(time.Second * 1), 68 | } 69 | 70 | functionClaims := FunctionClaims{} 71 | token, err := jwt.ParseWithClaims(bearer, &functionClaims, func(token *jwt.Token) (interface{}, error) { // key function: returns the key whose KeyID equals the token's "kid" header 72 | if a.opts.Debug { 73 | log.Printf("[JWT Auth] Token: audience: %v\tissuer: %v", functionClaims.Audience, functionClaims.Issuer) 74 | } 75 | 76 | kid, ok := token.Header["kid"].(string) 77 | if !ok { 78 | return nil, fmt.Errorf("invalid kid: %v", token.Header["kid"]) 79 | } 80 | 81 | // HV: Consider caching and refreshing the keyset to handle key rotations. 82 | var key *jwk.KeySpec 83 | for _, k := range a.keySet.Keys { 84 | if k.KeyID == kid { 85 | key = &k // safe under go 1.20 loop-variable semantics because the loop breaks immediately 86 | break 87 | } 88 | } 89 | 90 | if key == nil { 91 | return nil, fmt.Errorf("invalid kid: %s", kid) 92 | } 93 | return key.Key, nil 94 | }, parseOptions...) 95 | if err != nil { 96 | writeUnauthorized(w, fmt.Sprintf("failed to parse JWT token: %s", err)) 97 | log.Printf("%s %s - %d ACCESS DENIED - (%s)", r.Method, r.URL.Path, http.StatusUnauthorized, time.Since(st).Round(time.Millisecond)) 98 | return 99 | } 100 | 101 | if !token.Valid { 102 | writeUnauthorized(w, fmt.Sprintf("invalid JWT token: %s", bearer)) 103 | 104 | log.Printf("%s %s - %d ACCESS DENIED - (%s)", r.Method, r.URL.Path, http.StatusUnauthorized, time.Since(st).Round(time.Millisecond)) 105 | return 106 | } 107 | 108 | if !isAuthorized(functionClaims.Authentication, a.opts.Namespace, a.opts.Name) { // the token is valid, but its "function" claim must also cover Namespace:Name 109 | w.Header().Set("X-OpenFaaS-Internal", "faas-middleware") 110 | http.Error(w, "insufficient permissions", http.StatusForbidden) 111 | 112 | log.Printf("%s %s - %d ACCESS DENIED - (%s)", r.Method, r.URL.Path, http.StatusForbidden, time.Since(st).Round(time.Millisecond)) 113 | return 114 | } 115 | 116 | r.Header.Set("X-Auth-Seconds", fmt.Sprintf("%f", time.Since(st).Seconds())) // time spent authenticating, visible to next 117 | 118 | a.next.ServeHTTP(w, r) 119 | } 120 | 121 | // 
NewJWTAuthMiddleware creates a new middleware handler to handle authentication with OpenFaaS function 122 | // access tokens. The OpenID configuration and key set are fetched once, here, and are not refreshed afterwards. 123 | func NewJWTAuthMiddleware(opts JWTAuthOptions, next http.Handler) (http.Handler, error) { 124 | authority := authorityURL // the in-cluster gateway, or 127.0.0.1:8000 when opts.LocalAuthority is set 125 | if opts.LocalAuthority { 126 | authority = localAuthorityURL 127 | } 128 | 129 | config, err := getConfig(authority) 130 | if err != nil { 131 | return nil, err 132 | } 133 | 134 | if opts.Debug { 135 | log.Printf("[JWT Auth] Issuer: %s\tJWKS URI: %s", config.Issuer, config.JWKSURI) 136 | } 137 | 138 | keySet, err := getKeyset(config.JWKSURI) 139 | if err != nil { 140 | return nil, err 141 | } 142 | 143 | if opts.Debug { 144 | for _, key := range keySet.Keys { 145 | log.Printf("[JWT Auth] Key: %s", key.KeyID) 146 | } 147 | } 148 | 149 | return jwtAuth{ 150 | next: next, 151 | opts: opts, 152 | keySet: keySet, 153 | issuer: config.Issuer, 154 | }, nil 155 | } 156 | 157 | // writeUnauthorized replies to the request with the specified error message and 401 HTTP code. 158 | // It sets the WWW-Authenticate header. 159 | // It does not otherwise end the request; the caller should ensure no further writes are done to w. 160 | // The error message should be plain text. 
161 | func writeUnauthorized(w http.ResponseWriter, err string) { 162 | w.Header().Set("X-OpenFaaS-Internal", "faas-middleware") 163 | w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=%q", functionRealm)) // the realm value contains spaces, so it must be sent as a quoted-string 164 | http.Error(w, err, http.StatusUnauthorized) 165 | } 166 | 167 | func getKeyset(uri string) (jwk.KeySpecSet, error) { // getKeyset fetches the JSON Web Key Set at uri and decodes it. 168 | var set jwk.KeySpecSet 169 | req, err := http.NewRequest(http.MethodGet, uri, nil) 170 | if err != nil { 171 | return set, err 172 | } 173 | 174 | req.Header.Add("User-Agent", "openfaas-watchdog") 175 | 176 | res, err := http.DefaultClient.Do(req) 177 | if err != nil { 178 | return set, err 179 | } 180 | 181 | var body []byte 182 | 183 | if res.Body != nil { 184 | defer res.Body.Close() 185 | body, _ = io.ReadAll(res.Body) 186 | } 187 | 188 | if res.StatusCode != http.StatusOK { 189 | return set, fmt.Errorf("failed to get keyset from %s, status code: %d, body: %s", uri, res.StatusCode, string(body)) 190 | } 191 | 192 | if err := json.Unmarshal(body, &set); err != nil { 193 | return set, err 194 | } 195 | 196 | return set, nil 197 | } 198 | 199 | func getConfig(configURL string) (openIDConfiguration, error) { // getConfig fetches the OpenID configuration document at configURL and decodes it. 200 | var config openIDConfiguration 201 | 202 | req, err := http.NewRequest(http.MethodGet, configURL, nil) 203 | if err != nil { 204 | return config, err 205 | } 206 | 207 | res, err := http.DefaultClient.Do(req) 208 | if err != nil { 209 | return config, err 210 | } 211 | 212 | var body []byte 213 | if res.Body != nil { 214 | defer res.Body.Close() 215 | body, _ = io.ReadAll(res.Body) 216 | } 217 | 218 | if res.StatusCode != http.StatusOK { 219 | return config, fmt.Errorf("failed to get config from %s, status code: %d, body: %s", configURL, res.StatusCode, string(body)) 220 | } 221 | 222 | if err := json.Unmarshal(body, &config); err != nil { 223 | return config, err 224 | } 225 | 226 | return config, nil 227 | } 228 | 229 | type openIDConfiguration struct { // openIDConfiguration holds the fields read from the OpenID configuration document. 230 | Issuer string `json:"issuer"` 231 | JWKSURI string 
`json:"jwks_uri"` 232 | } 233 | 234 | type FunctionClaims struct { // FunctionClaims are the claims decoded from an OpenFaaS function access token. 235 | jwt.RegisteredClaims 236 | 237 | Authentication AuthPermissions `json:"function"` 238 | } 239 | 240 | type AuthPermissions struct { // AuthPermissions is the "function" claim; entries are namespace:name patterns in which * matches any run of characters. 241 | Permissions []string `json:"permissions"` 242 | Audience []string `json:"audience,omitempty"` 243 | } 244 | 245 | func isAuthorized(auth AuthPermissions, namespace, fn string) bool { // isAuthorized requires namespace:fn to match the audience list (an empty list matches anything) and the permission list. 246 | functionRef := fmt.Sprintf("%s:%s", namespace, fn) 247 | 248 | return matchResource(auth.Audience, functionRef, false) && 249 | matchResource(auth.Permissions, functionRef, true) 250 | } 251 | 252 | // matchResource checks if ref matches one of the resources. 253 | // The function will return true if a match is found. 254 | // If req is false, this function will return true if a match is found or the resource list is empty. 255 | func matchResource(resources []string, ref string, req bool) bool { 256 | if !req { 257 | if len(resources) == 0 { 258 | return true 259 | } 260 | } 261 | 262 | for _, res := range resources { 263 | if res == "*" { 264 | return true 265 | } 266 | 267 | if matchString(res, ref) { 268 | return true 269 | } 270 | } 271 | 272 | return false 273 | } 274 | 275 | func matchString(pattern string, value string) bool { // matchString reports whether the whole of value matches the wildcard pattern. 276 | if len(pattern) > 0 { 277 | result, _ := regexp.MatchString("^"+wildCardToRegexp(pattern)+"$", value) // anchored so a pattern cannot match a substring, e.g. "dev:env" must not grant "dev:env-admin"; the error is ignored because every literal is quoted with regexp.QuoteMeta 278 | return result 279 | } 280 | 281 | return pattern == value 282 | } 283 | 284 | // wildCardToRegexp converts a wildcard pattern to an unanchored regular expression pattern; matchString anchors it. 285 | func wildCardToRegexp(pattern string) string { 286 | var result strings.Builder 287 | for i, literal := range strings.Split(pattern, "*") { 288 | 289 | // Replace * with .* 290 | if i > 0 { 291 | result.WriteString(".*") 292 | } 293 | 294 | // Quote any regular expression meta characters in the 295 | // literal text. 
296 | result.WriteString(regexp.QuoteMeta(literal)) 297 | } 298 | return result.String() 299 | } 300 | -------------------------------------------------------------------------------- /auth/jwt_authenticator_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func Test_isAuthorized(t *testing.T) { // table-driven cases: the permission list must match, while an empty audience list allows any function 8 | tests := []struct { 9 | name string 10 | want bool 11 | permissions AuthPermissions 12 | namespace string 13 | function string 14 | }{ 15 | { 16 | name: "deny empty permission list", 17 | want: false, 18 | permissions: AuthPermissions{ 19 | Permissions: []string{}, 20 | }, 21 | namespace: "staging", 22 | function: "env", 23 | }, 24 | { 25 | name: "allow empty audience list", 26 | want: true, 27 | permissions: AuthPermissions{ 28 | Permissions: []string{"staging:env"}, 29 | }, 30 | namespace: "staging", 31 | function: "env", 32 | }, 33 | { 34 | name: "allow cluster wildcard", 35 | want: true, 36 | permissions: AuthPermissions{ 37 | Permissions: []string{"*"}, 38 | }, 39 | namespace: "staging", 40 | function: "figlet", 41 | }, 42 | { 43 | name: "allow function wildcard", 44 | want: true, 45 | permissions: AuthPermissions{ 46 | Permissions: []string{"dev:*"}, 47 | }, 48 | namespace: "dev", 49 | function: "figlet", 50 | }, 51 | { 52 | name: "allow namespace wildcard", 53 | want: true, 54 | permissions: AuthPermissions{ 55 | Permissions: []string{"*:env"}, 56 | }, 57 | namespace: "openfaas-fn", 58 | function: "env", 59 | }, 60 | { 61 | name: "allow function", 62 | want: true, 63 | permissions: AuthPermissions{ 64 | Permissions: []string{"openfaas-fn:env"}, 65 | }, 66 | namespace: "openfaas-fn", 67 | function: "env", 68 | }, 69 | { 70 | name: "deny function", 71 | want: false, 72 | permissions: AuthPermissions{ 73 | Permissions: []string{"openfaas-fn:env"}, 74 | }, 75 | namespace: "openfaas-fn", 76 | function: "figlet", 77 | }, 78 | { 79 | name: "deny namespace", 80 | want: 
false, 81 | permissions: AuthPermissions{ 82 | Permissions: []string{"openfaas-fn:*"}, 83 | }, 84 | namespace: "staging", 85 | function: "env", 86 | }, 87 | { 88 | name: "deny namespace wildcard", 89 | want: false, 90 | permissions: AuthPermissions{ 91 | Permissions: []string{"*:figlet"}, 92 | }, 93 | namespace: "staging", 94 | function: "env", 95 | }, 96 | { 97 | name: "multiple permissions allow function", 98 | want: true, 99 | permissions: AuthPermissions{ 100 | Permissions: []string{"openfaas-fn:*", "staging:env"}, 101 | }, 102 | namespace: "staging", 103 | function: "env", 104 | }, 105 | { 106 | name: "multiple permissions deny function", 107 | want: false, 108 | permissions: AuthPermissions{ 109 | Permissions: []string{"openfaas-fn:figlet", "staging-*:env"}, // "staging-*" needs the literal "staging-" prefix, so it does not match namespace "staging" 110 | }, 111 | namespace: "staging", 112 | function: "env", 113 | }, 114 | { 115 | name: "allow audience", 116 | want: true, 117 | permissions: AuthPermissions{ 118 | Permissions: []string{"openfaas-fn:*"}, 119 | Audience: []string{"openfaas-fn:env"}, 120 | }, 121 | namespace: "openfaas-fn", 122 | function: "env", 123 | }, 124 | { 125 | name: "deny audience", 126 | want: false, 127 | permissions: AuthPermissions{ 128 | Permissions: []string{"openfaas-fn:*"}, 129 | Audience: []string{"openfaas-fn:env"}, 130 | }, 131 | namespace: "openfaas-fn", 132 | function: "figlet", 133 | }, 134 | { 135 | name: "allow audience function wildcard", 136 | want: true, 137 | permissions: AuthPermissions{ 138 | Permissions: []string{"openfaas-fn:figlet"}, 139 | Audience: []string{"openfaas-fn:*"}, 140 | }, 141 | namespace: "openfaas-fn", 142 | function: "figlet", 143 | }, 144 | { 145 | name: "deny audience function wildcard", 146 | want: false, 147 | permissions: AuthPermissions{ 148 | Permissions: []string{"openfaas-fn:figlet", "dev:env"}, 149 | Audience: []string{"openfaas-fn:*"}, 150 | }, 151 | namespace: "dev", 152 | function: "env", 153 | }, 154 | { 155 | name: "deny audience namespace wildcard", 156 | want: false, 157 | 
permissions: AuthPermissions{ 158 | Permissions: []string{"openfaas-fn:*", "dev:*"}, 159 | Audience: []string{"*:env"}, 160 | }, 161 | namespace: "dev", 162 | function: "figlet", 163 | }, 164 | { 165 | name: "allow audience namespace wildcard", 166 | want: true, 167 | permissions: AuthPermissions{ 168 | Permissions: []string{"openfaas-fn:*", "dev:*"}, 169 | Audience: []string{"*:env"}, 170 | }, 171 | namespace: "openfaas-fn", 172 | function: "env", 173 | }, 174 | } 175 | 176 | for _, test := range tests { 177 | t.Run(test.name, func(t *testing.T) { // runs synchronously (no t.Parallel), so capturing the loop variable is safe under go 1.20 178 | want := test.want 179 | got := isAuthorized(test.permissions, test.namespace, test.function) 180 | 181 | if want != got { 182 | t.Errorf("want: %t, got: %t", want, got) 183 | } 184 | }) 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /concurrency-limiter/concurrency_limiter.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync/atomic" 7 | ) 8 | 9 | // Limiter is an interface that can be used to check if a limit has been met. 10 | type Limiter interface { 11 | Met() bool 12 | } 13 | 14 | type ConcurrencyLimiter struct { // ConcurrencyLimiter wraps backendHTTPHandler and replies 429 to requests that would exceed maxInflightRequests in-flight requests. 15 | backendHTTPHandler http.Handler 16 | /* 17 | We keep two counters here in order to make it so that we can know when a request has gone to completed 18 | in the tests. We could wrap these up in a condvar, so there's no need to spinlock, but that seems overkill 19 | for testing. 20 | 21 | This is effectively a very fancy semaphore built for optimistic concurrency only, and with spinlocks. If 22 | you want to add timeouts here / pessimistic concurrency, signaling needs to be added and/or a condvar esque 23 | sorta thing needs to be done to wake up waiters who are waiting post-spin. 
24 | 25 | Otherwise, there's all sorts of futzing in order to make sure that the concurrency limiter handler 26 | has completed 27 | The math works on overflow: 28 | var x, y uint64 29 | x = (1 << 64 - 1) 30 | y = (1 << 64 - 1) 31 | x++ 32 | fmt.Println(x) 33 | fmt.Println(y) 34 | fmt.Println(x - y) 35 | Prints: 36 | 0 37 | 18446744073709551615 38 | 1 39 | */ 40 | requestsStarted uint64 41 | requestsCompleted uint64 42 | 43 | maxInflightRequests uint64 44 | } 45 | 46 | // Met reports whether the number of in-flight requests has reached the limit. 47 | // It returns false for a nil limiter or when no limit (0) is set. 48 | func (cl *ConcurrencyLimiter) Met() bool { 49 | if cl == nil { 50 | return false 51 | } 52 | 53 | // We should not have any ConcurrencyLimiter created with a limit of 0 54 | // but return early if that's the case. 55 | if cl.maxInflightRequests == 0 { 56 | return false 57 | } 58 | 59 | // Load completions before starts: a request is always counted as started 60 | // before it is counted as completed, so this order cannot underflow. 61 | completedRequested := atomic.LoadUint64(&cl.requestsCompleted) 62 | requestsStarted := atomic.LoadUint64(&cl.requestsStarted) 63 | return requestsStarted-completedRequested >= cl.maxInflightRequests 64 | } 65 | 66 | // ServeHTTP passes the request to the backend handler, or replies with 67 | // 429 Too Many Requests when the limit of in-flight requests would be exceeded. 68 | func (cl *ConcurrencyLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 69 | // We should not have any ConcurrencyLimiter created with a limit of 0 70 | // but we'll check anyway and return early. 71 | if cl.maxInflightRequests == 0 { 72 | cl.backendHTTPHandler.ServeHTTP(w, r) 73 | return 74 | } 75 | 76 | // Snapshot completions before counting this request as started. Reading 77 | // them afterwards can include completions (e.g. 429 rejections) of requests 78 | // that started later, which under-counts in-flight requests and can admit 79 | // more than maxInflightRequests or underflow the subtraction below. 80 | completedRequested := atomic.LoadUint64(&cl.requestsCompleted) 81 | requestsStarted := atomic.AddUint64(&cl.requestsStarted, 1) 82 | if requestsStarted-completedRequested > cl.maxInflightRequests { 83 | // This is a failure pathway, and we do not want to block on the write to finish 84 | atomic.AddUint64(&cl.requestsCompleted, 1) 85 | 86 | // Some APIs only return JSON, since we can interfere here and send a plain/text 87 | // message, let's do the right thing so that downstream users can consume it.
88 | w.Header().Add("Content-Type", "text/plain") 89 | w.Header().Add("X-OpenFaaS-Internal", "faas-middleware") 90 | 91 | w.WriteHeader(http.StatusTooManyRequests) 92 | 93 | fmt.Fprintf(w, "Concurrent request limit exceeded. Max concurrent requests: %d\n", cl.maxInflightRequests) 94 | return 95 | } 96 | 97 | // Deferred so the slot is released even if the backend handler panics; 98 | // otherwise the in-flight count would never go back down. 99 | defer atomic.AddUint64(&cl.requestsCompleted, 1) 100 | cl.backendHTTPHandler.ServeHTTP(w, r) 101 | } 102 | 103 | // NewConcurrencyLimiter creates a ConcurrencyLimiter which passes requests to handler 104 | // while limiting the number of active, concurrent requests to concurrencyLimit. 105 | // 106 | // If the concurrency limit is less than, or equal to 0, then every request is passed 107 | // straight to handler and Met() always returns false. 108 | // 109 | // The Met() function will return true if the concurrency limit has been reached within 110 | // the handler at the time of the call. 111 | func NewConcurrencyLimiter(handler http.Handler, concurrencyLimit int) *ConcurrencyLimiter { 112 | // Treat a negative limit like 0 rather than letting uint64() wrap it to a huge value. 113 | if concurrencyLimit < 0 { 114 | concurrencyLimit = 0 115 | } 116 | 117 | return &ConcurrencyLimiter{ 118 | backendHTTPHandler: handler, 119 | maxInflightRequests: uint64(concurrencyLimit), 120 | } 121 | } 122 | 123 | // Handler returns the ConcurrencyLimiter as an http.Handler. 124 | func (cl *ConcurrencyLimiter) Handler() http.Handler { 125 | return cl 126 | } 127 | -------------------------------------------------------------------------------- /concurrency-limiter/concurrency_limiter_test.go: -------------------------------------------------------------------------------- 1 | package limiter 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "net/http/httptest" 8 | "sync" 9 | "sync/atomic" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func Test_Met_FalseWhenNil(t *testing.T) { 15 | var cl *ConcurrencyLimiter 16 | if cl.Met() != false { 17 | t.Fatalf("Want Met() to be false when nil, got: %t", cl.Met()) 18 | } 19 | } 20 | 21 | func Test_Met_FalseWhenNoLimitSet(t *testing.T) { 22 | t.Parallel() 23 | 24 | clMet := false 25 | 26 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 | time.Sleep(time.Millisecond * 10) 28 | 
w.WriteHeader(http.StatusAccepted) 29 | }) 30 | 31 | limit := 0 // 0 disables limiting, so Met() must stay false 32 | cl := NewConcurrencyLimiter(http.Handler(handler), limit) 33 | if cl.Met() == true { 34 | t.Fatalf("Want Met() to be false due to no requests, got: %t", cl.Met()) 35 | } 36 | 37 | wg := &sync.WaitGroup{} 38 | wg.Add(3) 39 | go func() { 40 | req := httptest.NewRequest("GET", "/", nil) 41 | rr := httptest.ResponseRecorder{} 42 | cl.ServeHTTP(&rr, req) 43 | 44 | wg.Done() 45 | }() 46 | go func() { 47 | req := httptest.NewRequest("GET", "/", nil) 48 | rr := httptest.ResponseRecorder{} 49 | cl.ServeHTTP(&rr, req) 50 | 51 | wg.Done() 52 | }() 53 | 54 | // Is there a better way to catch the Met() function whilst at least 55 | // one of the HTTP calls is in progress? 56 | go func() { 57 | for i := 0; i < 100; i++ { 58 | if cl.Met() == true { 59 | clMet = true 60 | break 61 | } 62 | time.Sleep(time.Millisecond * 1) 63 | } 64 | wg.Done() 65 | }() 66 | 67 | wg.Wait() 68 | 69 | want := false 70 | if clMet != want { 71 | t.Fatalf("Want Met() to be false due to a limit of 0 request and 2 in-flight, got: %t", cl.Met()) 72 | } 73 | } 74 | 75 | func Test_Met_True_When_OverLimit(t *testing.T) { 76 | t.Parallel() 77 | 78 | clMet := false 79 | 80 | handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 | time.Sleep(time.Millisecond * 10) 82 | w.WriteHeader(http.StatusAccepted) 83 | }) 84 | 85 | limit := 1 // one in-flight request reaches the limit, so Met() should report true while a request sleeps in the handler 86 | cl := NewConcurrencyLimiter(http.Handler(handler), limit) 87 | if cl.Met() == true { 88 | t.Fatalf("Want Met() to be false due to no requests, got: %t", cl.Met()) 89 | } 90 | 91 | wg := &sync.WaitGroup{} 92 | wg.Add(3) 93 | go func() { 94 | req := httptest.NewRequest("GET", "/", nil) 95 | rr := httptest.ResponseRecorder{} 96 | cl.ServeHTTP(&rr, req) 97 | 98 | wg.Done() 99 | }() 100 | go func() { 101 | req := httptest.NewRequest("GET", "/", nil) 102 | rr := httptest.ResponseRecorder{} 103 | cl.ServeHTTP(&rr, req) 104 | 105 | wg.Done() 106 | }() 107 | 108 | // Is there a better way to catch 
the Met() function whilst at least 109 | // one of the HTTP calls is in progress? 110 | go func() { 111 | for i := 0; i < 100; i++ { 112 | if cl.Met() == true { 113 | clMet = true 114 | break 115 | } 116 | time.Sleep(time.Millisecond * 1) 117 | } 118 | wg.Done() 119 | }() 120 | 121 | wg.Wait() 122 | 123 | want := true 124 | if clMet != want { 125 | t.Fatalf("Want Met() to be true due to a limit of 1 request and 2 in-flight, got: %t", cl.Met()) 126 | } 127 | } 128 | 129 | func makeFakeHandler(ctx context.Context, completeInFlightRequestChan chan struct{}) http.HandlerFunc { // makeFakeHandler replies 200 once released via completeInFlightRequestChan, or 503 if ctx is done first 130 | return func(w http.ResponseWriter, r *http.Request) { 131 | select { 132 | case <-ctx.Done(): 133 | w.WriteHeader(http.StatusServiceUnavailable) 134 | case <-completeInFlightRequestChan: 135 | w.WriteHeader(http.StatusOK) 136 | } 137 | } 138 | } 139 | 140 | func doRRandRequest(ctx context.Context, wg *sync.WaitGroup, cl http.Handler) *httptest.ResponseRecorder { 141 | // Serves a GET / request carrying ctx on a new goroutine tracked by wg; read rr only after wg.Wait() 142 | rr := httptest.NewRecorder() 143 | // This should never fail unless we're out of memory or something 144 | req, err := http.NewRequest("GET", "/", nil) 145 | if err != nil { 146 | panic(err) 147 | } 148 | req = req.WithContext(ctx) 149 | 150 | wg.Add(1) 151 | go func() { 152 | // Blocks until cl returns: with makeFakeHandler, when the request is released or ctx is done 153 | cl.ServeHTTP(rr, req) 154 | // Mark this request finished so wg.Wait() in the test can return 155 | wg.Done() 156 | }() 157 | 158 | return rr 159 | } 160 | 161 | func TestConcurrencyLimitUnderLimit(t *testing.T) { 162 | t.Parallel() 163 | 164 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 165 | defer cancel() 166 | 167 | completeInFlightRequestChan := make(chan struct{}) 168 | handler := makeFakeHandler(ctx, completeInFlightRequestChan) 169 | cl := NewConcurrencyLimiter(http.Handler(handler), 2) 170 | 171 | wg := &sync.WaitGroup{} 
172 | rr1 := doRRandRequest(ctx, wg, cl) 173 | // This will "release" the request rr1 174 | completeInFlightRequestChan <- struct{}{} 175 | 176 | // This should never take more than the timeout 177 | wg.Wait() 178 | 179 | // We want to access the response recorder directly, so we don't accidentally get an implicitly correct answer 180 | if rr1.Code != http.StatusOK { 181 | t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr1.Code) 182 | } 183 | 184 | } 185 | 186 | func TestConcurrencyLimitAtLimit(t *testing.T) { // two in-flight requests with a limit of 2 must both be served 187 | t.Parallel() 188 | 189 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 190 | defer cancel() 191 | 192 | completeInFlightRequestChan := make(chan struct{}) 193 | handler := makeFakeHandler(ctx, completeInFlightRequestChan) 194 | 195 | cl := NewConcurrencyLimiter(http.Handler(handler), 2) 196 | 197 | wg := &sync.WaitGroup{} 198 | rr1 := doRRandRequest(ctx, wg, cl) 199 | rr2 := doRRandRequest(ctx, wg, cl) 200 | 201 | completeInFlightRequestChan <- struct{}{} 202 | completeInFlightRequestChan <- struct{}{} 203 | 204 | wg.Wait() 205 | 206 | if rr1.Code != http.StatusOK { 207 | t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr1.Code) 208 | } 209 | if rr2.Code != http.StatusOK { 210 | t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr2.Code) 211 | } 212 | 213 | } 214 | 215 | func count(r *httptest.ResponseRecorder, code200s, code429s *int) { // count tallies 429 and 200 responses; any other status panics 216 | switch r.Code { 217 | case http.StatusTooManyRequests: 218 | *code429s = *code429s + 1 219 | case http.StatusOK: 220 | *code200s = *code200s + 1 221 | default: 222 | panic(fmt.Sprintf("Unknown code: %d", r.Code)) 223 | } 224 | } 225 | 226 | func TestConcurrencyLimitOverLimit(t *testing.T) { 227 | t.Parallel() 228 | 229 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 230 | defer cancel() 231 | completeInFlightRequestChan := make(chan struct{}, 3) 232 | handler := makeFakeHandler(ctx, completeInFlightRequestChan) 233 | 234 | cl := 
NewConcurrencyLimiter(http.Handler(handler), 2) 235 | 236 | wg := &sync.WaitGroup{} 237 | 238 | rr1 := doRRandRequest(ctx, wg, cl) 239 | rr2 := doRRandRequest(ctx, wg, cl) 240 | for ctx.Err() == nil { // wait until both requests are counted as started before sending the third 241 | if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 2 { 242 | break 243 | } 244 | time.Sleep(time.Millisecond) 245 | } 246 | rr3 := doRRandRequest(ctx, wg, cl) 247 | for ctx.Err() == nil { 248 | if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 3 { 249 | break 250 | } 251 | time.Sleep(time.Millisecond) 252 | } 253 | completeInFlightRequestChan <- struct{}{} 254 | completeInFlightRequestChan <- struct{}{} 255 | completeInFlightRequestChan <- struct{}{} 256 | 257 | wg.Wait() 258 | 259 | code200s := 0 260 | code429s := 0 261 | count(rr1, &code200s, &code429s) 262 | count(rr2, &code200s, &code429s) 263 | count(rr3, &code200s, &code429s) 264 | if code200s != 2 || code429s != 1 { // exactly one of the three requests must have been rejected 265 | t.Fatalf("code 200s: %d, and code429s: %d", code200s, code429s) 266 | } 267 | 268 | want := "text/plain" 269 | gotContentType := 0 270 | gotInternalHeader := 0 271 | if rr1.Header().Get("Content-Type") == want { 272 | gotContentType++ 273 | } 274 | if rr2.Header().Get("Content-Type") == want { 275 | gotContentType++ 276 | } 277 | if rr3.Header().Get("Content-Type") == want { 278 | gotContentType++ 279 | } 280 | 281 | if rr1.Header().Get("X-OpenFaaS-Internal") == "faas-middleware" { 282 | gotInternalHeader++ 283 | } 284 | if rr2.Header().Get("X-OpenFaaS-Internal") == "faas-middleware" { 285 | gotInternalHeader++ 286 | } 287 | if rr3.Header().Get("X-OpenFaaS-Internal") == "faas-middleware" { 288 | gotInternalHeader++ 289 | } 290 | 291 | if gotContentType == 0 { 292 | t.Fatalf("Want at least one request with Content-Type %q, got: %q %q %q", want, rr1.Header().Get("Content-Type"), rr2.Header().Get("Content-Type"), rr3.Header().Get("Content-Type")) 293 | } 294 | 295 | if gotInternalHeader == 0 { 296 | t.Fatalf("Want at least one 
request with X-OpenFaaS-Internal header, got: %q %q %q", rr1.Header().Get("X-OpenFaaS-Internal"), rr2.Header().Get("X-OpenFaaS-Internal"), rr3.Header().Get("X-OpenFaaS-Internal")) 297 | } 298 | } 299 | 300 | func TestConcurrencyLimitOverLimitAndRecover(t *testing.T) { // after the over-limit request is rejected and the others complete, a new request must be served 301 | t.Parallel() 302 | 303 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 304 | defer cancel() 305 | completeInFlightRequestChan := make(chan struct{}, 4) 306 | handler := makeFakeHandler(ctx, completeInFlightRequestChan) 307 | cl := NewConcurrencyLimiter(http.Handler(handler), 2) 308 | 309 | wg := &sync.WaitGroup{} 310 | 311 | rr1 := doRRandRequest(ctx, wg, cl) 312 | rr2 := doRRandRequest(ctx, wg, cl) 313 | for ctx.Err() == nil { 314 | if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 2 { 315 | break 316 | } 317 | time.Sleep(time.Millisecond) 318 | } 319 | // This will 429 320 | rr3 := doRRandRequest(ctx, wg, cl) 321 | for ctx.Err() == nil { 322 | if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 3 { 323 | break 324 | } 325 | time.Sleep(time.Millisecond) 326 | } 327 | completeInFlightRequestChan <- struct{}{} 328 | completeInFlightRequestChan <- struct{}{} 329 | completeInFlightRequestChan <- struct{}{} 330 | // Although we could do another wg.Wait here, I don't think we should because 331 | // it might provide a false sense of confidence 332 | for ctx.Err() == nil { // wait until the two served requests and the rejected one are all counted as completed 333 | if requestsCompleted := atomic.LoadUint64(&cl.requestsCompleted); requestsCompleted == 3 { 334 | break 335 | } 336 | time.Sleep(time.Millisecond) 337 | } 338 | rr4 := doRRandRequest(ctx, wg, cl) 339 | completeInFlightRequestChan <- struct{}{} 340 | wg.Wait() 341 | 342 | code200s := 0 343 | code429s := 0 344 | count(rr1, &code200s, &code429s) 345 | count(rr2, &code200s, &code429s) 346 | count(rr3, &code200s, &code429s) 347 | count(rr4, &code200s, &code429s) 348 | 349 | if code200s != 3 || code429s != 1 { 350 | t.Fatalf("code 200s: %d, and 
code429s: %d", code200s, code429s) 351 | } 352 | } 353 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/openfaas/faas-middleware 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/golang-jwt/jwt/v5 v5.2.1 7 | github.com/rakutentech/jwk-go v1.1.3 8 | ) 9 | 10 | require ( 11 | golang.org/x/crypto v0.17.0 // indirect 12 | golang.org/x/sys v0.15.0 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= 2 | github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= 3 | github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 4 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 5 | github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= 6 | github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= 7 | github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= 8 | github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= 9 | github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= 10 | github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= 11 | github.com/onsi/gomega v1.9.0 h1:R1uwffexN6Pr340GtYRIdZmAiN4J+iw6WG4wog1DUXg= 12 | github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA= 13 | github.com/rakutentech/jwk-go v1.1.3 h1:PiLwepKyUaW+QFG3ki78DIO2+b4IVK3nMhlxM70zrQ4= 14 | github.com/rakutentech/jwk-go v1.1.3/go.mod h1:LtzSv4/+Iti1nnNeVQiP6l5cI74GBStbhyXCYvgPZFk= 15 | golang.org/x/crypto 
v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 16 | golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 17 | golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= 18 | golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= 19 | golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 20 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 21 | golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= 22 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 23 | golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 24 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 25 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 26 | golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 27 | golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= 28 | golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 29 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 30 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 31 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= 32 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 33 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 34 | gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= 35 | 
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= 36 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= 37 | gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= 38 | gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= 39 | gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 40 | --------------------------------------------------------------------------------