├── LICENSE
├── README.md
├── map_store.go
├── map_store_test.go
├── throttle.go
├── throttle_test.go
└── wercker.yml


/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2020 Beat Richartz

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# throttle [![wercker status](https://app.wercker.com/status/55bf32b84fef488e32f82f728e680086/s "wercker status")](https://app.wercker.com/project/bykey/55bf32b84fef488e32f82f728e680086)

Simple throttling for martini, [negroni](https://github.com/martini-contrib/throttle/pull/6) or [Macaron](https://github.com/Unknwon/macaron).

[API Reference](http://godoc.org/github.com/beatrichartz/throttle)

## Description

Package `throttle` provides quota-based throttling.

#### Policy

`throttle.Policy` is a middleware that allows you to throttle by the given quota.

#### Quota

`throttle.Quota` is a quota type with a limit and a duration.

## Usage

This package provides a way to install rate limit or interval policies for throttling:

```go
m := martini.Classic()

// A Rate Limit Policy
m.Use(throttle.Policy(&throttle.Quota{
	Limit:  1000,
	Within: time.Hour,
}))

// An Interval Policy
m.Use(throttle.Policy(&throttle.Quota{
	Limit:  1,
	Within: time.Second,
}))

m.Any("/test", func() int {
	return http.StatusOK
})

// A Policy local to a given route
adminPolicy := throttle.Policy(&throttle.Quota{
	Limit:  100,
	Within: time.Hour,
})

m.Get("/admin", adminPolicy, func() int {
	return http.StatusOK
})

...
```

## Options
You can configure the options for throttling by passing in ``throttle.Options`` as the second argument to ``throttle.Policy``. Use it to configure the following options (defaults are used here):

```go
&throttle.Options{
	// The status code returned when the client exceeds the quota.
	// Defaults to 429 Too Many Requests
	StatusCode: 429,

	// The response body returned when the client exceeds the quota
	Message: "Too Many Requests",

	// A function to identify a request, must have the signature func(*http.Request) string
	// Defaults to a function identifying the request by IP or X-Forwarded-For header if provided
	// So if you want to identify by an API key given in request headers or something else, configure this option
	IdentificationFunction: nil, // nil keeps the default

	// The key prefix to use in any key value store
	KeyPrefix: "throttle",

	// The store to use. The key value store has to satisfy the throttle.KeyValueStorer interface
	// For further explanation, see below
	Store: nil, // nil keeps the default in-memory store

	// If the throttle is disabled or not
	// defaults to false
	Disabled: false,
}
```

## State Storage
Throttling relies on storage of one key per Policy and user in a (KeyValue) Storage. The interface the store has to satisfy is ``throttle.KeyValueStorer``, or, more explicitly:

```go
type KeyValueStorer interface {
	Get(key string) ([]byte, error)
	Set(key string, value []byte) error
}
```

This allows for drop-in replacement of the store with most common Go libraries for key value stores like [redis.go](https://github.com/hoisie/redis):

```go
var client redis.Client

m.Use(throttle.Policy(&throttle.Quota{
	Limit:  10,
	Within: time.Minute,
}, &throttle.Options{
	Store: &client,
}))
```

Adapters are also very easy to write. ``throttle`` prefixes every key, so your adapter does not have to care about it, and the stored value is stringified JSON.

The default state storage is in memory via a concurrent-safe `map[string][]byte` cleaning up every 15 minutes.
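
To illustrate how small such an adapter can be, here is a minimal, self-contained sketch. The `stringCache` backend, its `Fetch` and `Put` methods and the `cacheAdapter` type are hypothetical names invented for this example; only the `KeyValueStorer` interface is taken from above. One detail worth noting: `Get` should return a non-nil error for a missing key, since `throttle` treats a nil error as "an entry exists" and tries to decode the returned bytes as JSON.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// KeyValueStorer mirrors the interface throttle expects from a store.
type KeyValueStorer interface {
	Get(key string) ([]byte, error)
	Set(key string, value []byte) error
}

// stringCache is a hypothetical backend with a string-based API that does
// not match KeyValueStorer. It stands in for a real client library.
type stringCache struct {
	mu   sync.RWMutex
	data map[string]string
}

func newStringCache() *stringCache {
	return &stringCache{data: make(map[string]string)}
}

// Fetch returns the stored string and whether the key was present.
func (c *stringCache) Fetch(key string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	v, ok := c.data[key]
	return v, ok
}

// Put stores a string under the given key.
func (c *stringCache) Put(key, value string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.data[key] = value
}

// errKeyNotFound is returned by the adapter when a key is missing.
var errKeyNotFound = errors.New("key not found")

// cacheAdapter adapts stringCache to the KeyValueStorer interface.
type cacheAdapter struct {
	cache *stringCache
}

// Get converts the backend's (value, ok) result into ([]byte, error).
func (a cacheAdapter) Get(key string) ([]byte, error) {
	v, ok := a.cache.Fetch(key)
	if !ok {
		return nil, errKeyNotFound
	}
	return []byte(v), nil
}

// Set stores the raw bytes as a string in the backend.
func (a cacheAdapter) Set(key string, value []byte) error {
	a.cache.Put(key, string(value))
	return nil
}

func main() {
	var store KeyValueStorer = cacheAdapter{cache: newStringCache()}

	// A missing key yields an error rather than empty bytes.
	if _, err := store.Get("throttle_demo"); err != nil {
		fmt.Println("before set:", err)
	}

	if err := store.Set("throttle_demo", []byte(`{"count":1}`)); err != nil {
		fmt.Println("set failed:", err)
		return
	}

	value, err := store.Get("throttle_demo")
	fmt.Println("after set:", string(value), err)
}
```

With the real package, a value such as `cacheAdapter{cache: newStringCache()}` would be passed as the `Store` field of ``throttle.Options``, in the same way as `&client` in the redis example.
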
While the in-memory default works fine for clients running one instance of a martini server, for all other uses you should opt for a proper key value store.

## Headers & Status Codes
``throttle`` adds the following ``X-RateLimit-*`` headers to every response it controls:

- X-RateLimit-Limit: The maximum number of requests that the consumer is permitted to make within the given time window
- X-RateLimit-Remaining: The number of requests remaining in the current rate limit window
- X-RateLimit-Reset: The time at which the current rate limit window resets in [UTC epoch seconds](http://en.wikipedia.org/wiki/Unix_time)

No ``Retry-After`` header is added to the response, since ``X-RateLimit-Reset`` makes it redundant. Also, it is not recommended to use a 503 Service Unavailable status code when limiting the rate of requests, since the 5xx status code family indicates an error on the server's side.

## Authors

* [Beat Richartz](https://github.com/beatrichartz)

--------------------------------------------------------------------------------
/map_store.go:
--------------------------------------------------------------------------------
package throttle

import (
	"encoding/json"
	"reflect"
	"sync"
	"time"
)

const (
	defaultCleaningPeriod = 15 * time.Minute
)

// MapStore is a very simple implementation of a key value store (a concurrent-safe map)
type MapStore struct {
	*sync.RWMutex
	data    map[string][]byte
	binding FreshnessInformer
}

// FreshnessInformer reports whether a stored value is still fresh
type FreshnessInformer interface {
	IsFresh() bool
}

// MapStoreOptions configures a MapStore
type MapStoreOptions struct {
	// The period to clean the store in
	CleaningPeriod time.Duration
}

// MapStoreError is the error type for the key value store
type MapStoreError string

// Error implements the error interface for the key value store
func (err MapStoreError) Error() string {
	return "Throttle Map Store Error: " + string(err)
}

// Set a key
func (s *MapStore) Set(key string, value []byte) error {
	s.Lock()
	s.data[key] = value
	s.Unlock()

	return nil
}

// Delete a key
func (s *MapStore) Delete(key string) {
	s.Lock()
	delete(s.data, key)
	s.Unlock()
}

// Get a key, will return an error if the key does not exist
func (s *MapStore) Get(key string) ([]byte, error) {
	s.RLock()
	value, ok := s.data[key]
	s.RUnlock()
	if !ok {
		return nil, MapStoreError("Key " + key + " does not exist")
	}
	return value, nil
}

// Read decodes the value stored under key into a new value of the binding's type
func (s *MapStore) Read(key string) (FreshnessInformer, error) {
	byteArray, err := s.Get(key)
	if err != nil {
		return nil, err
	}

	// Decode into a fresh value instead of the shared binding. The previous
	// reflection over the binding looked up fields by JSON key on a
	// non-addressable value, so no field was ever set and every entry
	// was reported as stale.
	decoded := reflect.New(reflect.TypeOf(s.binding))
	if err := json.Unmarshal(byteArray, decoded.Interface()); err != nil {
		return nil, err
	}

	informer, ok := decoded.Elem().Interface().(FreshnessInformer)
	if !ok {
		return nil, MapStoreError("Value for key " + key + " is not a FreshnessInformer")
	}
	return informer, nil
}

// Clean the store from expired values
func (s *MapStore) Clean() {
	// Snapshot the keys under the read lock so the map is never iterated
	// while another goroutine writes to it
	s.RLock()
	keys := make([]string, 0, len(s.data))
	for key := range s.data {
		keys = append(keys, key)
	}
	s.RUnlock()

	for _, key := range keys {
		value, err := s.Read(key)
		if err != nil {
			// The key was deleted in the meantime or holds a value that
			// cannot be decoded; skip it instead of crashing the cleaner
			continue
		}
		if !value.IsFresh() {
			s.Delete(key)
		}
	}
}

// CleanEvery is a simple cleanup mechanism, cleaning the store once per cleaningPeriod
func (s *MapStore) CleanEvery(cleaningPeriod time.Duration) {
	for range time.Tick(cleaningPeriod) {
		s.Clean()
	}
}

// NewMapStore returns a simple key value store
func NewMapStore(binding FreshnessInformer, options ...*MapStoreOptions) *MapStore {
	s := &MapStore{
		&sync.RWMutex{},
		make(map[string][]byte),
		binding,
	}

	o := newMapStoreOptions(options)

	go s.CleanEvery(o.CleaningPeriod)

	return s
}

// Returns new map store options from defaults and given options
func newMapStoreOptions(options []*MapStoreOptions) *MapStoreOptions {
	o := &MapStoreOptions{
		defaultCleaningPeriod,
	}

	if len(options) == 0 || options[0] == nil {
		return o
	}

	if options[0].CleaningPeriod != 0 {
		o.CleaningPeriod = options[0].CleaningPeriod
	}

	return o
}

--------------------------------------------------------------------------------
/map_store_test.go:
--------------------------------------------------------------------------------
package throttle

import (
	"encoding/json"
	"math/rand"
	"strconv"
	"sync"
	"testing"
	"time"
)

func seedRandom() {
	rand.Seed(time.Now().UnixNano())
}

func sleepRandom() {
	time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
}

func TestSet(t *testing.T) {
	store := NewMapStore(accessCount{})
	store.Set("KEY", []byte("4"))
	value, err := store.Get("KEY")
	if err != nil {
		t.Error(err)
	}

	expectSame(t, string(value), "4")
}

func TestGet(t *testing.T) {
	seedRandom()
	store := NewMapStore(accessCount{})

	wg := &sync.WaitGroup{}
	var mu sync.Mutex // guards values, which several goroutines append to
	var values []string
	store.Set("KEY", []byte(strconv.FormatInt(int64(50), 10)))

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			sleepRandom()
			value, err := store.Get("KEY")
			if err != nil {
				t.Error(err)
			}
			mu.Lock()
			values = append(values, string(value))
			mu.Unlock()
		}()
	}

	wg.Wait()

	for _, val := range values {
		expectSame(t, val, "50")
	}
}

func TestRead(t *testing.T) {
	store := NewMapStore(accessCount{})

	wg := &sync.WaitGroup{}
	var mu sync.Mutex // guards values, which several goroutines append to
	var values []bool
	marshalled, err := json.Marshal(accessCount{
		64,
		time.Now(),
		10 * time.Millisecond,
	})
	if err != nil {
		t.Error(err)
	}
	store.Set("KEY", marshalled)

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			value, err := store.Read("KEY")
			time.Sleep(10 * time.Millisecond)
			if err != nil {
				t.Error(err)
				return
			}
			mu.Lock()
			values = append(values, value.IsFresh())
			mu.Unlock()
		}()
	}

	wg.Wait()

	for _, val := range values {
		expectSame(t, val, false)
	}
}

func TestDelete(t *testing.T) {
	store := NewMapStore(accessCount{})
	wg := &sync.WaitGroup{}
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(k int) {
			store.Set("KEY", []byte(strconv.FormatInt(int64(k), 10)))
			store.Delete("KEY")
			wg.Done()
		}(i)
	}

	wg.Wait()

	value, err := store.Get("KEY")
	if err == nil {
		t.Errorf("Expected no key to exist, but did: %v", value)
	}

	expectSame(t, err.Error(), "Throttle Map Store Error: Key KEY does not exist")
}

func TestCleaning(t *testing.T) {
	store := NewMapStore(accessCount{}, &MapStoreOptions{
		5 * time.Millisecond,
	})

	marshalled, err := json.Marshal(accessCount{
		64,
		time.Now(),
		10 * time.Millisecond,
	})

	if err != nil {
		t.Error(err)
	}

	wg := &sync.WaitGroup{}
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(k int) {
			store.Set("KEY"+strconv.FormatInt(int64(k), 10), marshalled)
			wg.Done()
		}(i)
	}

	wg.Wait()
	// Entries stay fresh for 10ms and the store is cleaned every 5ms, so
	// after 30ms at least one cleaning pass has seen the stale entries
	time.Sleep(30 * time.Millisecond)

	for i := 0; i < 5; i++ {
		value, err := store.Get("KEY" + strconv.FormatInt(int64(i), 10))
		if err == nil {
			t.Errorf("Expected no key to exist, but did: %v", value)
		} else {
			expectSame(t, err.Error(), "Throttle Map Store Error: Key KEY"+strconv.FormatInt(int64(i), 10)+" does not exist")
		}

	}
}

--------------------------------------------------------------------------------
/throttle.go:
--------------------------------------------------------------------------------
package throttle

import (
	"bytes"
	"encoding/json"
	"net"
	"net/http"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"
)

const (
	// Too Many Requests According to http://tools.ietf.org/html/rfc6585#page-3
	StatusTooManyRequests = 429

	// The default Status Code used
	defaultStatusCode = StatusTooManyRequests

	// The default Message to include, defaults to 429 status code title
	defaultMessage = "Too Many Requests"

	// The default key prefix for Key Value Storage
	defaultKeyPrefix = "throttle"

	// The header name to retrieve an IP address under a proxy
	forwardedForHeader = "X-FORWARDED-FOR"

	// The default for the disabled setting
	defaultDisabled = false
)

type Options struct {
	// The status code to be returned for throttled requests
	// Defaults to 429 Too Many Requests
	StatusCode int

	// The message to be returned as the body of throttled requests
	Message string

	// The function used to identify the requester
	// Defaults to IP identification
	IdentificationFunction func(*http.Request) string

	// The key prefix to use in any key value store
	// defaults to "throttle"
	KeyPrefix string

	// The store to use
	// defaults to a simple concurrent-safe map[string][]byte
	Store KeyValueStorer

	// If the throttle is disabled or not
	// defaults to false
	Disabled bool
}

// KeyValueStorer is the required interface for the Store Option
// This should allow for either drop-in replacement with compatible libraries,
// or easy write-up of adapters
type KeyValueStorer interface {
	// Simple Get Function
	Get(key string) ([]byte, error)
	// Simple Set Function
	Set(key string, value []byte) error
}

// The Quota is Request Rates per Time for a given policy
type Quota struct {
	// The Request Limit
	Limit uint64
	// The time window for the request Limit
	Within time.Duration
}

// KeyId returns the part of the store key that identifies this quota:
// the time window divided by the limit, in nanoseconds.
// Note that this panics with a division by zero when Limit is 0
func (q *Quota) KeyId() string {
	return strconv.FormatInt(int64(q.Within)/int64(q.Limit), 10)
}

// An access message to return to the user
type accessMessage struct {
	// The given status Code
	StatusCode int
	// The given message
	Message string
}

// Return a new access message with the properties given
func newAccessMessage(statusCode int, message string) *accessMessage {
	return &accessMessage{
		StatusCode: statusCode,
		Message:    message,
	}
}

// An access count for a single identified user.
// Will be stored in the key value store, 1 per Policy and User
type accessCount struct {
	Count    uint64        `json:"count"`
	Start    time.Time     `json:"start"`
	Duration time.Duration `json:"duration"`
}

// Determine if the count is still fresh
func (r accessCount) IsFresh() bool {
	return time.Now().UTC().Sub(r.Start) < r.Duration
}

// Increment the count when fresh, or reset and then increment when stale
func (r *accessCount) Increment() {
	if r.IsFresh() {
		r.Count++
	} else {
		r.Count = 1
		r.Start = time.Now().UTC()
	}
}

// Get the count, which is 0 once the time window has passed
func (r *accessCount) GetCount() uint64 {
	if r.IsFresh() {
		return r.Count
	}
	return 0
}

// Return a new access count with the given duration
func newAccessCount(duration time.Duration) *accessCount {
	return &accessCount{
		0,
		time.Now().UTC(),
		duration,
	}
}

// Unmarshal a stringified JSON representation of an access count
func accessCountFromBytes(accessCountBytes []byte) *accessCount {
	byteBufferString := bytes.NewBuffer(accessCountBytes)
	a := &accessCount{}
	if err := json.NewDecoder(byteBufferString).Decode(a); err != nil {
		panic(err.Error())
	}
	return a
}

// The controller, stores the allowed quota and has access to the store
type controller struct {
	*sync.Mutex
	quota *Quota
	store KeyValueStorer
}

// Get an access count by id
func (c *controller) GetAccessCount(id string) (a *accessCount) {
	accessCountBytes, err := c.store.Get(id)

	if err == nil {
		a = accessCountFromBytes(accessCountBytes)
	} else {
		a = newAccessCount(c.quota.Within)
	}

	return a
}

// Set an access count by id, will write to the store
func (c *controller) SetAccessCount(id string, a *accessCount) {
	marshalled, err := json.Marshal(a)
	if err != nil {
		panic(err.Error())
	}

	err = c.store.Set(id, marshalled)
	if err != nil {
		panic(err.Error())
	}
}

// Gets the access count, increments it and writes it back to the store
func (c *controller) RegisterAccess(id string) {
	c.Lock()
	defer c.Unlock()

	counter := c.GetAccessCount(id)
	counter.Increment()
	c.SetAccessCount(id, counter)
}

// Check if the controller denies access for the given id based on
// the quota and used access
func (c *controller) DeniesAccess(id string) bool {
	counter := c.GetAccessCount(id)
	return counter.GetCount() >= c.quota.Limit
}

// Get a time for the given id when the quota time window will be reset
func (c *controller) RetryAt(id string) time.Time {
	counter := c.GetAccessCount(id)

	return counter.Start.Add(c.quota.Within)
}

// Get the remaining limit for the given id
func (c *controller) RemainingLimit(id string) uint64 {
	counter := c.GetAccessCount(id)

	// Concurrent requests can push the count past the limit; clamp at zero
	// instead of letting the unsigned subtraction wrap around
	if used := counter.GetCount(); used < c.quota.Limit {
		return c.quota.Limit - used
	}
	return 0
}

// Return a new controller with the given quota and store
func newController(quota *Quota, store KeyValueStorer) *controller {
	return &controller{
		&sync.Mutex{},
		quota,
		store,
	}
}

// Identify via the given Identification Function
func (o *Options) Identify(req *http.Request) string {
	return o.IdentificationFunction(req)
}

// A throttling Policy
// Takes two arguments, one required:
// First is a Quota (A Limit with an associated time). When the given Limit
// of requests is reached by a user within the given time window,
// access to resources will be denied to this user
// Second is Options to use with this policy. For further information on options,
// see Options further above.
func Policy(quota *Quota, options ...*Options) func(resp http.ResponseWriter, req *http.Request) {
	o := newOptions(options)
	if o.Disabled {
		return func(resp http.ResponseWriter, req *http.Request) {}
	}

	controller := newController(quota, o.Store)

	return func(resp http.ResponseWriter, req *http.Request) {
		id := makeKey(o.KeyPrefix, quota.KeyId(), o.Identify(req))

		if controller.DeniesAccess(id) {
			msg := newAccessMessage(o.StatusCode, o.Message)
			setRateLimitHeaders(resp, controller, id)
			resp.WriteHeader(msg.StatusCode)
			resp.Write([]byte(msg.Message))
			return
		}

		controller.RegisterAccess(id)
		setRateLimitHeaders(resp, controller, id)
	}
}

// Set Rate Limit Headers helper function
func setRateLimitHeaders(resp http.ResponseWriter, controller *controller, id string) {
	headers := resp.Header()
	headers.Set("X-RateLimit-Limit", strconv.FormatUint(controller.quota.Limit, 10))
	headers.Set("X-RateLimit-Reset", strconv.FormatInt(controller.RetryAt(id).Unix(), 10))
	headers.Set("X-RateLimit-Remaining", strconv.FormatUint(controller.RemainingLimit(id), 10))
}

// The default identifier function. Identifies a client by IP
func defaultIdentify(req *http.Request) string {
	if forwardedFor := req.Header.Get(forwardedForHeader); forwardedFor != "" {
		if ipParsed := net.ParseIP(forwardedFor); ipParsed != nil {
			return ipParsed.String()
		}
	}

	ip, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		panic(err.Error())
	}
	return ip
}

// Make a key from various parts for use in the key value store
func makeKey(parts ...string) string {
	return strings.Join(parts, "_")
}

// Creates new default options and assigns any given options
func newOptions(options []*Options) *Options {
	o := Options{
		StatusCode:             defaultStatusCode,
		Message:                defaultMessage,
		IdentificationFunction: defaultIdentify,
		KeyPrefix:              defaultKeyPrefix,
		Store:                  nil,
		Disabled:               defaultDisabled,
	}

	// when all defaults, return it
	if len(options) == 0 || options[0] == nil {
		o.Store = NewMapStore(accessCount{})
		return &o
	}

	// map the given values to the options
	optionsValue := reflect.ValueOf(options[0])
	oValue := reflect.ValueOf(&o)
	numFields := optionsValue.Elem().NumField()

	for i := 0; i < numFields; i++ {
		if value := optionsValue.Elem().Field(i); value.IsValid() && value.CanSet() && isNonEmptyOption(value) {
			oValue.Elem().Field(i).Set(value)
		}
	}

	if o.Store == nil {
		o.Store = NewMapStore(accessCount{})
	}

	return &o
}

// Check if an option is assigned
func isNonEmptyOption(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.String:
		return v.Len() != 0
	case reflect.Bool:
		// Always true for a struct field, so bool options are always copied
		return v.IsValid()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() != 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return v.Uint() != 0
	case reflect.Float32, reflect.Float64:
		return v.Float() != 0
	case reflect.Interface, reflect.Ptr, reflect.Func:
		return !v.IsNil()
	}
	return false
}

--------------------------------------------------------------------------------
/throttle_test.go:
--------------------------------------------------------------------------------
package throttle

import (
	"net/http"
	"net/http/httptest"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/go-martini/martini"
)

const (
	host     string = "http://localhost:3000"
	endpoint string = "ws://localhost:3000"
)

// Test Helpers

// expectSame fails the test when the actual value a differs from the expected value b
func expectSame(t *testing.T, a interface{}, b interface{}) {
	if a != b {
		t.Errorf("Expected %T: %v to be %T: %v", a, a, b, b)
	}
}

func expectEmpty(t *testing.T, a []string) {
	if len(a) != 0 {
		t.Errorf("Expected %T: %v to be empty", a, a)
	}
}

// expectApproximateTimestamp accepts a when it equals b or lies one second after it
func expectApproximateTimestamp(t *testing.T, a int64, b int64) {
	if a != b && a != b+1 {
		t.Errorf("Expected %v to be %v or %v", a, b, b+1)
	}
}

func expectMatches(t *testing.T, reg string, result string) {
	r := regexp.MustCompile(reg)
	if !r.Match([]byte(result)) {
		t.Errorf("Expected %v to match %v", result, reg)
	}
}

func expectStatusCode(t *testing.T, expectedStatusCode int, actualStatusCode int) {
	if actualStatusCode != expectedStatusCode {
		t.Errorf("Expected StatusCode %d, but received %d", expectedStatusCode, actualStatusCode)
	}
}

func utcTimestamp() int64 {
	return time.Now().Unix()
}

type Expectation struct {
	StatusCode         int
	Body               string
	RateLimitLimit     string
	RateLimitRemaining string
	RateLimitReset     int64
	Wait               time.Duration
	ForwardedFor       string
	Concurrent         bool
}

func setupMartiniWithPolicy(limit uint64, within time.Duration, options ...*Options) *martini.ClassicMartini {
	m := martini.Classic()

	addPolicy(m, limit, within, options...)

	m.Any("/test", func() int {
		return http.StatusOK
	})

	return m
}

func addPolicy(m *martini.ClassicMartini, limit uint64, within time.Duration, options ...*Options) {
	m.Use(Policy(&Quota{
		Limit:  limit,
		Within: within,
	}, options...))
}

func setupMartiniWithPolicyAsHandler(limit uint64, within time.Duration, options ...*Options) *martini.ClassicMartini {
	m := martini.Classic()

	m.Any("/test", Policy(&Quota{
		Limit:  limit,
		Within: within,
	}, options...),
		func() int {
			return http.StatusOK
		})

	return m
}

func testResponseToExpectation(t *testing.T, m *martini.ClassicMartini, expectation *Expectation) {
	req, err := http.NewRequest("GET", "/test", strings.NewReader(""))
	if err != nil {
		t.Error(err)
		return
	}

	if expectation.ForwardedFor != "" {
		req.Header.Set("X-Forwarded-For", expectation.ForwardedFor)
	} else {
		req.RemoteAddr = "1.2.3.4:5000"
	}

	time.Sleep(expectation.Wait)
	recorder := httptest.NewRecorder()
	m.ServeHTTP(recorder, req)

	expectStatusCode(t, expectation.StatusCode, recorder.Code)
	if expectation.Body != "" {
		expectSame(t, recorder.Body.String(), expectation.Body)
	}

	// Header.Get returns "" for a missing header instead of panicking on an index
	header := recorder.Header()
	rateLimitLimit := header.Get("X-RateLimit-Limit")
	rateLimitRemaining := header.Get("X-RateLimit-Remaining")
	rateLimitReset := header.Get("X-RateLimit-Reset")

	if expectation.RateLimitLimit != "" {
		expectSame(t, rateLimitLimit, expectation.RateLimitLimit)
	}

	if expectation.RateLimitRemaining != "" {
		expectSame(t, rateLimitRemaining, expectation.RateLimitRemaining)
	}

	if expectation.RateLimitReset != 0 {
		resetTime, err := strconv.ParseInt(rateLimitReset, 10, 64)
		if err != nil {
			t.Error(err)
		}
		expectApproximateTimestamp(t, resetTime, expectation.RateLimitReset)
	}
}

func testResponses(t *testing.T, m *martini.ClassicMartini, expectations ...*Expectation) {
	runtime.GOMAXPROCS(runtime.NumCPU() * 2)
	wg := sync.WaitGroup{}
	for i, e := range expectations {
		if e.Concurrent {
			wg.Add(1)
			go func(k int, expectation *Expectation) {
				defer wg.Done()
				testResponseToExpectation(t, m, expectation)
			}(i, e)
		} else {
			wg.Wait()
			testResponseToExpectation(t, m, e)
		}
	}

	wg.Wait()
}

func TestTimeLimit(t *testing.T) {
	m := setupMartiniWithPolicyAsHandler(1, 10*time.Millisecond)
	testResponses(t, m, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		Wait:               10 * time.Millisecond,
	})
}

func TestTimeLimitWhenForwarded(t *testing.T) {
	m := setupMartiniWithPolicyAsHandler(1, 10*time.Millisecond)
	testResponses(t, m, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		ForwardedFor:       "2.3.4.5",
	}, &Expectation{
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		ForwardedFor:       "2.3.4.5",
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		Wait:               10 * time.Millisecond,
		ForwardedFor:       "2.3.4.5",
	})
}

func TestTimeLimitWithOptions(t *testing.T) {
	m := setupMartiniWithPolicy(1, 10*time.Millisecond, &Options{
		StatusCode: http.StatusBadRequest,
		Message:    "Server says no",
	})

	testResponses(t, m, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusBadRequest,
		Body:               "Server says no",
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		Wait:               10 * time.Millisecond,
	})
}

func TestLimitWhenDisabled(t *testing.T) {
	m := setupMartiniWithPolicy(1, 10*time.Millisecond, &Options{
		Disabled: true,
	})

	testResponses(t, m, &Expectation{
		StatusCode: http.StatusOK,
	}, &Expectation{
		StatusCode: http.StatusOK,
	}, &Expectation{
		StatusCode: http.StatusOK,
		Wait:       10 * time.Millisecond,
	})
}

func TestRateLimit(t *testing.T) {
	m := setupMartiniWithPolicy(2, 20*time.Millisecond)
	testResponses(t, m, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "2",
		RateLimitReset: utcTimestamp(),
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "2",
		RateLimitReset: utcTimestamp(),
	}, &Expectation{
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "2",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "2",
		RateLimitRemaining: "1",
		RateLimitReset:     utcTimestamp(),
		Wait:               20 * time.Millisecond,
	})
}

func TestRateLimitWithOptions(t *testing.T) {
	m := setupMartiniWithPolicyAsHandler(2, 10*time.Millisecond, &Options{
		StatusCode: http.StatusBadRequest,
		Message:    "Server says no",
	})
	testResponses(t, m, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "2",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "2",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:         http.StatusBadRequest,
		Body:               "Server says no",
		RateLimitLimit:     "2",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "2",
		RateLimitRemaining: "1",
		RateLimitReset:     utcTimestamp(),
		Wait:               10 * time.Millisecond,
	})
}

func TestMultiplePolicies(t *testing.T) {
	m := setupMartiniWithPolicyAsHandler(2, 20*time.Millisecond)
	addPolicy(m, 1, 5*time.Millisecond)

	testResponses(t, m, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "2",
		RateLimitRemaining: "1",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{ // Time Limit Throttling kicks in
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "1",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "2",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		Wait:               5 * time.Millisecond,
	}, &Expectation{
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "2",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
		Wait:               5 * time.Millisecond,
	})
}

func TestRateLimitWithConcurrentRequests(t *testing.T) {
	m := setupMartiniWithPolicy(5, 20*time.Millisecond)
	testResponses(t, m, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "5",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "5",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "5",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "5",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:     http.StatusOK,
		RateLimitLimit: "5",
		RateLimitReset: utcTimestamp(),
		Concurrent:     true,
	}, &Expectation{
		StatusCode:         StatusTooManyRequests,
		Body:               "Too Many Requests",
		RateLimitLimit:     "5",
		RateLimitRemaining: "0",
		RateLimitReset:     utcTimestamp(),
	}, &Expectation{
		StatusCode:         http.StatusOK,
		RateLimitLimit:     "5",
		RateLimitRemaining: "4",
		RateLimitReset:     utcTimestamp(),
		Wait:               20 * time.Millisecond,
	})
}

--------------------------------------------------------------------------------
/wercker.yml:
--------------------------------------------------------------------------------
box: wercker/golang@1.4.0

--------------------------------------------------------------------------------