├── .prettierignore ├── .husky ├── .gitignore └── pre-commit ├── .eslintignore ├── tsconfig.eslint.json ├── .babelrc ├── .npmignore ├── .prettierrc.json ├── jest.config.ts ├── src ├── index.ts ├── utils │ ├── __mocks__ │ │ └── redis.ts │ └── redis.ts ├── @types │ ├── expressMiddleware.d.ts │ ├── buildTypeWeights.d.ts │ └── rateLimit.d.ts ├── middleware │ ├── rateLimiterSetup.ts │ └── index.ts ├── rateLimiters │ ├── tokenBucket.ts │ ├── fixedWindow.ts │ ├── slidingWindowLog.ts │ └── slidingWindowCounter.ts └── analysis │ ├── QueryParser.ts │ └── buildTypeWeights.ts ├── .vscode ├── settings.json └── launch.json ├── .travis.yml ├── .github └── pull_request_template.md ├── .eslintrc.json ├── LICENSE ├── tsconfig.json ├── .gitignore ├── package.json ├── test ├── rateLimiters │ ├── fixedWindow.test.ts │ ├── tokenBucket.test.ts │ └── slidingWindowLog.test.ts ├── analysis │ └── weightFunction.test.ts └── middleware │ └── express.test.ts └── README.md /.prettierignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.husky/.gitignore: -------------------------------------------------------------------------------- 1 | _ 2 | -------------------------------------------------------------------------------- /.eslintignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | dist -------------------------------------------------------------------------------- /tsconfig.eslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./tsconfig.json", 3 | "include": [ 4 | "test/**/*", 5 | "src" 6 | ] 7 | } -------------------------------------------------------------------------------- /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [["@babel/preset-env", { "targets": { "node": "current" } }],
"@babel/preset-typescript"], 3 | } 4 | -------------------------------------------------------------------------------- /.husky/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | . "$(dirname "$0")/_/husky.sh" 3 | 4 | # Runs all code quality tools prior to each commit. 5 | npx lint-staged 6 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | **/* 2 | dist/src/utils/__mock__ 3 | dist/test/* 4 | !dist/src/analysis/* 5 | !dist/src/middleware/* 6 | !dist/src/rateLimiters/* 7 | !dist/src/utils/* 8 | !dist/src/* 9 | !package.json -------------------------------------------------------------------------------- /.prettierrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "trailingComma": "es5", 3 | "tabWidth": 4, 4 | "semi": true, 5 | "singleQuote": true, 6 | "printWidth": 100, 7 | "bracketSpacing": true, 8 | "arrowParens": "always", 9 | "proseWrap": "never" 10 | } 11 | -------------------------------------------------------------------------------- /jest.config.ts: -------------------------------------------------------------------------------- 1 | import type { Config } from '@jest/types'; 2 | 3 | const config: Config.InitialOptions = { 4 | verbose: true, 5 | roots: ['./test'], 6 | preset: 'ts-jest', 7 | testEnvironment: 'node', 8 | moduleFileExtensions: ['js', 'ts'], 9 | }; 10 | 11 | export default config; 12 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | export { default as expressGraphQLRateLimiter } from './middleware/index.js'; 2 | 3 | export { default as rateLimiter } from './middleware/rateLimiterSetup.js'; 4 | 5 | export { default as QueryParser } from './analysis/QueryParser.js'; 6 | 7 | export 
{ default as typeWeightsFromSchema } from './analysis/buildTypeWeights.js'; 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [], 3 | "settings": {}, 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll.eslint": true 6 | }, 7 | "editor.formatOnSave": true, 8 | "configurations": [{ 9 | "type": "node", 10 | "request": "launch", 11 | "name": "Jest Tests", 12 | "program": "${workspaceRoot}\\node_modules\\jest\\bin\\jest.js", 13 | "args": [ 14 | "-i" 15 | ], 16 | // "preLaunchTask": "build", 17 | "internalConsoleOptions": "openOnSessionStart", 18 | "outFiles": [ 19 | "${workspaceRoot}/dist/**/*" 20 | ], 21 | "envFile": "${workspaceRoot}/.env" 22 | }] 23 | } -------------------------------------------------------------------------------- /src/utils/__mocks__/redis.ts: -------------------------------------------------------------------------------- 1 | import Redis from 'ioredis'; 2 | 3 | // eslint-disable-next-line @typescript-eslint/no-var-requires 4 | const RedisMock = require('ioredis-mock'); 5 | 6 | const clients: Redis[] = []; 7 | 8 | /** 9 | * Connects to a client returning the client and a spe 10 | * @param options 11 | */ 12 | export function connect(): Redis { 13 | const client = new RedisMock(); 14 | clients.push(client); 15 | return client; 16 | } 17 | 18 | /** 19 | * Shutsdown all redis client connections 20 | */ 21 | export async function shutdown(): Promise<'OK'[]> { 22 | return Promise.all(clients.map((client: Redis) => client.quit())); 23 | } 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: node_js 2 | node_js: 3 | - 16 4 | - 17 5 | # - 18 6 | 7 | # run test for the above node versions for branches dev and main 8 | branches: 9 | only: 10 | - dev 11 | - 
main 12 | # scripts to run for each test 13 | script: 14 | - echo "Running tests against $(node -v) ..." 15 | - 'npm run lint' 16 | - 'npm run test' 17 | - 'npm run build' 18 | 19 | # specify deployment 20 | before_deploy: 21 | - 'npm run build' 22 | - 'npm run build:fix' 23 | 24 | deploy: 25 | on: 26 | branch: main 27 | tags: false 28 | skip_cleanup: true 29 | provider: npm 30 | email: $NPM_EMAIL_ADDRESS 31 | api_key: $NPM_API_KEY 32 | 33 | 34 | -------------------------------------------------------------------------------- /src/utils/redis.ts: -------------------------------------------------------------------------------- 1 | import Redis, { RedisOptions } from 'ioredis'; 2 | 3 | const clients: Redis[] = []; 4 | 5 | /** 6 | * Creates an ioredis client from the given options and tracks it so shutdown() can close it. 7 | * @param options ioredis connection options, passed unchanged to the Redis constructor 8 | */ 9 | export function connect(options: RedisOptions): Redis { 10 | // TODO: Figure out what other options we should set (timeouts, etc) 11 | const client: Redis = new Redis(options); // Default port is 6379 automatically 12 | clients.push(client); 13 | return client; 14 | } 15 | 16 | /** 17 | * Quits all tracked redis clients and stops tracking them, so a later call does not quit the same client twice. 18 | */ 19 | export async function shutdown(): Promise<'OK'[]> { 20 | // TODO: Add functionality to shutdown a client by an id 21 | // TODO: Error handling 22 | return Promise.all(clients.splice(0).map((client: Redis) => client.quit())); 23 | } 24 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | 8 | { 9 | "type": "node", 10 | "request": "launch", 11 | "name": "Jest Tests", 12 | "program": "${workspaceRoot}/node_modules/jest/bin/jest.js", 13 | "args": [ 14 | "-i", "--verbose", "--no-cache" 15 | ], 16 | // "preLaunchTask": "build", 17 | // "internalConsoleOptions": "openOnSessionStart", 18 | // "outFiles": [ 19 | // "${workspaceRoot}/dist/**/*" 20 | // ], 21 | // "envFile": "${workspaceRoot}/.env" 22 | }] 23 | } -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ### Summary 2 | _Provide a short summary of the changes in this PR_ 3 | 4 | ### Type of Change 5 | Please delete options that are not relevant. 6 | 7 | - [ ] Bug fix (non-breaking change which fixes an issue) 8 | - [ ] New feature (non-breaking change which adds functionality) 9 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 10 | - [ ] This change requires a documentation update 11 | 12 | ### Issues 13 | - Link any issues this PR resolves using keywords (resolve, closes, fixed) 14 | 15 | ### Evidence 16 | - Provide evidence of the changes functioning as expected or describe your tests. If tests are included in the CI pipeline this may be omitted. 17 | 18 | 19 | _(delete this line)_ Prior to submitting the PR, assign a reviewer from each team to review this PR.
20 | -------------------------------------------------------------------------------- /.eslintrc.json: -------------------------------------------------------------------------------- 1 | { 2 | "env": { 3 | "browser": true, 4 | "commonjs": true, 5 | "es2021": true 6 | }, 7 | "extends": [ 8 | "eslint:recommended", 9 | "airbnb-base", 10 | "airbnb-typescript/base", 11 | "plugin:@typescript-eslint/recommended", 12 | "plugin:import/typescript", 13 | "prettier" 14 | ], 15 | "parser": "@typescript-eslint/parser", 16 | "parserOptions": { 17 | "ecmaVersion": "latest", 18 | "project": "./tsconfig.json" 19 | }, 20 | "plugins": [ 21 | "import", "prettier" 22 | ], 23 | "rules": { 24 | "no-plusplus": [ 25 | 2, 26 | { 27 | "allowForLoopAfterthoughts": true 28 | } 29 | ], 30 | "prettier/prettier": [ 31 | "error" 32 | ] 33 | }, 34 | "ignorePatterns": [ 35 | "jest.*", 36 | "dist/*" 37 | ] 38 | } -------------------------------------------------------------------------------- /src/@types/expressMiddleware.d.ts: -------------------------------------------------------------------------------- 1 | import { RedisOptions } from 'ioredis'; 2 | import { TypeWeightConfig, TypeWeightSet } from './buildTypeWeights'; 3 | import { RateLimiterConfig } from './rateLimit'; 4 | 5 | // wraps the ioredis connection options together with an expiry for the rate limiting cache keys 6 | interface RedisConfig { 7 | keyExpiry?: number; 8 | options?: RedisOptions; 9 | } 10 | // RedisConfig once the middleware has configured it: keyExpiry and options are always set 11 | interface RedisConfigSet extends RedisConfig { 12 | keyExpiry: number; 13 | options: RedisOptions; 14 | } 15 | 16 | export interface ExpressMiddlewareConfig { 17 | rateLimiter: RateLimiterConfig; 18 | redis?: RedisConfig; 19 | typeWeights?: TypeWeightConfig; 20 | dark?: boolean; 21 | enforceBoundedLists?: boolean; 22 | depthLimit?: number; 23 | } 24 | 25 | export interface ExpressMiddlewareSet extends ExpressMiddlewareConfig { 26 | redis:
RedisConfigSet; 27 | typeWeights: TypeWeightSet; 28 | dark: boolean; 29 | enforceBoundedLists: boolean; 30 | depthLimit: number; 31 | } 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OSLabs Beta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/@types/buildTypeWeights.d.ts: -------------------------------------------------------------------------------- 1 | // NOTE(review): ArgumentNode and GraphQLOutputType are referenced below but never imported. 2 | // With "skipLibCheck": true (tsconfig.json) this declaration file is not type-checked, so both 3 | // silently resolve to error types that behave like `any`. Consider importing them from 'graphql' 4 | // once the code using these types has been checked to still compile. 5 | export interface Field { 6 | resolveTo?: string; 7 | weight?: FieldWeight; 8 | } 9 | export interface Fields { 10 | [index: string]: Field; 11 | } 12 | /** Function form of FieldWeight: receives the field's arguments, the query variables and the cost of its selections, and returns a weight. */ 13 | export type WeightFunction = (args: ArgumentNode[], variables, selectionsCost: number) => number; 14 | export type FieldWeight = number | WeightFunction; 15 | export interface Type { 16 | readonly weight: number; 17 | readonly fields: Fields; 18 | } 19 | export interface TypeWeightObject { 20 | [index: string]: Type; 21 | } 22 | export interface TypeWeightConfig { 23 | mutation?: number; 24 | query?: number; 25 | object?: number; 26 | scalar?: number; 27 | connection?: number; 28 | } 29 | export interface TypeWeightSet { 30 | mutation: number; 31 | query: number; 32 | object: number; 33 | scalar: number; 34 | connection: number; 35 | } 36 | type Variables = { 37 | [index: string]: unknown; 38 | }; 39 | 40 | // Type for use when getting fields for union types 41 | type FieldMap = { 42 | [index: string]: { 43 | type: GraphQLOutputType; 44 | weight?: FieldWeight; 45 | resolveTo?: string; 46 | }; 47 | }; 48 | -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "module": "ES2015", 4 | "moduleResolution": "node", 5 | "strict": true, 6 | "removeComments": false, 7 | "preserveConstEnums": true, 8 | "sourceMap": true, 9 | "target": "es6", 10 | "lib": [ 11 | "dom", 12 | "dom.iterable", 13 | "esnext" 14 | ], 15 | "allowJs": true, 16 | "skipLibCheck": true, 17 | "esModuleInterop": true, 18 | "allowSyntheticDefaultImports": true, 19 | "forceConsistentCasingInFileNames": true, 20 | "noFallthroughCasesInSwitch": true, 21 | "resolveJsonModule": true, 22 |
"isolatedModules": true, 23 | "noEmit": false, 24 | "typeRoots": [ 25 | "src/@types", 26 | "node_modules/@types" 27 | ], 28 | "types": [ 29 | "node", 30 | "jest" 31 | ], 32 | "outDir": "dist", 33 | "declaration": true, 34 | "declarationDir": "dist", 35 | }, 36 | "include": [ 37 | "src/**/*.ts", 38 | "src/**/*.js", 39 | "test/**/*.ts", 40 | "test/**/*.js" 41 | ], 42 | "exclude": [ 43 | "node_modules", 44 | "**/*.spec.ts", 45 | "dist", 46 | ] 47 | } -------------------------------------------------------------------------------- /src/@types/rateLimit.d.ts: -------------------------------------------------------------------------------- 1 | export interface RateLimiter { 2 | /** 3 | * Checks if a request is allowed under the given conditions and withdraws the specified number of tokens 4 | * @param uuid Unique identifier for the user associated with the request 5 | * @param timestamp UNIX format timestamp of when request was received 6 | * @param tokens Number of tokens being used in this request. 
Optional 7 | * @returns a RateLimiterResponse indicating with a sucess and tokens property indicating the number of tokens remaining 8 | */ 9 | processRequest: ( 10 | uuid: string, 11 | timestamp: number, 12 | tokens?: number 13 | ) => Promise; 14 | } 15 | 16 | export interface RateLimiterResponse { 17 | success: boolean; 18 | tokens: number; 19 | retryAfter?: number; 20 | } 21 | 22 | export interface RedisBucket { 23 | tokens: number; 24 | timestamp: number; 25 | } 26 | 27 | export interface FixedWindow { 28 | currentTokens: number; 29 | fixedWindowStart: number; 30 | } 31 | export interface RedisWindow extends FixedWindow { 32 | previousTokens: number; 33 | } 34 | 35 | export type RedisLog = RedisBucket[]; 36 | 37 | type BucketType = 'TOKEN_BUCKET' | 'LEAKY_BUCKET'; 38 | 39 | type WindowType = 'FIXED_WINDOW' | 'SLIDING_WINDOW_LOG' | 'SLIDING_WINDOW_COUNTER'; 40 | 41 | type BucketRateLimiter = { 42 | type: BucketType; 43 | refillRate: number; 44 | capacity: number; 45 | }; 46 | 47 | type WindowRateLimiter = { 48 | type: WindowType; 49 | windowSize: number; 50 | capacity: number; 51 | }; 52 | 53 | export type RateLimiterConfig = WindowRateLimiter | BucketRateLimiter; 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | 9 | # Diagnostic reports (https://nodejs.org/api/report.html) 10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 11 | 12 | # Runtime data 13 | pids 14 | *.pid 15 | *.seed 16 | *.pid.lock 17 | 18 | # Directory for instrumented libs generated by jscoverage/JSCover 19 | lib-cov 20 | 21 | # Coverage directory used by tools like istanbul 22 | coverage 23 | *.lcov 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 
| # Bower dependency directory (https://bower.io/) 32 | bower_components 33 | 34 | # node-waf configuration 35 | .lock-wscript 36 | 37 | # Compiled binary addons (https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # TypeScript cache 48 | *.tsbuildinfo 49 | 50 | # Optional npm cache directory 51 | .npm 52 | 53 | # Optional eslint cache 54 | .eslintcache 55 | 56 | # Microbundle cache 57 | .rpt2_cache/ 58 | .rts2_cache_cjs/ 59 | .rts2_cache_es/ 60 | .rts2_cache_umd/ 61 | 62 | # Optional REPL history 63 | .node_repl_history 64 | 65 | # Output of 'npm pack' 66 | *.tgz 67 | 68 | # Yarn Integrity file 69 | .yarn-integrity 70 | 71 | # dotenv environment variables file 72 | .env 73 | .env.test 74 | 75 | # parcel-bundler cache (https://parceljs.org/) 76 | .cache 77 | 78 | # Next.js build output 79 | .next 80 | 81 | # Nuxt.js build / generate output 82 | .nuxt 83 | dist 84 | 85 | # Gatsby files 86 | .cache/ 87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js 88 | # https://nextjs.org/blog/next-9-1#public-directory-support 89 | # public 90 | 91 | # vuepress build output 92 | .vuepress/dist 93 | 94 | # Serverless directories 95 | .serverless/ 96 | 97 | # FuseBox cache 98 | .fusebox/ 99 | 100 | # DynamoDB Local files 101 | .dynamodb/ 102 | 103 | # TernJS port file 104 | .tern-port 105 | 106 | build/* -------------------------------------------------------------------------------- /src/middleware/rateLimiterSetup.ts: -------------------------------------------------------------------------------- 1 | import Redis from 'ioredis'; 2 | import { RateLimiterConfig } from '../@types/rateLimit'; 3 | import TokenBucket from '../rateLimiters/tokenBucket'; 4 | import SlidingWindowCounter from '../rateLimiters/slidingWindowCounter'; 5 | import SlidingWindowLog from '../rateLimiters/slidingWindowLog'; 6 | import
FixedWindow from '../rateLimiters/fixedWindow'; 7 | 8 | /** 9 | * Instantiates the rate-limiting algorithm class selected by the developer, with its options. 10 | * 11 | * @param rateLimiter algorithm selection (`type`) and its options 12 | * @param client redis client the rate limiter stores its state in 13 | * @param keyExpiry expiry for the rate limiter's redis keys 14 | * @returns the configured rate limiter instance 15 | * @throws Error prefixed with "Error in expressGraphQLRateLimiter setting up rate-limiter:" when the algorithm is not implemented or the limiter rejects its options 16 | */ 17 | export default function setupRateLimiter( 18 | rateLimiter: RateLimiterConfig, 19 | client: Redis, 20 | keyExpiry: number 21 | ) { 22 | try { 23 | switch (rateLimiter.type) { 24 | case 'TOKEN_BUCKET': 25 | return new TokenBucket( 26 | rateLimiter.capacity, 27 | rateLimiter.refillRate, 28 | client, 29 | keyExpiry 30 | ); 31 | case 'LEAKY_BUCKET': 32 | throw new Error('Leaky Bucket algorithm has not been implemented.'); 33 | case 'FIXED_WINDOW': 34 | return new FixedWindow( 35 | rateLimiter.capacity, 36 | rateLimiter.windowSize, 37 | client, 38 | keyExpiry 39 | ); 40 | case 'SLIDING_WINDOW_LOG': 41 | return new SlidingWindowLog( 42 | rateLimiter.windowSize, 43 | rateLimiter.capacity, 44 | client, 45 | keyExpiry 46 | ); 47 | case 'SLIDING_WINDOW_COUNTER': 48 | return new SlidingWindowCounter( 49 | rateLimiter.windowSize, 50 | rateLimiter.capacity, 51 | client, 52 | keyExpiry 53 | ); 54 | default: 55 | // typescript should never let us invoke this function with anything other than the options above 56 | throw new Error('Selected rate limiting algorithm is not supported'); 57 | } 58 | } catch (err) { 59 | throw new Error(`Error in expressGraphQLRateLimiter setting up rate-limiter: ${err}`); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "graphql-limiter", 3 | "version": "1.3.0", 4 | "description": "A GraphQL rate limiting library using query complexity analysis.", 5 | "main": "./dist/src/index.js", 6 | "types": "./dist/src/index.d.ts", 7 |
"type": "module", 8 | "scripts": { 9 | "test": "jest --passWithNoTests --coverage --detectOpenHandles", 10 | "lint": "eslint src test", 11 | "lint:fix": "eslint --fix src test @types", 12 | "prettier": "prettier --write .", 13 | "prepare": "husky install", 14 | "build": "tsc", 15 | "build:fix": "node node_modules/.bin/yab dist" 16 | }, 17 | "repository": { 18 | "type": "git", 19 | "url": "git+https://github.com/oslabs-beta/graphql-gate.git" 20 | }, 21 | "keywords": [ 22 | "graphql", 23 | "graphqlgate", 24 | "rate-limiting", 25 | "throttling", 26 | "query", 27 | "express", 28 | "complexity", 29 | "analysis" 30 | ], 31 | "author": "Evan McNeely, Stephan Halarewicz, Flora Yufei Wu, Jon Dewey, Milos Popovic", 32 | "license": "ISC", 33 | "bugs": { 34 | "url": "https://github.com/oslabs-beta/GraphQL-Gate/issues" 35 | }, 36 | "homepage": "https://github.com/oslabs-beta/GraphQL-Gate#readme", 37 | "devDependencies": { 38 | "@babel/core": "^7.17.12", 39 | "@babel/preset-env": "^7.17.12", 40 | "@babel/preset-typescript": "^7.17.12", 41 | "@types/express": "^4.17.13", 42 | "@types/ioredis": "^4.28.10", 43 | "@types/ioredis-mock": "^5.6.0", 44 | "@types/jest": "^27.5.1", 45 | "@typescript-eslint/eslint-plugin": "^5.24.0", 46 | "@typescript-eslint/parser": "^5.24.0", 47 | "add-js-extension": "^1.0.4", 48 | "babel-jest": "^28.1.0", 49 | "eslint": "^8.15.0", 50 | "eslint-config-airbnb-base": "^15.0.0", 51 | "eslint-config-airbnb-typescript": "^17.0.0", 52 | "eslint-config-prettier": "^8.5.0", 53 | "eslint-plugin-import": "^2.26.0", 54 | "eslint-plugin-prettier": "^4.0.0", 55 | "husky": "^8.0.1", 56 | "ioredis-mock": "^8.2.2", 57 | "jest": "^28.1.0", 58 | "lint-staged": "^12.4.1", 59 | "npm": "^8.16.0", 60 | "prettier": "2.6.2", 61 | "ts-jest": "^28.0.2", 62 | "ts-node": "^10.8.0", 63 | "typescript": "^4.6.4" 64 | }, 65 | "lint-staged": { 66 | "*.{js, ts}": "eslint --cache --fix", 67 | "*.{js,ts,css,md}": "prettier --write --ignore-unknown" 68 | }, 69 | "dependencies": { 70 | 
"graphql": "^16.5.0", 71 | "ioredis": "^5.0.5" 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/rateLimiters/tokenBucket.ts: -------------------------------------------------------------------------------- 1 | import Redis from 'ioredis'; 2 | import { RateLimiter, RateLimiterResponse, RedisBucket } from '../@types/rateLimit'; 3 | 4 | /** 5 | * The TokenBucket instance of a RateLimiter limits requests based on a unique user ID. 6 | * Whenever a user makes a request the following steps are performed: 7 | * 1. Refill the bucket based on time elapsed since the previous request 8 | * 2. Update the timestamp of the last request. 9 | * 3. Allow the request and remove the requested amount of tokens from the bucket if the user has enough. 10 | * 4. Otherwise, disallow the request and do not update the token total. 11 | */ 12 | class TokenBucket implements RateLimiter { 13 | private capacity: number; 14 | 15 | private refillRate: number; 16 | 17 | private client: Redis; 18 | 19 | private keyExpiry: number; 20 | 21 | /** 22 | * Create a new instance of a TokenBucket rate limiter that can be connected to any database store 23 | * @param capacity max token bucket capacity 24 | * @param refillRate rate at which the token bucket is refilled 25 | * @param client redis client where rate limiter will cache information 26 | * @param expiry redis key expiry in ms 27 | */ 28 | constructor(capacity: number, refillRate: number, client: Redis, expiry: number) { 29 | this.capacity = capacity; 30 | this.refillRate = refillRate; 31 | this.client = client; 32 | this.keyExpiry = expiry; 33 | if (!refillRate || !capacity || refillRate <= 0 || capacity <= 0 || expiry <= 0) 34 | throw Error('TokenBucket refillRate, capacity and keyExpiry must be positive'); 35 | } 36 | 37 | /** 38 | * 39 | * 40 | * @param {string} uuid - unique identifer used to throttle requests 41 | * @param {number} timestamp - time the request was recieved 42 | * @param 
{number} [tokens=1] - complexity of the query for throttling requests 43 | * @return {*} {Promise} 44 | * @memberof TokenBucket 45 | */ 46 | public async processRequest( 47 | uuid: string, 48 | timestamp: number, 49 | tokens = 1 50 | ): Promise { 51 | // attempt to get the value for the uuid from the redis cache 52 | const bucketJSON = await this.client.get(uuid); 53 | 54 | // if the response is null, we need to create a bucket for the user 55 | if (!bucketJSON) { 56 | const newUserBucket: RedisBucket = { 57 | // conditionally set tokens depending on how many are requested comapred to the capacity 58 | tokens: tokens > this.capacity ? this.capacity : this.capacity - tokens, 59 | timestamp, 60 | }; 61 | // reject the request, not enough tokens could even be in the bucket 62 | if (tokens > this.capacity) { 63 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserBucket)); 64 | return { 65 | success: false, 66 | tokens: this.capacity, 67 | retryAfter: Infinity, 68 | }; 69 | } 70 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserBucket)); 71 | return { success: true, tokens: newUserBucket.tokens }; 72 | } 73 | 74 | // parse the returned string from redis and update their token budget based on the time lapse between queries 75 | const bucket: RedisBucket = await JSON.parse(bucketJSON); 76 | bucket.tokens = this.calculateTokenBudgetFromTimestamp(bucket, timestamp); 77 | 78 | const updatedUserBucket = { 79 | // conditionally set tokens depending on how many are requested comapred to the bucket 80 | tokens: bucket.tokens < tokens ? bucket.tokens : bucket.tokens - tokens, 81 | timestamp, 82 | }; 83 | if (bucket.tokens < tokens) { 84 | // reject the request, not enough tokens in bucket 85 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(updatedUserBucket)); 86 | return { 87 | success: false, 88 | tokens: bucket.tokens, 89 | retryAfter: 90 | tokens > this.capacity 91 | ? 
Infinity 92 | : Math.abs(tokens - bucket.tokens) * this.refillRate, 93 | }; 94 | } 95 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(updatedUserBucket)); 96 | return { success: true, tokens: updatedUserBucket.tokens }; 97 | } 98 | 99 | /** 100 | * Resets the rate limiter to the intial state by clearing the redis store. 101 | */ 102 | public reset(): void { 103 | this.client.flushall(); 104 | } 105 | 106 | /** 107 | * Calculates the tokens a user bucket should have given the time lapse between requests. 108 | */ 109 | private calculateTokenBudgetFromTimestamp = ( 110 | bucket: RedisBucket, 111 | timestamp: number 112 | ): number => { 113 | const timeSinceLastQueryInSeconds: number = Math.floor( 114 | (timestamp - bucket.timestamp) / 1000 // 1000 ms in a second 115 | ); 116 | const tokensToAdd = timeSinceLastQueryInSeconds * this.refillRate; 117 | const updatedTokenCount = bucket.tokens + tokensToAdd; 118 | return updatedTokenCount > this.capacity ? this.capacity : updatedTokenCount; 119 | }; 120 | } 121 | 122 | export default TokenBucket; 123 | -------------------------------------------------------------------------------- /src/rateLimiters/fixedWindow.ts: -------------------------------------------------------------------------------- 1 | import Redis from 'ioredis'; 2 | import { RateLimiter, RateLimiterResponse, FixedWindow as Window } from '../@types/rateLimit'; 3 | 4 | /** 5 | * The FixedWindow instance of a RateLimiter limits requests based on a unique user ID and a fixed time window. 6 | * Whenever a user makes a request the following steps are performed: 7 | * 1. Define the time window with fixed amount of queries. 8 | * 2. Update the timestamp of the last request. 9 | * 3. Allow the request and decrease the allowed amount of requests if the user has enough at this time window. 10 | * 4. Otherwise, disallow the request until the next time window opens. 
11 | */ 12 | 13 | class FixedWindow implements RateLimiter { 14 | private capacity: number; 15 | 16 | private keyExpiry: number; 17 | 18 | private windowSize: number; 19 | 20 | private client: Redis; 21 | 22 | /** 23 | * Create a new instance of a FixedWindow rate limiter that can be connected to any database store 24 | * @param capacity max requests capacity in one time window 25 | * @param windowSize rate at which the token bucket is refilled 26 | * @param client redis client where rate limiter will cache information 27 | */ 28 | 29 | constructor(capacity: number, windowSize: number, client: Redis, expiry: number) { 30 | this.capacity = capacity; 31 | this.windowSize = windowSize; 32 | this.client = client; 33 | this.keyExpiry = expiry; 34 | if (!windowSize || !capacity || windowSize <= 0 || capacity <= 0 || expiry <= 0) 35 | throw Error('FixedWindow windowSize, capacity and keyExpiry must be positive'); 36 | } 37 | 38 | /** 39 | * @function processRequest - Fixed Window algorithm to allow or block 40 | * based on the depth/complexity (in amount of tokens) of incoming requests. 41 | * Fixed Window 42 | * _________________________________ 43 | * | *full capacity | 44 | * | | move to next time window 45 | * | token adds up until full | ----------> 46 | *____._________________________________.____ 47 | * |<-- window size -->| 48 | *current timestamp next timestamp 49 | * 50 | * First, checks if a window exists in the redis cache. 51 | * If not, then `fixedWindowStart` is set as the current timestamp, and `currentTokens` is checked against `capacity`. 52 | * If enough room exists for the request, returns success as true and tokens as how many tokens remain in the current fixed window. 53 | * 54 | * If a window does exist in the cache, we first check if the timestamp is greater than the fixedWindowStart + windowSize. 
55 | * If it isn't, we update currentToken with the incoming token until reach the capcity 56 | * 57 | * @param {string} uuid - unique identifer used to throttle requests 58 | * @param {number} timestamp - time the request was recieved 59 | * @param {number} [tokens=1] - complexity of the query for throttling requests 60 | * @return {*} {Promise} 61 | * @memberof FixedWindow 62 | */ 63 | async processRequest( 64 | uuid: string, 65 | timestamp: number, 66 | tokens = 1 67 | ): Promise { 68 | // attempt to get the value for the uuid from the redis cache 69 | const windowJSON = await this.client.get(uuid); 70 | 71 | if (!windowJSON) { 72 | const newUserWindow: Window = { 73 | currentTokens: tokens > this.capacity ? 0 : tokens, 74 | fixedWindowStart: timestamp, 75 | }; 76 | 77 | if (tokens > this.capacity) { 78 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserWindow)); 79 | return { success: false, tokens: this.capacity, retryAfter: Infinity }; 80 | } 81 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserWindow)); 82 | return { success: true, tokens: this.capacity - newUserWindow.currentTokens }; 83 | } 84 | const window: Window = await JSON.parse(windowJSON); 85 | 86 | const previousWindowStart = window.fixedWindowStart; 87 | const updatedUserWindow = this.updateTimeWindow(window, timestamp); 88 | updatedUserWindow.currentTokens += tokens; 89 | // update the currentToken until reaches its capacity 90 | if (updatedUserWindow.currentTokens > this.capacity) { 91 | updatedUserWindow.currentTokens -= tokens; 92 | return { 93 | success: false, 94 | tokens: this.capacity - updatedUserWindow.currentTokens, 95 | retryAfter: Math.ceil((this.windowSize - (timestamp - previousWindowStart)) / 1000), 96 | }; 97 | } 98 | 99 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(updatedUserWindow)); 100 | return { 101 | success: true, 102 | tokens: this.capacity - updatedUserWindow.currentTokens, 103 | }; 104 | } 105 | 106 | /** 107 | * 
Resets the rate limiter to the initial state by clearing the redis store.
     */
    public reset(): void {
        // NOTE: FLUSHALL deletes every key on the Redis server (all databases), not only the
        // windows written by this limiter. The returned promise is not awaited.
        this.client.flushall();
    }

    /**
     * Returns a copy of `window` moved to the window that contains `timestamp`:
     * - timestamp still inside the stored window: the copy is unchanged;
     * - timestamp inside the window right after it: the start advances by exactly one windowSize;
     * - at least one whole window skipped: a new window starts at `timestamp`.
     * Whenever the window moves its token count restarts at 0. The input object is not mutated.
     */
    private updateTimeWindow = (window: Window, timestamp: number): Window => {
        const { currentTokens, fixedWindowStart } = window;
        const nextWindowStart = fixedWindowStart + this.windowSize;

        if (timestamp >= nextWindowStart) {
            const skippedWholeWindow = timestamp >= fixedWindowStart + this.windowSize * 2;
            return {
                currentTokens: 0,
                fixedWindowStart: skippedWholeWindow ? timestamp : nextWindowStart,
            };
        }
        return { currentTokens, fixedWindowStart };
    };
}

export default FixedWindow;

--------------------------------------------------------------------------------
/src/rateLimiters/slidingWindowLog.ts:
--------------------------------------------------------------------------------
import Redis from 'ioredis';
import { RateLimiter, RateLimiterResponse, RedisBucket, RedisLog } from '../@types/rateLimit';

/**
 * The SlidingWindowLog instance of a RateLimiter limits requests based on a unique user ID.
 * With the FixedWindow algorithm the token count restarts at every window boundary, so a user
 * can get extra requests through around the edges of a window. The SlidingWindowLog algorithm
 * addresses this by logging the timestamp and complexity of every accepted request and
 * discarding entries once they fall outside of the window. If the complexity of a request plus
 * that of the active entries in the log exceeds the capacity, the request is dropped.
 *
 * Whenever a user makes a request the following steps are performed:
 * 1. The user's log is obtained from redis.
 * 2. Any requests that are older than window size are dropped from the log.
 * 3.
The complexity of the current request is added to the complexity of all active requests in the log.
 * 4. If that total exceeds the specified capacity the request is dropped.
 * 5. Otherwise the request is allowed and appended to the end of the log (only if its complexity is > 0).
 *
 * The log is stored under the user's ID as a JSON array of { timestamp, tokens } entries, in the
 * order the requests were accepted.
 */
class SlidingWindowLog implements RateLimiter {
    private windowSize: number;

    private keyExpiry: number;

    private capacity: number;

    private client: Redis;

    /**
     * Create a new instance of a SlidingWindowLog rate limiter that can be connected to any redis store
     * @param windowSize size of window in milliseconds
     * @param capacity max number of tokens allowed in each window
     * @param client redis client where rate limiter will cache information
     * @param expiry key expiry passed to redis SETEX on every write (SETEX interprets it in seconds)
     * @throws SyntaxError if windowSize or capacity is missing or not positive, or expiry is not positive
     */
    constructor(windowSize: number, capacity: number, client: Redis, expiry: number) {
        if (!windowSize || !capacity || windowSize <= 0 || capacity <= 0 || expiry <= 0) {
            throw SyntaxError(
                'SlidingWindowLog window size, capacity and keyExpiry must be positive'
            );
        }
        this.windowSize = windowSize;
        this.capacity = capacity;
        this.client = client;
        this.keyExpiry = expiry;

        // TODO: move the log bookkeeping into a server-side Lua script (client.defineCommand)
        // backed by a redis list or sorted set, so that dropping expired entries and summing the
        // active complexity happen atomically on the redis server.
        // See https://stackoverflow.com/questions/35677682/filtering-deleting-items-from-a-redis-set
    }

    /**
     * Applies the sliding window log algorithm to one request.
     *
     * Entries whose timestamp is at or before `timestamp - windowSize` are expired and removed from
     * the log. The request is allowed when the tokens of the remaining entries plus `tokens` fit
     * within the capacity. The trimmed log (plus this request, if allowed and non-zero) is written
     * back with SETEX whether or not the request is allowed.
     *
     * @param {string} uuid - unique identifier used to throttle requests
     * @param {number} timestamp - time the request was received (ms)
     * @param {number} [tokens=1] - complexity of the query for throttling requests
     * @return {*} {Promise<RateLimiterResponse>} success flag, tokens left in the window and, when
     * blocked, the seconds until the request could succeed (Infinity if tokens > capacity)
     * @memberof SlidingWindowLog
     */
    async processRequest(
        uuid: string,
        timestamp: number,
        tokens = 1
    ): Promise<RateLimiterResponse> {
        const storedLog = await this.client.get(uuid);
        const requestLog: RedisLog = JSON.parse(storedLog || '[]');

        const windowStart = timestamp - this.windowSize;
        let activeTokens = 0;
        // smallest index such that the entries from it to the end still leave room for this request
        let earliestFittingIndex = requestLog.length;

        // Walk the log from the newest entry backwards, summing active tokens. The first expired
        // entry ends the walk: it and every entry before it are dropped.
        let index = requestLog.length - 1;
        for (; index >= 0; index -= 1) {
            const entry = requestLog[index];
            if (windowStart >= entry.timestamp) break;
            activeTokens += entry.tokens;
            if (this.capacity - activeTokens >= tokens) earliestFittingIndex = index;
        }
        const firstActiveIndex = index + 1;
        // avoid an unnecessary copy when nothing expired
        const activeLog = firstActiveIndex > 0 ? requestLog.slice(firstActiveIndex) : requestLog;

        if (activeTokens + tokens <= this.capacity) {
            if (tokens > 0) activeLog.push({ timestamp, tokens });
            await this.client.setex(uuid, this.keyExpiry, JSON.stringify(activeLog));
            return { success: true, tokens: this.capacity - (activeTokens + tokens) };
        }

        // Blocked: work out how long until enough of the oldest active entries have expired.
        let retryAfter = 0;
        if (tokens > this.capacity) {
            retryAfter = Infinity; // never fits, however much of the log expires
        } else if (earliestFittingIndex > 0) {
            // the request fits once the entry just before earliestFittingIndex leaves the window
            retryAfter = Math.ceil(
                (this.windowSize + requestLog[earliestFittingIndex - 1].timestamp - timestamp) /
                    1000
            );
        }

        await this.client.setex(uuid, this.keyExpiry, JSON.stringify(activeLog));
        return { success: false, tokens: this.capacity - activeTokens, retryAfter };
    }

    /**
     * Resets the rate limiter to the initial state by clearing the redis store.
*/
    public reset(): void {
        // NOTE: FLUSHALL deletes every key on the Redis server (all databases), not only the logs
        // written by this limiter. The returned promise is not awaited.
        this.client.flushall();
    }
}

export default SlidingWindowLog;

--------------------------------------------------------------------------------
/test/rateLimiters/fixedWindow.test.ts:
--------------------------------------------------------------------------------
import * as ioredis from 'ioredis';
import { FixedWindow as Window } from '../../src/@types/rateLimit';
import FixedWindow from '../../src/rateLimiters/fixedWindow';

// eslint-disable-next-line @typescript-eslint/no-var-requires
const RedisMock = require('ioredis-mock');

const CAPACITY = 10;
const WINDOW_SIZE = 6000;

let limiter: FixedWindow;
let client: ioredis.Redis;
let timestamp: number;
const user1 = '1';
const user2 = '2';
const user3 = '3';

async function getWindowFromClient(redisClient: ioredis.Redis, uuid: string): Promise<Window> {
    const res = await redisClient.get(uuid);
    // if no uuid is found, return -1 for tokens and timestamp, which are both impossible
    if (res === null) return { currentTokens: -1, fixedWindowStart: -1 };
    return JSON.parse(res);
}

async function setTokenCountInClient(
    redisClient: ioredis.Redis,
    uuid: string,
    currentTokens: number,
    fixedWindowStart: number
) {
    const value: Window = { currentTokens, fixedWindowStart };
    await redisClient.set(uuid, JSON.stringify(value));
}
describe('Test FixedWindow Rate Limiter', () => {
    beforeEach(async () => {
        client = new RedisMock();
        limiter = new FixedWindow(CAPACITY, WINDOW_SIZE, client, 8000);
        timestamp = new Date().valueOf();
    });
    describe('FixedWindow returns correct number of tokens and updates redis store as expected', () => {
        describe('after an ALLOWED request...', () => {
            // await the flush so the next test never observes a half-cleared store
            afterEach(async () => {
                await client.flushall();
            });
            test('current time window has no token initially', async () => {
                // zero token used in this time window
                const withdraw5 = 5;
                expect((await limiter.processRequest(user1, timestamp, withdraw5)).tokens).toBe(
                    CAPACITY - withdraw5
                );
                const tokenCountFull = await getWindowFromClient(client, user1);
                expect(tokenCountFull.currentTokens).toBe(5);
            });
            test('reached 40% capacity in current time window and still can pass request', async () => {
                const initial = 5;
                await setTokenCountInClient(client, user2, initial, timestamp);
                const partialWithdraw = 2;
                expect(
                    (
                        await limiter.processRequest(
                            user2,
                            timestamp + WINDOW_SIZE * 0.4,
                            partialWithdraw
                        )
                    ).tokens
                ).toBe(CAPACITY - initial - partialWithdraw);

                const tokenCountPartial = await getWindowFromClient(client, user2);
                expect(tokenCountPartial.currentTokens).toBe(initial + partialWithdraw);
            });

            test('window is partially full and request has no leftover tokens', async () => {
                const initial = 6;
                const partialWithdraw = 4;
                await setTokenCountInClient(client, user2, initial, timestamp);
                // both assertions must describe the same (allowed) request; a second call would be
                // blocked and only coincidentally report 0 tokens
                const allowed = await limiter.processRequest(user2, timestamp, partialWithdraw);
                expect(allowed.success).toBe(true);
                expect(allowed.tokens).toBe(0);
            });

            test('window is partially full and request exceeds tokens in availability', async () => {
                const initial = 6;
                const partialWithdraw = 5;
                await setTokenCountInClient(client, user2, initial, timestamp);
                const blocked = await limiter.processRequest(
                    user2,
                    timestamp + WINDOW_SIZE * 0.4,
                    partialWithdraw
                );
                expect(blocked.success).toBe(false);
                expect(blocked.retryAfter).toBe(Math.ceil((WINDOW_SIZE * 0.6) / 1000));
                expect(
                    (await limiter.processRequest(user2, timestamp, partialWithdraw)).tokens
                ).toBe(4);
            });
        });
        describe('after a BLOCKED request...', () => {
            afterEach(async () => {
                await client.flushall();
            });
            test('initial request is greater than capacity', async () => {
                // expect remaining tokens to be 10, b/c the 11 token request should be blocked
                const blocked = await limiter.processRequest(user1, timestamp, 11);
                expect(blocked.success).toBe(false);
                expect(blocked.retryAfter).toBe(Infinity);
                // expect current tokens in the window to still be 0
                expect((await getWindowFromClient(client, user1)).currentTokens).toBe(0);
            });
            test('window is partially full but not enough time elapsed to reach new window', async () => {
                const requestedTokens = 9;

                await setTokenCountInClient(client, user2, requestedTokens, timestamp);
                // expect remaining tokens to be 1, b/c the 2-token-request should be blocked
                const result = await limiter.processRequest(user2, timestamp + WINDOW_SIZE - 1, 2);

                expect(result.success).toBe(false);
                expect(result.tokens).toBe(1);
                expect(result.retryAfter).toBe(1); // 1 second

                // expect current tokens in the window to still be 9
                expect((await getWindowFromClient(client, user2)).currentTokens).toBe(9);
            });
        });
        describe('updateTimeWindow function works as expect', () => {
            afterEach(async () => {
                await client.flushall();
            });
            test('New window is initialized after reaching the window size', async () => {
                const fullRequest = 10;
                await setTokenCountInClient(client, user3, fullRequest, timestamp);
                const noAccess = await limiter.processRequest(
                    user3,
                    timestamp + WINDOW_SIZE - 1,
                    2
                );

                // expect not passing any request
                expect(noAccess.tokens).toBe(0);
                expect(noAccess.success).toBe(false);

                const newRequest = 1;
                expect(
                    (await limiter.processRequest(user3, timestamp + WINDOW_SIZE, newRequest))
                        .success
                ).toBe(true);
                const count = await getWindowFromClient(client, user3);
                expect(count.currentTokens).toBe(1);
            });
            test('Request will be passed after two window sizes', async () => {
                const fullRequest = 10;
                await setTokenCountInClient(client, user3, fullRequest, timestamp);
                const noAccess = await limiter.processRequest(
                    user3,
                    timestamp + WINDOW_SIZE - 1,
                    2
                );

                // expect not passing any request
                expect(noAccess.tokens).toBe(0);
                expect(noAccess.success).toBe(false);

                const newRequest = 6;
                // check if current time is over one window size
                const newAccess = await limiter.processRequest(
                    user3,
                    timestamp + WINDOW_SIZE * 2,
                    newRequest
                );

                expect(newAccess.tokens).toBe(4);
                expect(newAccess.success).toBe(true);
            });
        });
    });
});

--------------------------------------------------------------------------------
/test/analysis/weightFunction.test.ts:
--------------------------------------------------------------------------------
import 'ts-jest';
import { buildSchema, DocumentNode, parse } from 'graphql';
import { TypeWeightObject } from '../../src/@types/buildTypeWeights';
import buildTypeWeightsFromSchema from '../../src/analysis/buildTypeWeights';
import QueryParser from '../../src/analysis/QueryParser';
// Test the weight function generated by the typeweights object when a limiting keyword is provided

// Test cases:
// Default value provided to schema
// Arg passed in as variable
// Arg passed in as scalar
// Invalid arg type provided

// Default value passed with query

describe('Weight Function correctly parses Argument Nodes if', () => {
    const schema = buildSchema(`
    type Query {
        reviews(episode: Episode!, first: Int = 5): [Review]
        heroes(episode: Episode!, first: Int): [Review]
        villains(episode: Episode!, limit: Int! = 3): [Review]!
        characters(episode: Episode!, limit: Int!): [Review!]
        droids(episode: Episode!, limit: Int!): [Review!]!

    }
    type Review {
        episode: Episode
        stars: Int!
        commentary: String
        scalarList(last: Int): [Int]
        objectList(first: Int): [Object]
    }
    type Object {
        hi: String
    }
    enum Episode {
        NEWHOPE
        EMPIRE
        JEDI
    }`);
    // building the typeWeights object here since we're testing the weight function created in
    // the typeWeights object
    const typeWeights: TypeWeightObject = buildTypeWeightsFromSchema(schema);
    let queryParser: QueryParser;
    // every test starts from a fresh parser so results never depend on which test ran before
    beforeEach(() => {
        queryParser = new QueryParser(typeWeights, {});
    });
    describe('a default value is provided in the schema', () => {
        test('and a value is not provided with the query', () => {
            const query = `query { reviews(episode: NEWHOPE) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            expect(queryParser.processQuery(queryAST)).toBe(6);
        });

        test('and a scalar value is provided with the query', () => {
            const query = `query { reviews(episode: NEWHOPE, first: 3) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            expect(queryParser.processQuery(queryAST)).toBe(4);
        });

        test('and the argument is passed in as a variable', () => {
            const query = `query variableQuery ($items: Int){ reviews(episode: NEWHOPE, first: $items) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            queryParser = new QueryParser(typeWeights, { items: 7, first: 4 });
            expect(queryParser.processQuery(queryAST)).toBe(8);
            queryParser = new QueryParser(typeWeights, { first: 4, items: 7 });
            expect(queryParser.processQuery(queryAST)).toBe(8);
        });
    });

    describe('a default value is not provided in the schema', () => {
        xtest('and a value is not provied with the query', () => {
            const query = `query { heroes(episode: NEWHOPE) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            // FIXME: Update expected result if unbounded lists are supported
            expect(queryParser.processQuery(queryAST)).toBe(5);
        });

        test('and a scalar value is provided with the query', () => {
            const query = `query { heroes(episode: NEWHOPE, first: 3) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            expect(queryParser.processQuery(queryAST)).toBe(4);
        });

        test('and the argument is passed in as a variable', () => {
            const query = `query variableQuery ($items: Int){ heroes(episode: NEWHOPE, first: $items) { stars, episode } }`;
            const queryAST: DocumentNode = parse(query);
            queryParser = new QueryParser(typeWeights, { items: 7 });
            expect(queryParser.processQuery(queryAST)).toBe(8);
        });
    });

    test('the list is defined with non-null operators (!)', () => {
        const villainsQuery = `query { villains(episode: NEWHOPE, limit: 3) { stars, episode } }`;
        const villainsQueryAST: DocumentNode = parse(villainsQuery);
        expect(queryParser.processQuery(villainsQueryAST)).toBe(4);

        const charQuery = `query { characters(episode: NEWHOPE, limit: 3) { stars, episode } }`;
        const charQueryAST: DocumentNode = parse(charQuery);
        expect(queryParser.processQuery(charQueryAST)).toBe(4);

        const droidsQuery = `query droidsQuery { droids(episode: NEWHOPE, limit: 3) { stars, episode } }`;
        const droidsQueryAST: DocumentNode = parse(droidsQuery);
        expect(queryParser.processQuery(droidsQueryAST)).toBe(4);
    });

    test('a custom object weight was configured', () => {
        const customTypeWeights: TypeWeightObject = buildTypeWeightsFromSchema(schema, {
            object: 3,
        });
        queryParser = new QueryParser(customTypeWeights, {});
        const query = `query { heroes(episode: NEWHOPE, first: 3) { stars, episode } }`;
        const queryAST: DocumentNode = parse(query);
        expect(queryParser.processQuery(queryAST)).toBe(10);
    });

    test('a custom object weight was set to 0', () => {
        const customTypeWeights: TypeWeightObject = buildTypeWeightsFromSchema(schema, {
            object: 0,
        });
        queryParser = new QueryParser(customTypeWeights, {});
        const query = `query { heroes(episode: NEWHOPE, first: 3) { stars, episode } }`;
        const queryAST: DocumentNode = parse(query);
        expect(queryParser.processQuery(queryAST)).toBe(1); // 1 query
    });
    test('a custom scalar weight was set to greater than 0', () => {
        const customTypeWeights: TypeWeightObject = buildTypeWeightsFromSchema(schema, {
            scalar: 2,
        });
        queryParser = new QueryParser(customTypeWeights, {});
        const query = `query { heroes(episode: NEWHOPE, first: 3) { stars, episode } }`;
        const queryAST: DocumentNode = parse(query);
        expect(queryParser.processQuery(queryAST)).toBe(16);
    });

    test('variable names matching limiting keywords do not interfere with scalar argument values', () => {
        const query = `query variableQuery ($items: Int){ heroes(episode: NEWHOPE, first: 3) { stars, episode } }`;
        const queryAST: DocumentNode = parse(query);
        queryParser = new QueryParser(typeWeights, { first: 7 });
        expect(queryParser.processQuery(queryAST)).toBe(4);
    });

    test('nested queries with lists', () => {
        const query = `query { reviews(episode: NEWHOPE, first: 2) {stars, objectList(first: 3) {hi}}} `;
        expect(queryParser.processQuery(parse(query))).toBe(9); // 1 Query + 2 review + (2 * 3 objects)
    });

    test('queries with inner scalar lists', () => {
        const query = `query { reviews(episode: NEWHOPE, first: 2) {stars, scalarList(last: 3) }}`;
        expect(queryParser.processQuery(parse(query))).toBe(3); // 1 Query + 2 reviews
    });

    test('queries with inner scalar lists and custom scalar weight greater than 0', () => {
        const customTypeWeights: TypeWeightObject = buildTypeWeightsFromSchema(schema, {
            scalar: 2,
        });
        queryParser = new QueryParser(customTypeWeights, {});
        const query = `query { reviews(episode: NEWHOPE, first: 2) {stars, scalarList(last: 3) }}`;
        expect(queryParser.processQuery(parse(query))).toBe(19); // 1 Query + 2 reviews + 2 * (2 stars + 3 * 2 scalarList)
    });

    xtest('an invalid arg type is provided', () => {
        const query = `query { heroes(episode: NEWHOPE, first = 3) { stars, episode } }`;
        const queryAST: DocumentNode = parse(query);
        // FIXME: What is the expected behavior? Treat as unbounded?
        fail('test not implemented');
    });
});

--------------------------------------------------------------------------------
/src/middleware/index.ts:
--------------------------------------------------------------------------------
import EventEmitter from 'events';
import { parse, validate } from 'graphql';
import { GraphQLSchema } from 'graphql/type/schema';
import { Request, Response, NextFunction, RequestHandler } from 'express';
import buildTypeWeightsFromSchema, { defaultTypeWeightsConfig } from '../analysis/buildTypeWeights';
import setupRateLimiter from './rateLimiterSetup';
import { ExpressMiddlewareConfig, ExpressMiddlewareSet } from '../@types/expressMiddleware';
import { RateLimiterResponse } from '../@types/rateLimit';
import { connect } from '../utils/redis';
import QueryParser from '../analysis/QueryParser';

/**
 * Primary entry point for adding GraphQL Rate Limiting middleware to an Express Server
 * @param {GraphQLSchema} schema GraphQLSchema object
 * @param {ExpressMiddlewareConfig} middlewareConfig
 * , "rateLimiter" must be explicitly specified in the setup of the middleware.
* , "redis" connection options (https://ioredis.readthedocs.io/en/stable/API/#new_Redis) and an optional "keyExpiry" property (defaults to 24h)
 * , "typeWeights" optional type weight configuration for the GraphQL Schema. Developers can override default typeWeights. Defaults to {mutation: 10, query: 1, object: 1, scalar/enum: 0, connection: 2}
 * , "dark: true" will run the package in "dark mode" to monitor queries and rate limiting data before implementing rate limiting functionality. Defaults to false
 * , "enforceBoundedLists: true" will throw an error if any lists in the schema are not constrained by slicing arguments: Defaults to false
 * , "depthLimit: number" will block queries with deeper nesting than the specified depth. Will not block queries by depth by default
 * @returns {RequestHandler} express middleware that reads the GraphQL query from the query string
 * (GET) or the request body (POST), computes its complexity and calls the next middleware if the
 * query is allowed. It responds with 400 if the query (or its variables) cannot be parsed or the
 * query fails validation against the schema, and with 429 if the request is blocked (unless the
 * middleware runs in dark mode).
 * @throws Error if depthLimit is 1 or less
 */
export default function expressGraphQLRateLimiter(
    schema: GraphQLSchema,
    middlewareConfig: ExpressMiddlewareConfig
): RequestHandler {
    /**
     * Setup the middleware configuration with the passed in and default values
     * - redis "keyExpiry" defaults to 1 day (in ms)
     * - "typeWeights" defaults to defaultTypeWeightsConfig
     * - "dark" and "enforceBoundedLists" default to false
     * - "depthLimit" defaults to Infinity
     * NOTE(review): the FixedWindow and SlidingWindowLog limiters pass keyExpiry to redis SETEX,
     * whose expiry argument is in seconds, while this default is a day expressed in ms — confirm
     * whether rateLimiterSetup converts the unit.
     */
    const middlewareSetup: ExpressMiddlewareSet = {
        rateLimiter: middlewareConfig.rateLimiter,
        typeWeights: { ...defaultTypeWeightsConfig, ...middlewareConfig.typeWeights },
        redis: {
            keyExpiry: middlewareConfig.redis?.keyExpiry || 86400000,
            options: { ...middlewareConfig.redis?.options },
        },
        dark: middlewareConfig.dark || false,
        enforceBoundedLists: middlewareConfig.enforceBoundedLists || false,
        depthLimit: middlewareConfig.depthLimit || Infinity,
    };

    /**
     * No query can have a depth of less than 2, so a limit of 1 or less would block every query.
     * A limit of exactly 2 is valid (the check used to reject it, contradicting its own message).
     */
    if (middlewareSetup.depthLimit <= 1) {
        throw new Error(
            `Error in expressGraphQLRateLimiter: depthLimit cannot be less than or equal to 1`
        );
    }

    /** Build the type weight object, create the redis client and instantiate the ratelimiter */
    const typeWeightObject = buildTypeWeightsFromSchema(
        schema,
        middlewareSetup.typeWeights,
        middlewareSetup.enforceBoundedLists
    );
    const redisClient = connect(middlewareSetup.redis.options);
    const rateLimiter = setupRateLimiter(
        middlewareSetup.rateLimiter,
        redisClient,
        middlewareSetup.redis.keyExpiry
    );

    /**
     * We are using a queue and event emitter to handle situations where a user has two concurrent
     * requests being processed. The trailing request is added to the user's queue and awaits the
     * prior request's processing by the rate-limiter. This keeps the cache consistent and accurate
     * when under load from one user.
     */
    // pending request IDs for each user, in arrival order; the head of a queue is being processed
    const requestQueues: { [index: string]: string[] } = {};
    // emits a request ID when that request reaches the head of its user's queue
    const requestEvents = new EventEmitter();
    // source of request IDs that are unique for the lifetime of this middleware instance
    let requestCounter = 0;

    /**
     * Runs the rate limiter for the request at the head of the user's queue and settles its
     * promise. Afterwards (also when the rate limiter failed) the request is removed from the
     * queue and the next pending request is started, so a failure cannot stall later requests.
     */
    async function processRequestResolver(
        userId: string,
        timestamp: number,
        tokens: number,
        resolve: (value: RateLimiterResponse | PromiseLike<RateLimiterResponse>) => void,
        reject: (reason: unknown) => void
    ): Promise<void> {
        try {
            resolve(await rateLimiter.processRequest(userId, timestamp, tokens));
        } catch (err) {
            reject(err);
        } finally {
            const queue = requestQueues[userId];
            if (queue) {
                queue.shift();
                const nextRequestId = queue[0];
                // delete the request queue for this user if there are no more requests to process
                if (nextRequestId === undefined) delete requestQueues[userId];
                else requestEvents.emit(nextRequestId);
            }
        }
    }

    /**
     * Throttle rateLimiter.processRequest based on user IP to prevent inaccurate redis reads.
     * Throttling uses event-driven promise fulfillment: each request's ID is appended to the
     * user's queue, and the request waits for its ID to be emitted (i.e. for the previous request
     * to finish) before calling processRequest.
     * @param userId key used to group and rate limit requests (the client IP)
     * @param timestamp time the request was received (ms)
     * @param tokens complexity of the query
     * @returns the rate limiter's decision for this request
     */
    async function throttledProcess(
        userId: string,
        timestamp: number,
        tokens: number
    ): Promise<RateLimiterResponse> {
        // A counter guarantees unique IDs. The previous `${timestamp}${tokens}` scheme collided for
        // requests with the same timestamp and complexity (even across users), which woke several
        // waiting requests at once.
        requestCounter += 1;
        const requestId = `${requestCounter}`;

        if (!requestQueues[userId]) {
            requestQueues[userId] = [];
        }
        const queue = requestQueues[userId];
        queue.push(requestId);

        return new Promise<RateLimiterResponse>((resolve, reject) => {
            if (queue.length > 1) {
                // wait until the previous request for this user has been processed
                requestEvents.once(requestId, () => {
                    processRequestResolver(userId, timestamp, tokens, resolve, reject);
                });
            } else {
                processRequestResolver(userId, timestamp, tokens, resolve, reject);
            }
        });
    }

    /** Rate-limiting middleware */
    return async (
        req: Request,
        res: Response,
        next: NextFunction
    ): Promise<void | Response<any, Record<string, any>>> => {
        const requestTimestamp = new Date().valueOf();
        // Access the query and variables passed to the server in the query string (GET) or the
        // body (POST). Express always provides req.query (an empty object when there is no query
        // string), so look for the "query" parameter itself before falling back to the body.
        let query;
        let variables;
        if (req.query?.query) {
            query = req.query.query;
            variables = req.query.variables;
        } else if (req.body) {
            query = req.body.query;
            variables = req.body.variables;
        }
        if (!query) {
            console.error(
                'Error in expressGraphQLRateLimiter: There is no query on the request. Rate-Limiting skipped'
            );
            return next();
        }
        // variables sent in a query string arrive as a JSON-encoded string
        if (typeof variables === 'string') {
            try {
                variables = JSON.parse(variables);
            } catch {
                return res.status(400).json({
                    errors: [
                        {
                            message:
                                'Error in expressGraphQLRateLimiter: variables are not valid JSON',
                        },
                    ],
                });
            }
        }
        // check for a proxied ip address before using the ip address on request
        // (req.ips is an empty array unless Express's "trust proxy" setting is enabled)
        const ip: string = req.ips?.length ? req.ips[0] : req.ip;

        // A syntax error (or a non-string query) is reported to the client instead of escaping
        // the async handler as an unhandled promise rejection.
        let queryAST: ReturnType<typeof parse>;
        try {
            queryAST = parse(query);
        } catch (err) {
            const message = err instanceof Error ? err.message : String(err);
            return res.status(400).json({ errors: [{ message }] });
        }
        // validate the query against the schema. returns an array of errors.
        const validationErrors = validate(schema, queryAST);
        // return the errors to the client if the array has length. otherwise there are no errors
        if (validationErrors.length > 0) {
            return res.status(400).json({ errors: validationErrors });
        }

        try {
            const queryParser = new QueryParser(typeWeightObject, variables);
            const queryComplexity = queryParser.processQuery(queryAST);
            const rateLimiterResponse = await throttledProcess(
                ip,
                requestTimestamp,
                queryComplexity
            );
            const exceedsDepthLimit = queryParser.maxDepth > middlewareSetup.depthLimit;
            res.locals.graphqlGate = {
                timestamp: requestTimestamp,
                complexity: queryComplexity,
                tokens: rateLimiterResponse.tokens,
                success: rateLimiterResponse.success && !exceedsDepthLimit,
                depth: queryParser.maxDepth,
            };
            /**
             * Respond with status code 429 when
             * (1) the rate-limiter blocked the request, or (2) the query exceeded the depth limit,
             * and (3) the middleware is not running in dark mode.
             */
            if ((!rateLimiterResponse.success || exceedsDepthLimit) && !middlewareSetup.dark) {
                // a Retry-After header of Infinity means the request will never be accepted
                return res
                    .status(429)
                    .set({
                        'Retry-After': `${
                            exceedsDepthLimit ? Infinity : rateLimiterResponse.retryAfter
                        }`,
                    })
                    .json(res.locals.graphqlGate);
            }
            return next();
        } catch (err) {
            // log the error to the console and forward it to Express's error-handling middleware
            console.error(
                `Error in expressGraphQLRateLimiter processing query. Rate limiting is skipped: ${err}`
            );
            return next(err);
        }
    };
}

--------------------------------------------------------------------------------
/src/rateLimiters/slidingWindowCounter.ts:
--------------------------------------------------------------------------------
import Redis from 'ioredis';
import { RateLimiter, RateLimiterResponse, RedisWindow } from '../@types/rateLimit';

/**
 * The SlidingWindowCounter instance of a RateLimiter limits requests based on a unique user ID.
 * This algorithm improves upon the FixedWindowCounter because this algorithm prevents fixed window's
 * flaw of allowing doubled capacity requests when hugging the window's borders with a rolling window,
 * allowing us to average the requests between both windows proportionately with the rolling window's
 * takeup in each.
 *
 * Whenever a user makes a request the following steps are performed:
 * 1. Fixed windows are defined along with redis caches if previously undefined.
 * 2. Rolling windows are defined or updated based on the timestamp of the new request.
 * 3. Counter of the current fixed window is updated with the new request's token usage.
 * 4. If a new minute interval is reached, the averaging formula is run to prevent fixed window's flaw
 * of flooded requests around window borders
 * (ex.
1m windows, 10 token capacity: 1m59s 10 reqs 2m2s 10 reqs) 18 | */ 19 | class SlidingWindowCounter implements RateLimiter { 20 | private windowSize: number; 21 | 22 | private keyExpiry: number; 23 | 24 | private capacity: number; 25 | 26 | private client: Redis; 27 | 28 | /** 29 | * Create a new instance of a SlidingWindowCounter rate limiter that can be connected to any database store 30 | * @param windowSize size of each window in milliseconds (fixed and rolling) 31 | * @param capacity max capacity of tokens allowed per fixed window 32 | * @param client redis client where rate limiter will cache information 33 | */ 34 | constructor(windowSize: number, capacity: number, client: Redis, expiry: number) { 35 | this.windowSize = windowSize; 36 | this.capacity = capacity; 37 | this.client = client; 38 | this.keyExpiry = expiry; 39 | if (!windowSize || !capacity || windowSize <= 0 || capacity <= 0 || expiry <= 0) 40 | throw SyntaxError( 41 | 'SlidingWindowCounter window size, capacity and keyExpiry must be positive' 42 | ); 43 | } 44 | 45 | /** 46 | * @function processRequest - Sliding window counter algorithm to allow or block 47 | * based on the depth/complexity (in amount of tokens) of incoming requests. 48 | * 49 | * First, checks if a window exists in the redis cache. 50 | * 51 | * If not, then `fixedWindowStart` is set as the current timestamp, and `currentTokens` 52 | * is checked against `capacity`. If enough room exists for the request, returns 53 | * success as true and tokens as how many tokens remain in the current fixed window. 54 | * 55 | * If a window does exist in the cache, we first check if the timestamp is greater than 56 | * the fixedWindowStart + windowSize. 57 | * 58 | * If it isn't then we check the number of tokens in the arguments as well as in the cache 59 | * against the capacity and return success or failure from there while updating the cache. 
60 | * 61 | * If the timestamp is over the windowSize beyond the fixedWindowStart, then we update fixedWindowStart 62 | * to be fixedWindowStart + windowSize (to create a new fixed window) and 63 | * make previousTokens = currentTokens, and currentTokens equal to the number of tokens in args, if 64 | * not over capacity. 65 | * 66 | * Once previousTokens is not null, we then run functionality using the rolling window to compute 67 | * the formula this entire limiting algorithm is distinguished by: 68 | * 69 | * currentTokens + previousTokens * overlap % of rolling window over previous fixed window 70 | * 71 | * @param {string} uuid - unique identifer used to throttle requests 72 | * @param {number} timestamp - time the request was recieved 73 | * @param {number} [tokens=1] - complexity of the query for throttling requests 74 | * @return {*} {Promise} 75 | * RateLimiterResponse: {success: boolean, tokens: number} 76 | * (tokens represents the remaining available capacity of the window) 77 | * @memberof SlidingWindowCounter 78 | */ 79 | async processRequest( 80 | uuid: string, 81 | timestamp: number, 82 | tokens = 1 83 | ): Promise { 84 | // attempt to get the value for the uuid from the redis cache 85 | const windowJSON = await this.client.get(uuid); 86 | 87 | // if the response is null, we need to create a window for the user 88 | if (windowJSON === null) { 89 | const newUserWindow: RedisWindow = { 90 | // current and previous tokens represent how many tokens are in each window 91 | currentTokens: tokens <= this.capacity ? 
tokens : 0, 92 | previousTokens: 0, 93 | fixedWindowStart: timestamp, 94 | }; 95 | 96 | if (tokens <= this.capacity) { 97 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserWindow)); 98 | return { success: true, tokens: this.capacity - newUserWindow.currentTokens }; 99 | } 100 | 101 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(newUserWindow)); 102 | // tokens property represents how much capacity remains 103 | return { success: false, tokens: this.capacity, retryAfter: Infinity }; 104 | } 105 | 106 | // if the cache is populated 107 | 108 | const window: RedisWindow = await JSON.parse(windowJSON); 109 | 110 | const updatedUserWindow: RedisWindow = { 111 | currentTokens: window.currentTokens, 112 | previousTokens: window.previousTokens, 113 | fixedWindowStart: window.fixedWindowStart, 114 | }; 115 | 116 | // if request time is in a new window 117 | if (window.fixedWindowStart && timestamp >= window.fixedWindowStart + this.windowSize) { 118 | // if more than one window was skipped 119 | if (timestamp >= window.fixedWindowStart + this.windowSize * 2) { 120 | // if one or more windows was skipped, reset new window to be at current timestamp 121 | updatedUserWindow.previousTokens = 0; 122 | updatedUserWindow.currentTokens = 0; 123 | updatedUserWindow.fixedWindowStart = timestamp; 124 | } else { 125 | updatedUserWindow.previousTokens = updatedUserWindow.currentTokens; 126 | updatedUserWindow.currentTokens = 0; 127 | updatedUserWindow.fixedWindowStart = window.fixedWindowStart + this.windowSize; 128 | } 129 | } 130 | 131 | // assigned to avoid TS error, this var will never be used as 0 132 | // var is declared here so that below can be inside a conditional for efficiency's sake 133 | let rollingWindowProportion = 0; 134 | let previousRollingTokens = 0; 135 | 136 | if (updatedUserWindow.fixedWindowStart) { 137 | // proportion of rolling window present in previous window 138 | rollingWindowProportion = 139 | (this.windowSize - (timestamp 
- updatedUserWindow.fixedWindowStart)) / 140 | this.windowSize; 141 | 142 | // remove unecessary decimals, 0.xx is enough 143 | // rollingWindowProportion -= rollingWindowProportion % 0.01; 144 | 145 | // # of tokens present in rolling & previous window 146 | previousRollingTokens = Math.floor( 147 | updatedUserWindow.previousTokens * rollingWindowProportion 148 | ); 149 | } 150 | 151 | // # of tokens present in rolling and/or current window 152 | // if previous tokens is null, previousRollingTokens will be 0 153 | const rollingTokens = updatedUserWindow.currentTokens + previousRollingTokens; 154 | 155 | // if request is allowed 156 | if (tokens + rollingTokens <= this.capacity) { 157 | updatedUserWindow.currentTokens += tokens; 158 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(updatedUserWindow)); 159 | return { 160 | success: true, 161 | tokens: this.capacity - (updatedUserWindow.currentTokens + previousRollingTokens), 162 | }; 163 | } 164 | 165 | // if request is blocked 166 | await this.client.setex(uuid, this.keyExpiry, JSON.stringify(updatedUserWindow)); 167 | 168 | const { previousTokens, currentTokens } = updatedUserWindow; 169 | // Size and proportion of the window in seconds 170 | const windowSizeSeconds = this.windowSize / 1000; 171 | const rollingWindowProportionSeconds = windowSizeSeconds * rollingWindowProportion; 172 | // Tokens available for the request to use 173 | const tokensAvailable = this.capacity - (currentTokens + previousRollingTokens); 174 | // Additional tokens that are needed for the request to pass 175 | const tokensNeeded = tokens - tokensAvailable; 176 | // share of the tokens needed that can come from the previous window 177 | // 1. if the previous rolling portion of the window has more tokens than is needed for the request, than we need only those tokens needed from this window 178 | // 2. 
otherwise we need all the previous rolling tokens(and then some) for the request to pass 179 | const tokensNeededFromPreviousWindow = 180 | previousRollingTokens >= tokensNeeded ? tokensNeeded : previousRollingTokens; 181 | // time needed to wait to aquire the tokens needed from the previous window 182 | // 1. if the tokens available in the previous rolling window equals those needed form this window, we need to wait the remaing protion of this window to pass 183 | // 2. otherwise wait a fraction of that window to pass, determined by the ratio of previous rolling tokens available to the tokens needed from this window 184 | const timeToWaitFromPreviousTokens = 185 | previousRollingTokens === tokensNeededFromPreviousWindow 186 | ? rollingWindowProportionSeconds 187 | : rollingWindowProportionSeconds * 188 | ((previousTokens - tokensNeededFromPreviousWindow) / previousRollingTokens); 189 | // tokens needed from the current window for the request to pass 190 | const tokensNeededFromCurrentWindow = tokensNeeded - tokensNeededFromPreviousWindow; 191 | // time needed to wait to aquire the from the current window tfor the request to pass 192 | // 1. if the tokens needed from the current window is 0, thon no time is needed 193 | // 2. otherwise wait a fraction of time as determined by 194 | const timeToWaitFromCurrentTokens = 195 | tokensNeededFromCurrentWindow === 0 196 | ? 0 197 | : windowSizeSeconds * (tokensNeededFromCurrentWindow / currentTokens); 198 | 199 | return { 200 | success: false, 201 | tokens: this.capacity - (updatedUserWindow.currentTokens + previousRollingTokens), 202 | retryAfter: 203 | tokens > this.capacity 204 | ? Infinity 205 | : Math.ceil(timeToWaitFromPreviousTokens + timeToWaitFromCurrentTokens), 206 | }; 207 | } 208 | 209 | /** 210 | * Resets the rate limiter to the intial state by clearing the redis store. 
211 | */ 212 | public reset(): void { 213 | this.client.flushall(); 214 | } 215 | } 216 | 217 | export default SlidingWindowCounter; 218 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

GraphQLGate

4 | GitHub stars GitHub issues GitHub last commit 5 | 6 |

A GraphQL rate-limiting library with query complexity analysis for Node.js and Express

7 |
8 | 9 |   10 | ## Summary 11 | 12 | Developed under tech-accelerator [OSLabs](https://opensourcelabs.io/), GraphQLGate strives for a principled approach to complexity analysis and rate-limiting for GraphQL queries by accurately estimating an upper bound of the response size of the query. Within a loosely opinionated framework with lots of configuration options, you can reliably throttle GraphQL queries by complexity and depth to protect your GraphQL API. Our solution is inspired by [this paper](https://github.com/Alan-Cha/fse20/blob/master/submissions/functional/FSE-24/graphql-paper.pdf) from IBM research teams. 13 | 14 | ## Table of Contents 15 | 16 | - [Getting Started](#getting-started) 17 | - [Configuration](#configuration) 18 | - [Notes on Lists](#notes-on-lists) 19 | - [How It Works](#how-it-works) 20 | - [Response](#response) 21 | - [Error Handling](#error-handling) 22 | - [Internals](#internals) 23 | - [Future Development](#future-development) 24 | - [Contributions](#contributions) 25 | - [Developers](#developers) 26 | - [License](#license) 27 | 28 | ## Getting Started 29 | 30 | Install the package 31 | 32 | ``` 33 | npm i graphql-limiter 34 | ``` 35 | 36 | Import the package and add the rate-limiting middleware to the Express middleware chain before the GraphQL server. 37 | 38 | NOTE: a Redis server instance must be running in order for the limiter to cache data. 39 | 40 | ```javascript 41 | // import package 42 | import { expressGraphQLRateLimiter } from 'graphql-limiter'; 43 | 44 | /** 45 | * Import other dependencies 46 | * */ 47 | 48 | // Add the middleware into your GraphQL middleware chain 49 | app.use( 50 | '/gql', 51 | expressGraphQLRateLimiter(schemaObject, { 52 | rateLimiter: { 53 | type: 'TOKEN_BUCKET', 54 | refillRate: 10, 55 | capacity: 100, 56 | }, 57 | }) /** add GraphQL server here */ 58 | ); 59 | ``` 60 | 61 | ## Configuration 62 | 63 | 1. #### `schema: GraphQLSchema` | required 64 | 65 | 2.
#### `config: ExpressMiddlewareConfig` | required 66 | 67 | - `rateLimiter: RateLimiterOptions` | required 68 | 69 | - `type: 'TOKEN_BUCKET' | 'FIXED_WINDOW' | 'SLIDING_WINDOW_LOG' | 'SLIDING_WINDOW_COUNTER'` 70 | - `capacity: number` 71 | - `refillRate: number` | bucket algorithms only 72 | - `windowSize: number` | (in ms) window algorithms only 73 | 74 | - `redis: RedisConfig` 75 | 76 | - `options: RedisOptions` | [ioredis configuration options](https://github.com/luin/ioredis) | defaults to standard ioredis connection options (`localhost:6379`) 77 | - `keyExpiry: number` (ms) | custom expiry of keys in redis cache | defaults to 24 hours 78 | 79 | - `typeWeights: TypeWeightObject` 80 | 81 | - `mutation: number` | assigned weight of a mutation | defaults to 10 82 | - `query: number` | assigned weight of a query | defaults to 1 83 | - `object: number` | assigned weight of GraphQL object, interface and union types | defaults to `1` 84 | - `scalar: number` | assigned weight of GraphQL scalar and enum types | defaults to `0` 85 | 86 | - `depthLimit: number` | throttle queries by the depth of the nested structure | defaults to `Infinity` (i.e. no limit) 87 | - `enforceBoundedLists: boolean` | if true, an error will be thrown if any list types are not bounded by slicing arguments [`first`, `last`, `limit`] or directives | defaults to `false` 88 | - `dark: boolean` | if true, the package will calculate complexity, depth and tokens but not throttle any queries. Use this to dark launch the package and monitor the rate limiter's impact without limiting user requests | defaults to `false`
89 | 90 | All configuration options 91 | 92 | ```javascript 93 | expressGraphQLRateLimiter(schemaObject, { 94 | rateLimiter: { 95 | type: 'SLIDING_WINDOW_LOG', // rate-limiter selection 96 | windowSize: 6000, // 6 seconds 97 | capacity: 100, 98 | }, 99 | redis: { 100 | keyExpiry: 14400000, // 4 hours, defaults to 86400000 (24 hours) 101 | options: { 102 | host: 'localhost', // ioredis connection options 103 | port: 6379, 104 | }, 105 | }, 106 | typeWeights: { // weights of GraphQL types 107 | mutation: 10, 108 | query: 1, 109 | object: 1, 110 | scalar: 0, 111 | }, 112 | enforceBoundedLists: false, // defaults to false 113 | dark: false, // defaults to false 114 | depthLimit: 7 // defaults to Infinity (i.e. no depth limiting) 115 | }); 116 | ``` 117 | 118 | ## Notes on Lists 119 | 120 | For queries that return a list, the complexity can be determined by providing a slicing argument to the query (`first`, `last`, `limit`), or by using a schema directive. 121 | 122 | 1. Slicing arguments: lists must be bounded by one integer slicing argument in order to calculate the complexity for the field. This package supports the slicing arguments `first`, `last` and `limit`. The complexity of the list will be the value passed as the argument to the field. 123 | 124 | 2. Directives: To use directives, `@listCost` must be defined in your schema with `directive @listCost(cost: Int!) on FIELD_DEFINITION`. Then, on any field which resolves to an unbounded list, add `@listCost(cost: N)` where the integer `N` is the complexity for this field. 125 | 126 | (Note: Slicing arguments are preferred and will override the `@listCost` directive! `@listCost` is in place as a fallback.) 127 | 128 | ```graphql 129 | directive @listCost(cost: Int!) on FIELD_DEFINITION 130 | type Human { 131 | id: ID! 132 | } 133 | type Query { 134 | humans: [Human] @listCost(cost: 10) 135 | } 136 | ``` 137 | 138 | ## How It Works 139 | 140 | Requests are rate-limited based on the IP address associated with the request.
141 | 142 | On startup, the GraphQL (GQL) schema is parsed to build an object that maps GQL types/fields to their corresponding weights. Type weights can be provided during initial configuration. When a request is received, this object is used to cross-reference the fields queried by the user and compute the complexity of each field. The total complexity of the request is the sum of these values. 143 | 144 | Complexity is determined statically (before any resolvers are called) to estimate the upper bound of the response size - a proxy for the work done by the server to build the response. The total complexity is then used to allow/block the request based on popular rate-limiting algorithms. 145 | 146 | Requests for each user are processed sequentially by the rate limiter. 147 | 148 | Example (with default weights): 149 | 150 | ```graphql 151 | query { 152 | # 1 query 153 | hero(episode: EMPIRE) { 154 | # 1 object 155 | name # 0 scalar 156 | id # 0 scalar 157 | friends(first: 3) { 158 | # 3 objects 159 | name # 0 scalar 160 | id # 0 scalar 161 | } 162 | } 163 | reviews(episode: EMPIRE, limit: 5) { 164 | # 5 objects 165 | stars # 0 scalar 166 | commentary # 0 scalar 167 | } 168 | } # total complexity of 10 169 | ``` 170 | 171 | ## Response 172 | 173 | 1. Blocked Requests: blocked requests receive a response with: 174 | 175 | - status of `429` for `Too Many Requests` 176 | - `Retry-After` header indicating the time to wait in seconds before the request could be approved (`Infinity` if the complexity is greater than the rate-limiting capacity or the query depth exceeds `depthLimit`). 177 | - A JSON response with the remaining `tokens` available, `complexity` of the query, `depth` of the query, `success` of the query set to `false`, and the UNIX `timestamp` of the request 178 | 179 | 2.
Successful Requests: successful requests are passed on to the next function in the middleware chain with the following properties saved to `res.locals` 180 | 181 | ```javascript 182 | { 183 | graphqlGate: { 184 | success: boolean, // true when successful 185 | tokens: number, // tokens available after request 186 | complexity: number, // complexity of the query 187 | depth: number, // depth of the query 188 | timestamp: number, // UNIX timestamp 189 | } 190 | } 191 | ``` 192 | 193 | ## Error Handling 194 | 195 | - Incoming queries are validated against the GraphQL schema. If the query is invalid, a response with status code `400` is returned along with an array of the GraphQL errors that were found. 196 | - To avoid disrupting server activity, errors thrown during the analysis and rate-limiting of the query are logged and the request is passed on to the next piece of middleware in the chain. 197 | 198 | ## Internals 199 | 200 | This package exposes three additional utilities which comprise the internals of the package. Brief documentation for each follows. 201 | 202 | ### Complexity Analysis 203 | 204 | 1. #### `typeWeightsFromSchema` | function to create the type weight object from the schema for complexity analysis 205 | 206 | - `schema: GraphQLSchema` | GraphQL schema object 207 | - `typeWeightsConfig: TypeWeightConfig = defaultTypeWeightsConfig` | type weight configuration 208 | - `enforceBoundedLists = false` 209 | - returns: `TypeWeightObject` 210 | - usage: 211 | 212 | ```ts 213 | import { typeWeightsFromSchema } from 'graphql-limiter'; 214 | import { GraphQLSchema } from 'graphql/type/schema'; 215 | import { buildSchema } from 'graphql'; 216 | 217 | let schema: GraphQLSchema = buildSchema(`...`); 218 | 219 | const typeWeights: TypeWeightObject = typeWeightsFromSchema(schema); 220 | ``` 221 | 222 | 2.
#### `QueryParser` | class to calculate the complexity of the query based on the type weights and variables 223 | 224 | - `typeWeights: TypeWeightObject` 225 | - `variables: Variables` | variables on request 226 | - returns a class instance with method: 227 | 228 | - `processQuery(queryAST: DocumentNode): number` 229 | - returns: complexity of the query and exposes `maxDepth` property for depth limiting 230 | 231 | ```ts 232 | import { QueryParser } from 'graphql-limiter'; 233 | import { parse, validate, DocumentNode } from 'graphql'; 234 | 235 | let queryAST: DocumentNode = parse(`...`); 236 | 237 | const queryParser: QueryParser = new QueryParser(typeWeights, variables); 238 | 239 | // query must be validated against the schema before processing the query 240 | const validationErrors = validate(schema, queryAST); 241 | 242 | const complexity: number = queryParser.processQuery(queryAST); 243 | ``` 244 | 245 | ### Rate-limiting 246 | 247 | 3. #### `rateLimiter` | returns a rate limiting class instance based on selections 248 | 249 | - `rateLimiter: RateLimiterConfig` | see "configuration" -> rateLimiter 250 | - `client: Redis` | an ioredis client 251 | - `keyExpiry: number` | time (ms) for key to persist in cache 252 | - returns a rate limiter class instance with method: 253 | 254 | - `processRequest(uuid: string, timestamp: number, tokens = 1): Promise<RateLimiterResponse>` 255 | - returns: `{ success: boolean, tokens: number, retryAfter?: number }` | where `tokens` is tokens available, `retryAfter` is time to wait in seconds before the request would be successful and `success` is false if the request is blocked 256 | 257 | ```ts 258 | import { rateLimiter } from 'graphql-limiter'; 259 | 260 | const limiter: RateLimiter = rateLimiter( 261 | { 262 | type: 'TOKEN_BUCKET', 263 | refillRate: 1, 264 | capacity: 10, 265 | }, 266 | redisClient, 267 | 86400000 // 24 hours 268 | ); 269 | 270 | const response: RateLimiterResponse = await limiter.processRequest( 271 | 'user-1', 272 | new Date().valueOf(), 273 | 5 274 | );
275 | ``` 276 | 277 | ## Future Development 278 | 279 | - Ability to use this package with other caching technologies or libraries 280 | - Implement "resolve complexity analysis" for queries 281 | - Implement leaky bucket algorithm for rate-limiting 282 | - Experiment with performance improvements 283 | - caching optimization 284 | - Ensure connection pagination conventions can be accurately accounted for in complexity analysis 285 | - Ability to use middleware with other server frameworks 286 | 287 | ## Contributions 288 | 289 | Contributions to the code, examples, documentation, etc. are very much appreciated. 290 | 291 | - Please report issues and bugs directly in this [GitHub project](https://github.com/oslabs-beta/GraphQL-Gate/issues). 292 | 293 | ## Developers 294 | 295 | - [Evan McNeely](https://github.com/evanmcneely) 296 | - [Stephan Halarewicz](https://github.com/shalarewicz) 297 | - [Flora Yufei Wu](https://github.com/feiw101) 298 | - [Jon Dewey](https://github.com/donjewey) 299 | - [Milos Popovic](https://github.com/milos381) 300 | 301 | ## License 302 | 303 | This product is licensed under the MIT License - see the LICENSE file for details. 304 | 305 | This is an open source product. 306 | 307 | This product is accelerated by OS Labs.
308 | -------------------------------------------------------------------------------- /src/analysis/QueryParser.ts: -------------------------------------------------------------------------------- 1 | import { 2 | DocumentNode, 3 | FieldNode, 4 | SelectionSetNode, 5 | DefinitionNode, 6 | Kind, 7 | DirectiveNode, 8 | SelectionNode, 9 | } from 'graphql'; 10 | import { FieldWeight, TypeWeightObject, Variables } from '../@types/buildTypeWeights'; 11 | /** 12 | * The AST node functions call each other following the nested structure below 13 | * Each function handles a specific GraphQL AST node type 14 | * 15 | * AST nodes call each other in the following way 16 | * 17 | * Document Node 18 | * | 19 | * Definiton Node 20 | * (operation and fragment definitons) 21 | * / | 22 | * |-----> Selection Set Node <-------| 23 | * | / 24 | * | Selection Node 25 | * | (Field, Inline fragment and fragment spread) 26 | * | | | \ 27 | * | Field Node | fragmentCache 28 | * | | | 29 | * |<--calculateCast | 30 | * | | 31 | * |<------------------| 32 | */ 33 | 34 | class QueryParser { 35 | private typeWeights: TypeWeightObject; 36 | 37 | private depth: number; 38 | 39 | public maxDepth: number; 40 | 41 | private variables: Variables; 42 | 43 | private fragmentCache: { [index: string]: { complexity: number; depth: number } }; 44 | 45 | constructor(typeWeights: TypeWeightObject, variables: Variables) { 46 | this.typeWeights = typeWeights; 47 | this.variables = variables; 48 | this.fragmentCache = {}; 49 | this.depth = 0; 50 | this.maxDepth = 0; 51 | } 52 | 53 | private calculateCost( 54 | node: FieldNode, 55 | parentName: string, 56 | typeName: string, 57 | typeWeight: FieldWeight 58 | ) { 59 | let complexity = 0; 60 | // field resolves to an object or a list with possible selections 61 | let selectionsCost = 0; 62 | let calculatedWeight = 0; 63 | 64 | if (node.selectionSet) { 65 | selectionsCost += this.selectionSetNode(node.selectionSet, typeName); 66 | } 67 | // if there are arguments 
and this is a list, call the 'weightFunction' to get the weight of this field. otherwise the weight is static and can be accessed through the typeWeights object 68 | if (node.arguments && typeof typeWeight === 'function') { 69 | // FIXME: May never happen but what if weight is a function and arguments don't exist 70 | calculatedWeight += typeWeight([...node.arguments], this.variables, selectionsCost); 71 | } else if (typeof typeWeight === 'number') { 72 | calculatedWeight += typeWeight + selectionsCost; 73 | } else { 74 | calculatedWeight += this.typeWeights[typeName].weight + selectionsCost; 75 | } 76 | complexity += calculatedWeight; 77 | 78 | return complexity; 79 | } 80 | 81 | private fieldNode(node: FieldNode, parentName: string): number { 82 | try { 83 | let complexity = 0; 84 | // the node must have a parent in typeweights or the analysis will fail. this should never happen 85 | const parentType = this.typeWeights[parentName]; 86 | if (!parentType) { 87 | throw new Error( 88 | `ERROR: QueryParser Failed to obtain parentType for parent: ${parentName} and node: ${node.name.value}` 89 | ); 90 | } 91 | 92 | let typeName: string | undefined; 93 | let typeWeight: FieldWeight | undefined; 94 | 95 | if (node.name.value === '__typename') return complexity; // this will be zero, ie. 
this field has no complexity 96 | 97 | if (node.name.value in this.typeWeights) { 98 | // node is an object type in the typeWeight root 99 | typeName = node.name.value; 100 | typeWeight = this.typeWeights[typeName].weight; 101 | complexity += this.calculateCost(node, parentName, typeName, typeWeight); 102 | } else if (parentType.fields[node.name.value].resolveTo) { 103 | // node is a field on a typeWeight root, field resolves to another type in type weights or a list 104 | typeName = parentType.fields[node.name.value].resolveTo; 105 | typeWeight = parentType.fields[node.name.value].weight; 106 | // if this is a list typeWeight is a weight function 107 | // otherwise the weight would be null as the weight is defined on the typeWeights root 108 | if (typeName && typeWeight) { 109 | // Type is a list and has a weight function 110 | complexity += this.calculateCost(node, parentName, typeName, typeWeight); 111 | } else if (typeName) { 112 | // resolve type exists at root of typeWeight object and is not a list 113 | typeWeight = this.typeWeights[typeName].weight; 114 | complexity += this.calculateCost(node, parentName, typeName, typeWeight); 115 | } else { 116 | throw new Error( 117 | `ERROR: QueryParser Failed to obtain resolved type name or weight for node: ${parentName}.${node.name.value}` 118 | ); 119 | } 120 | } else { 121 | // field is a scalar 122 | typeName = node.name.value; 123 | if (typeName) { 124 | typeWeight = parentType.fields[typeName].weight; 125 | if (typeof typeWeight === 'number') { 126 | complexity += typeWeight; 127 | } else { 128 | throw new Error( 129 | `ERROR: QueryParser Failed to obtain type weight for ${parentName}.${node.name.value}` 130 | ); 131 | } 132 | } else { 133 | throw new Error( 134 | `ERROR: QueryParser Failed to obtain type name for ${parentName}.${node.name.value}` 135 | ); 136 | } 137 | } 138 | return complexity; 139 | } catch (err) { 140 | throw new Error( 141 | `ERROR: QueryParser.fieldNode Uncaught error handling 
${parentName}.${ 142 | node.name.value 143 | }\n 144 | ${err instanceof Error && err.stack}` 145 | ); 146 | } 147 | } 148 | 149 | /** 150 | * Return true if: 151 | * 1. there is no directive 152 | * 2. there is a directive named inlcude and the value is true 153 | * 3. there is a directive named skip and the value is false 154 | */ 155 | // THIS IS NOT CALLED ANYWEHERE. IN PROGRESS 156 | private directiveCheck(directive: DirectiveNode): boolean { 157 | if (directive?.arguments) { 158 | // get the first argument 159 | const argument = directive.arguments[0]; 160 | // ensure the argument name is 'if' 161 | const argumentHasVariables = 162 | argument.value.kind === Kind.VARIABLE && argument.name.value === 'if'; 163 | // access the value of the argument depending on whether it is passed as a variable or not 164 | let directiveArgumentValue; 165 | if (argument.value.kind === Kind.BOOLEAN) { 166 | directiveArgumentValue = Boolean(argument.value.value); 167 | } else if (argumentHasVariables) { 168 | directiveArgumentValue = Boolean(this.variables[argument.value.name.value]); 169 | } 170 | 171 | return ( 172 | (directive.name.value === 'include' && directiveArgumentValue === true) || 173 | (directive.name.value === 'skip' && directiveArgumentValue === false) 174 | ); 175 | } 176 | return true; 177 | } 178 | 179 | private selectionNode(node: SelectionNode, parentName: string): number { 180 | let complexity = 0; 181 | // TODO: complete implementation of directives include and skip 182 | /** 183 | * process this node only if: 184 | * 1. there is no directive 185 | * 2. there is a directive named inlcude and the value is true 186 | * 3. 
there is a directive named skip and the value is false 187 | */ 188 | // const directive = node.directives; 189 | // if (directive && this.directiveCheck(directive[0])) { 190 | this.depth += 1; 191 | if (this.depth > this.maxDepth) this.maxDepth = this.depth; 192 | // the kind of a field node will either be field, fragment spread or inline fragment 193 | if (node.kind === Kind.FIELD) { 194 | complexity += this.fieldNode(node, parentName.toLowerCase()); 195 | } else if (node.kind === Kind.FRAGMENT_SPREAD) { 196 | // add complexity and depth from fragment cache 197 | const { complexity: fragComplexity, depth: fragDepth } = 198 | this.fragmentCache[node.name.value]; 199 | complexity += fragComplexity; 200 | this.depth += fragDepth; 201 | if (this.depth > this.maxDepth) this.maxDepth = this.depth; 202 | this.depth -= fragDepth; 203 | 204 | // This is a leaf 205 | // need to parse fragment definition at root and get the result here 206 | } else if (node.kind === Kind.INLINE_FRAGMENT) { 207 | const { typeCondition } = node; 208 | 209 | // named type is the type from which inner fields should be take 210 | // If the TypeCondition is omitted, an inline fragment is considered to be of the same type as the enclosing context 211 | const namedType = typeCondition ? 
typeCondition.name.value.toLowerCase() : parentName;

            // TODO: Handle directives like @include and @skip
            // subtract 1 before, and add one after, entering the fragment selection to negate the additional level of depth added
            this.depth -= 1;
            complexity += this.selectionSetNode(node.selectionSet, namedType);
            this.depth += 1;
        } else {
            throw new Error(`ERROR: QueryParser.selectionNode: node type not supported`);
        }

        this.depth -= 1;
        //* }
        return complexity;
    }

    /**
     * Sums the complexity of every selection in a selection set.
     *
     * Fields, fragment spreads and inline fragments without a type condition always apply, so
     * their costs are summed. Typed inline fragments (e.g. `... on Human` and `... on Droid`
     * inside a union or interface selection) are alternatives for a single result object, so only
     * the most expensive one is added. This is an estimate: when more than one typed fragment can
     * apply to the same object their costs should really be summed.
     *
     * @param node selection set to analyze
     * @param parentName name of the type the selections are made on (passed straight through)
     * @returns the estimated complexity of the selection set
     */
    private selectionSetNode(node: SelectionSetNode, parentName: string): number {
        let complexity = 0;
        let maxFragmentComplexity = 0;
        // Selection sets only act as intermediaries, so the current parent is passed through unchanged.
        for (const selectionNode of node.selections) {
            const selectionCost = this.selectionNode(selectionNode, parentName);

            // FIXME: Consider the case where 2 typed fragments are applicable
            if (selectionNode.kind === Kind.INLINE_FRAGMENT && selectionNode.typeCondition) {
                // Typed inline fragment: keep only the most expensive alternative.
                if (selectionCost > maxFragmentComplexity) {
                    maxFragmentComplexity = selectionCost;
                }
            } else {
                // Fields, fragment spreads and untyped inline fragments always apply.
                complexity += selectionCost;
            }
        }
        return complexity + maxFragmentComplexity;
    }

    /**
     * Computes the complexity contributed by a single top-level definition.
     *
     * An operation definition (query, mutation or subscription) costs the weight of its root type
     * plus the cost of its selections; operations whose root type has no entry in the type weights
     * contribute nothing. A fragment definition contributes nothing directly: its complexity is
     * computed once and stored in `fragmentCache`, to be charged wherever the fragment is used.
     *
     * @param node definition node from the parsed document
     * @returns the complexity of the definition (0 for fragments and unhandled definition kinds)
     */
    private definitionNode(node: DefinitionNode): number {
        let complexity = 0;
        if (node.kind === Kind.OPERATION_DEFINITION) {
            // Use one lower-cased key for the membership check, the weight lookup and the parent
            // type handed to the selections so the three can never disagree. toLowerCase (not
            // toLocaleLowerCase) keeps the key independent of the host locale.
            const operationType = node.operation.toLowerCase();
            if (operationType in this.typeWeights) {
                complexity += this.typeWeights[operationType].weight;
                if (node.selectionSet) {
                    complexity += this.selectionSetNode(node.selectionSet, operationType);
                }
            }
        } else if (node.kind === Kind.FRAGMENT_DEFINITION) {
            // A fragment is declared on the named type in its type condition. Parse the complexity
            // of this fragment once and store it for use when analyzing the nodes that spread it.
            const namedType = node.typeCondition.name.value;
            // The GraphQL spec forbids duplicate fragment names, so each name is one cache entry.
            const fragmentName = node.name.value;

            const fragmentComplexity = this.selectionSetNode(
                node.selectionSet,
                namedType.toLowerCase()
            );

            // Don't count fragment complexity in the node's complexity. Only when fragment is used.
            this.fragmentCache[fragmentName] = {
                complexity: fragmentComplexity,
                // subtract one from the calculated depth of the fragment to correct for the
                // additional depth the fragment adds to the query when used
                depth: this.maxDepth - 1,
            };
        }
        // TODO: Verify that there are no other type definition nodes that need to be handled (see ast.d.ts in 'graphql')
        // else {
        //
        //     // Other types include TypeSystemDefinitionNode (Schema, Type, Directive) and
        //     // TypeSystemExtensionNode(Schema, Type);
        //     throw new Error(`ERROR: QueryParser.definitionNode: ${node.kind} type not supported`);
        // }
        return complexity;
    }

    /**
     * Computes the complexity of a whole document by summing its definitions.
     *
     * Fragment definitions are processed before every other definition so that their cached
     * complexity is available to the operations that use them. Within each group, definitions
     * keep their source order.
     *
     * @param node parsed GraphQL document
     * @returns the summed complexity of all definitions
     */
    private documentNode(node: DocumentNode): number {
        const fragments = node.definitions.filter(
            (definition) => definition.kind === Kind.FRAGMENT_DEFINITION
        );
        const others = node.definitions.filter(
            (definition) => definition.kind !== Kind.FRAGMENT_DEFINITION
        );
        let complexity = 0;
        for (const definition of [...fragments, ...others]) {
            complexity += this.definitionNode(definition);
        }
        return complexity;
    }

    /**
     * Returns the estimated complexity of a parsed GraphQL document.
     *
     * @param queryAST the parsed query document
     * @returns the estimated complexity
     */
    public processQuery(queryAST: DocumentNode): number {
        return this.documentNode(queryAST);
    }
}

export default QueryParser;

--------------------------------------------------------------------------------
/test/rateLimiters/tokenBucket.test.ts:
--------------------------------------------------------------------------------
import * as ioredis from 'ioredis';
import { RedisBucket } from '../../src/@types/rateLimit';
import TokenBucket from '../../src/rateLimiters/tokenBucket';

// eslint-disable-next-line @typescript-eslint/no-var-requires
const RedisMock = require('ioredis-mock');

const CAPACITY = 10;
// FIXME: Changing the refill rate affects test outcomes.
const REFILL_RATE = 1; // 1 token per second
const keyExpiry = 1000000;

let limiter: TokenBucket;
let client: ioredis.Redis;
let timestamp: number;
const user1 = '1';
const user2 = '2';
const user3 = '3';
const user4 = '4';

/**
 * Reads a user's bucket directly from the store.
 * Returns { tokens: -1, timestamp: -1 } when the user has no entry; both values are otherwise impossible.
 */
async function getBucketFromClient(redisClient: ioredis.Redis, uuid: string): Promise<RedisBucket> {
    const res = await redisClient.get(uuid);
    if (res === null) return { tokens: -1, timestamp: -1 };
    return JSON.parse(res);
}

/** Writes a bucket holding `tokens` tokens, last updated at `time`, directly into the store. */
async function setTokenCountInClient(
    redisClient: ioredis.Redis,
    uuid: string,
    tokens: number,
    time: number
) {
    const value: RedisBucket = { tokens, timestamp: time };
    await redisClient.set(uuid, JSON.stringify(value));
}

describe('Test TokenBucket Rate Limiter', () => {
    beforeEach(async () => {
        // Initialize a new token bucket before each test:
        // create a fresh mock store and initialize the token bucket algorithm on it
        client = new RedisMock();
        limiter = new TokenBucket(CAPACITY, REFILL_RATE, client, keyExpiry);
        timestamp = new Date().valueOf();
    });

    describe('TokenBucket returns correct number of tokens and updates redis store as expected', () => {
        describe('after an ALLOWED request...', () => {
            afterEach(async () => {
                // await the flush so the next test never observes a half-cleared store
                await client.flushall();
            });
            test('bucket is initially full', async () => {
                // Bucket initially full
                const withdraw5 = 5;
                expect((await limiter.processRequest(user1, timestamp, withdraw5)).tokens).toBe(
                    CAPACITY - withdraw5
                );
                const tokenCountFull = await getBucketFromClient(client, user1);
                expect(tokenCountFull.tokens).toBe(CAPACITY - withdraw5);
            });

            test('bucket is partially full and request has leftover tokens', async () => {
                // Bucket partially full but enough time has elapsed to fill the bucket since the last request and
                // has leftover tokens after request
                const initial = 6;
                const partialWithdraw = 1;
                await setTokenCountInClient(client, user2, initial, timestamp);
                expect(
                    (
                        await limiter.processRequest(
                            user2,
                            timestamp + 1000 * (CAPACITY - initial),
                            initial + partialWithdraw
                        )
                    ).tokens
                ).toBe(CAPACITY - (initial + partialWithdraw));
                const tokenCountPartial = await getBucketFromClient(client, user2);
                expect(tokenCountPartial.tokens).toBe(CAPACITY - (initial + partialWithdraw));
            });

            // Bucket partially full and no leftover tokens after request
            test('bucket is partially full and request has no leftover tokens', async () => {
                const initial = 6;
                await setTokenCountInClient(client, user2, initial, timestamp);
                expect((await limiter.processRequest(user2, timestamp, initial)).tokens).toBe(0);
                const tokenCountPartialToEmpty = await getBucketFromClient(client, user2);
                expect(tokenCountPartialToEmpty.tokens).toBe(0);
            });

            // Bucket initially empty but enough time elapsed to partially fill bucket since last request
            test('bucket is initially empty but enough time has elapsed to partially fill the bucket', async () => {
                await setTokenCountInClient(client, user4, 0, timestamp);
                expect((await limiter.processRequest(user4, timestamp + 6000, 4)).tokens).toBe(2);
                const count = await getBucketFromClient(client, user4);
                expect(count.tokens).toBe(2);
            });
        });

        describe('after a BLOCKED request...', () => {
            let redisData: RedisBucket;

            afterAll(async () => {
                await client.flushall();
            });

            test('where intial request is greater than bucket capacity', async () => {
                // Initial request greater than capacity
                expect((await limiter.processRequest(user1, timestamp, CAPACITY + 1)).tokens).toBe(
                    CAPACITY
                );

                redisData = await getBucketFromClient(client, user1);
                expect(redisData.tokens).toBe(CAPACITY);
            });

            test('Bucket is partially full but not enough time elapsed to complete the request', async () => {
                // Bucket is partially full and time has elapsed but not enough to allow the current request
                const fillLevel = 5;
                const timeDelta = 3;
                const requestedTokens = 9;
                await setTokenCountInClient(client, user2, fillLevel, timestamp);

                expect(
                    (
                        await limiter.processRequest(
                            user2,
                            timestamp + timeDelta * 1000,
                            requestedTokens
                        )
                    ).tokens
                ).toBe(fillLevel + timeDelta * REFILL_RATE);

                redisData = await getBucketFromClient(client, user2);
                expect(redisData.tokens).toBe(fillLevel + timeDelta * REFILL_RATE);
            });
        });
    });

    describe('Token Bucket functions as expected', () => {
        afterEach(async () => {
            await client.flushall();
        });
        test('allows a user to consume up to their current allotment of tokens', async () => {
            // "free requests"
            expect((await limiter.processRequest(user1, timestamp, 0)).success).toBe(true);
            // Test 1 token requested
            expect((await limiter.processRequest(user1, timestamp, 1)).success).toBe(true);
            // Test < CAPACITY tokens requested
            expect((await limiter.processRequest(user2, timestamp, CAPACITY - 1)).success).toBe(
                true
            );
            // <= CAPACITY tokens requested
            expect((await limiter.processRequest(user3, timestamp, CAPACITY)).success).toBe(true);
        });

        test("blocks requests exceeding the user's current allotment of tokens", async () => {
            // Test > capacity tokens requested
            expect((await limiter.processRequest(user1, timestamp, CAPACITY + 1)).success).toBe(
                false
            );

            // Empty user 1's bucket
            const value: RedisBucket = { tokens: 0, timestamp };
            await client.set(user1, JSON.stringify(value));

            // bucket is empty. Shouldn't be allowed to take 1 token
            expect((await limiter.processRequest(user1, timestamp, 1)).success).toBe(false);

            // Should still be allowed to process "free" requests
            expect((await limiter.processRequest(user1, timestamp, 0)).success).toBe(true);
        });

        test('token bucket never exceeds maximum capacity', async () => {
            // make sure bucket doesn't exceed max size without any requests.
            // Fill the user's bucket then request additional tokens after an interval
            const value: RedisBucket = { tokens: CAPACITY, timestamp };
            await client.set(user1, JSON.stringify(value));
            expect(
                (await limiter.processRequest(user1, timestamp + 1000, CAPACITY + 1)).success
            ).toBe(false);
            expect(
                (await limiter.processRequest(user1, timestamp + 10000, CAPACITY + 1)).success
            ).toBe(false);
            expect(
                (await limiter.processRequest(user1, timestamp + 100000, CAPACITY + 1)).success
            ).toBe(false);
        });

        test('token bucket refills at specified rate', async () => {
            // make sure bucket refills if user takes tokens.
            const withdraw = 5;
            let timeDelta = 3;
            await limiter.processRequest(user1, timestamp, withdraw); // 5 tokens after this
            expect(
                (
                    await limiter.processRequest(
                        user1,
                        timestamp + timeDelta * 1000, // wait 3 seconds -> 8 tokens available
                        withdraw + REFILL_RATE * timeDelta // 5 + 3 = 8 tokens requested, 0 remaining after this
                    )
                ).tokens
            ).toBe(0);

            // check if bucket refills completely and doesn't spill over.
            timeDelta = 2 * CAPACITY;
            expect(
                (await limiter.processRequest(user1, timestamp + timeDelta * 1000, CAPACITY + 1))
                    .tokens
            ).toBe(CAPACITY);
        });

        test('bucket allows custom refill rates', async () => {
            const doubleRefillClient: ioredis.Redis = new RedisMock();
            limiter = new TokenBucket(CAPACITY, 2, doubleRefillClient, keyExpiry);

            await setTokenCountInClient(doubleRefillClient, user1, 0, timestamp);

            const timeDelta = 5;
            expect(
                (await limiter.processRequest(user1, timestamp + timeDelta * 1000, 0)).tokens
            ).toBe(timeDelta * 2);
        });

        test('users have their own buckets', async () => {
            const requested = 6;
            const user3Tokens = 8;
            // Add tokens for user 3 so we have both a user that exists in the store (3) and one that doesn't (2)
            await setTokenCountInClient(client, user3, user3Tokens, timestamp);

            // issue a request for user 1;
            await limiter.processRequest(user1, timestamp, requested);

            // Check that each user has the expected amount of tokens.
            expect((await getBucketFromClient(client, user1)).tokens).toBe(CAPACITY - requested);
            expect((await getBucketFromClient(client, user2)).tokens).toBe(-1); // not in the store so this returns -1
            expect((await getBucketFromClient(client, user3)).tokens).toBe(user3Tokens);

            await limiter.processRequest(user2, timestamp, 1);
            expect((await getBucketFromClient(client, user1)).tokens).toBe(CAPACITY - requested);
            expect((await getBucketFromClient(client, user2)).tokens).toBe(CAPACITY - 1);
            expect((await getBucketFromClient(client, user3)).tokens).toBe(user3Tokens);
        });

        test('bucket does not allow capacity or refill rate <= 0', () => {
            expect(() => new TokenBucket(0, 1, client, keyExpiry)).toThrow(
                'TokenBucket refillRate, capacity and keyExpiry must be positive'
            );
            expect(() => new TokenBucket(-10, 1, client, keyExpiry)).toThrow(
                'TokenBucket refillRate, capacity and keyExpiry must be positive'
            );
            expect(() => new TokenBucket(10, -1, client, keyExpiry)).toThrow(
                'TokenBucket refillRate, capacity and keyExpiry must be positive'
            );
            expect(() => new TokenBucket(10, 0, client, keyExpiry)).toThrow(
                'TokenBucket refillRate, capacity and keyExpiry must be positive'
            );
            expect(() => new TokenBucket(10, 2, client, 0)).toThrow(
                'TokenBucket refillRate, capacity and keyExpiry must be positive'
            );
        });

        test('All buckets should be able to be reset', async () => {
            const tokens = 5;
            await setTokenCountInClient(client, user1, tokens, timestamp);
            await setTokenCountInClient(client, user2, tokens, timestamp);
            await setTokenCountInClient(client, user3, tokens, timestamp);

            // await in case reset is asynchronous; awaiting a non-promise is harmless
            await limiter.reset();

            expect((await limiter.processRequest(user1, timestamp, CAPACITY)).success).toBe(true);
            expect((await limiter.processRequest(user2, timestamp, CAPACITY - 1)).success).toBe(
                true
            );
            expect((await limiter.processRequest(user3, timestamp, CAPACITY + 1)).success).toBe(
                false
            );
        });
    });
    describe('returns "retryAfter" if a request fails and', () => {
        /**
         * Strategy
         * Check where limiting request is at either end of log and in the middle
         * Infinity if > capacity (handled above)
         * doesn't appear if success (handled above)
         * */
        beforeEach(() => {
            timestamp = 1000;
        });

        afterEach(async () => {
            // Clear the store so each test only sees the state it sets up itself
            // (the "no key in cache" test relies on this).
            await client.flushall();
        });

        test('the user already has key in cache and tokens is less than capacity', async () => {
            // set 5 tokens in bucket
            await setTokenCountInClient(client, user1, 5, timestamp);
            // wait 2 seconds (7 tokens available) and request 9 tokens
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 2000, 9);
            // this is 2 tokens more than in bucket, should have to wait 2 seconds
            expect(retryAfter).toBe(2);
        });

        test('the user already has key in cache and tokens is greater than capacity', async () => {
            // set 9 tokens in bucket
            await setTokenCountInClient(client, user1, 9, timestamp);
            // wait 2 seconds and request 11 tokens
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 2000, 11);
            // more tokens than the bucket can ever hold, so the request can never succeed
            expect(retryAfter).toBe(Infinity);
        });

        test('the user has no key in cache and tokens is greater than capacity', async () => {
            // wait 2 seconds and request 11 tokens
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 2000, 11);
            // more tokens than the bucket can ever hold, so the request can never succeed
            expect(retryAfter).toBe(Infinity);
        });
    });
    describe('Token Bucket correctly updates redis store', () => {
        test('timestamp correctly updated in redis', async () => {
            let redisData: RedisBucket;

            // blocked request
            await limiter.processRequest(user1, timestamp, CAPACITY + 1);
            redisData = await getBucketFromClient(client, user1);
            expect(redisData.timestamp).toBe(timestamp);

            timestamp += 1000;
            // allowed request
            await limiter.processRequest(user2, timestamp, CAPACITY);
            redisData = await getBucketFromClient(client, user2);
            expect(redisData.timestamp).toBe(timestamp);
        });

        test('All buckets should be able to be reset', async () => {
            // add data to redis
            const time = new Date();
            const value = JSON.stringify({ tokens: 0, timestamp: time.valueOf() });

            await client.set(user1, value);
            await client.set(user2, value);
            await client.set(user3, value);

            // await in case reset is asynchronous; otherwise the reads below could race it
            await limiter.reset();

            const resetUser1 = await client.get(user1);
            const resetUser2 = await client.get(user2);
            const resetUser3 = await client.get(user3);
            expect(resetUser1).toBe(null);
            expect(resetUser2).toBe(null);
            expect(resetUser3).toBe(null);
        });
    });
});

--------------------------------------------------------------------------------
/src/analysis/buildTypeWeights.ts:
--------------------------------------------------------------------------------
import {
    ArgumentNode,
    GraphQLArgument,
    GraphQLNamedType,
    GraphQLObjectType,
    GraphQLInterfaceType,
    GraphQLOutputType,
    isCompositeType,
    isEnumType,
    isInterfaceType,
    isListType,
    isNonNullType,
    isObjectType,
    isScalarType,
    isUnionType,
    Kind,
    ValueNode,
    GraphQLUnionType,
    GraphQLFieldMap,
    isInputObjectType,
} from 'graphql';
import { ObjMap } from 'graphql/jsutils/ObjMap';
import { GraphQLSchema } from 'graphql/type/schema';
import {
    TypeWeightConfig,
    TypeWeightSet,
    TypeWeightObject,
    Variables,
    Type,
    Fields,
    FieldMap,
} from '../@types/buildTypeWeights';

export const KEYWORDS = ['first', 'last', 'limit'];

// default configuration weights for GraphQL types
export const defaultTypeWeightsConfig: TypeWeightSet = {
    mutation: 10,
    object: 1,
    scalar: 0,
    connection: 2,
    query: 1,
};

/**
 * Parses the fields on an object or interface type (including Query and Mutation) and returns the
 * type's entry in type weight object format.
 *
 * - Scalar fields get the scalar weight.
 * - Object, interface, union and enum fields record the (lower-cased) type they resolve to.
 * - List fields get a fixed weight (scalar/enum lists when scalars are free, or an
 *   `@listCost(cost: Int!)` directive) or a weight function driven by a slicing argument (KEYWORDS).
 *
 * @param type the object or interface type to parse
 * @param typeWeightObject the type weight object being built; list weight functions read the weight
 *   of their element type from it when they are called
 * @param typeWeights generic weights for each kind of type
 * @param enforceBoundedLists when true, throw for lists with neither `@listCost` nor a slicing argument
 * @returns the weight and field weights for `type`
 * @throws SyntaxError if `@listCost` is not given a non-negative integer cost as its first argument
 * @throws Error for unbounded lists when enforceBoundedLists is set, or for unsupported field types
 */
function parseObjectFields(
    type: GraphQLObjectType | GraphQLInterfaceType,
    typeWeightObject: TypeWeightObject,
    typeWeights: TypeWeightSet,
    enforceBoundedLists: boolean
): Type {
    let result: Type;
    switch (type.name) {
        case 'Query':
            result = { weight: typeWeights.query, fields: {} };
            break;
        case 'Mutation':
            result = { weight: typeWeights.mutation, fields: {} };
            break;
        default:
            result = { weight: typeWeights.object, fields: {} };
            break;
    }

    const fields = type.getFields();

    // Iterate through the fields and add the required data to the result.
    // Type names are lower-cased with toLowerCase (not toLocaleLowerCase) so that resolveTo always
    // matches the keys of the type weight object regardless of the host locale.
    Object.keys(fields).forEach((field: string) => {
        // The GraphQL type that this field represents
        let fieldType: GraphQLOutputType = fields[field].type;
        if (isNonNullType(fieldType)) fieldType = fieldType.ofType;
        if (isScalarType(fieldType)) {
            result.fields[field] = {
                weight: typeWeights.scalar,
            };
        } else if (
            isInterfaceType(fieldType) ||
            isEnumType(fieldType) ||
            isObjectType(fieldType) ||
            isUnionType(fieldType)
        ) {
            result.fields[field] = {
                resolveTo: fieldType.name.toLowerCase(),
            };
        } else if (isListType(fieldType)) {
            // 'listType' is the GraphQL type that the list resolves to
            let listType = fieldType.ofType;
            if (isNonNullType(listType)) listType = listType.ofType;
            if (isScalarType(listType) && typeWeights.scalar === 0) {
                // list won't compound if weight is zero
                result.fields[field] = {
                    weight: typeWeights.scalar,
                };
            } else if (isEnumType(listType) && typeWeights.scalar === 0) {
                // list won't compound if weight of enum is zero
                result.fields[field] = {
                    resolveTo: listType.toString().toLowerCase(),
                };
            } else {
                // fieldAdded is a boolean flag to check if we have added something to the typeweight object for this field.
                // if we reach end of the list and fieldAdded is false, we have an unbounded list.
                let fieldAdded = false;
                // if the @listCost directive is given for the field, apply the cost argument's value to the field's weight
                const directives = fields[field].astNode?.directives;
                if (directives && directives.length > 0) {
                    directives.forEach((dir) => {
                        if (dir.name.value === 'listCost') {
                            fieldAdded = true;
                            // a directive without arguments is a configuration error, not a crash
                            const costArg = dir.arguments?.[0];
                            if (
                                costArg &&
                                costArg.value.kind === Kind.INT &&
                                Number(costArg.value.value) >= 0
                            ) {
                                result.fields[field] = {
                                    resolveTo: listType.toString().toLowerCase(),
                                    weight: Number(costArg.value.value),
                                };
                            } else {
                                throw new SyntaxError(`@listCost directive improperly configured`);
                            }
                        }
                    });
                }

                // check for slicing arguments on field for bounding lists
                fields[field].args.forEach((arg: GraphQLArgument) => {
                    // If field has an argument matching one of the limiting keywords and resolves to a list
                    // then the weight of the field should be dependent on both the weight of the resolved type and the limiting argument.
                    if (KEYWORDS.includes(arg.name)) {
                        fieldAdded = true;
                        /** "weight" property is a function that calculates the list complexity based on:
                         * 1. the cost of its field selections
                         * 2. the value of the slicing argument (multiplier)
                         * 3. the weight of the list's element type */
                        result.fields[field] = {
                            resolveTo: listType.toString().toLowerCase(),
                            weight: (
                                args: ArgumentNode[],
                                variables: Variables,
                                selectionsCost: number
                            ): number => {
                                const limitArg: ArgumentNode | undefined = args.find(
                                    (cur) => cur.name.value === arg.name
                                );
                                // Union element types may have no entry in typeWeightObject; fall back to the
                                // generic object weight, which is the weight union types are given.
                                const weight = isCompositeType(listType)
                                    ? (typeWeightObject[listType.name.toLowerCase()]?.weight ??
                                      typeWeights.object)
                                    : typeWeights.scalar; // Note this includes enums

                                // The requested list size: a literal, a variable (falling back to the schema
                                // default), or the schema default when the argument is omitted.
                                let requested: unknown;
                                if (limitArg) {
                                    const node: ValueNode = limitArg.value;
                                    if (Kind.INT === node.kind) {
                                        requested = node.value;
                                    } else if (Kind.VARIABLE === node.kind) {
                                        // `??` rather than `||` so that an explicit 0 is respected
                                        requested = variables[node.name.value] ?? arg.defaultValue;
                                    }
                                    // ? what else can get through here
                                } else {
                                    // if there is no argument provided with the query, check the schema for a default
                                    requested = arg.defaultValue;
                                }

                                // if there is no usable argument or defaultValue, multiplier stays 1,
                                // effectively making the list size equal to 1 as a last resort
                                let multiplier = 1;
                                if (requested !== undefined && requested !== null) {
                                    const parsed = Number(requested);
                                    // ignore non-numeric values; never let a negative limit reduce the cost
                                    if (Number.isFinite(parsed)) multiplier = Math.max(parsed, 0);
                                }
                                return multiplier * (selectionsCost + weight);
                            },
                        };
                    }
                });

                // throw an error if an unbounded list has no @listCost directive attached or slicing arguments
                // and the enforceBoundedLists configuration option is set to true
                if (fieldAdded === false && enforceBoundedLists) {
                    throw new Error(
                        `ERROR: buildTypeWeights: Use directive @listCost(cost: Int!) on unbounded lists, or limit query results with ${KEYWORDS}`
                    );
                }
            }
        } else {
            // FIXME what else can get through here
            throw new Error(`ERROR: buildTypeWeight: Unsupported field type: ${fieldType}`);
        }
    });

    return result;
}

/**
 * Recursively compares two output types for equality based on type names.
 * Named types (object, union, enum, interface, scalar) are equal when they are the same kind with
 * the same name; list and non-null wrappers are equal when their wrapped types are equal.
 * @param a first type
 * @param b second type
 * @returns true if the types are recursively equal.
 */
function compareTypes(a: GraphQLOutputType, b: GraphQLOutputType): boolean {
    // Base Case: named types => compare kind and type names
    // Recursive Case (List / NonNull): compare ofType
    return (
        (isObjectType(a) && isObjectType(b) && a.name === b.name) ||
        (isUnionType(a) && isUnionType(b) && a.name === b.name) ||
        (isEnumType(a) && isEnumType(b) && a.name === b.name) ||
        (isInterfaceType(a) && isInterfaceType(b) && a.name === b.name) ||
        (isScalarType(a) && isScalarType(b) && a.name === b.name) ||
        (isListType(a) && isListType(b) && compareTypes(a.ofType, b.ofType)) ||
        (isNonNullType(a) && isNonNullType(b) && compareTypes(a.ofType, b.ofType))
    );
}

/**
 *
 * @param unionType union type to be parsed
 * @param typeWeightObject type weight mapping object that must already contain all of the types in the schema.
221 | * @returns object mapping field names for each union type to their respective weights, resolve type names and resolve type object 222 | */ 223 | function getFieldsForUnionType( 224 | unionType: GraphQLUnionType, 225 | typeWeightObject: TypeWeightObject 226 | ): FieldMap[] { 227 | return unionType.getTypes().map((objectType: GraphQLObjectType) => { 228 | // Get the field data for this type 229 | const fields: GraphQLFieldMap = objectType.getFields(); 230 | 231 | const fieldMap: FieldMap = {}; 232 | Object.keys(fields).forEach((field: string) => { 233 | // Get the weight of this field on from parent type on the root typeWeight object. 234 | // this only exists for scalars and lists (which resolve to a function); 235 | const { weight, resolveTo } = 236 | typeWeightObject[objectType.name.toLowerCase()].fields[field]; 237 | 238 | fieldMap[field] = { 239 | type: fields[field].type, 240 | weight, // will only be undefined for object types 241 | resolveTo, 242 | }; 243 | }); 244 | return fieldMap; 245 | }); 246 | } 247 | 248 | /** 249 | * 250 | * @param typesInUnion 251 | * @returns a single field map containg information for fields common to the union 252 | */ 253 | function getSharedFieldsFromUnionTypes(typesInUnion: FieldMap[]): FieldMap { 254 | return typesInUnion.reduce((prev: FieldMap, fieldMap: FieldMap): FieldMap => { 255 | // iterate through the field map checking the types for any common field names 256 | const sharedFields: FieldMap = {}; 257 | Object.keys(prev).forEach((field: string) => { 258 | if (fieldMap[field]) { 259 | if (compareTypes(prev[field].type, fieldMap[field].type)) { 260 | // they match add the type to the next set 261 | sharedFields[field] = prev[field]; 262 | } 263 | } 264 | }); 265 | return sharedFields; 266 | }); 267 | } 268 | 269 | /** 270 | * Parses the provided union types and returns a type weight object with any fields common to all types 271 | * in a union added to the union type 272 | * @param unionTypes union types to be 
parsed. 273 | * @param typeWeights object specifying generic type weights. 274 | * @param typeWeightObject original type weight object 275 | * @returns 276 | */ 277 | function parseUnionTypes( 278 | unionTypes: GraphQLUnionType[], 279 | typeWeights: TypeWeightSet, 280 | typeWeightObject: TypeWeightObject 281 | ) { 282 | const typeWeightsWithUnions: TypeWeightObject = { ...typeWeightObject }; 283 | 284 | unionTypes.forEach((unionType: GraphQLUnionType) => { 285 | /** 286 | * 1. For each provided union type. We first obtain the fields for each object that 287 | * is part of the union and store these in an object 288 | * When obtaining types, save: 289 | * - field name 290 | * - type object to which the field resolves. This holds any information for recursive types (lists / not null / unions) 291 | * - weight - for easy lookup later 292 | * - resolveTo type - for easy lookup later 293 | * 2. We then reduce the array of objects from step 1 a single object only containing fields 294 | * common to each type in the union. To determine field "equality" we compare the field names and 295 | * recursively compare the field types: 296 | * */ 297 | 298 | // types is an array mapping each field name to it's respective output type 299 | // const typesInUnion = getFieldsForUnionType(unionType, typeWeightObject); 300 | const typesInUnion: FieldMap[] = getFieldsForUnionType(unionType, typeWeightObject); 301 | 302 | // reduce the data for all the types in the union 303 | const commonFields: FieldMap = getSharedFieldsFromUnionTypes(typesInUnion); 304 | 305 | // transform commonFields into the correct format for the type weight object 306 | const fieldTypes: Fields = {}; 307 | 308 | Object.keys(commonFields).forEach((field: string) => { 309 | /** 310 | * The type weight object requires that: 311 | * a. scalars have a weight 312 | * b. lists have a resolveTo and weight property 313 | * c. objects have a resolveTo type. 
314 | * */ 315 | 316 | let current = commonFields[field].type; 317 | if (isNonNullType(current)) current = current.ofType; 318 | if (isScalarType(current)) { 319 | fieldTypes[field] = { 320 | weight: commonFields[field].weight, 321 | }; 322 | } else if ( 323 | isObjectType(current) || 324 | isInterfaceType(current) || 325 | isUnionType(current) || 326 | isEnumType(current) 327 | ) { 328 | fieldTypes[field] = { 329 | resolveTo: commonFields[field].resolveTo, 330 | }; 331 | } else if (isListType(current)) { 332 | fieldTypes[field] = { 333 | resolveTo: commonFields[field].resolveTo, 334 | weight: commonFields[field].weight, 335 | }; 336 | } else { 337 | throw new Error('Unhandled union type. Should never get here'); 338 | } 339 | }); 340 | typeWeightsWithUnions[unionType.name.toLowerCase()] = { 341 | fields: fieldTypes, 342 | weight: typeWeights.object, 343 | }; 344 | }); 345 | 346 | return typeWeightsWithUnions; 347 | } 348 | /** 349 | * Parses all types in the provided schema object excempt for Query, Mutation 350 | * and built in types that begin with '__' and outputs a new TypeWeightObject 351 | * @param schema 352 | * @param typeWeights 353 | * @param enforceBoundedLists 354 | * @returns 355 | */ 356 | function parseTypes( 357 | schema: GraphQLSchema, 358 | typeWeights: TypeWeightSet, 359 | enforceBoundedLists: boolean 360 | ): TypeWeightObject { 361 | const typeMap: ObjMap = schema.getTypeMap(); 362 | 363 | const result: TypeWeightObject = {}; 364 | 365 | const unions: GraphQLUnionType[] = []; 366 | 367 | // Handle Object, Interface, Enum and Union types 368 | Object.keys(typeMap).forEach((type) => { 369 | const typeName: string = type.toLowerCase(); 370 | const currentType: GraphQLNamedType = typeMap[type]; 371 | 372 | // Get all types that aren't Query or Mutation or a built in type that starts with '__' 373 | if (!type.startsWith('__')) { 374 | if (isObjectType(currentType) || isInterfaceType(currentType)) { 375 | // Add the type and it's associated fields to 
the result 376 | result[typeName] = parseObjectFields( 377 | currentType, 378 | result, 379 | typeWeights, 380 | enforceBoundedLists 381 | ); 382 | } else if (isEnumType(currentType)) { 383 | result[typeName] = { 384 | fields: {}, 385 | weight: typeWeights.scalar, 386 | }; 387 | } else if (isUnionType(currentType)) { 388 | unions.push(currentType); 389 | } else if (!isScalarType(currentType) && !isInputObjectType(currentType)) { 390 | throw new Error(`ERROR: buildTypeWeight: Unsupported type: ${currentType}`); 391 | } 392 | } 393 | }); 394 | 395 | // parse union types to complete the build of the typeWeightObject 396 | return parseUnionTypes(unions, typeWeights, result); 397 | } 398 | 399 | /** 400 | * The default typeWeightsConfig object is based off of Shopifys implementation of query 401 | * cost analysis. Our function should input a users configuration of type weights or fall 402 | * back on shopifys settings. We can change this later. 403 | * 404 | * This function should 405 | * - iterate through the schema object and create the typeWeightObject as described in the tests 406 | * - validate that the typeWeightsConfig parameter has no negative values (throw an error if it does) 407 | * 408 | * @param schema 409 | * @param enforceBoundedLists Defaults to false 410 | * @param typeWeightsConfig Defaults to {mutation: 10, object: 1, field: 0, connection: 2} 411 | */ 412 | function buildTypeWeightsFromSchema( 413 | schema: GraphQLSchema, 414 | typeWeightsConfig: TypeWeightConfig = defaultTypeWeightsConfig, 415 | enforceBoundedLists = false 416 | ): TypeWeightObject { 417 | try { 418 | if (!schema) throw new Error('Missing Argument: schema is required'); 419 | 420 | // Merge the provided type weights with the default to account for missing values 421 | const typeWeights: TypeWeightSet = { 422 | ...defaultTypeWeightsConfig, 423 | ...typeWeightsConfig, 424 | }; 425 | 426 | // Confirm that any custom weights are non-negative 427 | 
Object.entries(typeWeights).forEach((value: [string, number]) => { 428 | if (value[1] < 0) { 429 | throw new Error( 430 | `Type weights cannot be negative. Received: ${value[0]}: ${value[1]} ` 431 | ); 432 | } 433 | }); 434 | 435 | return parseTypes(schema, typeWeights, enforceBoundedLists); 436 | } catch (err) { 437 | throw new Error(`Error in expressGraphQLRateLimiter when parsing schema object: ${err}`); 438 | } 439 | } 440 | 441 | export default buildTypeWeightsFromSchema; 442 | -------------------------------------------------------------------------------- /test/middleware/express.test.ts: -------------------------------------------------------------------------------- 1 | import 'ts-jest'; 2 | import { Request, Response, NextFunction, RequestHandler } from 'express'; 3 | import { GraphQLSchema, buildSchema } from 'graphql'; 4 | import * as ioredis from 'ioredis'; 5 | import expressGraphQLRateLimiter from '../../src/middleware/index'; 6 | 7 | import * as redis from '../../src/utils/redis'; 8 | 9 | const mockConnect = jest.spyOn(redis, 'connect'); 10 | 11 | // eslint-disable-next-line @typescript-eslint/no-var-requires 12 | const RedisMock = require('ioredis-mock'); 13 | 14 | let middleware: RequestHandler; 15 | let mockRequest: Partial; 16 | let complexRequest: Partial; 17 | let mockResponse: Partial; 18 | let nextFunction: NextFunction = jest.fn(); 19 | const schema: GraphQLSchema = buildSchema(` 20 | directive @listCost(cost: Int!) on FIELD_DEFINITION 21 | type Query { 22 | hero(episode: Episode): Character 23 | reviews(episode: Episode!, first: Int): [Review] 24 | character(id: ID!): Character 25 | droid(id: ID!): Droid 26 | human(id: ID!): Human 27 | scalars: Scalars 28 | } 29 | enum Episode { 30 | NEWHOPE 31 | EMPIRE 32 | JEDI 33 | } 34 | interface Character { 35 | id: ID! 36 | name: String! 37 | friends: [Character] @listCost(cost: 10) 38 | appearsIn: [Episode]! 39 | } 40 | type Human implements Character { 41 | id: ID! 42 | name: String! 
homePlanet: String
        friends: [Character] @listCost(cost: 10)
        appearsIn: [Episode]!
    }
    type Droid implements Character {
        id: ID!
        name: String!
        friends: [Character] @listCost(cost: 10)
        primaryFunction: String
        appearsIn: [Episode]!
    }
    type Review {
        episode: Episode
        stars: Int!
        commentary: String
    }
    type Scalars {
        num: Int,
        id: ID,
        float: Float,
        bool: Boolean,
        string: String
        test: Test,
    }
    type Test {
        name: String,
        variable: Scalars
    }
`);

describe('Express Middleware tests', () => {
    afterEach(() => {
        redis.shutdown();
    });
    describe('Middleware is configurable...', () => {
        xdescribe('...successfully connects to redis using standard connection options', () => {
            let mockRedis;
            beforeEach(() => {
                mockRedis = new RedisMock();
            });

            xtest('...via url', () => {
                // TODO: Connect to redis instance and add 'connect' event listener
                // assert that event listener is called once
                expect(true).toBeFalsy();

                // expect.assertions(1);
                // redis.on('connect', () => {
                //     expect(true);
                // });
                // expressGraphQLRateLimiter(schema, {
                //     rateLimiter: {
                //         type: 'TOKEN_BUCKET',
                //         options: { refillRate: 1, bucketSize: 10 },
                //     },
                //     redis: { options: { host: '//localhost:6379' } },
                // });
            });

            xtest('via socket', () => {
                // TODO: Connect to redis instance and add 'connect' event listener
                // assert that event listener is called once
                expect(true).toBeFalsy();
            });

            xtest('defaults to localhost', () => {
                // TODO: Connect to redis instance and add 'connect' event listener
                // assert that event listener is called once
                expect(true).toBeFalsy();
            });
        });

        describe('...Can be configured to use a valid algorithm', () => {
            test('...Token Bucket', () => {
                // FIXME: Is it possible to check which algorithm was chosen beyond error checking?
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 10,
                        },
                    })
                ).not.toThrow();
            });

            xtest('...Leaky Bucket', () => {
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'LEAKY_BUCKET',
                            refillRate: 1,
                            capacity: 10, // FIXME: Replace with valid params
                        },
                    })
                ).not.toThrow();
            });

            test('...Fixed Window', () => {
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'FIXED_WINDOW',
                            capacity: 1,
                            windowSize: 1000,
                        },
                    })
                ).not.toThrow();
            });

            test('...Sliding Window Log', () => {
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'SLIDING_WINDOW_LOG',
                            windowSize: 1000,
                            capacity: 10,
                        },
                    })
                ).not.toThrow();
            });

            test('...Sliding Window Counter', () => {
                // Previously configured SLIDING_WINDOW_LOG, so the counter algorithm was never
                // exercised by this test.
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'SLIDING_WINDOW_COUNTER',
                            windowSize: 1,
                            capacity: 10,
                        },
                    })
                ).not.toThrow();
            });
        });

        xdescribe('... throws an error', () => {
            test('... for invalid schemas', () => {
                // NOTE(review): buildSchema runs outside expect(), so a syntax error here fails
                // the test instead of satisfying toThrow. Revisit before enabling this suite.
                const invalidSchema: GraphQLSchema = buildSchema(`{Query {name}`);

                expect(() =>
                    expressGraphQLRateLimiter(invalidSchema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 10,
                        },
                    })
                ).toThrow('GraphQLError');
            });

            xtest('... if unable to connect to redis', () => {
                expect(async () =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 10,
                        },

                        redis: { options: { host: 'localhost', port: 1 } },
                    })
                ).toThrow('ECONNREFUSED');
            });
        });

        describe('...other configuration parameters', () => {
            beforeAll(() => mockConnect.mockImplementation(() => new RedisMock()));
            beforeEach(() => {
                mockRequest = {
                    body: {
                        // complexity should be 10 if 'first' is accounted for
                        // (same query as complexRequest in 'Middleware is Functional')
                        query: `query {
                            droid(id: 1) {
                                name
                            }
                            reviews(episode: NEWHOPE, first: 8) {
                                episode
                                stars
                                commentary
                            }
                        } `,
                    },
                    ip: '111',
                };

                mockResponse = {
                    json: jest.fn(),
                    send: jest.fn(),
                    set: jest.fn().mockReturnThis(),
                    sendStatus: jest.fn(),
                    status: jest.fn().mockReturnThis(),
                    locals: {},
                };
                nextFunction = jest.fn();
            });

            test('can be configured to run in dark mode', async () => {
                middleware = expressGraphQLRateLimiter(schema, {
                    rateLimiter: {
                        type: 'TOKEN_BUCKET',
                        refillRate: 1,
                        capacity: 2,
                    },
                    dark: true,
                });

                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
                // ratelimiting capacity is set very low
                // request exceeds capacity
                // request will not be blocked
                expect(nextFunction).toBeCalled();
                expect(mockResponse.json).not.toBeCalled();
                expect(mockResponse.locals?.graphqlGate.success).toBe(false);
            });

            test('can be configured to throw an error for unbounded lists', () => {
                // The SDL needs `type` keywords. Without them buildSchema itself throws a syntax
                // error, and toThrow() passes no matter what enforceBoundedLists does.
                const unboundedSchema = buildSchema(`
                    type Query {
                        biglist: [List]
                    }
                    type List {
                        stuff: String
                    }
                `);

                // Control: the same schema is accepted when bounded lists are not enforced, so
                // the throw below can only come from enforceBoundedLists.
                expect(() =>
                    expressGraphQLRateLimiter(unboundedSchema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 2,
                        },
                    })
                ).not.toThrow();

                expect(() =>
                    expressGraphQLRateLimiter(unboundedSchema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 2,
                        },
                        enforceBoundedLists: true,
                    })
                ).toThrow();
            });

            test('can be configured to limit requests by depth', async () => {
                middleware = expressGraphQLRateLimiter(schema, {
                    rateLimiter: {
                        type: 'TOKEN_BUCKET',
                        refillRate: 1,
                        capacity: 20,
                    },
                    depthLimit: 3,
                });

                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
                // depthLimit is set very low
                // request will be blocked
                expect(mockResponse.json).toBeCalled();
                expect(mockResponse.locals?.graphqlGate.success).toBe(false);
                expect(nextFunction).not.toBeCalled();
            });

            // ? test for key expiry in redis cache?
            test('can be configured with a key expiry without error', () => {
                expect(() =>
                    expressGraphQLRateLimiter(schema, {
                        rateLimiter: {
                            type: 'TOKEN_BUCKET',
                            refillRate: 1,
                            capacity: 2,
                        },
                        redis: { keyExpiry: 4000 },
                    })
                ).not.toThrow();
            });
        });
    });

    describe('Middleware is Functional', () => {
        // Before each test configure a new middleware and mock req, res objects.
        let ip = 0;
        beforeAll(() => {
            jest.useFakeTimers('modern');
            mockConnect.mockImplementation(() => new RedisMock());
        });

        afterAll(() => {
            jest.useRealTimers();
            jest.clearAllTimers();
            jest.clearAllMocks();
        });

        beforeEach(async () => {
            middleware = expressGraphQLRateLimiter(schema, {
                rateLimiter: {
                    type: 'TOKEN_BUCKET',
                    refillRate: 1,
                    capacity: 10,
                },
            });
            mockRequest = {
                body: {
                    // complexity should be 2 (1 Query + 1 Scalar)
                    query: `query {
                        scalars {
                            num
                        }
                    }`,
                },
                // a fresh ip per test keeps each test's token bucket independent
                ip: `${(ip += 1)}`,
            };

            mockResponse = {
                json: jest.fn(),
                send: jest.fn(),
                set: jest.fn().mockReturnThis(),
                sendStatus: jest.fn(),
                status: jest.fn().mockReturnThis(),
                locals: {},
            };

            complexRequest = {
                // complexity should be 10 if 'first' is accounted for.
                // Query: 1, droid: 1, reviews 8: 1)
                body: {
                    query: `query {
                        droid(id: 1) {
                            name
                        }
                        reviews(episode: NEWHOPE, first: 8) {
                            episode
                            stars
                            commentary
                        }
                    } `,
                },
                ip: `${ip + 100}`,
            };
            nextFunction = jest.fn();
        });

        describe('Adds expected properties to res.locals', () => {
            test('Adds UNIX timestamp', async () => {
                jest.useRealTimers();
                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
                jest.useFakeTimers();

                // confirm that this is timestamp +/- 5 minutes of now.
                const now: number = Date.now().valueOf();
                const diff: number = Math.abs(
                    now - (mockResponse.locals?.graphqlGate.timestamp || 0)
                );
                expect(diff).toBeLessThan(5 * 60 * 1000);
            });

            test('adds complexity', async () => {
                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

                expect(mockResponse.locals?.graphqlGate).toHaveProperty('complexity');
                expect(typeof mockResponse.locals?.graphqlGate.complexity).toBe('number');
                expect(mockResponse.locals?.graphqlGate.complexity).toBeGreaterThanOrEqual(0);
            });

            test('adds tokens', async () => {
                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

                expect(mockResponse.locals?.graphqlGate).toHaveProperty('tokens');
                expect(typeof mockResponse.locals?.graphqlGate.tokens).toBe('number');
                expect(mockResponse.locals?.graphqlGate.tokens).toBeGreaterThanOrEqual(0);
            });

            test('adds success', async () => {
                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

                expect(mockResponse.locals?.graphqlGate).toHaveProperty('success');
                expect(typeof mockResponse.locals?.graphqlGate.success).toBe('boolean');
            });

            test('adds depth', async () => {
                await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

                expect(mockResponse.locals?.graphqlGate).toHaveProperty('depth');
                expect(typeof mockResponse.locals?.graphqlGate.depth).toBe('number');
                expect(mockResponse.locals?.graphqlGate.depth).toBeGreaterThanOrEqual(0);
            });
        });

        describe('Correctly limits requests', () => {
            describe('Allows requests', () => {
                test('...a single request', async () => {
                    // successful request calls next without any arguments.
                    await middleware(
                        mockRequest as Request,
                        mockResponse as Response,
                        nextFunction
                    );
                    expect(nextFunction).toBeCalledTimes(1);
                    expect(nextFunction).toBeCalledWith();
                });

                test('Multiple valid requests at > 10 second intervals', async () => {
                    // The middleware is awaited elsewhere in this file; the element type also
                    // admits void in case RequestHandler is declared as returning void.
                    const requests: Array<void | Promise<void>> = [];
                    for (let i = 0; i < 3; i++) {
                        requests.push(
                            middleware(
                                complexRequest as Request,
                                mockResponse as Response,
                                nextFunction
                            )
                        );
                        // advance the timers by 10 seconds for the next request
                        jest.advanceTimersByTime(10000);
                    }
                    await Promise.all(requests);
                    expect(nextFunction).toBeCalledTimes(3);
                    for (let i = 1; i <= 3; i++) {
                        expect(nextFunction).nthCalledWith(i);
                    }
                });

                test('Multiple valid requests at within one second', async () => {
                    const requests: Array<void | Promise<void>> = [];

                    for (let i = 0; i < 3; i++) {
                        // Send 3 queries of complexity 2. These should all succeed
                        requests.push(
                            middleware(
                                mockRequest as Request,
                                mockResponse as Response,
                                nextFunction
                            )
                        );

                        // advance the timers by 20 milliseconds for the next request
                        jest.advanceTimersByTime(20);
                    }
                    await Promise.all(requests);
                    expect(nextFunction).toBeCalledTimes(3);
                    expect(nextFunction).toBeCalledWith();
                });
            });

            describe('BLOCKS requests', () => {
                test('A single request that exceeds capacity', async () => {
                    nextFunction = jest.fn();

                    const blockedRequest: Partial<Request> = {
                        // complexity should be 12 if 'first' is accounted for.
                        // scalars: 1, droid: 1, reviews (10 * (1 Review, 0 episode))
                        body: {
                            query: `query {
                                scalars {
                                    num
                                }
                                droid(id: 1) {
                                    name
                                }
                                reviews(episode: NEWHOPE, first: 10) {
                                    episode
                                    stars
                                    commentary
                                }
                            } `,
                        },
                        ip: '1100',
                    };

                    expect(nextFunction).not.toBeCalled();
                    await middleware(
                        blockedRequest as Request,
                        mockResponse as Response,
                        nextFunction
                    );
                    expect(mockResponse.status).toHaveBeenCalledWith(429);
                    expect(nextFunction).not.toBeCalled();

                    // FIXME: There are multiple functions to send a response
                    // json, send html, sendStatus etc. How do we check at least one was called
                    expect(mockResponse.json).toBeCalled();
                });

                test('Multiple queries that exceed token limit', async () => {
                    const requests: Array<void | Promise<void>> = [];

                    for (let i = 0; i < 5; i++) {
                        // Send 5 queries of complexity 2. These should all succeed
                        requests.push(
                            middleware(
                                mockRequest as Request,
                                mockResponse as Response,
                                nextFunction
                            )
                        );

                        // advance the timers by 20 milliseconds for the next request
                        jest.advanceTimersByTime(20);
                    }

                    await Promise.all(requests);
                    // Send a 6th request that should be blocked.
                    const next: NextFunction = jest.fn();

                    const lastRequest = middleware(
                        mockRequest as Request,
                        mockResponse as Response,
                        next
                    );

                    await lastRequest;

                    expect(mockResponse.status).toHaveBeenCalledWith(429);
                    expect(next).not.toBeCalled();

                    // FIXME: See above comment on sending responses
                    expect(mockResponse.json).toBeCalled();
                });

                xtest('Retry-After header is on blocked response', () => {
                    // TODO:
                });
            });
        });

        xtest('Uses User IP Address in Redis', async () => {
            // FIXME: In order to test this accurately the middleware would need to connect
            // to a mock instance or the tests would need to connect to an actual redis instance
            // We could use NODE_ENV variable in the implementation to determine the connection type.

            // TODO: connect to the actual redis client here. Make sure to disconnect for proper teardown
            const client: ioredis.Redis = new RedisMock();
            await client.connect();
            // Check for change in the redis store for the IP key

            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-ignore mockRequest will always have an ip address.
            const initialValue: string | null = await client.get(mockRequest.ip);

            // await so the store is updated before it is read back below
            await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

            // eslint-disable-next-line @typescript-eslint/ban-ts-comment
            // @ts-ignore
            const finalValue: string | null = await client.get(mockRequest.ip);

            expect(finalValue).not.toBeNull();
            expect(finalValue).not.toBe(initialValue);
        });

        xdescribe('handles error correctly', () => {
            // validation errors
            // redis connection errors in token bucket
            // complexity analysis errors
        });
    });
});

--------------------------------------------------------------------------------
/test/rateLimiters/slidingWindowLog.test.ts:
--------------------------------------------------------------------------------
import 'ts-jest';
import * as ioredis from 'ioredis';
import { RateLimiterResponse, RedisLog } from '../../src/@types/rateLimit';
import SlidingWindowLog from '../../src/rateLimiters/slidingWindowLog';

// eslint-disable-next-line @typescript-eslint/no-var-requires
const RedisMock = require('ioredis-mock');

const WINDOW_SIZE = 1000;
const CAPACITY = 10;

let limiter: SlidingWindowLog;
let client: ioredis.Redis;
let timestamp: number;
const user1 = '1';
const user2 = '2';
const user3 = '3';

/** Reads and parses the request log stored under `uuid`; a missing key yields an empty log. */
async function getLogFromClient(redisClient: ioredis.Redis, uuid: string): Promise<RedisLog> {
    const res = await redisClient.get(uuid);
    // no log has been stored for this uuid yet, which is equivalent to an empty log
    if (res === null) return [];
    return JSON.parse(res);
}

/** Serializes `log` as JSON and stores it under `uuid`, replacing any existing value. */
async function setLogInClient(redisClient: ioredis.Redis, uuid: string, log: RedisLog) {
    await redisClient.set(uuid, JSON.stringify(log));
}

/**
 * Strategy
 *
 * Log and Redis updates
 *   Doesn't exist
 *     1.
Request with complexity 0 => allowed.
 *     2. Request with complexity < capacity => allowed
 *     3. Request with complexity = capacity => allowed
 *     4. Request with complexity > capacity => blocked
 *   Empty
 *     1. Request with complexity 0 => allowed.
 *     2. Request with complexity < capacity => allowed
 *     3. Request with complexity = capacity => allowed
 *     4. Request with complexity > capacity => blocked
 *   Contains active requests (still in window)
 *     1. sum of requests = capacity => blocked
 *     2. sum of requests < capacity
 *       1. current request complexity small enough => allowed
 *       2. current request complexity = remaining capacity => allowed
 *       3. current request complexity too big => blocked
 *       4. current request complexity = 0 => allowed
 *   Contains expired requests (no longer in the window)
 *     1. Request with complexity 0 => allowed.
 *     2. Request with complexity < capacity => allowed
 *     3. Request with complexity = capacity => allowed
 *     4. Request with complexity > capacity => blocked
 *   Contains active and expired requests (both in and out of the window)
 *     1. Request with complexity 0 => allowed.
 *     2. Request with complexity < capacity => allowed
 *     3. Request with complexity = remaining capacity => allowed
 *     4. Request with complexity > remaining capacity => blocked
 *
 * RateLimiter Functionality
 *   User logs are unique
 *
 * Config:
 *   Capacity and Window Size must be positive
 *   Custom capacity and window size allowed
 *
 *
 * reset()
 *   flushes all data from the redis store
 */

describe('SlidingWindowLog Rate Limiter', () => {
    beforeAll(() => {
        client = new RedisMock();
    });

    beforeEach(() => {
        limiter = new SlidingWindowLog(WINDOW_SIZE, CAPACITY, client, 80000);
        timestamp = new Date().valueOf();
    });

    afterEach(async () => {
        await client.flushall();
    });

    describe('correctly limits requests and updates redis when...', () => {
        describe('the redis log is empty, does not exist, or only contains expired requests', () => {
            // User 1 => no log exists
            let user1Response: RateLimiterResponse;
            let user1Log: RedisLog;
            // User 2 => empty log
            let user2Response: RateLimiterResponse;
            let user2Log: RedisLog;
            // User 3 => log has expired requests
            let user3Response: RateLimiterResponse;
            let user3Log: RedisLog;

            beforeEach(async () => {
                await setLogInClient(client, user2, []);
                const user3Timestamps = [
                    timestamp - 2 * WINDOW_SIZE,
                    timestamp - WINDOW_SIZE - 1,
                    timestamp - WINDOW_SIZE,
                ];
                await setLogInClient(
                    client,
                    user3,
                    user3Timestamps.map((time, i) => ({ timestamp: time, tokens: i + 1 }))
                );
            });

            test('and the request complexity is zero', async () => {
                [user1Response, user2Response, user3Response] = await Promise.all([
                    limiter.processRequest(user1, timestamp, 0),
                    limiter.processRequest(user2, timestamp, 0),
                    limiter.processRequest(user3, timestamp, 0),
                ]);

                // Check the received response
                const expectedResponse: RateLimiterResponse = { tokens: 10, success: true };
                expect(user1Response).toEqual(expectedResponse);
                expect(user2Response).toEqual(expectedResponse);
                expect(user3Response).toEqual(expectedResponse);

                // Check that redis is correctly updated.
                [user1Log, user2Log, user3Log] = await Promise.all([
                    getLogFromClient(client, user1),
                    getLogFromClient(client, user2),
                    getLogFromClient(client, user3),
                ]);
                expect(user1Log).toEqual([]);
                expect(user2Log).toEqual([]);
                expect(user3Log).toEqual([]);
            });
            test('and the request complexity is less than the capacity', async () => {
                const user1Tokens = 3;
                const user2Tokens = 4;
                const user3Tokens = 2;
                [user1Response, user2Response, user3Response] = await Promise.all([
                    limiter.processRequest(user1, timestamp, user1Tokens),
                    limiter.processRequest(user2, timestamp, user2Tokens),
                    limiter.processRequest(user3, timestamp, user3Tokens),
                ]);

                // Check the received response
                expect(user1Response).toEqual({ tokens: CAPACITY - user1Tokens, success: true });
                expect(user2Response).toEqual({ tokens: CAPACITY - user2Tokens, success: true });
                expect(user3Response).toEqual({ tokens: CAPACITY - user3Tokens, success: true });

                // Check that redis is correctly updated.
                [user1Log, user2Log, user3Log] = await Promise.all([
                    getLogFromClient(client, user1),
                    getLogFromClient(client, user2),
                    getLogFromClient(client, user3),
                ]);
                expect(user1Log).toEqual([{ timestamp, tokens: user1Tokens }]);
                expect(user2Log).toEqual([{ timestamp, tokens: user2Tokens }]);
                expect(user3Log).toEqual([{ timestamp, tokens: user3Tokens }]);
            });
            test('and the request complexity is equal to the capacity', async () => {
                const user1Tokens = CAPACITY;
                const user2Tokens = CAPACITY;
                const user3Tokens = CAPACITY;

                [user1Response, user2Response, user3Response] = await Promise.all([
                    limiter.processRequest(user1, timestamp, user1Tokens),
                    limiter.processRequest(user2, timestamp, user2Tokens),
                    limiter.processRequest(user3, timestamp, user3Tokens),
                ]);

                // Check the received response
                const expectedResponse: RateLimiterResponse = { tokens: 0, success: true };
                expect(user1Response).toEqual(expectedResponse);
                expect(user2Response).toEqual(expectedResponse);
                expect(user3Response).toEqual(expectedResponse);

                // Check that redis is correctly updated.
                [user1Log, user2Log, user3Log] = await Promise.all([
                    getLogFromClient(client, user1),
                    getLogFromClient(client, user2),
                    getLogFromClient(client, user3),
                ]);
                expect(user1Log).toEqual([{ timestamp, tokens: user1Tokens }]);
                expect(user2Log).toEqual([{ timestamp, tokens: user2Tokens }]);
                expect(user3Log).toEqual([{ timestamp, tokens: user3Tokens }]);
            });
            test('and the request complexity is greater than the capacity', async () => {
                const user1Tokens = CAPACITY + 1;
                const user2Tokens = CAPACITY + 1;
                const user3Tokens = CAPACITY + 1;

                [user1Response, user2Response, user3Response] = await Promise.all([
                    limiter.processRequest(user1, timestamp, user1Tokens),
                    limiter.processRequest(user2, timestamp, user2Tokens),
                    limiter.processRequest(user3, timestamp, user3Tokens),
                ]);

                // Check the received response
                const expectedResponse: RateLimiterResponse = {
                    tokens: CAPACITY,
                    success: false,
                    retryAfter: Infinity,
                };
                expect(user1Response).toEqual(expectedResponse);
                expect(user2Response).toEqual(expectedResponse);
                expect(user3Response).toEqual(expectedResponse);

                // Check that redis is correctly updated.
                [user1Log, user2Log, user3Log] = await Promise.all([
                    getLogFromClient(client, user1),
                    getLogFromClient(client, user2),
                    getLogFromClient(client, user3),
                ]);
                expect(user1Log).toEqual([]);
                expect(user2Log).toEqual([]);
                expect(user3Log).toEqual([]);
            });
        });

        describe('the redis log contains active requests in the window when...', () => {
            test('the sum of requests is equal to capacity', async () => {
                // add 2 requests to the redis store 3, 7
                const initialLog = [
                    { timestamp, tokens: 3 },
                    { timestamp: timestamp + 100, tokens: 7 },
                ];
                await setLogInClient(client, user1, initialLog);

                timestamp += 100;
                const response: RateLimiterResponse = await limiter.processRequest(
                    user1,
                    timestamp,
                    1
                );

                expect(response.tokens).toBe(0);
                expect(response.success).toBe(false);

                const redisLog = await getLogFromClient(client, user1);
                expect(redisLog).toEqual(initialLog);
            });
            describe('the sum of requests is less than capacity and..', () => {
                let initialLog: RedisLog;
                let initialTokenSum = 0;

                beforeEach(async () => {
                    // Build the log here rather than in beforeAll: a nested beforeAll runs
                    // before the outer beforeEach, so it would capture a stale timestamp (or
                    // an undefined one when this describe is run on its own).
                    initialLog = [
                        { timestamp, tokens: 3 },
                        { timestamp: timestamp + 100, tokens: 4 },
                    ];
                    initialTokenSum = 7;
                    await setLogInClient(client, user1, initialLog);
                    timestamp += 200;
                });
                test('the current request complexity is small enough to be allowed', async () => {
                    const tokens = 2;
                    const response: RateLimiterResponse = await limiter.processRequest(
                        user1,
                        timestamp,
                        tokens
                    );

                    expect(response.tokens).toBe(CAPACITY - (initialTokenSum + tokens));
                    expect(response.success).toBe(true);

                    const redisLog = await getLogFromClient(client, user1);

                    expect(redisLog).toEqual([...initialLog, { timestamp, tokens }]);
                });

                test('the current request has complexity = remaining capacity', async () => {
                    const tokens = 3;
                    const response: RateLimiterResponse = await limiter.processRequest(
                        user1,
                        timestamp,
                        tokens
                    );

                    expect(response.tokens).toBe(CAPACITY - (initialTokenSum + tokens));
                    expect(response.success).toBe(true);

                    const redisLog = await getLogFromClient(client, user1);

                    expect(redisLog).toEqual([...initialLog, { timestamp, tokens }]);
                });
                test('the current request complexity to big to be allowed', async () => {
                    const tokens = 4;
                    const response: RateLimiterResponse = await limiter.processRequest(
                        user1,
                        timestamp,
                        tokens
                    );

                    expect(response.tokens).toBe(CAPACITY - initialTokenSum);
                    expect(response.success).toBe(false);

                    const redisLog = await getLogFromClient(client, user1);

                    expect(redisLog).toEqual(initialLog);
                });
                test('the current request complexity = 0', async () => {
                    const tokens = 0;
                    const response: RateLimiterResponse = await limiter.processRequest(
                        user1,
                        timestamp,
                        tokens
                    );

                    expect(response.tokens).toBe(CAPACITY - initialTokenSum);
                    expect(response.success).toBe(true);

                    const redisLog = await getLogFromClient(client, user1);

                    expect(redisLog).toEqual(initialLog);
                });
            });
        });

        describe('the redis log contains active and expired requests when...', () => {
            // Current request is sent at timestamp + 1.5 * WINDOW_SIZE (1500)
            let initialLog: RedisLog;
            let activeLog: RedisLog;
            let activeTokenSum = 0;

            beforeEach(async () => {
                // Built per test from the fresh timestamp set by the outer beforeEach so the
                // active/expired split below is deterministic.
                initialLog = [
                    { timestamp, tokens: 1 }, // expired
                    { timestamp: timestamp + 100, tokens: 2 }, // expired
                    { timestamp: timestamp + 600, tokens: 3 }, // active
                    { timestamp: timestamp + 700, tokens: 4 }, // active
                ];
                activeLog = initialLog.slice(2);
                activeTokenSum = 7;
                await setLogInClient(client, user1, initialLog);
                timestamp += 1500;
            });

            test('the current request has complexity 0', async () => {
                const response: RateLimiterResponse = await limiter.processRequest(
                    user1,
                    timestamp,
                    0
                );

                expect(response.tokens).toBe(CAPACITY - activeTokenSum);
                expect(response.success).toBe(true);

                const redisLog = await getLogFromClient(client, user1);

                expect(redisLog).toEqual(activeLog);
            });
            test('the current request has complexity < capacity', async () => {
                const tokens = 2;
                const response: RateLimiterResponse = await limiter.processRequest(
                    user1,
                    timestamp,
                    tokens
                );

                expect(response.tokens).toBe(CAPACITY - (activeTokenSum + tokens));
                expect(response.success).toBe(true);

                const redisLog = await getLogFromClient(client, user1);

                expect(redisLog).toEqual([...activeLog, { timestamp, tokens }]);
            });
            test('the current request has complexity = remaining capacity', async () => {
                const tokens = 3;
                const response: RateLimiterResponse = await limiter.processRequest(
                    user1,
                    timestamp,
                    tokens
                );

                expect(response.tokens).toBe(CAPACITY - (activeTokenSum + tokens));
                expect(response.success).toBe(true);

                const redisLog = await getLogFromClient(client, user1);

                expect(redisLog).toEqual([...activeLog, { timestamp, tokens }]);
            });
            test('the current request has complexity > capacity => blocked', async () => {
                const tokens = 4;
                const response: RateLimiterResponse = await limiter.processRequest(
                    user1,
                    timestamp,
                    tokens
                );

                expect(response.tokens).toBe(CAPACITY - activeTokenSum);
                expect(response.success).toBe(false);

                const redisLog = await getLogFromClient(client, user1);

                expect(redisLog).toEqual(activeLog);
            });
        });

        test('the log contains a request on a window boundary', async () => {
            const initialLog = [{ timestamp, tokens: CAPACITY }];

            await setLogInClient(client, user1, initialLog);

            // Should not be allowed to perform any requests inside the window
            const inWindowRequest = await limiter.processRequest(
                user1,
                timestamp + WINDOW_SIZE - 1,
                1
            );
            expect(inWindowRequest.success).toBe(false);
            // A request exactly WINDOW_SIZE later starts a new window
            const startNewWindowRequest = await limiter.processRequest(
                user1,
                timestamp + WINDOW_SIZE,
                1
            );
            expect(startNewWindowRequest.success).toBe(true);
        });
    });

    describe('returns "retryAfter" if a request fails and', () => {
        /**
         * Strategy
         *   Check where the limiting request is at either end of the log and in the middle
         *   Infinity if > capacity (handled above)
         *   doesn't appear if success (handled above)
         * */
        beforeEach(() => {
            timestamp = 1000;
            limiter = new SlidingWindowLog(WINDOW_SIZE * 5, CAPACITY, client, 80000);
        });

        test('the limiting request was is at the beginning of the log', async () => {
            const requestLog = [
                { timestamp, tokens: 9 }, // limiting request
                { timestamp: timestamp + 1000, tokens: 1 }, // newer request
            ];
            await setLogInClient(client, user1, requestLog);
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 2000, 9);
            expect(retryAfter).toBe((WINDOW_SIZE * 5 - 2000) / 1000); // 3 seconds
        });

        test('the limiting request was is at the end of the log', async () => {
            const requestLog = [
                { timestamp, tokens: 1 }, // older request
                { timestamp: timestamp + 1000, tokens: 9 }, // limiting request
            ];
            await setLogInClient(client, user1, requestLog);
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 2000, 9);
            expect(retryAfter).toBe(Math.ceil((1000 + WINDOW_SIZE * 5 - 2000) / 1000)); // 4 seconds
        });

        test('the limiting request was is the middle of the log', async () => {
            const requestLog = [
                { timestamp, tokens: 1 }, // older request
                { timestamp: timestamp + 1000, tokens: 8 }, // limiting request
                { timestamp: timestamp + 2000, tokens: 1 }, // newer request
            ];
            await setLogInClient(client, user1, requestLog);
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 3000, 9);
            expect(retryAfter).toBe(Math.ceil((1000 + WINDOW_SIZE * 5 - 3000) / 1000)); // 3 seconds
        });

        test('request exceeds the capacity', async () => {
            const requestLog = [
                { timestamp, tokens: 1 }, // older request
                { timestamp: timestamp + 1000, tokens: 8 }, // limiting request
                { timestamp: timestamp + 2000, tokens: 1 }, // newer request
            ];
            await setLogInClient(client, user1, requestLog);
            const { retryAfter } = await limiter.processRequest(user1, timestamp + 3000, 11);
            expect(retryAfter).toBe(Infinity);
        });
    });
    test('users have their own logs', async () => {
        const requested = 6;
        const user3Tokens = 8;
        // Add log for user 3 so we have both a user that exists in the store (3) and one that doesn't (2)
        await setLogInClient(client, user3, [{ tokens: user3Tokens, timestamp }]);

        // issue a request for user 1;
        await limiter.processRequest(user1, timestamp + 100, requested);

        // Check that each user has the expected log
        expect(await getLogFromClient(client, user1)).toEqual([
            {
                timestamp: timestamp + 100,
                tokens: requested,
            },
        ]);
        expect(await getLogFromClient(client, user2)).toEqual([]);
        expect(await getLogFromClient(client, user3)).toEqual([{ timestamp, tokens: user3Tokens }]);

        await limiter.processRequest(user2, timestamp + 200, 1);
        expect(await getLogFromClient(client, user1)).toEqual([
            {
                timestamp: timestamp + 100,
                tokens: requested,
            },
        ]);
        expect(await getLogFromClient(client, user2)).toEqual([
            {
                timestamp: timestamp + 200,
                tokens: 1,
            },
        ]);
        expect(await getLogFromClient(client, user3)).toEqual([{ timestamp, tokens: user3Tokens }]);
    });

    test('is able to be reset', async () => {
        const tokens = 5;
        await setLogInClient(client, user1, [{ tokens, timestamp }]);
        await setLogInClient(client, user2, [{ tokens, timestamp }]);
        await setLogInClient(client, user3, [{ tokens, timestamp }]);

        limiter.reset();

        // `.resolves` assertions must be awaited; otherwise the test can finish before they
        // settle and a failing expectation goes unreported.
        await expect(getLogFromClient(client, user1)).resolves.toEqual([]);
        await expect(getLogFromClient(client, user2)).resolves.toEqual([]);
        await expect(getLogFromClient(client, user3)).resolves.toEqual([]);

        expect((await limiter.processRequest(user1, timestamp, CAPACITY)).success).toBe(true);
        expect((await limiter.processRequest(user2, timestamp, CAPACITY - 1)).success).toBe(true);
        expect((await limiter.processRequest(user3, timestamp, CAPACITY + 1)).success).toBe(false);
    });

    describe('is configurable...', () => {
        test('does not allow capacity or window size <= 0', () => {
            expect(() => new SlidingWindowLog(0, 1, client, 8000)).toThrow(
                'SlidingWindowLog window size, capacity and keyExpiry must be positive'
            );
            expect(() => new SlidingWindowLog(-10, 1, client, 8000)).toThrow(
                'SlidingWindowLog window size, capacity and keyExpiry must be positive'
            );
            expect(() => new SlidingWindowLog(10, -1, client, 8000)).toThrow(
                'SlidingWindowLog window size, capacity and keyExpiry must be positive'
            );
            expect(() => new SlidingWindowLog(10, 0, client, 8000)).toThrow(
                'SlidingWindowLog window size, capacity and keyExpiry must be positive'
            );
        });

        test('...allows custom window size and capacity', async () => {
            const customWindow = 500;
            const customSizelimiter = new SlidingWindowLog(customWindow, CAPACITY, client, 8000);

            let customSizeSuccess = await customSizelimiter
                .processRequest(user1, timestamp, CAPACITY)
                .then((res) => res.success);
            expect(customSizeSuccess).toBe(true);

            customSizeSuccess = await customSizelimiter
                .processRequest(user1, timestamp + 100, CAPACITY)
                .then((res) => res.success);
            expect(customSizeSuccess).toBe(false);

            // The first request is exactly one custom window old, so a full-capacity request
            // is allowed again (same boundary rule as the window-boundary test above).
            customSizeSuccess = await customSizelimiter
                .processRequest(user1, timestamp + customWindow, CAPACITY)
                .then((res) => res.success);
            expect(customSizeSuccess).toBe(true);

            // Reset the redis store
            customSizelimiter.reset();

            const customCapacity = 5;
            const customCapacitylimiter = new SlidingWindowLog(
                WINDOW_SIZE,
                customCapacity,
                client,
                8000
            );

            let customCapacitySuccess = await customCapacitylimiter
                .processRequest(user1, timestamp, customCapacity + 1)
                .then((res) => res.success);
            expect(customCapacitySuccess).toBe(false);

            customCapacitySuccess = await customCapacitylimiter
                .processRequest(user1, timestamp + 100, customCapacity)
                .then((res) => res.success);
            expect(customCapacitySuccess).toBe(true);
        });
    });
});
