├── .github └── workflows │ ├── build-n-release.yml │ └── codeql-analysis.yml ├── .gitignore ├── .nvmrc ├── .prettierrc.yml ├── LICENSE ├── README.md ├── SECURITY.md ├── TODO ├── _dev └── firebase-functions-rate-limiter.code-workspace ├── package-lock.json ├── package.json ├── scripts ├── run-integration-tests-with-firebase-emulator.sh └── setup-firebase-emulator-for-tests.sh ├── src ├── .npmignore ├── FirebaseFunctionsRateLimiter.integration.test.ts ├── FirebaseFunctionsRateLimiter.mock.integration.test.ts ├── FirebaseFunctionsRateLimiter.ts ├── FirebaseFunctionsRateLimiterConfiguration.spec.test.ts ├── FirebaseFunctionsRateLimiterConfiguration.ts ├── GenericRateLimiter.spec.test.ts ├── GenericRateLimiter.ts ├── _test │ └── test_environment.ts ├── index.ts ├── persistence │ ├── FirestorePersistenceProvider.ts │ ├── PersistenceProvider.ts │ ├── PersistenceProviderMock.ts │ ├── PersistenceProviders.integration.test.ts │ ├── PersistenceRecord.ts │ └── RealtimeDbPersistenceProvider.ts ├── timestamp │ ├── FirebaseTimestampProvider.ts │ ├── TimestampProvider.ts │ └── TimestampProviderMock.test.ts ├── types │ ├── EquivalentTypes.spec.test.ts │ ├── FirestoreEquivalent.ts │ └── RealtimeDbEquivalent.ts └── utils.test.ts ├── tsconfig.json ├── tsconfig.lint.json └── tslint.json /.github/workflows/build-n-release.yml: -------------------------------------------------------------------------------- 1 | name: "Build and release" 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | 7 | jobs: 8 | build_test: 9 | name: Build and test 10 | runs-on: ubuntu-latest 11 | 12 | strategy: 13 | fail-fast: true 14 | matrix: 15 | nodejs: [12, 14, 16, 18] 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v2 20 | 21 | - uses: actions/setup-node@v3 22 | with: 23 | node-version: ${{ matrix.nodejs }} # must reference the `nodejs` matrix key defined above 24 | 25 | - run: scripts/setup-firebase-emulator-for-tests.sh 26 | 27 | - run: npm ci 28 | 29 | - run: npm run build 30 | 31 | - run: npm run testall-with-coverage-lcov 32 | 
33 | - run: npm run upload-coverage 34 | 35 | release: 36 | name: Release 37 | runs-on: ubuntu-latest 38 | needs: [build_test] 39 | permissions: 40 | contents: write 41 | issues: write 42 | pull-requests: write 43 | 44 | steps: 45 | - name: Checkout repository 46 | uses: actions/checkout@v2 47 | 48 | - uses: actions/setup-node@v3 49 | with: 50 | node-version: 16 51 | 52 | - run: npm ci 53 | - run: npm run build 54 | - run: npx semantic-release 55 | env: 56 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 57 | NPM_TOKEN: ${{ secrets.NPM_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | name: "CodeQL" 7 | 8 | on: 9 | push: 10 | branches: [master] 11 | pull_request: 12 | # The branches below must be a subset of the branches above 13 | branches: [master] 14 | schedule: 15 | - cron: '0 0 * * 3' 16 | 17 | jobs: 18 | analyze: 19 | name: Analyze 20 | runs-on: ubuntu-latest 21 | 22 | strategy: 23 | fail-fast: false 24 | matrix: 25 | # Override automatic language detection by changing the below list 26 | # Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python'] 27 | language: ['javascript'] 28 | # Learn more... 29 | # https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection 30 | 31 | steps: 32 | - name: Checkout repository 33 | uses: actions/checkout@v2 34 | 35 | # Initializes the CodeQL tools for scanning.
36 | - name: Initialize CodeQL 37 | uses: github/codeql-action/init@v3 38 | with: 39 | languages: ${{ matrix.language }} 40 | # If you wish to specify custom queries, you can do so here or in a config file. 41 | # By default, queries listed here will override any specified in a config file. 42 | # Prefix the list here with "+" to use these queries and those in the config file. 43 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 44 | 45 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 46 | # If this step fails, then you should remove it and run the build manually (see below) 47 | - name: Autobuild 48 | uses: github/codeql-action/autobuild@v3 49 | 50 | # ℹ️ Command-line programs to run using the OS shell. 51 | # 📚 https://git.io/JvXDl 52 | 53 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 54 | # and modify them (or add more) to build your code if your project 55 | # uses a compiled language 56 | 57 | #- run: | 58 | # make bootstrap 59 | # make release 60 | 61 | - name: Perform CodeQL Analysis 62 | uses: github/codeql-action/analyze@v3 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /dist 2 | /.DS_Store 3 | /database-debug.log 4 | /firestore-debug.log 5 | 6 | # Logs 7 | logs 8 | *.log 9 | npm-debug.log* 10 | yarn-debug.log* 11 | yarn-error.log* 12 | 13 | # Runtime data 14 | pids 15 | *.pid 16 | *.seed 17 | *.pid.lock 18 | 19 | # Directory for instrumented libs generated by jscoverage/JSCover 20 | lib-cov 21 | 22 | # Coverage directory used by tools like istanbul 23 | coverage 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 | # Bower dependency directory (https://bower.io/) 32 | bower_components 33 | 34 | # node-waf configuration 35 | 
.lock-wscript 36 | 37 | # Compiled binary addons (https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # Optional npm cache directory 48 | .npm 49 | 50 | # Optional eslint cache 51 | .eslintcache 52 | 53 | # Optional REPL history 54 | .node_repl_history 55 | 56 | # Output of 'npm pack' 57 | *.tgz 58 | 59 | # Yarn Integrity file 60 | .yarn-integrity 61 | 62 | # dotenv environment variables file 63 | .env 64 | 65 | # next.js build output 66 | .next 67 | -------------------------------------------------------------------------------- /.nvmrc: -------------------------------------------------------------------------------- 1 | v16 2 | -------------------------------------------------------------------------------- /.prettierrc.yml: -------------------------------------------------------------------------------- 1 | trailingComma: all 2 | tabWidth: 4 3 | semi: true 4 | printWidth: 120 5 | singleQuote: false 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jędrzej Lewandowski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Firebase functions rate limiter 2 | [![npm](https://img.shields.io/npm/v/firebase-functions-rate-limiter.svg?style=flat-square)](https://www.npmjs.com/package/firebase-functions-rate-limiter) [![Code coverage](https://img.shields.io/codecov/c/gh/jblew/firebase-functions-rate-limiter?style=flat-square)](https://codecov.io/gh/jblew/firebase-functions-rate-limiter) [![License](https://img.shields.io/github/license/Jblew/firebase-functions-rate-limiter.svg?style=flat-square)](https://github.com/Jblew/firebase-functions-rate-limiter/blob/master/LICENSE) [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](http://makeapullrequest.com) 3 | 4 | 5 | 6 | Q: How to limit rate of firebase function calls? 
7 | A: Use `firebase-functions-rate-limiter` 8 | 9 | Mission: **limit number of calls per specified period of time** 10 | 11 | 12 | 13 | ## Key features: 14 | 15 | - Two backends: Realtime Database (efficient) or Firestore (convenient) 16 | - Easy: call a single function, no configuration 17 | - Efficient: only a single read call to the database (or Firestore); two calls if the limit is not exceeded and usage is recorded 18 | - Concurrent-safe: uses atomic transactions in both backends 19 | - Clean: Uses only one key (or collection in firestore backend), creates a single document for each qualifier. Does not leave rubbish in your database. 20 | - Typescript typings included 21 | - No firebase configuration required. You do not have to create any indexes or rules. 22 | - .mock() factory to make functions testing easier 23 | - Works with NodeJS 12, 14, 16, 18 24 | 25 | 26 | ## Installation 27 | 28 | ```bash 29 | $ npm install --save firebase-functions-rate-limiter 30 | ``` 31 | 32 | Then: 33 | 34 | ```typescript 35 | import FirebaseFunctionsRateLimiter from "firebase-functions-rate-limiter"; 36 | // or 37 | const { FirebaseFunctionsRateLimiter } = require("firebase-functions-rate-limiter"); 38 | ``` 39 | 40 | 41 | 42 | ## Usage 43 | 44 | **Example 1**: limit calls for everyone: 45 | 46 | ```javascript 47 | import * as admin from "firebase-admin"; 48 | import * as functions from "firebase-functions"; 49 | import { FirebaseFunctionsRateLimiter } from "firebase-functions-rate-limiter"; 50 | 51 | admin.initializeApp(functions.config().firebase); 52 | const database = admin.database(); 53 | 54 | const limiter = FirebaseFunctionsRateLimiter.withRealtimeDbBackend( 55 | { 56 | name: "rate_limiter_collection", 57 | maxCalls: 2, 58 | periodSeconds: 15, 59 | }, 60 | database, 61 | ); 62 | exports.testRateLimiter = 63 | functions.https.onRequest(async (req, res) => { 64 | await limiter.rejectOnQuotaExceededOrRecordUsage(); // will throw functions.https.HttpsError when the quota is exceeded 65 | 66 |
res.send("Function called"); 67 | }); 68 | 69 | ``` 70 | 71 | > You can use two functions: `limiter.rejectOnQuotaExceededOrRecordUsage(qualifier?)` will throw a *functions.https.HttpsError* when the limit is exceeded, while `limiter.isQuotaExceededOrRecordUsage(qualifier?)` gives you the ability to choose how to handle the situation. 72 | 73 | 74 | **Example 2**: limit calls for each user separately (function called directly - please refer to [firebase docs on this topic](https://firebase.google.com/docs/functions/callable)): 75 | 76 | ```javascript 77 | import * as admin from "firebase-admin"; 78 | import * as functions from "firebase-functions"; 79 | import { FirebaseFunctionsRateLimiter } from "firebase-functions-rate-limiter"; 80 | 81 | admin.initializeApp(functions.config().firebase); 82 | const database = admin.database(); 83 | 84 | const perUserlimiter = FirebaseFunctionsRateLimiter.withRealtimeDbBackend( 85 | { 86 | name: "per_user_limiter", 87 | maxCalls: 2, 88 | periodSeconds: 15, 89 | }, 90 | database, 91 | ); 92 | 93 | exports.authenticatedFunction = 94 | functions.https.onCall(async (data, context) => { 95 | if (!context.auth || !context.auth.uid) { 96 | throw new functions.https.HttpsError( 97 | "failed-precondition", 98 | "Please authenticate", 99 | ); 100 | } 101 | const uidQualifier = "u_" + context.auth.uid; 102 | const isQuotaExceeded = await perUserlimiter.isQuotaExceededOrRecordUsage(uidQualifier); 103 | if (isQuotaExceeded) { 104 | throw new functions.https.HttpsError( 105 | "failed-precondition", 106 | "Call quota exceeded for this user.
Try again later", 107 | ); 108 | } 109 | 110 | return { result: "Function called" }; 111 | }); 112 | 113 | ``` 114 | 115 | 116 | 117 | ### Step-by-step 118 | 119 | **#1** Initialize admin app and get Realtime database object 120 | 121 | ```typescript 122 | admin.initializeApp(functions.config().firebase); 123 | const database = admin.database(); 124 | ``` 125 | 126 | **#2** Create limiter object outside of the function scope and pass the configuration and Database object. Configuration options are listed below. 127 | 128 | ```typescript 129 | const someLimiter = FirebaseFunctionsRateLimiter.withRealtimeDbBackend( 130 | { 131 | name: "limiter_some", 132 | maxCalls: 10, 133 | periodSeconds: 60, 134 | }, 135 | database, 136 | ); 137 | ``` 138 | 139 | **#3** Inside the function, call isQuotaExceededOrRecordUsage. This is an async function, so do not forget about **await**! The function checks whether the limit was exceeded. If the limit was not exceeded, it records this usage and returns false. Otherwise, it returns true, and a write is only performed if there are usage records that are older than the specified period and are about to be cleared. 140 | 141 | ```typescript 142 | exports.testRateLimiter = 143 | functions.https.onRequest(async (req, res) => { 144 | const quotaExceeded = await limiter.isQuotaExceededOrRecordUsage(); 145 | if (quotaExceeded) { 146 | // respond with error 147 | } else { 148 | // continue 149 | } 150 | ``` 151 | 152 | **#3 with qualifier**. Optionally you can pass **a qualifier** to the function. A qualifier is a string that identifies a separate type of call. If you pass a qualifier, usage will be recorded separately for each distinct qualifier, and the counts won't sum up.
153 | 154 | ```typescript 155 | exports.testRateLimiter = 156 | functions.https.onRequest(async (req, res) => { 157 | const qualifier = "user_1"; 158 | const quotaExceeded = await limiter.isQuotaExceededOrRecordUsage(qualifier); 159 | if (quotaExceeded) { 160 | // respond with error 161 | } else { 162 | // continue 163 | } 164 | ``` 165 | 166 | 167 | 168 | ## Configuration 169 | 170 | ```typescript 171 | const configuration = { 172 | name: // a collection with this name will be created 173 | periodSeconds: // the length of test period in seconds 174 | maxCalls: // number of maximum allowed calls in the period 175 | debug: // boolean (default false) 176 | }; 177 | ``` 178 | 179 | #### Choose backend: 180 | 181 | ```typescript 182 | const limiter = FirebaseFunctionsRateLimiter.withRealtimeDbBackend(configuration, database) 183 | // or 184 | const limiter = FirebaseFunctionsRateLimiter.withFirestoreBackend(configuration, firestore) 185 | // or, for functions unit testing convenience: 186 | const limiter = FirebaseFunctionsRateLimiter.mock() 187 | ``` 188 | 189 | 190 | 191 | ## Methods 192 | 193 | - `isQuotaExceededOrRecordUsage(qualifier?: string)` — Checks if the quota was exceeded. If not — it records the call time in the appropriate backend. Resolves to `true` when the quota is exceeded and `false` otherwise. 194 | 195 | - `rejectOnQuotaExceededOrRecordUsage(qualifier?: string, errorFactory?: (configuration) => Error)` — Checks if the quota was exceeded. If not — it records the call time in the appropriate backend and resolves; if it was — it is rejected with *functions.https.HttpsError*. This particular error can be caught when calling the firebase function directly (see https://firebase.google.com/docs/functions/callable). When errorFactory is provided, it is used to obtain the error that is thrown in case of an exceeded limit. 196 | 197 | - `isQuotaAlreadyExceeded(qualifier?: string)` — Checks if the quota was exceeded, but does not record a usage. If you use this, you must call isQuotaExceededOrRecordUsage() to record the usage.
198 | 199 | - `getConfiguration()` — Returns this rate limiter configuration. 200 | 201 | - ~~`isQuotaExceeded(qualifier?: string)`~~ — **deprecated**: renamed to isQuotaExceededOrRecordUsage 202 | 203 | - ~~`rejectOnQuotaExceeded(qualifier?: string)`~~ — **deprecated**: renamed to rejectOnQuotaExceededOrRecordUsage 204 | 205 | 206 | 207 | **Why is there no `recordUsage()` method?** This library uses a document-per-qualifier data model which requires a read call before the update call. Read-and-update is performed inside an atomic transaction in both backends. It would not be concurrency-safe if the read-and-update transaction was split into separate calls. 208 | 209 | 210 | 211 | ### Firebase configuration 212 | 213 | There is no configuration needed in Firebase. This library does not do document search, so you do not need indexes. Also, functions are executed in the firebase admin environment, so you do not have to specify any rules. 214 | 215 | 216 | 217 | ### Need help? 218 | 219 | - Feel free to email me at jedrzejblew@gmail.com 220 | 221 | 222 | 223 | ### Would you like to help? 224 | 225 | Warmly welcomed: 226 | 227 | - Bug reports via issues 228 | - Enhancement requests via issues 229 | - Pull requests 230 | - Security reports to jedrzejblew@gmail.com 231 | 232 | 233 | 234 | *** 235 | 236 | Made with ❤️ by [Jędrzej Lewandowski](https://jblewandowski.com/).
237 | 238 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | The following versions of the project are 6 | currently being supported with security updates: 7 | 8 | | Version | Supported | 9 | | ------- | ------------------ | 10 | | 3.1.x | :white_check_mark: | 11 | | < 3.1 | :x: | 12 | 13 | ## Reporting a Vulnerability 14 | 15 | If you find any security issue related to this project, please email me as soon as possible at jedrzejblew[@]gmail.com. 16 | Please start the email subject with `Security issue firebase-functions-rate-limiter: xxx`. 17 | Thanks a lot for your commitment! :) 18 | -------------------------------------------------------------------------------- /TODO: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jblew/firebase-functions-rate-limiter/b9325c4bf872355e6929ef56993605bd75f27941/TODO -------------------------------------------------------------------------------- /_dev/firebase-functions-rate-limiter.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [{ 3 | "path": ".."
4 | }], 5 | "settings": { 6 | "editor.rulers": [ 7 | 110, 8 | 120 9 | ], 10 | "git.ignoreLimitWarning": true, 11 | "prettier.semi": true, 12 | "prettier.singleQuote": false, 13 | "prettier.trailingComma": "all", 14 | "prettier.printWidth": 120, 15 | "prettier.tabWidth": 4, 16 | } 17 | } -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "firebase-functions-rate-limiter", 3 | "version": "0.0.0-development", 4 | "description": "JS/TS library that allows you to set per - time, per - user or per - anything limits for calling Firebase cloud functions", 5 | "main": "dist/index.js", 6 | "types": "dist/index.d.ts", 7 | "engines": { 8 | "node": ">=16" 9 | }, 10 | "files": [ 11 | "/dist", 12 | "package-lock.js" 13 | ], 14 | "scripts": { 15 | "build:cleanbefore": "rm -rf dist", 16 | "build:lint": "tslint -c tslint.json -p tsconfig.lint.json", 17 | "build:node": "tsc", 18 | "build": "npm run build:cleanbefore && npm run build:node && npm run build:lint", 19 | "prepare": "NODE_ENV=production npm run build", 20 | "test": "find src -name '*.spec.test.ts' | TS_NODE_FILES=true TS_NODE_CACHE=false TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' xargs mocha -r ts-node/register --require source-map-support/register", 21 | "do_verify": "find src -name '*.integration.test.ts' | TS_NODE_FILES=true TS_NODE_CACHE=false TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' xargs mocha -r ts-node/register --require source-map-support/register", 22 | "verify": "bash scripts/run-integration-tests-with-firebase-emulator.sh", 23 | "lint-fix": "tslint --fix -c tslint.json -p tsconfig.lint.json", 24 | "checkall": "npm run build:lint && npm run build && npm run test && npm run verify", 25 | "testall": "npm run test && npm run verify", 26 | "testall-with-coverage": "nyc npm run testall", 27 | "testall-with-coverage-lcov": "nyc --reporter=lcov npm run testall",
28 | "upload-coverage": "codecov", 29 | "semantic-release": "semantic-release" 30 | }, 31 | "dependencies": { 32 | "firebase-admin": "^10.3.0", 33 | "firebase-functions": "^3.21.2", 34 | "ow": "^0.28.1" 35 | }, 36 | "devDependencies": { 37 | "@firebase/testing": "^0.20.11", 38 | "@types/chai": "^4.3.1", 39 | "@types/chai-as-promised": "^7.1.5", 40 | "@types/lodash": "^4.14.182", 41 | "@types/mocha": "^9.1.1", 42 | "@types/node": "^17.0.42", 43 | "@types/sinon": "^10.0.11", 44 | "@types/uuid": "^8.3.4", 45 | "chai": "^4.3.6", 46 | "chai-as-promised": "^7.1.1", 47 | "codecov": "^3.8.3", 48 | "istanbul": "^0.4.5", 49 | "lodash": "^4.17.21", 50 | "mocha": "^10.0.0", 51 | "nyc": "^15.1.0", 52 | "semantic-release": "^17.4.7", 53 | "sinon": "^14.0.0", 54 | "ts-node": "^10.8.1", 55 | "tslint": "^6.1.3", 56 | "typescript": "^4.7.3", 57 | "uuid": "^8.3.2" 58 | }, 59 | "nyc": { 60 | "extension": [ 61 | ".ts" 62 | ], 63 | "exclude": [ 64 | "**/*.d.ts", 65 | "**/*.test.ts", 66 | "**/_test" 67 | ], 68 | "include": [ 69 | "src/**/*.ts" 70 | ], 71 | "reporter": [ 72 | "html" 73 | ], 74 | "all": true 75 | }, 76 | "release": {}, 77 | "repository": { 78 | "type": "git", 79 | "url": "https://github.com/Jblew/firebase-functions-rate-limiter" 80 | }, 81 | "keywords": [ 82 | "firebase", 83 | "firebase-functions", 84 | "rate-limiter" 85 | ], 86 | "author": "Jędrzej Lewandowski (https://jedrzej.lewandowski.doctor/)", 87 | "contributors": [ 88 | "Jędrzej Lewandowski (https://jedrzej.lewandowski.doctor/)" 89 | ], 90 | "license": "MIT", 91 | "bugs": { 92 | "url": "https://github.com/Jblew/firebase-functions-rate-limiter/issues" 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /scripts/run-integration-tests-with-firebase-emulator.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env bash 3 | set -e # fail on first error 4 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )/.."
# parent dir of scripts dir 5 | cd "${DIR}" 6 | 7 | echo "Begin tests in ${DIR}" 8 | 9 | firebase emulators:exec --only firestore,database "npm run do_verify" 10 | -------------------------------------------------------------------------------- /scripts/setup-firebase-emulator-for-tests.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env bash 3 | set -e # fail on first error 4 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )/.." # parent dir of scripts dir 5 | cd "${DIR}" 6 | 7 | npm i -g firebase firebase-tools 8 | firebase setup:emulators:firestore 9 | firebase setup:emulators:database 10 | -------------------------------------------------------------------------------- /src/.npmignore: -------------------------------------------------------------------------------- 1 | *.test.ts 2 | _test -------------------------------------------------------------------------------- /src/FirebaseFunctionsRateLimiter.integration.test.ts: -------------------------------------------------------------------------------- 1 | /* tslint:disable:max-classes-per-file no-console */ 2 | import * as firebase from "@firebase/testing"; 3 | import * as functions from "firebase-functions"; 4 | 5 | import { FirebaseFunctionsRateLimiter } from "./FirebaseFunctionsRateLimiter"; 6 | import { mock } from "./FirebaseFunctionsRateLimiter.mock.integration.test"; 7 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 8 | import { PersistenceRecord } from "./persistence/PersistenceRecord"; 9 | import { delayMs } from "./utils.test"; 10 | import { expect, uuid, _ } from "./_test/test_environment"; 11 | // Integration suite; `npm run verify` runs it inside the firebase emulators (firestore, database). 12 | describe("FirebaseFunctionsRateLimiter", () => { 13 | // Scenarios below run for each qualifier mode and each backend (firestore, realtimedb, mock). 14 | before("startup", async function() { 15 | this.timeout(4000); 16 | const { firestore, database } = mock("firestore", {}); 17 | await firestore 18 | .collection("a") 19 | .doc("a") 20 | .get(); 21 | await
database.ref("a").set({ a: "a" }); 22 | }); 23 | // Delete every @firebase/testing app after each test; shutdown failures are only logged as warnings. 24 | afterEach(async () => { 25 | try { 26 | await Promise.all(firebase.apps().map(app => app.delete())); 27 | } catch (error) { 28 | console.warn("Warning: Error in firebase shutdown " + error); 29 | } 30 | }); 31 | // Every scenario runs twice: with a fresh random qualifier, and with none (DEFAULT_QUALIFIER is then the document key). 32 | [ 33 | { 34 | name: "with qualifier", 35 | qualifierFactory() { 36 | return `q${uuid()}`; 37 | }, 38 | }, 39 | { 40 | name: "without qualifier", 41 | qualifierFactory() { 42 | return undefined; 43 | }, 44 | }, 45 | ].forEach(test => 46 | describe(test.name, () => { 47 | const backends = ["firestore", "realtimedb", "mock"] as const; 48 | backends.forEach((backend: "firestore" | "realtimedb" | "mock") => 49 | describe("Backend " + backend, () => { 50 | describe("isQuotaExceededOrRecordUsage", () => { 51 | it("Uses qualifier to identify document in the collection", async () => { 52 | const { rateLimiter, uniqueCollectionName, getDocument } = mock(backend, {}); 53 | const qualifier = test.qualifierFactory(); 54 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 55 | 56 | const doc = await getDocument( 57 | uniqueCollectionName, 58 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 59 | ); 60 | 61 | const record = doc as PersistenceRecord; 62 | expect(record.u) 63 | .to.be.an("array") 64 | .with.length(1); 65 | }); 66 | 67 | it("Increments counter when limit is not exceeded", async () => { 68 | const { rateLimiter, getDocument, uniqueCollectionName } = mock(backend, { 69 | maxCalls: 10, 70 | }); 71 | const qualifier = test.qualifierFactory(); 72 | 73 | const noOfTestCalls = 5; 74 | for (let i = 0; i < noOfTestCalls; i++) { 75 | await delayMs(5); 76 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 77 | } 78 | 79 | const doc = await getDocument( 80 | uniqueCollectionName, 81 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 82 | ); 83 | 84 | const record = doc as PersistenceRecord; 85 | expect(record.u) 86 | .to.be.an("array") 87 | .with.length(noOfTestCalls);
88 | }); 89 | 90 | it("Does not increment counter when limit is exceeded", async () => { 91 | const maxCalls = 5; 92 | const noOfTestCalls = 10; 93 | 94 | const { rateLimiter, getDocument, uniqueCollectionName } = mock(backend, { 95 | maxCalls, 96 | }); 97 | const qualifier = test.qualifierFactory(); 98 | 99 | for (let i = 0; i < noOfTestCalls; i++) { 100 | await delayMs(5); 101 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 102 | } 103 | 104 | const doc = await getDocument( 105 | uniqueCollectionName, 106 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 107 | ); 108 | 109 | const record = doc as PersistenceRecord; 110 | expect(record.u) 111 | .to.be.an("array") 112 | .with.length(maxCalls); 113 | }); 114 | 115 | it("Calls older than period are removed from the database", async function() { 116 | this.timeout(3000); 117 | 118 | const maxCalls = 2; 119 | const periodSeconds = 1; 120 | 121 | const { rateLimiter, uniqueCollectionName, getDocument } = mock(backend, { 122 | maxCalls, 123 | periodSeconds, 124 | }); 125 | const qualifier = test.qualifierFactory(); 126 | // One call, then wait past periodSeconds so its record is stale when the second call arrives. 127 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 128 | await delayMs(periodSeconds * 1000 + 200); 129 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 130 | await delayMs(200); 131 | 132 | const doc = await getDocument( 133 | uniqueCollectionName, 134 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 135 | ); 136 | const record = doc as PersistenceRecord; 137 | expect(record.u) 138 | .to.be.an("array") 139 | .with.length(1); 140 | }); 141 | }); 142 | 143 | describe("rejectOnQuotaExceededOrRecordUsage", () => { 144 | it("throws functions.https.HttpsException when limit is exceeded", async () => { 145 | const maxCalls = 1; 146 | const noOfTestCalls = 2; 147 | 148 | const { rateLimiter } = mock(backend, { 149 | maxCalls, 150 | }); 151 | const qualifier = test.qualifierFactory(); 152 | 153 | for (let i = 0; i < noOfTestCalls; i++) { 154 |
await delayMs(5); 155 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 156 | } 157 | 158 | await expect( 159 | rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier), 160 | ).to.eventually.be.rejectedWith(functions.https.HttpsError); 161 | }); 162 | 163 | it("Is fulfilled when limit is not exceeded", async () => { 164 | const maxCalls = 10; 165 | const noOfTestCalls = 2; 166 | 167 | const { rateLimiter } = mock(backend, { 168 | maxCalls, 169 | }); 170 | const qualifier = test.qualifierFactory(); 171 | 172 | for (let i = 0; i < noOfTestCalls; i++) { 173 | await delayMs(5); 174 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 175 | } 176 | 177 | await expect(rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier)).to.eventually.be 178 | .fulfilled; 179 | }); 180 | 181 | it("When error factory is provided, uses it to throw the error", async () => { 182 | const { rateLimiter } = mock(backend, { 183 | maxCalls: 1, 184 | }); 185 | const qualifier = test.qualifierFactory(); 186 | await rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier); 187 | 188 | const errorFactory = () => new Error("error-from-factory"); 189 | await expect( 190 | rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier, errorFactory), 191 | ).to.eventually.be.rejectedWith(/error-from-factory/); 192 | }); 193 | 194 | it("Provides valid configuration to error factory", async () => { 195 | const { rateLimiter, config } = mock(backend, { 196 | maxCalls: 1, 197 | }); 198 | const qualifier = test.qualifierFactory(); 199 | await rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier); 200 | 201 | const errorFactory = (configInErrorFactory: FirebaseFunctionsRateLimiterConfiguration) => { 202 | expect(configInErrorFactory).to.deep.include(config); 203 | return new Error("error-from-factory"); 204 | }; 205 | 206 | await expect( 207 | rateLimiter.rejectOnQuotaExceededOrRecordUsage(qualifier, errorFactory), 208 | ).to.eventually.be.rejectedWith(/error-from-factory/); 209 | });
210 | }); 211 | // Assertions shared by the recording method and the non-recording isQuotaAlreadyExceeded. 212 | [ 213 | { 214 | name: "isQuotaExceededOrRecordUsage", 215 | methodFactory(rateLimiter: FirebaseFunctionsRateLimiter) { 216 | return rateLimiter.isQuotaExceededOrRecordUsage.bind(rateLimiter); 217 | }, 218 | }, 219 | { 220 | name: "isQuotaAlreadyExceeded", 221 | methodFactory(rateLimiter: FirebaseFunctionsRateLimiter) { 222 | return rateLimiter.isQuotaAlreadyExceeded.bind(rateLimiter); 223 | }, 224 | }, 225 | ].forEach(testedMethod => 226 | describe(`#${testedMethod.name}`, () => { 227 | it("Limit is exceeded if too much calls in specified period", async () => { 228 | const maxCalls = 5; 229 | const noOfTestCalls = 10; 230 | 231 | const { rateLimiter } = mock(backend, { 232 | maxCalls, 233 | }); 234 | const qualifier = test.qualifierFactory(); 235 | 236 | for (let i = 0; i < noOfTestCalls; i++) { 237 | await delayMs(5); 238 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 239 | } 240 | 241 | const method = testedMethod.methodFactory(rateLimiter); 242 | expect(await method(qualifier)).to.be.equal(true); 243 | }); 244 | 245 | it("Limit is not exceeded if too much calls not in specified period", async function() { 246 | this.timeout(3000); 247 | 248 | const maxCalls = 2; 249 | const periodSeconds = 1; 250 | 251 | const { rateLimiter } = mock(backend, { 252 | maxCalls, 253 | periodSeconds, 254 | }); 255 | const qualifier = test.qualifierFactory(); 256 | 257 | await rateLimiter.isQuotaExceededOrRecordUsage(qualifier); 258 | await delayMs(periodSeconds * 1000 + 200); 259 | 260 | const method = testedMethod.methodFactory(rateLimiter); 261 | expect(await method(qualifier)).to.be.equal(false); 262 | }); 263 | }), 264 | ); 265 | }), 266 | ); 267 | // Backend-specific checks that read the raw storage location directly. 268 | describe("Firestore backend specific tests", () => { 269 | it("Writes to specified collection", async () => { 270 | const { rateLimiter, firestore, uniqueCollectionName } = mock("firestore", {}); 271 | await rateLimiter.isQuotaExceededOrRecordUsage(); 272 | 273 | const collection = await
firestore.collection(uniqueCollectionName).get(); 274 | expect(collection.size).to.be.equal(1); 275 | }); 276 | }); 277 | 278 | describe("Realtimedb backend specific tests", () => { 279 | it("Writes to specified key", async () => { 280 | const { rateLimiter, database, uniqueCollectionName } = mock("realtimedb", {}); 281 | await rateLimiter.isQuotaExceededOrRecordUsage(); 282 | 283 | const collection = (await database.ref(`${uniqueCollectionName}`).once("value")).val(); 284 | expect(_.keys(collection).length).to.be.equal(1); 285 | }); 286 | }); 287 | 288 | describe("getConfiguration", () => { 289 | it("Returns correct configuration", () => { 290 | const { rateLimiter, config } = mock("firestore", { maxCalls: 5 }); 291 | expect(rateLimiter.getConfiguration()).to.deep.include(config); 292 | }); 293 | }); 294 | }), 295 | ); 296 | }); 297 | -------------------------------------------------------------------------------- /src/FirebaseFunctionsRateLimiter.mock.integration.test.ts: -------------------------------------------------------------------------------- 1 | /* tslint:disable:max-classes-per-file no-console */ 2 | import * as firebase from "@firebase/testing"; 3 | import * as _ from "lodash"; 4 | import "mocha"; 5 | import { v4 as uuid } from "uuid"; 6 | 7 | import { FirebaseFunctionsRateLimiter } from "./FirebaseFunctionsRateLimiter"; 8 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 9 | import { PersistenceProviderMock } from "./persistence/PersistenceProviderMock"; 10 | // Builds a rate limiter for `backend` on a fresh @firebase/testing app (firestore and database handles are created for every backend), using a random uuid as the collection name and spreading `configApply` over { name, debug: false }; also returns getDocument(), which reads a stored record back from the chosen backend. Throws "Unknown backend" for any other backend value. 11 | export function mock( 12 | backend: "firestore" | "realtimedb" | "mock", 13 | configApply: FirebaseFunctionsRateLimiterConfiguration, 14 | ) { 15 | const app = firebase.initializeTestApp({ projectId: "unit-testing-" + Date.now(), databaseName: "db" }); 16 | const uniqueCollectionName = uuid(); 17 | const uniqueDocName = uuid(); 18 | const firestore = app.firestore(); 19 | const database = app.database(); 20 | const
persistenceProviderMock = new PersistenceProviderMock(); 21 | async function getDocument(collection: string, doc: string): Promise { // NOTE(review): `Promise` has no type argument here, which TypeScript rejects (TS2314); likely `Promise<any>` was lost in extraction — confirm against the repository. 22 | if (backend === "firestore") { 23 | return (await firestore 24 | .collection(collection) 25 | .doc(doc) 26 | .get()).data(); 27 | } else if (backend === "realtimedb") { 28 | return (await database.ref(`${collection}/${doc}`).once("value")).val(); 29 | } else if (backend === "mock") { 30 | return persistenceProviderMock.getRecord(collection, doc); 31 | } else throw new Error("Unknown backend " + backend); 32 | } 33 | const config: FirebaseFunctionsRateLimiterConfiguration = { 34 | name: uniqueCollectionName, 35 | debug: false, 36 | ...configApply, 37 | }; 38 | let rateLimiter: FirebaseFunctionsRateLimiter; 39 | if (backend === "firestore") rateLimiter = FirebaseFunctionsRateLimiter.withFirestoreBackend(config, firestore); 40 | else if (backend === "realtimedb") { 41 | rateLimiter = FirebaseFunctionsRateLimiter.withRealtimeDbBackend(config, database); 42 | } else if (backend === "mock") rateLimiter = FirebaseFunctionsRateLimiter.mock(config, persistenceProviderMock); 43 | else throw new Error("Unknown backend " + backend); 44 | return { 45 | app, 46 | firestore, 47 | database, 48 | uniqueCollectionName, 49 | uniqueDocName, 50 | rateLimiter, 51 | getDocument, 52 | config, 53 | }; 54 | } 55 | -------------------------------------------------------------------------------- /src/FirebaseFunctionsRateLimiter.ts: -------------------------------------------------------------------------------- 1 | // tslint:disable no-console 2 | import * as admin from "firebase-admin"; 3 | import * as functions from "firebase-functions"; 4 | import ow from "ow"; 5 | 6 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 7 | import { GenericRateLimiter } from "./GenericRateLimiter"; 8 | import { FirestorePersistenceProvider } from "./persistence/FirestorePersistenceProvider"; 9 | import { 
PersistenceProvider } from "./persistence/PersistenceProvider"; 10 | import { PersistenceProviderMock } from "./persistence/PersistenceProviderMock"; 11 | import { RealtimeDbPersistenceProvider } from "./persistence/RealtimeDbPersistenceProvider"; 12 | import { FirebaseTimestampProvider } from "./timestamp/FirebaseTimestampProvider"; 13 | import { FirestoreEquivalent } from "./types/FirestoreEquivalent"; 14 | import { RealtimeDbEquivalent } from "./types/RealtimeDbEquivalent"; 15 | 16 | export class FirebaseFunctionsRateLimiter { 17 | public static DEFAULT_QUALIFIER = "default_qualifier"; 18 | 19 | /* 20 | * Factories 21 | */ 22 | public static withFirestoreBackend( 23 | configuration: FirebaseFunctionsRateLimiterConfiguration, 24 | firestore: admin.firestore.Firestore | FirestoreEquivalent, 25 | ): FirebaseFunctionsRateLimiter { 26 | const provider = new FirestorePersistenceProvider(firestore); 27 | return new FirebaseFunctionsRateLimiter(configuration, provider); 28 | } 29 | 30 | public static withRealtimeDbBackend( 31 | configuration: FirebaseFunctionsRateLimiterConfiguration, 32 | realtimeDb: admin.database.Database | RealtimeDbEquivalent, 33 | ): FirebaseFunctionsRateLimiter { 34 | const provider = new RealtimeDbPersistenceProvider(realtimeDb); 35 | return new FirebaseFunctionsRateLimiter(configuration, provider); 36 | } 37 | 38 | public static mock( 39 | configuration?: FirebaseFunctionsRateLimiterConfiguration, 40 | persistenceProviderMock?: PersistenceProviderMock, 41 | ): FirebaseFunctionsRateLimiter { 42 | const defaultConfig: FirebaseFunctionsRateLimiterConfiguration = { 43 | periodSeconds: 10, 44 | maxCalls: Number.MAX_SAFE_INTEGER, 45 | }; 46 | /* istanbul ignore next */ 47 | const provider = persistenceProviderMock || new PersistenceProviderMock(); 48 | /* istanbul ignore next */ 49 | return new FirebaseFunctionsRateLimiter(configuration || defaultConfig, provider); 50 | } 51 | 52 | /* 53 | * Implementation 54 | */ 55 | 56 | private configurationFull: 
FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull; 57 | private genericRateLimiter: GenericRateLimiter; 58 | private debugFn: (msg: string) => void; 59 | /** Merges the given configuration over DEFAULT_CONFIGURATION, validates it, and wires the debug function and a FirebaseTimestampProvider into the underlying GenericRateLimiter. */ 60 | private constructor( 61 | configuration: FirebaseFunctionsRateLimiterConfiguration, 62 | persistenceProvider: PersistenceProvider, 63 | ) { 64 | this.configurationFull = { 65 | ...FirebaseFunctionsRateLimiterConfiguration.DEFAULT_CONFIGURATION, 66 | ...configuration, 67 | }; 68 | ow(this.configurationFull, "configuration", ow.object); 69 | FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull.validate(this.configurationFull); 70 | 71 | this.debugFn = this.constructDebugFn(this.configurationFull); 72 | persistenceProvider.setDebugFn(this.debugFn); 73 | 74 | const timestampProvider = new FirebaseTimestampProvider(); 75 | this.genericRateLimiter = new GenericRateLimiter( 76 | this.configurationFull, 77 | persistenceProvider, 78 | timestampProvider, 79 | this.debugFn, 80 | ); 81 | } 82 | 83 | /* istanbul ignore next because this method was renamed and is now deprecated */ 84 | /** 85 | * Checks if quota is exceeded. If not — records usage time in the backend database. 86 | * The method is deprecated as it was renamed to isQuotaExceededOrRecordUsage 87 | * 88 | * @param qualifier — a string that identifies the limited resource accessor (for example the user id) 89 | * @deprecated 90 | */ 91 | public async isQuotaExceeded(qualifier?: string): Promise { 92 | return this.isQuotaExceededOrRecordUsage(qualifier); 93 | } 94 | 95 | 96 | /** 97 | * Checks if quota is exceeded. If not — records usage time in the backend database.
98 | * 99 | * @param qualifier — a string that identifies the limited resource accessor (for example the user id) 100 | * @returns true if the quota was already exceeded (usage is NOT recorded); false if the usage was recorded 101 | */ 102 | public async isQuotaExceededOrRecordUsage(qualifier?: string): Promise { 103 | return await this.genericRateLimiter.isQuotaExceededOrRecordCall( 104 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 105 | ); 106 | } 107 | 108 | /* istanbul ignore next because this method was renamed and is now deprecated */ 109 | /** 110 | * Checks if quota is exceeded. If it is — rejects with functions.https.HttpsError; otherwise records usage 111 | * time in the backend database and resolves. (HttpsError is the type of error that can be caught when 112 | * firebase function is called directly: see https://firebase.google.com/docs/functions/callable) 113 | * The method is deprecated as it was renamed to rejectOnQuotaExceededOrRecordUsage 114 | * 115 | * @param qualifier — a string that identifies the limited resource accessor (for example the user id) 116 | * @deprecated 117 | */ 118 | public async rejectOnQuotaExceeded(qualifier?: string): Promise { 119 | await this.rejectOnQuotaExceededOrRecordUsage(qualifier); 120 | } 121 | 122 | /** 123 | * Checks if quota is exceeded. If it is — rejects with functions.https.HttpsError ("resource-exhausted") or with 124 | * the error returned by errorFactory; otherwise records usage time in the backend database and resolves. 125 | * (HttpsError is the type of error that can be caught when firebase function is called directly: see https://firebase.google.com/docs/functions/callable) 126 | * 127 | * @param qualifier (optional) — a string that identifies the limited resource accessor (for example the user id) 128 | * @param errorFactory (optional) — when errorFactory is provided, it is used to obtain 129 | * error that is thrown in case of exceeded limit.
130 | */ 131 | public async rejectOnQuotaExceededOrRecordUsage( 132 | qualifier?: string, 133 | errorFactory?: (config: FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull) => Error, 134 | ): Promise { 135 | const isExceeded = await this.genericRateLimiter.isQuotaExceededOrRecordCall( 136 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 137 | ); 138 | if (isExceeded) { 139 | if (errorFactory) throw errorFactory(this.getConfiguration()); 140 | else throw this.constructRejectionError(qualifier); 141 | } 142 | } 143 | 144 | /** 145 | * Checks if quota is exceeded. If not — DOES NOT RECORD USAGE. It only checks if limit was 146 | * previously exceeded or not. 147 | * @param qualifier — a string that identifies the limited resource accessor (for example the user id) 148 | */ 149 | public async isQuotaAlreadyExceeded(qualifier?: string): Promise { 150 | return await this.genericRateLimiter.isQuotaAlreadyExceededDoNotRecordCall( 151 | qualifier || FirebaseFunctionsRateLimiter.DEFAULT_QUALIFIER, 152 | ); 153 | } 154 | 155 | /** 156 | * Returns this rate limiter configuration 157 | */ 158 | public getConfiguration(): FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull { 159 | return this.configurationFull; 160 | } 161 | 162 | /* 163 | * Private methods 164 | */ 165 | private constructRejectionError(qualifier?: string): functions.https.HttpsError { 166 | const c = this.configurationFull; 167 | const msg = 168 | `FirebaseFunctionsRateLimiter error: Limit of ${c.maxCalls} calls per ` + 169 | `${c.periodSeconds} seconds exceeded for ${qualifier ? 
"specified qualifier in " : ""}` + 170 | `limiter ${c.name}`; 171 | return new functions.https.HttpsError("resource-exhausted", msg); 172 | } 173 | 174 | private constructDebugFn( 175 | config: FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull, 176 | ): (msg: string) => void { 177 | /* istanbul ignore if */ 178 | if (config.debug) return (msg: string) => console.log(msg); 179 | else { 180 | return (msg: string) => { 181 | /* */ 182 | }; 183 | } 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /src/FirebaseFunctionsRateLimiterConfiguration.spec.test.ts: -------------------------------------------------------------------------------- 1 | /* tslint:disable:max-classes-per-file */ 2 | import { use as chaiUse } from "chai"; 3 | import * as chaiAsPromised from "chai-as-promised"; 4 | import * as _ from "lodash"; 5 | import "mocha"; 6 | 7 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 8 | 9 | chaiUse(chaiAsPromised); 10 | 11 | describe("FirebaseFunctionsRateLimiterConfiguration", () => { 12 | it("Default configuration passes validation", async () => { 13 | FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull.validate( 14 | FirebaseFunctionsRateLimiterConfiguration.DEFAULT_CONFIGURATION, 15 | ); 16 | }); 17 | }); 18 | -------------------------------------------------------------------------------- /src/FirebaseFunctionsRateLimiterConfiguration.ts: -------------------------------------------------------------------------------- 1 | import ow from "ow"; 2 | 3 | export interface FirebaseFunctionsRateLimiterConfiguration { 4 | name?: string; 5 | periodSeconds?: number; 6 | maxCalls?: number; 7 | debug?: boolean; 8 | } 9 | 10 | export namespace FirebaseFunctionsRateLimiterConfiguration { 11 | export interface ConfigurationFull extends FirebaseFunctionsRateLimiterConfiguration { 12 | name: string; 13 | periodSeconds: number; 14 | maxCalls: number; 15 
| debug: boolean; 16 | } 17 | 18 | export namespace ConfigurationFull { 19 | export function validate(o: ConfigurationFull & FirebaseFunctionsRateLimiterConfiguration) { 20 | ow(o.name, "configuration.name", ow.string.nonEmpty); 21 | ow(o.periodSeconds, "configuration.periodSeconds", ow.number.integer.finite.greaterThan(0)); 22 | ow(o.maxCalls, "configuration.maxCalls", ow.number.integer.finite.greaterThan(0)); 23 | ow(o.debug, "configuration.debug", ow.boolean); 24 | } 25 | } 26 | 27 | export const DEFAULT_CONFIGURATION: ConfigurationFull = { 28 | name: "rlimit", 29 | periodSeconds: 5 * 60, 30 | maxCalls: 5, 31 | debug: false, 32 | }; 33 | } 34 | -------------------------------------------------------------------------------- /src/GenericRateLimiter.spec.test.ts: -------------------------------------------------------------------------------- 1 | /* tslint:disable:max-classes-per-file */ 2 | 3 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 4 | import { GenericRateLimiter } from "./GenericRateLimiter"; 5 | import { PersistenceProviderMock } from "./persistence/PersistenceProviderMock"; 6 | import { TimestampProviderMock } from "./timestamp/TimestampProviderMock.test"; 7 | import { expect, sinon, _ } from "./_test/test_environment"; 8 | 9 | const sampleConfiguration: FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull = { 10 | name: "rate_limiter_1", 11 | periodSeconds: 5 * 60, 12 | maxCalls: 1, 13 | debug: false, 14 | }; 15 | const sampleQualifier = "samplequalifier"; 16 | 17 | describe("GenericRateLimiter", () => { 18 | function mock(configChanges: object) { 19 | const persistenceProviderMock: PersistenceProviderMock = new PersistenceProviderMock(); 20 | persistenceProviderMock.persistenceObject = {}; 21 | const timestampProviderMock = new TimestampProviderMock(); 22 | const genericRateLimiter = new GenericRateLimiter( 23 | { ...sampleConfiguration, ...configChanges }, 24 | 
persistenceProviderMock, 25 | timestampProviderMock, 26 | ); 27 | return { genericRateLimiter, timestampProviderMock, persistenceProviderMock }; 28 | } 29 | 30 | describe("#isQuotaAlreadyExceededDoNotRecordCall", () => { 31 | it("Calls get on PersistenceProvider", async () => { 32 | const { genericRateLimiter, persistenceProviderMock } = mock({}); 33 | persistenceProviderMock.get = sinon.spy(persistenceProviderMock.get); 34 | 35 | await genericRateLimiter.isQuotaAlreadyExceededDoNotRecordCall(sampleQualifier); 36 | 37 | expect((persistenceProviderMock.get as sinon.SinonSpy).callCount, "get call count").to.be.equal(1); 38 | }); 39 | }); 40 | 41 | describe("#isQuotaExceededOrRecordCall", () => { 42 | it("Quota is not exceeded on first call when maxCalls=1", async () => { 43 | const { genericRateLimiter } = mock({ maxCalls: 1 }); 44 | 45 | expect(await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier)).to.be.equal(false); 46 | }); 47 | 48 | it("Does not fail on empty collection", async () => { 49 | const { genericRateLimiter } = mock({}); 50 | 51 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 52 | }); 53 | 54 | it("Calls updateAndGet on PersistenceProvider", async () => { 55 | const { genericRateLimiter, persistenceProviderMock } = mock({}); 56 | persistenceProviderMock.updateAndGet = sinon.spy(persistenceProviderMock.updateAndGet); 57 | 58 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 59 | 60 | expect( 61 | (persistenceProviderMock.updateAndGet as sinon.SinonSpy).callCount, 62 | "updateAndGet call count", 63 | ).to.be.equal(1); 64 | }); 65 | 66 | it("Puts new current timestamp when quota was not exceeded", async () => { 67 | const { genericRateLimiter, persistenceProviderMock, timestampProviderMock } = mock({}); 68 | 69 | const sampleTimestamp = _.random(10, 5000); 70 | timestampProviderMock.setTimestampSeconds(sampleTimestamp); 71 | 72 | await 
genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 73 | 74 | expect(_.values(persistenceProviderMock.persistenceObject)[0].u).to.contain(sampleTimestamp); 75 | }); 76 | 77 | it("does not put current timestamp when quota was exceeded", async () => { 78 | const { genericRateLimiter, persistenceProviderMock, timestampProviderMock } = mock({ 79 | maxCalls: 1, 80 | periodSeconds: 20, 81 | }); 82 | 83 | const sampleTimestamp = _.random(10, 5000); 84 | timestampProviderMock.setTimestampSeconds(sampleTimestamp); 85 | const quotaExceeded1 = await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 86 | expect(quotaExceeded1).to.be.equal(false); 87 | 88 | timestampProviderMock.setTimestampSeconds(sampleTimestamp + 1); 89 | const quotaExceeded2 = await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 90 | expect(quotaExceeded2).to.be.equal(true); 91 | 92 | expect(_.values(persistenceProviderMock.persistenceObject)[0].u) 93 | .to.be.an("array") 94 | .with.length(1); 95 | }); 96 | 97 | describe("threshold tests", () => { 98 | const savedTimestamps: number[] = []; 99 | const persistenceProviderMock: PersistenceProviderMock = new PersistenceProviderMock(); 100 | 101 | before(async () => { 102 | const timestampProviderMock = new TimestampProviderMock(); 103 | const periodSeconds = 5; 104 | const maxCalls = 10; 105 | const genericRateLimiter = new GenericRateLimiter( 106 | { ...sampleConfiguration, periodSeconds, maxCalls }, 107 | persistenceProviderMock, 108 | timestampProviderMock, 109 | ); 110 | 111 | let timestamp = _.random(10, 5000); 112 | 113 | for (let i = 0; i < 6; i++) { 114 | timestampProviderMock.setTimestampSeconds(timestamp); 115 | savedTimestamps.push(timestamp); 116 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 117 | timestamp += periodSeconds / 3 + 0.1; // remember: never push floats to the edges ;) 118 | } 119 | }); 120 | 121 | it("saved record does not contain timestamps below threshold", () => 
{ 122 | expect(_.values(persistenceProviderMock.persistenceObject)[0].u) 123 | .to.be.an("array") 124 | .with.length(3) 125 | .that.contains(savedTimestamps[savedTimestamps.length - 1]) 126 | .and.contains(savedTimestamps[savedTimestamps.length - 2]) 127 | .and.contains(savedTimestamps[savedTimestamps.length - 3]); 128 | }); 129 | 130 | it("saved record contains all timestamps above or equal threshold", () => { 131 | expect(_.values(persistenceProviderMock.persistenceObject)[0].u) 132 | .to.be.an("array") 133 | .with.length(3) 134 | .that.does.not.contain(savedTimestamps[0]) 135 | .and.does.not.contains(savedTimestamps[1]) 136 | .and.does.not.contains(savedTimestamps[2]); 137 | }); 138 | }); 139 | 140 | it("updates or reads only single qualifier", async () => { 141 | const periodSeconds = 30; 142 | const maxCalls = 3; 143 | const { genericRateLimiter, timestampProviderMock } = mock({ 144 | maxCalls, 145 | periodSeconds, 146 | }); 147 | 148 | let timestamp = _.random(10, 5000); 149 | 150 | for (let i = 0; i < 5; i++) { 151 | timestampProviderMock.setTimestampSeconds(timestamp); 152 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 153 | timestamp += 1; 154 | } 155 | 156 | timestampProviderMock.setTimestampSeconds(timestamp); 157 | expect(await genericRateLimiter.isQuotaExceededOrRecordCall("another_qualifier")).to.be.equal(false); 158 | }); 159 | }); 160 | 161 | describe("check tests", function() { 162 | [ 163 | { 164 | name: "isQuotaExceededOrRecordCall", 165 | methodFactory(genericRateLimiter: GenericRateLimiter) { 166 | return genericRateLimiter.isQuotaExceededOrRecordCall.bind(genericRateLimiter); 167 | }, 168 | }, 169 | { 170 | name: "isQuotaAlreadyExceededDoNotRecordCall", 171 | methodFactory(genericRateLimiter: GenericRateLimiter) { 172 | return genericRateLimiter.isQuotaAlreadyExceededDoNotRecordCall.bind(genericRateLimiter); 173 | }, 174 | }, 175 | ].forEach(testedMethod => 176 | describe(`#${testedMethod.name}`, () => { 177 | 
it("returns true if there are more calls than maxCalls", async () => { 178 | const periodSeconds = 20; 179 | const maxCalls = 3; 180 | const { genericRateLimiter, timestampProviderMock } = mock({ 181 | maxCalls, 182 | periodSeconds, 183 | }); 184 | 185 | let timestamp = _.random(10, 5000); 186 | 187 | for (let i = 0; i < 6; i++) { 188 | timestampProviderMock.setTimestampSeconds(timestamp); 189 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 190 | timestamp += 1; 191 | } 192 | const method = testedMethod.methodFactory(genericRateLimiter); 193 | expect(await method(sampleQualifier)).to.be.equal(true); 194 | }); 195 | 196 | it("returns false if there are exactly maxCalls calls in the period", async () => { 197 | const periodSeconds = 20; 198 | const maxCalls = 3; 199 | const { genericRateLimiter, timestampProviderMock } = mock({ 200 | maxCalls, 201 | periodSeconds, 202 | }); 203 | 204 | let timestamp = _.random(10, 5000); 205 | 206 | for (let i = 0; i < 2; i++) { 207 | timestampProviderMock.setTimestampSeconds(timestamp); 208 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 209 | timestamp += 1; 210 | } 211 | // the following call is the third, should be passed 212 | const method = testedMethod.methodFactory(genericRateLimiter); 213 | expect(await method(sampleQualifier)).to.be.equal(false); 214 | }); 215 | 216 | it("returns false if there are no calls, maxCalls=1 ant this is the first call", async () => { 217 | const periodSeconds = 20; 218 | const maxCalls = 1; 219 | const { genericRateLimiter, timestampProviderMock } = mock({ 220 | maxCalls, 221 | periodSeconds, 222 | }); 223 | 224 | const timestamp = _.random(10, 5000); 225 | timestampProviderMock.setTimestampSeconds(timestamp); 226 | const method = testedMethod.methodFactory(genericRateLimiter); 227 | expect(await method(sampleQualifier)).to.be.equal(false); 228 | }); 229 | 230 | it("returns false if there are less calls than maxCalls", async () => { 231 | const 
periodSeconds = 20; 232 | const maxCalls = 10; 233 | const { genericRateLimiter, timestampProviderMock } = mock({ 234 | maxCalls, 235 | periodSeconds, 236 | }); 237 | 238 | let timestamp = _.random(10, 5000); 239 | 240 | for (let i = 0; i < 2; i++) { 241 | timestampProviderMock.setTimestampSeconds(timestamp); 242 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 243 | timestamp += 1; 244 | } 245 | // the following call is the third, should be passed 246 | const method = testedMethod.methodFactory(genericRateLimiter); 247 | expect(await method(sampleQualifier)).to.be.equal(false); 248 | }); 249 | 250 | it("returns false if exceeding calls are out of the period", async () => { 251 | const periodSeconds = 20; 252 | const maxCalls = 5; 253 | const { genericRateLimiter, timestampProviderMock } = mock({ 254 | maxCalls, 255 | periodSeconds, 256 | }); 257 | 258 | let timestamp = _.random(10, 5000); 259 | 260 | for (let i = 0; i < 10; i++) { 261 | timestampProviderMock.setTimestampSeconds(timestamp); 262 | await genericRateLimiter.isQuotaExceededOrRecordCall(sampleQualifier); 263 | timestamp += 1; 264 | } 265 | 266 | timestamp += 30; 267 | timestampProviderMock.setTimestampSeconds(timestamp); 268 | 269 | const method = testedMethod.methodFactory(genericRateLimiter); 270 | expect(await method(sampleQualifier)).to.be.equal(false); 271 | }); 272 | }), 273 | ); 274 | }); 275 | }); 276 | -------------------------------------------------------------------------------- /src/GenericRateLimiter.ts: -------------------------------------------------------------------------------- 1 | import ow from "ow"; 2 | 3 | import { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 4 | import { PersistenceProvider } from "./persistence/PersistenceProvider"; 5 | import { PersistenceRecord } from "./persistence/PersistenceRecord"; 6 | import { TimestampProvider } from "./timestamp/TimestampProvider"; 7 | 8 | export class 
GenericRateLimiter { 9 | private configuration: FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull; 10 | private persistenceProvider: PersistenceProvider; 11 | private timestampProvider: TimestampProvider; 12 | private debugFn: (msg: string) => void; 13 | /** Merges the given configuration over DEFAULT_CONFIGURATION and validates it, then validates and stores the persistence and timestamp providers. debugFn defaults to a no-op. */ 14 | public constructor( 15 | configuration: FirebaseFunctionsRateLimiterConfiguration, 16 | persistenceProvider: PersistenceProvider, 17 | timestampProvider: TimestampProvider, 18 | debugFn: (msg: string) => void = (msg: string) => { 19 | /* */ 20 | }, 21 | ) { 22 | this.configuration = { ...FirebaseFunctionsRateLimiterConfiguration.DEFAULT_CONFIGURATION, ...configuration }; 23 | ow(this.configuration, "configuration", ow.object); 24 | FirebaseFunctionsRateLimiterConfiguration.ConfigurationFull.validate(this.configuration); 25 | 26 | this.persistenceProvider = persistenceProvider; 27 | ow(this.persistenceProvider, "persistenceProvider", ow.object); 28 | 29 | this.timestampProvider = timestampProvider; 30 | ow(this.timestampProvider, "timestampProvider", ow.object); 31 | 32 | this.debugFn = debugFn; 33 | } 34 | /** Checks the quota for `qualifier` inside persistenceProvider.updateAndGet and records a usage at the current timestamp only when the quota was not exceeded. Returns true if the quota was already exceeded (no usage recorded). */ 35 | public async isQuotaExceededOrRecordCall(qualifier: string): Promise { 36 | const resultHolder = { 37 | isQuotaExceeded: false, 38 | }; 39 | await this.persistenceProvider.updateAndGet(this.configuration.name, qualifier, record => { 40 | return this.runTransactionForAnswer(record, resultHolder); 41 | }); 42 | 43 | return resultHolder.isQuotaExceeded; 44 | } 45 | /** Read-only check: returns true if the usages newer than (now - periodSeconds) already reach maxCalls. Does not record a usage. */ 46 | public async isQuotaAlreadyExceededDoNotRecordCall(qualifier: string): Promise { 47 | const timestampsSeconds = this.getTimestampsSeconds(); 48 | const record = await this.persistenceProvider.get(this.configuration.name, qualifier); 49 | const recentUsages: number[] = this.selectRecentUsages(record.u, timestampsSeconds.threshold); 50 | return this.isQuotaExceeded(recentUsages.length); 51 | } 52 | /** Updater for updateAndGet: keeps only in-period usages, stores the verdict in resultHolder and appends the current timestamp only when the quota was not exceeded. */ 53 | private runTransactionForAnswer( 54 | input: PersistenceRecord, 55 | resultHolder: { isQuotaExceeded: boolean }, 56 | ):
PersistenceRecord { 57 | const timestampsSeconds = this.getTimestampsSeconds(); 58 | 59 | this.debugFn("Got record with usages " + input.u.length); 60 | 61 | const recentUsages: number[] = this.selectRecentUsages(input.u, timestampsSeconds.threshold); 62 | this.debugFn("Of these usages there are " + recentUsages.length + " usages that count into period"); 63 | 64 | const result = this.isQuotaExceeded(recentUsages.length); 65 | resultHolder.isQuotaExceeded = result; 66 | this.debugFn("The result is quotaExceeded=" + result); 67 | 68 | if (!result) { 69 | this.debugFn("Quota was not exceeded, so recording a usage at " + timestampsSeconds.current); 70 | recentUsages.push(timestampsSeconds.current); 71 | } 72 | 73 | const newRecord: PersistenceRecord = { 74 | u: recentUsages, 75 | }; 76 | return newRecord; 77 | } 78 | /** Returns the usages strictly newer than the threshold; a usage exactly at the threshold is dropped. */ 79 | private selectRecentUsages(allUsages: number[], timestampThresholdSeconds: number): number[] { 80 | const recentUsages: number[] = []; 81 | 82 | for (const usageTime of allUsages) { 83 | if (usageTime > timestampThresholdSeconds) { 84 | recentUsages.push(usageTime); 85 | } 86 | } 87 | return recentUsages; 88 | } 89 | /** The quota counts as exceeded once the in-period usage count reaches maxCalls. */ 90 | private isQuotaExceeded(numOfRecentUsages: number): boolean { 91 | return numOfRecentUsages >= this.configuration.maxCalls; 92 | } 93 | /** Current timestamp from timestampProvider and the lower bound (current - periodSeconds) of the counting window. */ 94 | private getTimestampsSeconds(): { current: number; threshold: number } { 95 | const currentServerTimestampSeconds: number = this.timestampProvider.getTimestampSeconds(); 96 | return { 97 | current: currentServerTimestampSeconds, 98 | threshold: currentServerTimestampSeconds - this.configuration.periodSeconds, 99 | }; 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/_test/test_environment.ts: -------------------------------------------------------------------------------- 1 | import { expect, use as chaiUse } from "chai"; 2 | import * as chaiAsPromised from "chai-as-promised"; 3 | import * as _ from "lodash"; 4 | import "mocha"; 5 |
import * as sinon from "sinon"; 6 | import { v4 as uuid } from "uuid"; 7 | 8 | chaiUse(chaiAsPromised); 9 | 10 | export { _, expect, sinon, uuid }; 11 | -------------------------------------------------------------------------------- /src/index.ts: -------------------------------------------------------------------------------- 1 | export { FirebaseFunctionsRateLimiter } from "./FirebaseFunctionsRateLimiter"; 2 | export { FirebaseFunctionsRateLimiterConfiguration } from "./FirebaseFunctionsRateLimiterConfiguration"; 3 | export { FirestoreEquivalent } from "./types/FirestoreEquivalent"; 4 | export { RealtimeDbEquivalent } from "./types/RealtimeDbEquivalent"; 5 | 6 | import { FirebaseFunctionsRateLimiter } from "./FirebaseFunctionsRateLimiter"; 7 | export default FirebaseFunctionsRateLimiter; 8 | -------------------------------------------------------------------------------- /src/persistence/FirestorePersistenceProvider.ts: -------------------------------------------------------------------------------- 1 | import * as admin from "firebase-admin"; 2 | import ow from "ow"; 3 | 4 | import { FirestoreEquivalent } from "../types/FirestoreEquivalent"; 5 | 6 | import { PersistenceProvider } from "./PersistenceProvider"; 7 | import { PersistenceRecord } from "./PersistenceRecord"; 8 | 9 | /** PersistenceProvider backed by Firestore: each record is the document `recordName` in the collection `collectionName`. */ export class FirestorePersistenceProvider implements PersistenceProvider { 10 | private firestore: admin.firestore.Firestore | FirestoreEquivalent; 11 | private debugFn: (msg: string) => void; 12 | 13 | /* istanbul ignore next (debugFn), because typescript injects if for default parameters */ 14 | public constructor( 15 | firestore: FirestoreEquivalent, 16 | debugFn: (msg: string) => void = (msg: string) => { 17 | /* */ 18 | }, 19 | ) { 20 | this.firestore = firestore; 21 | ow(this.firestore, "firestore", ow.object); 22 | 23 | this.debugFn = debugFn; 24 | } 25 | 26 | /** Reads the record, applies updaterFn and writes the result back only when it changed; the read and the write both go through the Firestore transaction so that conflicting concurrent updates are detected by Firestore instead of silently overwriting each other. Resolves to the updated record. */ public async updateAndGet( 27 | collectionName: string, 28 | recordName: string,
29 | updaterFn: (record: PersistenceRecord) => PersistenceRecord, 30 | ): Promise<PersistenceRecord> { 31 | let result: PersistenceRecord | undefined; 32 | await this.runTransaction(async (transaction: any) => { 33 | const record = await this.getRecord(collectionName, recordName, transaction); 34 | const updatedRecord = updaterFn(record); 35 | if (this.hasRecordChanged(record, updatedRecord)) { 36 | await this.saveRecord(collectionName, recordName, updatedRecord, transaction); 37 | } 38 | result = updatedRecord; 39 | }); 40 | /* istanbul ignore next */ 41 | if (!result) throw new Error("FirestorePersistenceProvider: Persistence record could not be updated"); 42 | return result; 43 | } 44 | 45 | /** Plain (non-transactional) read of the record; an empty record when the document does not exist. */ public async get(collectionName: string, recordName: string): Promise<PersistenceRecord> { 46 | return await this.getRecord(collectionName, recordName); 47 | } 48 | 49 | public setDebugFn(debugFn: (msg: string) => void) { 50 | this.debugFn = debugFn; 51 | } 52 | 53 | /** Runs asyncTransactionFn inside firestore.runTransaction and hands it the transaction object, so the reads and writes it performs can join that transaction. */ private async runTransaction(asyncTransactionFn: (transaction: any) => Promise<void>): Promise<void> { 54 | return await this.firestore.runTransaction(async (transaction: any) => { 55 | await asyncTransactionFn(transaction); 56 | }); 57 | } 58 | 59 | /** Loads and validates a record, reading through `transaction` when one exposing get() is given and directly from the document otherwise; a missing document yields an empty record. */ private async getRecord(collectionName: string, recordName: string, transaction?: any): Promise<PersistenceRecord> { 60 | const docRef = this.getDocumentRef(collectionName, recordName); 61 | const useTransaction = !!transaction && typeof transaction.get === "function"; 62 | const docSnapshot = useTransaction ? await transaction.get(docRef) : await docRef.get(); 63 | this.debugFn("Got record from collection=" + collectionName + ", document=" + recordName); 64 | 65 | if (!docSnapshot.exists) return this.createEmptyRecord(); 66 | 67 | const record: PersistenceRecord = docSnapshot.data() as PersistenceRecord; 68 | PersistenceRecord.validate(record); 69 | return record; 70 | } 71 | 72 | /** Writes a record, through `transaction` when one exposing set() is given and directly to the document otherwise. */ private async saveRecord(collectionName: string, recordName: string, record: PersistenceRecord, transaction?: any): Promise<void> { 73 | this.debugFn("Save record collection=" + collectionName + ", document=" + recordName); 74 | const docRef = this.getDocumentRef(collectionName, recordName); 75 | const useTransaction = !!transaction && typeof transaction.set === "function"; 76 | await (useTransaction ? transaction.set(docRef, record) : docRef.set(record)); 77 | } 78 | 79 | private getDocumentRef( 80 | collectionName: string, 81 | recordName: string, 82 | ): FirestoreEquivalent.DocumentReferenceEquivalent {
83 | return this.firestore.collection(collectionName).doc(recordName); 84 | } 85 | 86 | private createEmptyRecord(): PersistenceRecord { 87 | return { 88 | u: [], 89 | }; 90 | } 91 | 92 | /** True when the records differ in length or hold different usage values; order is ignored because sorted copies are compared. */ private hasRecordChanged(oldRecord: PersistenceRecord, newRecord: PersistenceRecord): boolean { 93 | if (oldRecord.u.length !== newRecord.u.length) { 94 | return true; 95 | } else { 96 | const a1 = oldRecord.u.concat().sort(); 97 | const a2 = newRecord.u.concat().sort(); 98 | for (let i = 0; i < a1.length; i++) { 99 | if (a1[i] !== a2[i]) return true; 100 | } 101 | return false; 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/persistence/PersistenceProvider.ts: -------------------------------------------------------------------------------- 1 | import { PersistenceRecord } from "./PersistenceRecord"; 2 | 3 | /** Storage backend for rate-limiter records, addressed by collection name and record name. */ export interface PersistenceProvider { 4 | /** Passes the stored record to `updater`, stores the returned record and resolves to it. */ updateAndGet( 5 | collectionName: string, 6 | recordName: string, 7 | updater: (record: PersistenceRecord) => PersistenceRecord, 8 | ): Promise<PersistenceRecord>; 9 | /** Reads the stored record without modifying it. */ get(collectionName: string, recordName: string): Promise<PersistenceRecord>; 10 | setDebugFn(debugFn: (msg: string) => void): void; 11 | } 12 | -------------------------------------------------------------------------------- /src/persistence/PersistenceProviderMock.ts: -------------------------------------------------------------------------------- 1 | import { PersistenceProvider } from "./PersistenceProvider"; 2 | import { PersistenceRecord } from "./PersistenceRecord"; 3 | 4 | /** In-memory PersistenceProvider for tests; records are kept in persistenceObject under the key collectionName + "_" + recordName. */ export class PersistenceProviderMock implements PersistenceProvider { 5 | public persistenceObject: { [x: string]: PersistenceRecord } = {}; 6 | 7 | public async updateAndGet( 8 | collectionName: string, 9 | recordName: string, 10 | updaterFn: (record: PersistenceRecord) => PersistenceRecord, 11 | ): Promise<PersistenceRecord> { 12 | let result: PersistenceRecord | undefined; 13 | await this.runTransaction(async () => { 14 | const record = await this.getRecord(collectionName, recordName);
15 | const updatedRecord = updaterFn(record); 16 | await this.saveRecord(collectionName, recordName, updatedRecord); 17 | result = updatedRecord; 18 | }); 19 | /* istanbul ignore next */ 20 | if (!result) throw new Error("PersistenceProviderMock: Persistence record could not be updated"); 21 | return result; 22 | } 23 | 24 | public async get(collectionName: string, recordName: string): Promise<PersistenceRecord> { 25 | return await this.getRecord(collectionName, recordName); 26 | } 27 | 28 | public setDebugFn(debugFn: (msg: string) => void) { 29 | // 30 | } 31 | 32 | public async getRecord(collectionName: string, recordName: string): Promise<PersistenceRecord> { 33 | await this.delay(2); 34 | const key = this.getKey(collectionName, recordName); 35 | return this.persistenceObject[key] || this.createEmptyRecord(); 36 | } 37 | 38 | /* NOTE(review): nothing serializes calls here and getRecord/saveRecord each await a delay, so concurrent updateAndGet calls on the same key can interleave. */ private async runTransaction(asyncTransactionFn: () => Promise<void>): Promise<void> { 39 | await asyncTransactionFn(); 40 | } 41 | 42 | private async saveRecord(collectionName: string, recordName: string, record: PersistenceRecord): Promise<void> { 43 | await this.delay(2); 44 | const key = this.getKey(collectionName, recordName); 45 | this.persistenceObject[key] = record; 46 | } 47 | 48 | private getKey(collectionName: string, recordName: string): string { 49 | return collectionName + "_" + recordName; 50 | } 51 | 52 | private createEmptyRecord(): PersistenceRecord { 53 | return { 54 | u: [], 55 | }; 56 | } 57 | 58 | private delay(delayMs: number) { 59 | return new Promise<void>(function(resolve, reject) { 60 | setTimeout(function() { 61 | resolve(); 62 | }, delayMs); 63 | }); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /src/persistence/PersistenceProviders.integration.test.ts: -------------------------------------------------------------------------------- 1 | /* tslint:disable:max-classes-per-file */ 2 | import * as firebase from "@firebase/testing"; 3 | 4 | import { expect, sinon, uuid, _ } from "../_test/test_environment"; 5 | 6 | import {
FirestorePersistenceProvider } from "./FirestorePersistenceProvider"; 7 | import { PersistenceProvider } from "./PersistenceProvider"; 8 | import { PersistenceProviderMock } from "./PersistenceProviderMock"; 9 | import { PersistenceRecord } from "./PersistenceRecord"; 10 | import { RealtimeDbPersistenceProvider } from "./RealtimeDbPersistenceProvider"; 11 | 12 | /* Contract tests run against every PersistenceProvider implementation: Firestore and Realtime Database (through @firebase/testing) and the in-memory PersistenceProviderMock. */ describe("PersistenceProviders", function() { 13 | this.timeout(4000); 14 | 15 | function mock() { 16 | const app = firebase.initializeTestApp({ projectId: "unit-testing-" + Date.now(), databaseName: "db" }); 17 | const uniqueCollectionName = uuid(); 18 | const uniqueDocName = uuid(); 19 | const firestore = app.firestore(); 20 | const database = app.database(); 21 | const provider: PersistenceProvider = undefined as any; 22 | const emptyPersistenceRecord: PersistenceRecord = { u: [] }; 23 | const nonModifyingUpdater = (pr: PersistenceRecord) => pr; 24 | return { 25 | app, 26 | firestore, 27 | database, 28 | uniqueCollectionName, 29 | uniqueDocName, 30 | provider, 31 | emptyPersistenceRecord, 32 | nonModifyingUpdater, 33 | }; 34 | } 35 | 36 | const mockFirestoreProvider: typeof mock = () => { 37 | const mockResult = mock(); 38 | const provider = new FirestorePersistenceProvider(mockResult.firestore); 39 | return { ...mockResult, provider }; 40 | }; 41 | 42 | const mockRealtimeProvider: typeof mock = () => { 43 | const mockResult = mock(); 44 | const provider = new RealtimeDbPersistenceProvider(mockResult.database); 45 | return { ...mockResult, provider }; 46 | }; 47 | 48 | const mockMockProvider: typeof mock = () => { 49 | const mockResult = mock(); 50 | const provider = new PersistenceProviderMock(); 51 | return { ...mockResult, provider }; 52 | }; 53 | 54 | afterEach(async () => { 55 | await Promise.all(firebase.apps().map(app => app.delete())); 56 | }); 57 | 58 | before("startup", async function() { 59 | this.timeout(4000); 60 | const { firestore, database } = mock(); 61 | await firestore
62 | .collection("a") 63 | .doc("a") 64 | .get(); 65 | await database.ref("a").set({ a: "a" }); 66 | }); 67 | 68 | [ 69 | { name: "FirestorePersistenceProvider", mockFactory: mockFirestoreProvider }, 70 | { name: "RealtimeDbPersistenceProvider", mockFactory: mockRealtimeProvider }, 71 | { name: "PersistenceProviderMock", mockFactory: mockMockProvider }, 72 | ].forEach(test => 73 | describe(test.name, () => { 74 | describe("#updateAndGet", () => { 75 | it("Runs transaction code", async () => { 76 | const { 77 | provider, 78 | uniqueCollectionName, 79 | uniqueDocName, 80 | emptyPersistenceRecord, 81 | } = test.mockFactory(); 82 | const spy = sinon.spy(); 83 | await provider.updateAndGet(uniqueCollectionName, uniqueDocName, record => { 84 | spy(); 85 | return emptyPersistenceRecord; 86 | }); 87 | expect(spy.callCount).to.be.equal(1); 88 | }); 89 | 90 | it("Resolves when transaction callback is finished", async () => { 91 | const { provider, uniqueCollectionName, uniqueDocName } = test.mockFactory(); 92 | const spy = sinon.spy(); 93 | await provider.updateAndGet(uniqueCollectionName, uniqueDocName, record => { 94 | spy(); 95 | return { u: [] }; 96 | }); 97 | expect(spy.callCount).to.be.equal(1); 98 | }); 99 | 100 | it("Returns empty record when no data", async () => { 101 | const { provider, uniqueCollectionName, uniqueDocName, nonModifyingUpdater } = test.mockFactory(); 102 | const rec = await provider.updateAndGet(uniqueCollectionName, uniqueDocName, nonModifyingUpdater); 103 | expect(rec.u) 104 | .to.be.an("array") 105 | .with.length(0); 106 | }); 107 | 108 | it("Saves record properly", async () => { 109 | const { provider, uniqueCollectionName, uniqueDocName, nonModifyingUpdater } = test.mockFactory(); 110 | 111 | const recToBeSaved: PersistenceRecord = { 112 | u: [1, 2, 3], 113 | }; 114 | await provider.updateAndGet(uniqueCollectionName, uniqueDocName, r => recToBeSaved); 115 | 116 | const recRetrieved = await provider.updateAndGet( 117 | uniqueCollectionName,
118 | uniqueDocName, 119 | nonModifyingUpdater, 120 | ); 121 | expect(recRetrieved.u) 122 | .to.be.an("array") 123 | .with.length(recToBeSaved.u.length) 124 | .that.have.members(recToBeSaved.u); 125 | }); 126 | }); 127 | 128 | describe("#get", () => { 129 | it("Returns empty record when no data", async () => { 130 | const { provider, uniqueCollectionName, uniqueDocName } = test.mockFactory(); 131 | const rec = await provider.get(uniqueCollectionName, uniqueDocName); 132 | expect(rec.u) 133 | .to.be.an("array") 134 | .with.length(0); 135 | }); 136 | 137 | it("Returns previously saved record", async () => { 138 | const { provider, uniqueCollectionName, uniqueDocName } = test.mockFactory(); 139 | 140 | const recToBeSaved: PersistenceRecord = { 141 | u: [1, 2, 3], 142 | }; 143 | await provider.updateAndGet(uniqueCollectionName, uniqueDocName, r => recToBeSaved); 144 | 145 | const recRetrieved = await provider.get(uniqueCollectionName, uniqueDocName); 146 | expect(recRetrieved.u) 147 | .to.be.an("array") 148 | .with.length(recToBeSaved.u.length) 149 | .that.have.members(recToBeSaved.u); 150 | }); 151 | }); 152 | }), 153 | ); 154 | }); 155 | -------------------------------------------------------------------------------- /src/persistence/PersistenceRecord.ts: -------------------------------------------------------------------------------- 1 | import ow from "ow"; 2 | 3 | export interface PersistenceRecord { 4 | // "u" instead of "usages" to save data transfer 5 | u: number[]; 6 | } 7 | 8 | export namespace PersistenceRecord { 9 | /** Shallow validation: the record must be an object and record.u an array; item types are not checked. */ export function validate(r: PersistenceRecord) { 10 | ow(r, "record", ow.object); 11 | ow(r.u, "record.u", ow.array); // checking item types is a costly operation so we skip it 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/persistence/RealtimeDbPersistenceProvider.ts: -------------------------------------------------------------------------------- 1 | import ow from "ow"; 2 | 3 | import {
RealtimeDbEquivalent } from "../types/RealtimeDbEquivalent"; 4 | 5 | import { PersistenceProvider } from "./PersistenceProvider"; 6 | import { PersistenceRecord } from "./PersistenceRecord"; 7 | 8 | /** PersistenceProvider backed by the Realtime Database: each record is stored at the path `${collectionName}/${recordName}`. */ export class RealtimeDbPersistenceProvider implements PersistenceProvider { 9 | private database: RealtimeDbEquivalent; 10 | 11 | private debugFn: (msg: string) => void; 12 | 13 | /* istanbul ignore next (debugFn), because typescript injects if for default parameters */ 14 | public constructor( 15 | database: RealtimeDbEquivalent, 16 | debugFn: (msg: string) => void = (msg: string) => { 17 | /* */ 18 | }, 19 | ) { 20 | this.database = database; 21 | ow(this.database, "database", ow.object); 22 | 23 | this.debugFn = debugFn; 24 | } 25 | 26 | /** Applies updaterFn inside ref.transaction (null data is passed in as an empty record) and resolves to the committed value, or an empty record when that value is null. Throws when no snapshot is returned or the transaction is not committed. */ public async updateAndGet( 27 | collectionName: string, 28 | recordName: string, 29 | updaterFn: (record: PersistenceRecord) => PersistenceRecord, 30 | ): Promise<PersistenceRecord> { 31 | const ref = this.getDatabaseRef(collectionName, recordName); 32 | 33 | const response = await ref.transaction(dataToUpdate => this.wrapUpdaterFn(updaterFn)(dataToUpdate)); 34 | const { snapshot, committed } = response; 35 | /* istanbul ignore next because this is not testable locally */ 36 | if (!snapshot) throw new Error("RealtimeDbPersistenceProvider: realtime db didn't respond with data"); 37 | /* istanbul ignore next because this is not testable locally */ 38 | if (!committed) throw new Error("RealtimeDbPersistenceProvider: could not save data"); 39 | 40 | const data = snapshot.val(); 41 | if (data === null) return this.createEmptyRecord(); 42 | else return data as PersistenceRecord; 43 | } 44 | 45 | /** Reads the node once; a null value yields an empty record. */ public async get(collectionName: string, recordName: string): Promise<PersistenceRecord> { 46 | const snapshot = await this.getDatabaseRef(collectionName, recordName).once("value"); 47 | 48 | const data = snapshot.val(); 49 | if (data === null) return this.createEmptyRecord(); 50 | else return data as PersistenceRecord; 51 | } 52 |
53 | public setDebugFn(debugFn: (msg: string) => void) { 54 | this.debugFn = debugFn; 55 | } 56 | 57 | /** Adapts updaterFn to the ref.transaction callback: null data is replaced by an empty record before updaterFn runs. */ private wrapUpdaterFn(updaterFn: (record: PersistenceRecord) => PersistenceRecord): (data: any) => any { 58 | return (data: any) => { 59 | this.debugFn("RealtimeDbPersistenceProvider: updateFn called with data of type " + typeof data); 60 | if (data === null) { 61 | const emptyRecord = this.createEmptyRecord(); 62 | const updatedPr = updaterFn(emptyRecord); 63 | return updatedPr; 64 | } else { 65 | const updatedPr = updaterFn(data); 66 | return updatedPr; 67 | } 68 | }; 69 | } 70 | 71 | private getDatabaseRef(collectionName: string, recordName: string) { 72 | const refName = `${collectionName}/${recordName}`; 73 | return this.database.ref(refName); 74 | } 75 | 76 | private createEmptyRecord(): PersistenceRecord { 77 | return { 78 | u: [], 79 | }; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/timestamp/FirebaseTimestampProvider.ts: -------------------------------------------------------------------------------- 1 | import * as admin from "firebase-admin"; 2 | 3 | import { TimestampProvider } from "./TimestampProvider"; 4 | 5 | /** TimestampProvider that reads the seconds field of admin.firestore.Timestamp.now(). */ export class FirebaseTimestampProvider implements TimestampProvider { 6 | public getTimestampSeconds(): number { 7 | return admin.firestore.Timestamp.now().seconds; 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /src/timestamp/TimestampProvider.ts: -------------------------------------------------------------------------------- 1 | /** Source of the current time, in seconds. */ export interface TimestampProvider { 2 | getTimestampSeconds(): number; 3 | } 4 | -------------------------------------------------------------------------------- /src/timestamp/TimestampProviderMock.test.ts: -------------------------------------------------------------------------------- 1 | import { TimestampProvider } from "./TimestampProvider"; 2 | 3 | /** Test TimestampProvider: returns the value given to setTimestampSeconds, or Date.now() / 1000 while none has been set. */ export class TimestampProviderMock implements TimestampProvider {
4 | private timestampNowSeconds: number | undefined = undefined; 5 | 6 | public getTimestampSeconds(): number { 7 | /* compare with undefined (not falsiness) so that an explicitly set 0 is honoured */ if (this.timestampNowSeconds === undefined) return Date.now() / 1000; 8 | else return this.timestampNowSeconds; 9 | } 10 | 11 | public setTimestampSeconds(timestampSeconds: number) { 12 | this.timestampNowSeconds = timestampSeconds; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /src/types/EquivalentTypes.spec.test.ts: -------------------------------------------------------------------------------- 1 | import * as firebaseTypes from "firebase/app"; 2 | 3 | import { _ } from "../_test/test_environment"; 4 | 5 | import { FirestoreEquivalent } from "./FirestoreEquivalent"; 6 | import { RealtimeDbEquivalent } from "./RealtimeDbEquivalent"; 7 | 8 | /* Compile-time checks: the firebase/app Firestore and Database types must be assignable to FirestoreEquivalent and RealtimeDbEquivalent. */ describe("Firebase equivalents", () => { 9 | // tslint:disable prefer-const 10 | let firestore!: firebaseTypes.firestore.Firestore; 11 | let database!: firebaseTypes.database.Database; 12 | 13 | describe("FirestoreEquivalent", () => { 14 | function acceptFirestoreEquivalent(firestoreEquivalent: FirestoreEquivalent) { 15 | return firestoreEquivalent; 16 | } 17 | 18 | it("Matches firebase/app typings", () => { 19 | acceptFirestoreEquivalent(firestore); 20 | }); 21 | }); 22 | 23 | describe("RealtimeDbEquivalent", () => { 24 | function acceptRealtimeDbEquivalent(realtimeDbEquivalent: RealtimeDbEquivalent) { 25 | return realtimeDbEquivalent; 26 | } 27 | 28 | it("Matches firebase/app typings", () => { 29 | acceptRealtimeDbEquivalent(database); 30 | }); 31 | }); 32 | }); 33 | -------------------------------------------------------------------------------- /src/types/FirestoreEquivalent.ts: -------------------------------------------------------------------------------- 1 | /** Structural subset of the Firestore API used by FirestorePersistenceProvider: runTransaction() and collection().doc(). */ export interface FirestoreEquivalent { 2 | runTransaction(tCallback: (transaction: any) => Promise<any>): Promise<any>; 3 | 4 | collection( 5 | name: string, 6 | ): { 7 | doc(name: string): FirestoreEquivalent.DocumentReferenceEquivalent; 8 | }; 9 | } 10 | 11 |
export namespace FirestoreEquivalent { 12 | /** Subset of a Firestore DocumentReference: read with get(), write with set(). */ export interface DocumentReferenceEquivalent { 13 | get(): Promise<{ 14 | exists: boolean; 15 | data(): object | undefined; 16 | }>; 17 | set(record: object): Promise<any>; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/types/RealtimeDbEquivalent.ts: -------------------------------------------------------------------------------- 1 | /** Structural subset of the Realtime Database API used by RealtimeDbPersistenceProvider. */ export interface RealtimeDbEquivalent { 2 | ref(path?: string | any | undefined): RealtimeDbEquivalent.Reference; 3 | } 4 | 5 | export namespace RealtimeDbEquivalent { 6 | export interface Reference { 7 | transaction( 8 | updateFn: (data: any) => any, 9 | completeFn?: any, 10 | ): Promise<{ committed: boolean; snapshot: DataSnapshot | null }>; 11 | once(eventType: "value"): Promise<DataSnapshot>; 12 | } 13 | 14 | export interface DataSnapshot { 15 | val(): any; 16 | exists(): boolean; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/utils.test.ts: -------------------------------------------------------------------------------- 1 | /** Resolves after `ms` milliseconds. */ export function delayMs(ms: number) { 2 | return new Promise((resolve) => setTimeout(resolve, ms)); 3 | } -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "module": "commonjs", 4 | "target": "es5", 5 | "strict": true, 6 | "declaration": true, 7 | "moduleResolution": "node", 8 | "allowSyntheticDefaultImports": false, 9 | "noImplicitAny": true, 10 | "allowJs": false, 11 | "sourceMap": true, 12 | "outDir": "dist", 13 | "baseUrl": "src/", 14 | "paths": { 15 | "*": [ 16 | "node_modules/*", 17 | "src/types/*" 18 | ] 19 | } 20 | }, 21 | "include": [ 22 | "src/**/*" 23 | ], 24 | "exclude": [ 25 | "src/**/_test", 26 | "src/**/*.test.ts" 27 | ] 28 | }
-------------------------------------------------------------------------------- /tsconfig.lint.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "./tsconfig.json", 3 | "exclude": [] 4 | } -------------------------------------------------------------------------------- /tslint.json: -------------------------------------------------------------------------------- 1 | { 2 | "defaultSeverity": "error", 3 | "extends": [ 4 | "tslint:recommended" 5 | ], 6 | "jsRules": { 7 | "max-line-length": { 8 | "options": [ 9 | 120 10 | ] 11 | } 12 | }, 13 | "rules": { 14 | "max-line-length": { 15 | "options": [ 16 | 120 17 | ] 18 | }, 19 | "no-namespace": false, 20 | "only-arrow-functions": false, 21 | "object-literal-sort-keys": false, 22 | "interface-name": false, 23 | "arrow-parens": false, 24 | "semicolon": [ 25 | true, 26 | "always" 27 | ], 28 | "object-literal-shorthand": true, 29 | "ordered-imports": [ 30 | true, 31 | { 32 | "grouped-imports": "true" 33 | } 34 | ], 35 | "no-unused-variable": true, 36 | "cyclomatic-complexity": [ 37 | true, 38 | 10 39 | ], 40 | "curly": [ 41 | true, 42 | "ignore-same-line" 43 | ] 44 | }, 45 | "rulesDirectory": [] 46 | } --------------------------------------------------------------------------------