├── example ├── __init__.py ├── wsgi.py ├── urls.py └── settings.py ├── requirements-dev.txt ├── django_graphql_ratelimit ├── __init__.py ├── middleware.py ├── ratelimit.py └── tests.py ├── .gitignore ├── requirements.txt ├── .flake8 ├── Makefile ├── .circleci └── config.yml ├── tox.ini ├── .github └── workflows │ └── main.yml ├── manage.py ├── setup.py └── README.md /example/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | tox 3 | flake8 4 | django-nose 5 | twine 6 | -------------------------------------------------------------------------------- /django_graphql_ratelimit/__init__.py: -------------------------------------------------------------------------------- 1 | from .ratelimit import GQLRatelimitKey, ratelimit # noqa 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | *.egg-info 4 | dist 5 | *.tox 6 | *.sw[a-z] 7 | build 8 | .python-version 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | graphene 2 | graphene-django 3 | django-ratelimit>=4.0.0 4 | # Get request ip addres 5 | django-ipware 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503 3 | max-line-length = 80 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | clean: 2 | rm -f dist/* 3 | test: 4 | tox 5 | build: clean 6 | @python setup.py sdist bdist_wheel 7 | release: build 8 | @twine upload dist/* 9 | 10 | .PHONY: release 11 | -------------------------------------------------------------------------------- /django_graphql_ratelimit/middleware.py: -------------------------------------------------------------------------------- 1 | from ipware import get_client_ip 2 | 3 | 4 | def ParseClientIpMiddleware(get_response): 5 | def middleware(request): 6 | request.META["REMOTE_ADDR"] = get_client_ip(request)[0] 7 | response = get_response(request) 8 | return response 9 | 10 | return middleware 11 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | - image: themattrix/tox 10 | 11 | working_directory: ~/repo 12 | steps: 13 | - run: apt-get update && apt-get install -y git ssh 14 | - checkout 15 | - run: tox 16 | -------------------------------------------------------------------------------- /example/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for example project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'example.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 2.0 3 | skipsdist = True 4 | envlist = pep8, py3-django{18,19,20,21,22} 5 | 6 | [testenv] 7 | deps= 8 | django18: Django>=1.8,<1.9 9 | django19: Django>=1.9,<2.0 10 | django20: Django>=2.0,<2.1 11 | django21: Django>=2.1,<2.2 12 | django22: Django>=2.2,<2.3 13 | -r{toxinidir}/requirements-dev.txt 14 | commands = 15 | python manage.py test 16 | 17 | [testenv:pep8] 18 | deps = flake8 19 | changedir={toxinidir} 20 | commands = flake8 ./example django_graphql_ratelimit 21 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: 11 | - '3.9' 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | python -m pip install tox tox-gh-actions 22 | - name: Test with tox 23 | run: tox 24 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Django's command-line utility for administrative tasks.""" 3 | 
import os 4 | import sys 5 | 6 | 7 | def main(): 8 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "example.settings") 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError as exc: 12 | raise ImportError( 13 | "Couldn't import Django. Are you sure it's installed and " 14 | "available on your PYTHONPATH environment variable? Did you " 15 | "forget to activate a virtual environment?" 16 | ) from exc 17 | execute_from_command_line(sys.argv) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /example/urls.py: -------------------------------------------------------------------------------- 1 | """example URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/2.2/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. 
Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path 18 | 19 | urlpatterns = [ 20 | path('admin/', admin.site.urls), 21 | ] 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | install_requires = open("requirements.txt").read().splitlines() 4 | 5 | setup( 6 | name="django-graphql-ratelimit", 7 | version="1.0.2", 8 | description="Use django-ratelimit for graphql", 9 | long_description=open("README.md").read(), 10 | long_description_content_type="text/markdown", 11 | author="o3o3o", 12 | author_email="o3o3o.me@gmail.com", 13 | url="https://github.com/o3o3o/django-graphql-ratelimit", 14 | license="Apache Software License", 15 | include_package_data=True, 16 | packages=find_packages(), 17 | install_requires=install_requires, 18 | classifiers=[ 19 | "Development Status :: 5 - Production/Stable", 20 | "Environment :: Web Environment", 21 | "Framework :: Django", 22 | "Intended Audience :: Developers", 23 | "License :: OSI Approved :: Apache Software License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.3", 28 | "Programming Language :: Python :: 3.4", 29 | "Programming Language :: Python :: 3.5", 30 | "Programming Language :: Python :: 3.6", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: Implementation :: CPython", 33 | "Programming Language :: Python :: Implementation :: PyPy", 34 | "Topic :: Software Development :: Libraries :: Python Modules", 35 | ], 36 | ) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
[![CircleCI](https://circleci.com/gh/o3o3o/django-graphql-ratelimit.svg?style=svg)](https://circleci.com/gh/o3o3o/django-graphql-ratelimit) [![PyPI version](https://badge.fury.io/py/django-graphql-ratelimit.svg)](https://badge.fury.io/py/django-graphql-ratelimit) 2 | 3 | Makes [django-ratelimit](https://github.com/jsocol/django-ratelimit) easier to use with GraphQL in Django. 4 | 5 | 6 | # Install 7 | 8 | ``` 9 | pip install django-graphql-ratelimit 10 | ``` 11 | 12 | # Usage 13 | 14 | In addition to the standard keys, `key` supports `gql:xxx`, where `xxx` is the name of a resolver argument whose value is used as the rate-limit key. 15 | 16 | ```python 17 | from django_graphql_ratelimit import ratelimit 18 | 19 | class RequestSMSCode(graphene.Mutation): 20 | class Arguments: 21 | phone = graphene.String(required=True) 22 | 23 | ok = graphene.Boolean() 24 | 25 | @ratelimit(key="ip", rate="10/m", block=True) 26 | @ratelimit(key="gql:phone", rate="5/m", block=True) 27 | def mutate(self, info, phone): 28 | request = info.context 29 | # send sms code logic 30 | return RequestSMSCode(ok=True) 31 | ``` 32 | You can use [django-ratelimit keys](https://django-ratelimit.readthedocs.io/en/latest/keys.html#common-keys) except `get:xxx` and `post:xxx`: 33 | * `ip` - Use the request IP address (i.e. `request.META['REMOTE_ADDR']`) 34 | I suggest using [django-ipware](https://github.com/un33k/django-ipware) to determine the client IP; add the following middleware at the top of `MIDDLEWARE` in your settings: 35 | ``` 36 | MIDDLEWARE = [ 37 | "django_graphql_ratelimit.middleware.ParseClientIpMiddleware", 38 | ... 39 | ] 40 | ``` 41 | 42 | * `header:x-x` - Use the value of request.META.get('HTTP_X_X', ''). 43 | * `user` - Use an appropriate value from request.user. Do not use with unauthenticated users. 44 | * `user_or_ip` - Use an appropriate value from `request.user` if the user is authenticated, otherwise use `request.META['REMOTE_ADDR']`.
45 | -------------------------------------------------------------------------------- /django_graphql_ratelimit/ratelimit.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from django.conf import settings 3 | from django.utils.module_loading import import_string 4 | from functools import wraps 5 | from django_ratelimit import ALL, UNSAFE 6 | from django_ratelimit.exceptions import Ratelimited 7 | from django_ratelimit.core import is_ratelimited 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | __all__ = ["ratelimit"] 13 | 14 | 15 | def GQLRatelimitKey(group, request): 16 | return request.gql_rl_field 17 | 18 | 19 | def ratelimit(group=None, key=None, rate=None, method=ALL, block=False): 20 | def decorator(fn): 21 | @wraps(fn) 22 | def _wrapped(root, info, **kw): 23 | request = info.context 24 | 25 | old_limited = getattr(request, "limited", False) 26 | 27 | if key and key.startswith("gql:"): 28 | _key = key.split("gql:")[1] 29 | value = kw.get(_key, None) 30 | if not value: 31 | raise ValueError(f"Cannot get key: {key}") 32 | request.gql_rl_field = value 33 | 34 | new_key = GQLRatelimitKey 35 | else: 36 | new_key = key 37 | 38 | ratelimited = is_ratelimited( 39 | request=request, 40 | group=group, 41 | fn=fn, 42 | key=new_key, 43 | rate=rate, 44 | method=method, 45 | increment=True, 46 | ) 47 | 48 | request.limited = ratelimited or old_limited 49 | 50 | if ratelimited and block: 51 | # logger.warn( 52 | # "url:<%s> is denied for <%s> in Ratelimit" 53 | # % (request.path, request.META["REMOTE_ADDR"]) 54 | # ) 55 | cls = getattr(settings, "RATELIMIT_EXCEPTION_CLASS", Ratelimited) 56 | raise (import_string(cls) if isinstance(cls, str) else cls)( 57 | "rate_limited" 58 | ) 59 | return fn(root, info, **kw) 60 | 61 | return _wrapped 62 | 63 | return decorator 64 | 65 | 66 | ratelimit.ALL = ALL 67 | ratelimit.UNSAFE = UNSAFE 68 | 
-------------------------------------------------------------------------------- /django_graphql_ratelimit/tests.py: -------------------------------------------------------------------------------- 1 | import graphene 2 | from django.test import TestCase 3 | 4 | from django.test import RequestFactory 5 | from graphene.test import Client 6 | from django_graphql_ratelimit import ratelimit 7 | 8 | 9 | class MockUser(object): 10 | def __init__(self, authenticated=False): 11 | self.pk = 1 12 | self.is_authenticated = authenticated 13 | self.META = {"REMOTE_ADDR": "192.168.1.1"} 14 | 15 | 16 | class RequestSMSCode(graphene.Mutation): 17 | class Arguments: 18 | phone = graphene.String(required=True) 19 | 20 | ok = graphene.Boolean() 21 | 22 | # @ratelimit(key="ip", rate="1/m", block=True) 23 | # @ratelimit(key="user_or_ip", rate="2/m", block=True) 24 | @ratelimit(key="gql:phone", rate="1/m", block=True) 25 | def mutate(self, info, phone): 26 | # request = info.context 27 | # send sms code logic 28 | return RequestSMSCode(ok=True) 29 | 30 | 31 | class Picture(graphene.ObjectType): 32 | url = graphene.String() 33 | 34 | 35 | class Query(graphene.ObjectType): 36 | pictures = graphene.List(Picture) 37 | 38 | @ratelimit(key="user_or_ip", rate="1/m", block=True) 39 | def resolve_pictures(root, info): 40 | return [Picture(url="https://example.com")] 41 | 42 | 43 | class Mutation(graphene.ObjectType): 44 | request_sms_code = RequestSMSCode.Field() 45 | 46 | 47 | schema = graphene.Schema(query=Query, mutation=Mutation) 48 | 49 | 50 | class GqlRatelimitTestCase(TestCase): 51 | def setUp(self): 52 | self.rf = RequestFactory() 53 | self.context = self.rf.get("/") 54 | self.context.user = MockUser() 55 | self.client = Client(schema) 56 | 57 | def test_rateliimt_with_gql(self): 58 | query = """ 59 | mutation { 60 | requestSmsCode(phone: "+8612345678"){ 61 | ok 62 | } 63 | } 64 | """ 65 | resp = self.client.schema.execute(query, context_value=self.context) 66 | 
self.assertIsNone(resp.errors, msg=resp.errors) 67 | 68 | # ratelimited by phone 69 | resp = self.client.schema.execute(query, context_value=self.context) 70 | self.assertIn("rate_limited", str(resp.errors[0]), msg=resp.errors) 71 | 72 | query = """ 73 | query { 74 | pictures{ 75 | url 76 | } 77 | } 78 | """ 79 | resp = self.client.schema.execute(query, context_value=self.context) 80 | self.assertIsNone(resp.errors, msg=resp.errors) 81 | 82 | resp = self.client.schema.execute(query, context_value=self.context) 83 | # ratelimit by user_or_ip 84 | self.assertIn("rate_limited", str(resp.errors[0]), msg=resp.errors) 85 | -------------------------------------------------------------------------------- /example/settings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for example project. 3 | 4 | Generated by 'django-admin startproject' using Django 2.2.5. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/2.2/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/2.2/ref/settings/ 11 | """ 12 | 13 | import os 14 | 15 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = "9+bdkzs&m$(*$1d8e_)(1ghtbs$^^!#66lbbx-yx0zl)z9@6_p" 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 
26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | "django.contrib.admin", 35 | "django.contrib.auth", 36 | "django.contrib.contenttypes", 37 | "django.contrib.sessions", 38 | "django.contrib.messages", 39 | "django.contrib.staticfiles", 40 | ] 41 | 42 | MIDDLEWARE = [ 43 | "django.middleware.security.SecurityMiddleware", 44 | "django.contrib.sessions.middleware.SessionMiddleware", 45 | "django.middleware.common.CommonMiddleware", 46 | "django.middleware.csrf.CsrfViewMiddleware", 47 | "django.contrib.auth.middleware.AuthenticationMiddleware", 48 | "django.contrib.messages.middleware.MessageMiddleware", 49 | "django.middleware.clickjacking.XFrameOptionsMiddleware", 50 | ] 51 | 52 | ROOT_URLCONF = "example.urls" 53 | 54 | TEMPLATES = [ 55 | { 56 | "BACKEND": "django.template.backends.django.DjangoTemplates", 57 | "DIRS": [], 58 | "APP_DIRS": True, 59 | "OPTIONS": { 60 | "context_processors": [ 61 | "django.template.context_processors.debug", 62 | "django.template.context_processors.request", 63 | "django.contrib.auth.context_processors.auth", 64 | "django.contrib.messages.context_processors.messages", 65 | ] 66 | }, 67 | } 68 | ] 69 | 70 | WSGI_APPLICATION = "example.wsgi.application" 71 | 72 | 73 | # Database 74 | # https://docs.djangoproject.com/en/2.2/ref/settings/#databases 75 | 76 | DATABASES = { 77 | "default": { 78 | "ENGINE": "django.db.backends.sqlite3", 79 | "NAME": os.path.join(BASE_DIR, "db.sqlite3"), 80 | } 81 | } 82 | 83 | 84 | # Password validation 85 | # https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators 86 | 87 | AUTH_PASSWORD_VALIDATORS = [ 88 | { 89 | "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" 90 | }, 91 | {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, 92 | {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, 93 | {"NAME": 
"django.contrib.auth.password_validation.NumericPasswordValidator"}, 94 | ] 95 | 96 | 97 | # Internationalization 98 | # https://docs.djangoproject.com/en/2.2/topics/i18n/ 99 | 100 | LANGUAGE_CODE = "en-us" 101 | 102 | TIME_ZONE = "UTC" 103 | 104 | USE_I18N = True 105 | 106 | USE_L10N = True 107 | 108 | USE_TZ = True 109 | 110 | 111 | # Static files (CSS, JavaScript, Images) 112 | # https://docs.djangoproject.com/en/2.2/howto/static-files/ 113 | 114 | STATIC_URL = "/static/" 115 | 116 | TEST_RUNNER = "django_nose.NoseTestSuiteRunner" 117 | --------------------------------------------------------------------------------