├── .gitignore
├── README.md
├── manage.py
├── setup.py
├── sprinklers
    ├── __init__.py
    ├── app_settings.py
    ├── base.py
    └── registry.py
└── tests
    ├── __init__.py
    ├── migrations
        ├── 0001_initial.py
        └── __init__.py
    ├── models.py
    ├── settings.py
    ├── tasks.py
    └── test_sprinklers.py


/.gitignore:
--------------------------------------------------------------------------------
/.idea/
*.pyc
*.db
/project/
/dist
*.egg-info
.DS_Store
/build
*#
*~
.coverage
/htmlcov/
*.sqlite*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## TL;DR

Django-Sprinklers imposes structure on jobs that perform asynchronous processing of Django models.

Specify a queryset and a function to perform on each object in the queryset and Sprinklers will distribute your queryset into a group of asynchronous jobs, perform logging, and track the status of each job.

Using this wrapper you can avoid repeatedly writing fiddly bits of celery code and repetitive logging so you can focus on your business logic.

## For example

At a high level celery cron jobs often take the form of "run this queryset, and take an action on each object." For example:

- Find all customers that should be billed for our recurring service, and bill them.
- Find all users who should receive a welcome email, and send them each a message.
- For each product listing, update the details from an external service.
- etc. etc.

Typically code for these cron tasks would look something like:

```python
from celery import task

@task
def refresh_objects():
    logger.info("Starting refresh objects...")
    qs = Item.objects.all()
    for obj in qs:
        if obj.needs_update():
            # Remember to pass 'id' and not the object itself!
            # Remember to call .delay on the task!
            get_updated_field_from_slow_external_service.delay(obj.id)
    logger.info("I would love to say 'finished' here, but I spawned a bunch of async tasks and can't actually do that...")

@task
def get_updated_field_from_slow_external_service(id):
    try:
        item = Item.objects.get(pk=id)
    except Item.DoesNotExist:
        logger.info("grrr...")
        return  # nothing to update; without this, 'item' would be undefined below
    logger.info("Starting update of item %s..." % id)
    item.field = ExternalServiceWrapper().get(id)['field']
    item.save()
    logger.info("Successful update of item %s." % id)
```

This is fine, but as logic gets more complex and as you add more jobs that follow a similar pattern, you'll find that you handle logging slightly differently from job to job, that you want to run code after all the subtasks have completed, and that in general things start to look a bit messy.

Use Sprinklers to impose structure and make these jobs testable.

```python

# tasks.py

from sprinklers.base import SprinklerBase, registry, SubtaskValidationException

class ItemUpdateSprinkler(SprinklerBase):

    def get_queryset(self):
        return Item.objects.all()

    def validate(self, obj):
        if not obj.needs_update():
            raise SubtaskValidationException()

    def subtask(self, obj):
        obj.field = ExternalServiceWrapper().get(obj.id)['field']
        obj.save()
        return obj.id  # gets aggregated into a results argument

    def finished(self, results):
        logger.info("Updated %s items." % len(results))

registry.register(ItemUpdateSprinkler)


# This is the entry point to the job.
# You can use it in your crontab configuration as normal:

# CELERYBEAT_SCHEDULE = {
#     'item.tasks.start_item_sprinkler': {
#         'task': 'item.tasks.start_item_sprinkler',
#         'schedule': crontab(hour=0, minute=0),
#     },
# }

@task
def start_item_sprinkler():
    ItemUpdateSprinkler().start()


```

You can also pass `**kwargs` into the Sprinkler's constructor, e.g. `ItemUpdateSprinkler(**kwargs).start()`. They are stored on `self.kwargs` and are accessible downstream to all Sprinkler methods. See tasks.py and test_sprinklers.py in /tests for how this works.

## Testing

The Sprinkler tests are a bit trickier to run than just 'manage.py test' because every attempt has been made to mimic an async production celery environment.

In order to run the tests you will need:

0. CELERY_ALWAYS_EAGER = False in tests/settings.py
1. A local postgres server set up in line with the DB config in tests.settings
2. A running redis server with default localhost config (redis://localhost:6379/0)
3. A running celery daemon/worker.
4. Then run 'python manage.py test'

I find it easiest to run 2 & 3 via:

```
screen -d -S 'redis' -m redis-server
screen -d -S 'celery' -m python manage.py celeryd
```

This will run each in the background, and you can 'screen -r redis' or 'screen -r celery' to view them (Ctrl-a-d to detach).

If you are working on this project directly, **remember to restart celery after code changes** to django-sprinklers. Celery does not live reload!

## FAQ

- Q: Will this work on any iterable? Does it have to be a Django queryset?
- A: It has to be a queryset (or a ValuesQuerySet). Sprinklers relies on some introspection to determine which model class to use for individual object retrieval.
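
The FAQ answer above can be illustrated with a minimal sketch. It assumes simplified stand-ins: `Item`, `FakeQuerySet`, `resolve_model` and `extract_ids` are hypothetical names used only for illustration, not part of Django or django-sprinklers. The sketch mirrors two things `SprinklerBase` does in sprinklers/base.py: it reads `.model` off the queryset, and it collects ids from either model instances or `values()` dicts.

```python
# Simplified stand-ins; NOT Django's or django-sprinklers' real classes.
class Item:
    def __init__(self, id):
        self.id = id


class FakeQuerySet:
    """Hypothetical stand-in for a QuerySet: iterable and exposes .model."""

    def __init__(self, model, rows):
        self.model = model
        self._rows = rows

    def __iter__(self):
        return iter(self._rows)


def resolve_model(iterable):
    # Mirrors `self.get_queryset().model` in SprinklerBase.__init__.
    return iterable.model


def extract_ids(iterable):
    # Mirrors the comprehension in SprinklerBase.start(): values() rows are
    # dicts, regular querysets yield model instances.
    return [row["id"] if isinstance(row, dict) else row.id for row in iterable]


objects_qs = FakeQuerySet(Item, [Item(1), Item(2)])
values_qs = FakeQuerySet(Item, [{"id": 1}, {"id": 2}])

ids_from_objects = extract_ids(objects_qs)
ids_from_values = extract_ids(values_qs)
model_for_values = resolve_model(values_qs)

try:
    resolve_model([Item(1), Item(2)])  # a plain list has no .model attribute
except AttributeError:
    plain_list_supported = False
else:
    plain_list_supported = True
```

Both queryset flavours yield the same ids, while the plain list fails at the model lookup. That lookup is what each subtask later relies on to re-fetch its object by primary key.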
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import os
import sys

if __name__ == "__main__":
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings")

    from django.core.management import execute_from_command_line

    execute_from_command_line(sys.argv)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup
from sprinklers import __version__

# Utility function to read the README file.
# Used for the long_description. It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ...

def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()

setup(
    name = "django-sprinklers",
    version = __version__,
    author = "Chris Clark",
    author_email = "chris@untrod.com",
    description = ("A simple framework for bulletproof distributed tasks with Django and Celery."),
    license = "MIT",
    keywords = "django celery sprinklers sprinkler distributed tasks",
    url = "https://github.com/chrisclark/django-sprinklers",
    packages=['sprinklers'],
    long_description=read('README.md'),
    classifiers=[
        "Topic :: Utilities",
    ],
    install_requires=[
        'Django>=1.6.7',
    ],
    include_package_data=True,
    zip_safe = False,
)

--------------------------------------------------------------------------------
/sprinklers/__init__.py:
--------------------------------------------------------------------------------
__version_info__ = {
    'major': 0,
    'minor': 0,
    'micro': 1,
    'releaselevel': 'final',
    'serial': 0
}

def get_version(short=False):
    assert __version_info__['releaselevel'] in ('alpha', 'beta', 'final')
    vers = ["%(major)i.%(minor)i" % __version_info__, ]
    if __version_info__['micro']:
        vers.append(".%(micro)i" % __version_info__)
    if __version_info__['releaselevel'] != 'final' and not short:
        vers.append('%s%i' % (__version_info__['releaselevel'][0], __version_info__['serial']))
    return ''.join(vers)

__version__ = get_version()

--------------------------------------------------------------------------------
/sprinklers/app_settings.py:
--------------------------------------------------------------------------------
from django.conf import settings


SPRINKLER_DEFAULT_SHARD_SIZE = getattr(settings, 'SPRINKLER_DEFAULT_SHARD_SIZE', 20000)

--------------------------------------------------------------------------------
/sprinklers/base.py:
--------------------------------------------------------------------------------
from . import app_settings
from celery import chord, current_app, Task
from .registry import sprinkler_registry as registry
import logging
import uuid
from time import time

logger = logging.getLogger('')


def async_subtask(obj_pk, sprinkler_name, kwargs):
    """
    async_subtask -- inner implementation of :func:`_async_subtask`

    Useful for overriding the task decoration on the default implementation to
    change the ``rate_limit``, ``name``, and other registered properties of
    :class:`celery.Task` in dependent modules:

    >>> from django.contrib.auth.models import User
    >>> from celery import current_app
    >>> from sprinklers.base import async_subtask, SprinklerBase
    >>> one_per_minute = current_app.task(
    ...     async_subtask,
    ...     rate_limit='1/m',
    ...     name='one_per_minute',
    ... )
    >>> class SomeSprinkler(SprinklerBase):
    ...     _async_subtask = one_per_minute
    ...     def subtask(self, obj):
    ...         print(obj)
    ...     def get_queryset(self):
    ...         return User.objects.all()[0:5]
    ...
    >>> SomeSprinkler().start()
    """
    return registry[sprinkler_name](**kwargs)._run_subtask(obj_pk)


_async_subtask = current_app.task(async_subtask)


@current_app.task()
def _async_shard_start(shard_id, from_pk, to_pk, sprinkler_name, kwargs):
    sprinkler = registry[sprinkler_name](**kwargs)
    return sprinkler.shard_start(shard_id, from_pk, to_pk)


@current_app.task()
def _sprinkler_shard_finished_wrap(results, shard_id, sprinkler_name, kwargs):
    sprinkler = registry[sprinkler_name](**kwargs)
    sprinkler.log(f"shard finished: {shard_id}")
    sprinkler.shard_finished(shard_id, results)


@current_app.task()
def _sprinkler_finished_wrap(results, sprinkler_name, kwargs):
    sprinkler = registry[sprinkler_name](**kwargs)
    sprinkler.log("Finished with results (length %s): %s" % (len(results), results))
    sprinkler.finished(results)


class SubtaskValidationException(Exception):
    pass


class SprinklerBase(object):
    subtask_queue = current_app.conf.CELERY_DEFAULT_QUEUE
    klass = None

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        if self.klass is None:
            self.klass = self.get_queryset().model

    def start(self):
        qs = self.get_queryset()
        ids = [o['id'] if isinstance(o, dict) else o.id for o in qs]

        async_subtask = (
            self._async_subtask
            if isinstance(getattr(self, '_async_subtask', None), Task)
            else _async_subtask
        )

        c = chord(
            (
                # .s is shorthand for .signature()
                async_subtask
                .s(i, self.__class__.__name__, self.kwargs)
                .set(queue=self.get_subtask_queue())
                for i in ids
            ),
            _sprinkler_finished_wrap.s(sprinkler_name=self.__class__.__name__, kwargs=self.kwargs).set(queue=self.get_subtask_queue())
        )

        start_time = time()
        c.apply_async()
        end_time = time()

        duration = (end_time - start_time) * 1000
        self.log("Started with %s objects in %sms." % (len(ids), duration))
        self.log("Started with objects: %s" % ids)

    def finished(self, results):
        pass

    def get_queryset(self):
        raise NotImplementedError

    def validate(self, obj):
        """Should raise SubtaskValidationException if validation fails."""
        pass

    def subtask(self, obj):
        """ Do work on obj and return whatever results are needed."""
        raise NotImplementedError

    def get_subtask_queue(self):
        return self.subtask_queue

    def on_error(self, obj, e):
        """ Called if an unexpected exception, e, occurs while running the subtask on obj.
        Results from this function will be aggregated into the results passed to the
        .finished() method. To emulate default Celery behavior, just reraise e here.
        Note that raising an exception in subtask execution will prevent the chord from
        ever firing its callback (though other subtasks will continue to execute)."""
        raise e

    def on_validation_exception(self, obj, e):
        """ Called if validate raises a SubtaskValidationException."""
        return None

    def _run_subtask(self, obj_pk):
        """Executes the sprinkle pipeline. Should not be overridden."""
        obj = None
        try:
            obj = self.klass.objects.get(pk=obj_pk)
            self._log_execution_step(self.validate, obj)
            # if subtask() doesn't return a value, return the object id so something more helpful than None
            # gets aggregated into the results object (passed to 'finished').
            return self._log_execution_step(self.subtask, obj) or obj.id
        except self.klass.DoesNotExist:
            self.log("Object <%s - %s> does not exist." % (self.klass.__name__, obj_pk))
        except SubtaskValidationException as e:
            self.log("Validation failed for object %s: %s" % (obj, e))
            return self.on_validation_exception(obj, e)
        except Exception as e:
            if obj is not None:
                self.log("Unexpected exception for object %s: %s" % (obj, e))
                return self.on_error(obj, e)
            raise e

    def _log_execution_step(self, fn, obj):
        fn_name = fn.__name__.split('.')[-1]
        self.log("%s is starting for object %s." % (fn_name, obj))
        res = fn(obj)
        self.log("%s has finished for object %s." % (fn_name, obj))
        return res

    def __repr__(self):
        return "%s - %s" % (str(self.__class__.__name__), self.kwargs)

    def log(self, msg):
        logger.info("SPRINKLER %s: %s" % (self, msg))


class ShardedSprinkler(SprinklerBase):
    shard_size = app_settings.SPRINKLER_DEFAULT_SHARD_SIZE

    def start(self):
        shards = list(self.build_shards())

        # the sharded sprinkler calls finished on output of shard_start for each shard, passing the shard ID,
        # rather than the results of the completed shard tasks

        c = chord(
            (
                _async_shard_start.s(shard_id, from_pk, to_pk, self.__class__.__name__, self.kwargs).set(queue=self.get_subtask_queue())
                for shard_id, from_pk, to_pk in shards
            ),
            _sprinkler_finished_wrap.s(sprinkler_name=self.__class__.__name__, kwargs=self.kwargs).set(queue=self.get_subtask_queue())
        )

        start_time = time()
        c.apply_async()
        end_time = time()

        duration = (end_time - start_time) * 1000
        shard_ids = [shard[0] for shard in shards]
        self.log(f"Started with {len(shards)} shards in {duration}ms.")
        self.log(f"Started with shards: {[str(shard_id) for shard_id in shard_ids]}")
        return shard_ids

    def shard_start(self, shard_id, from_pk=None, to_pk=None):
        pks = self.get_queryset_pks(from_pk, to_pk)

        c = chord(
            (
                _async_subtask.s(pk, self.__class__.__name__, self.kwargs).set(queue=self.get_subtask_queue())
                for pk in pks
            ),
            _sprinkler_shard_finished_wrap.s(sprinkler_name=self.__class__.__name__, shard_id=shard_id, kwargs=self.kwargs).set(queue=self.get_subtask_queue())
        )

        start_time = time()
        c.apply_async()
        end_time = time()

        duration = (end_time - start_time) * 1000
        self.log(f"Started shard {shard_id} in {duration}ms.")

        return shard_id

    def shard_finished(self, shard_id, results):
        pass

    def get_queryset_pks(self, from_pk=None, to_pk=None):
        queryset = self.get_queryset().only('pk').order_by('pk')

        if from_pk is not None:
            queryset = queryset.filter(pk__gt=from_pk)

        if to_pk is not None:
            queryset = queryset.filter(pk__lte=to_pk)

        # values_list in django 1.11 is broken and will run out of memory when iterating over a large queryset, even with .iterator()
        # the following code does basically the same thing as values_list, without running out of memory
        db = queryset.db
        compiler = queryset.query.get_compiler(db)
        results = compiler.execute_sql(chunked_fetch=True)

        for row in compiler.results_iter(results):
            yield row[0]

    def build_shards(self):
        last_pk = None
        next_pk = None

        for i, pk in enumerate(self.get_queryset_pks(), 1):
            if i % self.shard_size == 0:
                last_pk = next_pk
                next_pk = pk
                yield uuid.uuid4(), last_pk, next_pk

        yield uuid.uuid4(), next_pk, None

--------------------------------------------------------------------------------
/sprinklers/registry.py:
--------------------------------------------------------------------------------
class SprinklerRegistry(object):

    def __init__(self):
        self._registry = {}

    def register(self, sprinkler):
        self._registry[sprinkler.__name__] = sprinkler

    def __getitem__(self, key):
        return self._registry[key]

sprinkler_registry = SprinklerRegistry()

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/groveco/django-sprinklers/8eb7af66c5ef52aa5a7fc3bf3d3f8e365e0ae886/tests/__init__.py

--------------------------------------------------------------------------------
/tests/migrations/0001_initial.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import models, migrations


class Migration(migrations.Migration):

    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='DummyModel',
            fields=[
                ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
                ('name', models.CharField(default=b'Not sprinkled :(', max_length=128)),
            ],
            options={
            },
            bases=(models.Model,),
        ),
    ]

--------------------------------------------------------------------------------
/tests/migrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/groveco/django-sprinklers/8eb7af66c5ef52aa5a7fc3bf3d3f8e365e0ae886/tests/migrations/__init__.py

--------------------------------------------------------------------------------
/tests/models.py:
--------------------------------------------------------------------------------
from django.db import models


class DummyModel(models.Model):
    name = models.CharField(max_length=128, default="Not sprinkled :(")

    def __str__(self):
        return self.name
--------------------------------------------------------------------------------
/tests/settings.py:
--------------------------------------------------------------------------------
"""
Django settings for sprinkles project.

For more information on this file, see
https://docs.djangoproject.com/en/1.7/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.7/ref/settings/
"""

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
import os, sys
import djcelery
djcelery.setup_loader()
BASE_DIR = os.path.dirname(os.path.dirname(__file__))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.7/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'u@-wftu5&@j$9yk5@fquus5%#1i!gp0smnoocx50+h!h67!jq1'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

TEMPLATE_DEBUG = True

ALLOWED_HOSTS = []


# Application definition

INSTALLED_APPS = (
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'tests',
    'djcelery',
)

MIDDLEWARE_CLASSES = (
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
)


# Database
# https://docs.djangoproject.com/en/1.7/ref/settings/#databases
#
# DATABASES = {
#     'default': {
#         'ENGINE': 'django.db.backends.sqlite3',
#         'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
#         'TEST_NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
#     }
# }


DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.postgresql_psycopg2',
        'NAME': 'sprinklers',
        'TEST': {
            'NAME': 'sprinklers',
        },
        'USER': 'postgres',
        'HOST': 'localhost',
        'PORT': '5432',
    },
}

DISABLE_TRANSACTION_MANAGEMENT = True

# Internationalization
# https://docs.djangoproject.com/en/1.7/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.7/howto/static-files/

STATIC_URL = '/static/'

TEMPLATE_DIRS = (
    os.path.join(BASE_DIR, 'templates'),
)

# Change to True to run in a single thread for easier debugging
CELERY_ALWAYS_EAGER = False
CELERYD_HIJACK_ROOT_LOGGER = False
BROKER_URL = 'redis://localhost:6379/0'

from celery.signals import setup_logging
@setup_logging.connect
def configure_logging(sender=None, **kwargs):
    import logging
    import logging.config
    logging.config.dictConfig(LOGGING)

LOGGING = {
    'version': 1,
    'disable_existing_loggers': True,
    'formatters': {
        'verbose': {
            'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
        },
    },
    'handlers': {
        'console': {
            'level': 'INFO',
            'class': 'logging.StreamHandler',
            'formatter': 'verbose'
        },
    },
    'loggers': {
        '': {
            'handlers': ['console'],
            'level': 'INFO',
            'propagate': True
        },
    }
}

--------------------------------------------------------------------------------
/tests/tasks.py:
--------------------------------------------------------------------------------
from sprinklers.base import SprinklerBase, ShardedSprinkler, registry, SubtaskValidationException
from tests.models import DummyModel
from celery import task
from traceback import format_exc


@task
def run_sample_sprinkler(**kwargs):
    SampleSprinkler(**kwargs).start()

@task
def run_sharded_sprinkler(**kwargs):
    ShardedSampleSprinkler(**kwargs).start()

class SampleSprinkler(SprinklerBase):

    def get_queryset(self):
        """
        This IF normally wouldn't be necessary because you either would or wouldn't be passing kwargs
        into the sprinkler. But I have tests that sometimes do and sometimes don't use kwargs, hence the
        check.
        """
        if self.kwargs.get('name', None):
            return DummyModel.objects.filter(name=self.kwargs['name']).all()
        if self.kwargs.get('values', None):
            return DummyModel.objects.all().values('id')
        return DummyModel.objects.all()

    def subtask(self, obj):
        if self.kwargs.get('raise_error') and obj.name == 'fail':
            raise AttributeError("Oh noes!")
        obj.name = "Sprinkled!"
        obj.save()
        if self.kwargs.get('special_return'):
            return True

    def validate(self, obj):
        if self.kwargs.get('fail'):
            raise SubtaskValidationException

    def on_validation_exception(self, obj, e):
        return "v_fail"

    def finished(self, results):
        # Persist results to an external source (the database) so I can unit test this.
        # Note that it writes the entire result obj as the name
        if self.kwargs.get('persist_results'):
            DummyModel(name="%s" % results).save()

    def on_error(self, obj, e):
        print("Here's the error: " + format_exc())
        return False

registry.register(SampleSprinkler)

class ShardedSampleSprinkler(ShardedSprinkler):
    shard_size = 2

    def get_queryset(self):
        if self.kwargs.get('name', None):
            return DummyModel.objects.filter(name=self.kwargs['name'])

        return DummyModel.objects.all()

    def subtask(self, obj):
        if self.kwargs.get('raise_error') and obj.name == 'fail':
            raise AttributeError("Oh noes!")

        obj.name = "Sharded!"
        obj.save()

        if self.kwargs.get('special_return'):
            return True

    def validate(self, obj):
        if self.kwargs.get('fail'):
            raise SubtaskValidationException

    def on_validation_exception(self, obj, e):
        return "v_fail"

    def finished(self, results):
        # Persist results to an external source (the database) so I can unit test this.
        # Note that it writes the entire result obj as the name

        if self.kwargs.get('persist_results'):
            DummyModel(name="%s" % results).save()

    def on_error(self, obj, e):
        print("Here's the error: " + format_exc())
        return False

registry.register(ShardedSampleSprinkler)

--------------------------------------------------------------------------------
/tests/test_sprinklers.py:
--------------------------------------------------------------------------------
from django.test import TransactionTestCase
from tests.models import DummyModel
from tests.tasks import run_sample_sprinkler, run_sharded_sprinkler, SampleSprinkler
from django.conf import settings
import time


class SprinklerTest(TransactionTestCase):

    @classmethod
    def tearDown(self):
        if not settings.CELERY_ALWAYS_EAGER:
            time.sleep(2)

    def _run(self, **kwargs):
        r = run_sample_sprinkler.delay(**kwargs)
        if not settings.CELERY_ALWAYS_EAGER:
            time.sleep(2)

    def _run_sharded(self, **kwargs):
        r = run_sharded_sprinkler.delay(**kwargs)
        if not settings.CELERY_ALWAYS_EAGER:
            time.sleep(2)
        return r.get()

    def test_objects_get_sprinkled(self):
        DummyModel(name="foo").save()
        DummyModel(name="foo").save()
        self._run()
        for d in DummyModel.objects.all():
            self.assertEqual(d.name, "Sprinkled!")

    def test_works_with_values_queryset(self):
        DummyModel(name="foo").save()
        DummyModel(name="foo").save()
        self._run(values=True)
        for d in DummyModel.objects.all():
            self.assertEqual(d.name, "Sprinkled!")

    def test_queryset_refreshes_on_each_sprinkling(self):

        DummyModel(name="foo").save()
        self._run()

        # Make sure we don't incorrectly pass this test through sheer luck by generating the number
        # of models that happens to match the results cache of the SampleSprinkler queryset.
        # This was a bigger issue in an earlier version of sprinklers, but it still makes me feel good
        # knowing that these tests pass and sprinklers will always refresh their queryset when they run.
        cur_len = len(SampleSprinkler().get_queryset())
        for i in range(cur_len + 5):
            DummyModel(name="foo").save()

        self._run()

        for d in DummyModel.objects.all():
            self.assertEqual(d.name, "Sprinkled!")

    def test_parameters_in_qs(self):

        DummyModel(name="qux").save()
        DummyModel(name="mux").save()

        self._run(name="qux")
        self.assertFalse(DummyModel.objects.filter(name="qux").exists())
        self.assertTrue(DummyModel.objects.filter(name="mux").exists())

    def test_sprinkler_finished(self):
        DummyModel(name="qux").save()
        DummyModel(name="mux").save()
        self._run(persist_results=True, special_return=True)
        self.assertEqual(DummyModel.objects.filter(name=str([True, True])).count(), 1)

    def test_validation_exception(self):
        DummyModel(name="foo").save()
        self._run(fail=True, persist_results=True)
        self.assertTrue(DummyModel.objects.filter(name="foo").exists())
        self.assertEqual(DummyModel.objects.filter(name=str(['v_fail'])).count(), 1)

    def test_default_return_value_for_subtask(self):
        d1 = DummyModel(name="qux")
        d1.save()
        d2 = DummyModel(name="mux")
        d2.save()
        self._run(persist_results=True)
        self.assertEqual(DummyModel.objects.filter(name=str([d1.id, d2.id])).count(), 1)

    def test_error_on_subtask_calls_on_error(self):
        DummyModel(name="fail").save()
        DummyModel(name="succeed").save()
        self._run(raise_error=True, persist_results=True, special_return=True)
        self.assertEqual(DummyModel.objects.filter(name=str([False, True])).count(), 1)

    def test_sharded_sprinkler(self):
        for i in range(10):
            DummyModel(name="sharded").save()
        self._run_sharded(name="sharded")
        self.assertEqual(DummyModel.objects.filter(name="sharded").count(), 0)

--------------------------------------------------------------------------------