├── .bumpversion.cfg ├── .github ├── FUNDING.yml └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── celery_heimdall ├── __init__.py ├── config.py ├── contrib │ ├── README.md │ ├── __init__.py │ └── inspector │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── models.py │ │ └── monitor.py ├── errors.py └── task.py ├── pyproject.toml └── tests ├── conftest.py ├── test_only_after.py ├── test_rate_limited.py └── test_unique.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.0.1 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:pyproject.toml] 7 | parse = "version = (?P<major>\d+)\.(?P<minor>.*)" 8 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: TkTech 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: TkTech 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with a single custom sponsorship URL 13 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Create release 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | jobs: 9 | sdist: 10 | name: Creating source release 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | 15 | - name: Set up Python 16 
| uses: actions/setup-python@v1 17 | with: 18 | python-version: "3.10" 19 | 20 | - name: Install Poetry 21 | uses: snok/install-poetry@v1 22 | 23 | - name: Publishing 24 | env: 25 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 26 | run: | 27 | poetry config pypi-token.pypi $PYPI_TOKEN 28 | poetry publish --build -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | fail-fast: false 10 | matrix: 11 | python-version: [3.8, 3.9, "3.10", "3.11"] 12 | 13 | services: 14 | redis: 15 | image: redis 16 | options: >- 17 | --health-cmd "redis-cli ping" 18 | --health-interval 10s 19 | --health-timeout 5s 20 | --health-retries 5 21 | ports: 22 | - 6379:6379 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install Poetry 33 | uses: snok/install-poetry@v1 34 | 35 | - name: Installing 36 | run: poetry install --no-interaction 37 | 38 | - name: Running tests 39 | run: | 40 | poetry run pytest --cov=celery_heimdall --cov-report=xml 41 | 42 | - name: Upload coverage 43 | uses: codecov/codecov-action@v4 44 | with: 45 | file: ./coverage.xml 46 | fail_ci_if_error: true 47 | flags: unittests 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | slug: TkTech/celery-heimdall -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .pytest_cache 4 | poetry.lock 5 | *.pyc 6 | venv 7 | heimdall.db -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Tyler Kennedy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # celery-heimdall 2 | 3 | [![codecov](https://codecov.io/gh/TkTech/celery-heimdall/branch/main/graph/badge.svg?token=1A2CVHQ25Q)](https://codecov.io/gh/TkTech/celery-heimdall) 4 | ![GitHub](https://img.shields.io/github/license/tktech/celery-heimdall) 5 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/celery-heimdall) 6 | 7 | Celery Heimdall is a set of common utilities useful for the Celery background 8 | worker framework, built on top of Redis. It's not trying to handle every use 9 | case, but to be an easy, modern, and maintainable drop-in solution for 90% of 10 | projects. 
11 | 12 | ## Features 13 | 14 | - Globally unique tasks, allowing only 1 copy of a task to execute at a time, or 15 | within a time period (ex: "Don't allow queuing until an hour has passed") 16 | - Global rate limiting. Celery has built-in rate limiting, but it's a rate limit 17 | _per worker_, making it unsuitable for purposes such as limiting requests to 18 | an API. 19 | 20 | ## Installation 21 | 22 | `pip install celery-heimdall` 23 | 24 | ## Usage 25 | 26 | ### Unique Tasks 27 | 28 | Imagine you have a task that starts when a user presses a button. This task 29 | takes a long time and a lot of resources to generate a report. You don't want 30 | the user to press the button 10 times and start 10 tasks. In this case, you 31 | want what Heimdall calls a unique task: 32 | 33 | ```python 34 | from celery import shared_task 35 | from celery_heimdall import HeimdallTask 36 | 37 | @shared_task(base=HeimdallTask, heimdall={'unique': True}) 38 | def generate_report(customer_id): 39 | pass 40 | ``` 41 | 42 | All we've done here is change the base Task class that Celery will use to run 43 | the task, and passed in some options for Heimdall to use. This task is now 44 | unique - for the given arguments, only 1 will ever run at the same time. 45 | 46 | #### Expiry 47 | 48 | What happens if our task dies, or something goes wrong? We might end up in a 49 | situation where our lock never gets cleared, called [deadlock][]. 
To work around 50 | this, we add a maximum time before the task is allowed to be queued again: 51 | 52 | 53 | ```python 54 | from celery import shared_task 55 | from celery_heimdall import HeimdallTask 56 | 57 | @shared_task( 58 | base=HeimdallTask, 59 | heimdall={ 60 | 'unique': True, 61 | 'unique_timeout': 60 * 60 62 | } 63 | ) 64 | def generate_report(customer_id): 65 | pass 66 | ``` 67 | 68 | Now, `generate_report` will be allowed to run again in an hour even if the 69 | task got stuck, the worker ran out of memory, the machine burst into flames, 70 | etc... 71 | 72 | #### Custom Keys 73 | 74 | By default, a hash of the task name and its arguments is used as the lock key. 75 | But this often might not be what you want. What if you only want one report at 76 | a time, even for different customers? Ex: 77 | 78 | ```python 79 | from celery import shared_task 80 | from celery_heimdall import HeimdallTask 81 | 82 | @shared_task( 83 | base=HeimdallTask, 84 | heimdall={ 85 | 'unique': True, 86 | 'key': lambda args, kwargs: 'generate_report' 87 | } 88 | ) 89 | def generate_report(customer_id): 90 | pass 91 | ``` 92 | By specifying our own key function, we can completely customize how we determine 93 | if a task is unique. 94 | 95 | #### The Existing Task 96 | 97 | By default, if you try to queue up a unique task that is already running, 98 | Heimdall will return the existing task's `AsyncResult`. This lets you write 99 | simple code that doesn't need to care if a task is unique or not. Imagine a 100 | simple API endpoint that starts a report when it's hit, but we only want it 101 | to run one at a time. 
The below is all you need: 102 | 103 | ```python 104 | import time 105 | from celery import shared_task 106 | from celery_heimdall import HeimdallTask 107 | 108 | @shared_task(base=HeimdallTask, heimdall={'unique': True}) 109 | def generate_report(customer_id): 110 | time.sleep(10) 111 | 112 | def my_api_call(customer_id: int): 113 | return { 114 | 'status': 'RUNNING', 115 | 'task_id': generate_report.delay(customer_id).id 116 | } 117 | ``` 118 | 119 | Everytime `my_api_call` is called with the same `customer_id`, the same 120 | `task_id` will be returned by `generate_report.delay()` until the original task 121 | has completed. 122 | 123 | Sometimes you'll want to catch that the task was already running when you tried 124 | to queue it again. We can tell Heimdall to raise an exception in this case: 125 | 126 | ```python 127 | import time 128 | from celery import shared_task 129 | from celery_heimdall import HeimdallTask, AlreadyQueuedError 130 | 131 | 132 | @shared_task( 133 | base=HeimdallTask, 134 | heimdall={ 135 | 'unique': True, 136 | 'unique_raises': True 137 | } 138 | ) 139 | def generate_report(customer_id): 140 | time.sleep(10) 141 | 142 | 143 | def my_api_call(customer_id: int): 144 | try: 145 | task = generate_report.delay(customer_id) 146 | return {'status': 'STARTED', 'task_id': task.id} 147 | except AlreadyQueuedError as exc: 148 | return {'status': 'ALREADY_RUNNING', 'task_id': exc.likely_culprit} 149 | ``` 150 | 151 | By setting `unique_raises` to `True` when we define our task, an 152 | `AlreadyQueuedError` will be raised when you try to queue up a unique task 153 | twice. The `AlreadyQueuedError` has two properties: 154 | 155 | - `likely_culprit`, which contains the task ID of the already-running task, 156 | - `expires_in`, which is the time remaining (in seconds) before the 157 | already-running task is considered to be expired. 
158 | 159 | #### Unique Interval Task 160 | 161 | What if we want the task to only run once in an hour, even if it's finished? 162 | In those cases, we want it to run, but not clear the lock when it's finished: 163 | 164 | ```python 165 | from celery import shared_task 166 | from celery_heimdall import HeimdallTask 167 | 168 | @shared_task( 169 | base=HeimdallTask, 170 | heimdall={ 171 | 'unique': True, 172 | 'unique_timeout': 60 * 60, 173 | 'unique_wait_for_expiry': True 174 | } 175 | ) 176 | def generate_report(customer_id): 177 | pass 178 | ``` 179 | 180 | By setting `unique_wait_for_expiry` to `True`, the task will finish, and won't 181 | allow another `generate_report()` to be queued until `unique_timeout` has 182 | passed. 183 | 184 | ### Rate Limiting 185 | 186 | Celery offers rate limiting out of the box. However, this rate limiting applies 187 | on a per-worker basis. There's no reliable way to rate limit a task across all 188 | your workers. Heimdall makes this easy: 189 | 190 | ```python 191 | from celery import shared_task 192 | from celery_heimdall import HeimdallTask, RateLimit 193 | 194 | @shared_task( 195 | base=HeimdallTask, 196 | heimdall={ 197 | 'rate_limit': RateLimit((2, 60)) 198 | } 199 | ) 200 | def download_report_from_amazon(customer_id): 201 | pass 202 | ``` 203 | 204 | This says "every 60 seconds, only allow this task to run 2 times". If a task 205 | can't be run because it would violate the rate limit, it'll be rescheduled. 206 | 207 | It's important to note this does not guarantee that your task will run _exactly_ 208 | twice a second, just that it won't run _more_ than twice a second. Tasks are 209 | rescheduled with a random jitter to prevent the [thundering herd][] problem. 210 | 211 | 212 | #### Dynamic Rate Limiting 213 | 214 | Just like you can dynamically provide a key for a task, you can also 215 | dynamically provide a rate limit based off that key. 
216 | 217 | 218 | ```python 219 | from celery import shared_task 220 | from celery_heimdall import HeimdallTask, RateLimit 221 | 222 | 223 | @shared_task( 224 | base=HeimdallTask, 225 | heimdall={ 226 | # Provide a lower rate limit for the customer with the ID 10, for everyone 227 | # else provide a higher rate limit. 228 | 'rate_limit': RateLimit(lambda args: (1, 30) if args[0] == 10 else (2, 30)), 229 | 'key': lambda args, kwargs: f'customer_{args[0]}' 230 | } 231 | ) 232 | def download_report_from_amazon(customer_id): 233 | pass 234 | ``` 235 | 236 | 237 | ## Inspirations 238 | 239 | These are more mature projects which inspired this library, and which may 240 | support older versions of Celery & Python then this project. 241 | 242 | - [celery_once][], which is unfortunately abandoned and the reason this project 243 | exists. 244 | - [celery_singleton][] 245 | - [This snippet][snip] by Vigrond, and subsequent improvements by various 246 | contributors. 247 | 248 | 249 | [celery_once]: https://github.com/cameronmaske/celery-once 250 | [celery_singleton]: https://github.com/steinitzu/celery-singleton 251 | [deadlock]: https://en.wikipedia.org/wiki/Deadlock 252 | [thundering herd]: https://en.wikipedia.org/wiki/Thundering_herd_problem 253 | [snip]: https://gist.github.com/Vigrond/2bbea9be6413415e5479998e79a1b11a -------------------------------------------------------------------------------- /celery_heimdall/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ("HeimdallTask", "AlreadyQueuedError", "RateLimit", "Strategy") 2 | 3 | from celery_heimdall.task import HeimdallTask, RateLimit, Strategy 4 | from celery_heimdall.errors import AlreadyQueuedError 5 | -------------------------------------------------------------------------------- /celery_heimdall/config.py: -------------------------------------------------------------------------------- 1 | from celery import Celery 2 | from celery.app import 
app_or_default 3 | 4 | 5 | class Config: 6 | def __init__(self, app: Celery, *, task=None): 7 | self.app = app_or_default(app) 8 | self.task = task 9 | 10 | def _from_task_or_app(self, key, default): 11 | if self.task: 12 | v = getattr(self.task, "heimdall", {}).get(key) 13 | if v is not None: 14 | return v 15 | 16 | return self.app.conf.get(f"heimdall_{key}", default) 17 | 18 | @property 19 | def unique_lock_timeout(self): 20 | return self._from_task_or_app("unique_lock_timeout", 1) 21 | 22 | @property 23 | def unique_lock_blocking(self): 24 | return self._from_task_or_app("unique_lock_blocking", True) 25 | 26 | @property 27 | def unique_timeout(self): 28 | return self._from_task_or_app("unique_timeout", 60 * 60) 29 | 30 | @property 31 | def lock_prefix(self): 32 | return self._from_task_or_app("lock_prefix", "h-lock:") 33 | 34 | @property 35 | def rate_limit_prefix(self): 36 | return self._from_task_or_app("rate_limit_prefix", "h-rate:") 37 | 38 | @property 39 | def unique_raises(self): 40 | return self._from_task_or_app("unique_raises", False) 41 | -------------------------------------------------------------------------------- /celery_heimdall/contrib/README.md: -------------------------------------------------------------------------------- 1 | # Contrib 2 | 3 | This directory contains optional integrations and tools. 
-------------------------------------------------------------------------------- /celery_heimdall/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TkTech/celery-heimdall/d7e41a959bb9f82bdad2d329fd09b084a6df2996/celery_heimdall/contrib/__init__.py -------------------------------------------------------------------------------- /celery_heimdall/contrib/inspector/README.md: -------------------------------------------------------------------------------- 1 | # Inspector 2 | 3 | **Note:** This tool is in beta, and currently only tested against SQLite as a 4 | data store. 5 | 6 | The Inspector is a minimal debugging tool for working with Celery queues and 7 | tasks. It is an optional component of celery-heimdall and not installed by 8 | default. 9 | 10 | It runs a monitor, which populates any SQLAlchemy-compatible database with the 11 | state of your Celery cluster. 12 | 13 | ## Why? 14 | 15 | This tool is used to assist in debugging, generate graphs of queues for 16 | documentation, to verify the final state of Celery after tests, etc... 17 | 18 | Flower deprecated their graphs page, and now require you to use prometheus and 19 | grafana, which is overkill when you just want to see what's been running. 
20 | 21 | ## Installation 22 | 23 | ``` 24 | pip install celery-heimdall[inspector] 25 | ``` -------------------------------------------------------------------------------- /celery_heimdall/contrib/inspector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TkTech/celery-heimdall/d7e41a959bb9f82bdad2d329fd09b084a6df2996/celery_heimdall/contrib/inspector/__init__.py -------------------------------------------------------------------------------- /celery_heimdall/contrib/inspector/cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | from celery import Celery 5 | 6 | from celery_heimdall.contrib.inspector.monitor import monitor 7 | 8 | 9 | @click.group() 10 | def cli(): 11 | """ 12 | heimdall-inspector provides tools for introspecting a live Celery cluster. 13 | """ 14 | 15 | 16 | @cli.command("monitor") 17 | @click.argument("broker_url") 18 | @click.option( 19 | "--enable-events", 20 | default=False, 21 | is_flag=True, 22 | help=( 23 | "Sends a command-and-control message to all Celery workers to start" 24 | " emitting worker events before starting the server." 25 | ), 26 | ) 27 | @click.option( 28 | "--db", 29 | default="heimdall.db", 30 | type=click.Path(dir_okay=False, writable=True, path_type=Path), 31 | help=("Use the provided path to store our sqlite database."), 32 | ) 33 | def monitor_command(broker_url: str, enable_events: bool, db: Path): 34 | """ 35 | Starts a monitor to watch for Celery events and records them to an SQLite 36 | database. 37 | 38 | Optionally enables event monitoring on a live cluster if --enable-events is 39 | provided. Note that it will not stop events when finished. 
40 | """ 41 | if enable_events: 42 | celery_app = Celery(broker=broker_url) 43 | celery_app.control.enable_events() 44 | 45 | monitor(broker=broker_url, db=db) 46 | 47 | 48 | if __name__ == "__main__": 49 | cli() 50 | -------------------------------------------------------------------------------- /celery_heimdall/contrib/inspector/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from sqlalchemy import ( 4 | Column, 5 | TIMESTAMP, 6 | Integer, 7 | String, 8 | func, 9 | BigInteger, 10 | text, 11 | Enum, 12 | ) 13 | from sqlalchemy.orm import declarative_base, sessionmaker 14 | 15 | Base = declarative_base() 16 | Session = sessionmaker() 17 | 18 | 19 | class WorkerStatus(enum.Enum): 20 | #: Worker is answering heartbeats. 21 | ALIVE = 0 22 | #: We didn't get an offline event, but we're not getting heartbeats. 23 | LOST = 10 24 | #: We got a shutdown event for the worker. 25 | OFFLINE = 20 26 | 27 | 28 | class TaskStatus(enum.Enum): 29 | RECEIVED = 0 30 | STARTED = 10 31 | SUCCEEDED = 20 32 | FAILED = 30 33 | REJECTED = 40 34 | REVOKED = 50 35 | RETRIED = 60 36 | 37 | 38 | class TaskInstance(Base): 39 | __tablename__ = "task_instance" 40 | 41 | uuid = Column(String, primary_key=True) 42 | name = Column(String) 43 | status = Column(Enum(TaskStatus)) 44 | hostname = Column(String) 45 | args = Column(String) 46 | kwargs = Column(String) 47 | 48 | runtime = Column(Integer) 49 | 50 | received = Column(TIMESTAMP, server_default=func.now()) 51 | started = Column(TIMESTAMP, nullable=True) 52 | failed = Column(TIMESTAMP, nullable=True) 53 | rejected = Column(TIMESTAMP, nullable=True) 54 | succeeded = Column(TIMESTAMP, nullable=True) 55 | 56 | retries = Column(Integer, server_default=text("0")) 57 | last_seen = Column(TIMESTAMP, server_onupdate=func.now()) 58 | 59 | 60 | class Worker(Base): 61 | __tablename__ = "worker" 62 | 63 | #: The hostname of a worker is used as its ID. 
64 | id = Column(String, primary_key=True) 65 | 66 | #: How often the worker is configured to send heartbeats. 67 | frequency = Column(Integer, server_default=text("0")) 68 | #: Name of the worker software 69 | sw_identity = Column(String, nullable=True) 70 | #: Version of the worker software. 71 | sw_version = Column(String, nullable=True) 72 | #: Host operating system of the worker. 73 | sw_system = Column(String, nullable=True) 74 | 75 | #: Number of currently executing tasks. 76 | active = Column(BigInteger, server_default=text("0")) 77 | #: Number of processed tasks. 78 | processed = Column(BigInteger, server_default=text("0")) 79 | 80 | #: Last known status of the worker. 81 | status = Column(Enum(WorkerStatus)) 82 | 83 | first_seen = Column(TIMESTAMP, server_default=func.now()) 84 | last_seen = Column(TIMESTAMP, server_onupdate=func.now()) 85 | -------------------------------------------------------------------------------- /celery_heimdall/contrib/inspector/monitor.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from pathlib import Path 3 | 4 | from celery import Celery 5 | from sqlalchemy import create_engine, insert, func, update 6 | from sqlalchemy.dialects.sqlite import insert 7 | 8 | from celery_heimdall.contrib.inspector import models 9 | 10 | 11 | def task_received(event): 12 | with models.Session() as session: 13 | session.execute( 14 | insert(models.TaskInstance.__table__).values( 15 | uuid=event["uuid"], 16 | name=event["name"], 17 | status=models.TaskStatus.RECEIVED, 18 | hostname=event["hostname"], 19 | args=event["args"], 20 | kwargs=event["kwargs"], 21 | received=datetime.datetime.fromtimestamp(event["timestamp"]), 22 | ) 23 | ) 24 | session.commit() 25 | 26 | 27 | def task_started(event): 28 | with models.Session() as session: 29 | session.execute( 30 | update(models.TaskInstance.__table__) 31 | .where(models.TaskInstance.uuid == event["uuid"]) 32 | .values( 33 | 
runtime=event.get("runtime", 0), 34 | status=models.TaskStatus.STARTED, 35 | started=datetime.datetime.fromtimestamp(event["timestamp"]), 36 | last_seen=func.now(), 37 | ) 38 | ) 39 | session.commit() 40 | 41 | 42 | def task_succeeded(event): 43 | with models.Session() as session: 44 | session.execute( 45 | update(models.TaskInstance.__table__) 46 | .where(models.TaskInstance.uuid == event["uuid"]) 47 | .values( 48 | runtime=event.get("runtime", 0), 49 | status=models.TaskStatus.SUCCEEDED, 50 | succeeded=datetime.datetime.fromtimestamp(event["timestamp"]), 51 | last_seen=func.now(), 52 | ) 53 | ) 54 | session.commit() 55 | 56 | 57 | def task_retried(event): 58 | with models.Session() as session: 59 | session.execute( 60 | update(models.TaskInstance.__table__) 61 | .where(models.TaskInstance.uuid == event["uuid"]) 62 | .values( 63 | runtime=event.get("runtime", 0), 64 | status=models.TaskStatus.RETRIED, 65 | retries=models.TaskInstance.retries + 1, 66 | last_seen=func.now(), 67 | ) 68 | ) 69 | session.commit() 70 | 71 | 72 | def task_failed(event): 73 | with models.Session() as session: 74 | session.execute( 75 | update(models.TaskInstance.__table__) 76 | .where(models.TaskInstance.uuid == event["uuid"]) 77 | .values( 78 | runtime=event.get("runtime", 0), 79 | status=models.TaskStatus.FAILED, 80 | failed=datetime.datetime.fromtimestamp(event["timestamp"]), 81 | last_seen=func.now(), 82 | ) 83 | ) 84 | session.commit() 85 | 86 | 87 | def task_rejected(event): 88 | with models.Session() as session: 89 | session.execute( 90 | update(models.TaskInstance.__table__) 91 | .where(models.TaskInstance.uuid == event["uuid"]) 92 | .values( 93 | runtime=event.get("runtime", 0), 94 | status=models.TaskStatus.REJECTED, 95 | rejected=datetime.datetime.fromtimestamp(event["timestamp"]), 96 | last_seen=func.now(), 97 | ) 98 | ) 99 | session.commit() 100 | 101 | 102 | def worker_event(event): 103 | field_mapping = { 104 | "freq": models.Worker.frequency, 105 | "sw_ident": 
models.Worker.sw_identity, 106 | "sw_ver": models.Worker.sw_version, 107 | "sw_sys": models.Worker.sw_system, 108 | "active": models.Worker.active, 109 | "processed": models.Worker.processed, 110 | } 111 | 112 | payload = { 113 | "last_seen": func.now(), 114 | "status": { 115 | "worker-heartbeat": models.WorkerStatus.ALIVE, 116 | "worker-online": models.WorkerStatus.ALIVE, 117 | "worker-offline": models.WorkerStatus.OFFLINE, 118 | }.get(event["type"], models.WorkerStatus.LOST), 119 | } 120 | for k, v in field_mapping.items(): 121 | if k in event: 122 | payload[v] = event[k] 123 | 124 | # FIXME: Support postgres / MySQL 125 | with models.Session() as session: 126 | session.execute( 127 | insert(models.Worker.__table__) 128 | .values({"id": event["hostname"], **payload}) 129 | .on_conflict_do_update(index_elements=["id"], set_=payload) 130 | ) 131 | session.commit() 132 | 133 | 134 | def monitor(*, broker: str, db: Path): 135 | """ 136 | A real-time Celery event monitor which captures events and populates a 137 | supported SQLAlchemy database. 
138 | """ 139 | app = Celery(broker=broker) 140 | 141 | engine = create_engine(f"sqlite:///{db}") 142 | models.Session.configure(bind=engine) 143 | models.Base.metadata.create_all(engine) 144 | 145 | with app.connection() as connection: 146 | recv = app.events.Receiver( 147 | connection, 148 | handlers={ 149 | # '*': state.event, 150 | "task-started": task_started, 151 | "task-rejected": task_rejected, 152 | "task-failed": task_failed, 153 | "task-received": task_received, 154 | "task-succeeded": task_succeeded, 155 | "task-retried": task_retried, 156 | "worker-online": worker_event, 157 | "worker-heartbeat": worker_event, 158 | "worker-offline": worker_event, 159 | }, 160 | ) 161 | recv.capture(limit=None, timeout=None, wakeup=True) 162 | -------------------------------------------------------------------------------- /celery_heimdall/errors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class AlreadyQueuedError(Exception): 5 | """ 6 | Raised when a task has already been enqueued and may not be added to the 7 | queue. 8 | 9 | The exception may have the property `likely_culprit` set. If it is, this 10 | is the Celery Task ID of the task _most likely_ holding onto the lock. 11 | 12 | `likely_culprit` is here to assist in debugging deadlocks. Retrieving this 13 | value is not atomic, and thus should not be relied upon. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | *, 19 | expires_in: Optional[int] = None, 20 | likely_culprit: Optional[str] = None, 21 | ): 22 | super().__init__() 23 | self.likely_culprit = likely_culprit 24 | self.expires_in = expires_in 25 | 26 | def __repr__(self): 27 | return ( 28 | "<AlreadyQueuedError(" 29 | f"likely_culprit={self.likely_culprit!r}, " 30 | f"expires_in={self.expires_in!r})>" 31 | ) 32 | -------------------------------------------------------------------------------- /celery_heimdall/task.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import enum 3 | import hashlib 4 | import datetime 5 | import inspect 6 | from abc import ABC 7 | from typing import Union, Tuple, Callable 8 | 9 | import redis 10 | import redis.lock 11 | import celery 12 | from kombu import serialization 13 | from kombu.utils import uuid 14 | 15 | from celery_heimdall.config import Config 16 | from celery_heimdall.errors import AlreadyQueuedError 17 | 18 | 19 | class Strategy(enum.Enum): 20 | DEFAULT = 10 21 | 22 | 23 | @dataclasses.dataclass 24 | class RateLimit: 25 | rate_limit: Union[Tuple, Callable] 26 | strategy: Strategy = Strategy.DEFAULT 27 | 28 | 29 | def acquire_lock(task: "HeimdallTask", key: str, timeout: int, *, task_id: str): 30 | acquired = redis.lock.Lock( 31 | task.heimdall_redis, 32 | key, 33 | timeout=timeout, 34 | blocking=task.heimdall_config.unique_lock_blocking, 35 | blocking_timeout=task.heimdall_config.unique_lock_timeout, 36 | ).acquire(token=task_id) 37 | 38 | if not acquired: 39 | pipe = task.heimdall_redis.pipeline() 40 | pipe.get(key) 41 | pipe.ttl(key) 42 | task_id, ttl = pipe.execute() 43 | 44 | raise AlreadyQueuedError( 45 | # TTL may be -1 or -2 if the key didn't exist, depending on the 46 | # version of Redis. 
47 | expires_in=max(0, ttl), 48 | likely_culprit=task_id.decode("utf-8") if task_id else None, 49 | ) 50 | 51 | return acquired 52 | 53 | 54 | def release_lock(task: "HeimdallTask", key: str): 55 | task.heimdall_redis.delete(key) 56 | 57 | 58 | def unique_key_for_task( 59 | task: "HeimdallTask", args, kwargs, *, prefix="" 60 | ) -> str: 61 | """ 62 | Given a task and its arguments, generate a unique key which can be used 63 | to identify it. 64 | """ 65 | h = getattr(task, "heimdall", {}) 66 | 67 | # When Celery deserializes the arguments for a job, args and kwargs will 68 | # be `[]` or `{}`, even if they were `None` when serialized. Ensure we 69 | # do the same here or the hashes will never match when arguments are empty. 70 | args = args or [] 71 | kwargs = kwargs or {} 72 | 73 | # User specified an explicit key function. 74 | if "key" in h: 75 | if callable(h["key"]): 76 | return prefix + h["key"](args, kwargs) 77 | return prefix + h["key"] 78 | 79 | # Try to generate a unique key from the arguments given to the task. 80 | # Most of the cases where this will fail are also cases where Celery 81 | # will be unable to serialize the job, so we're not too concerned with 82 | # validation. 83 | _, _, data = serialization.dumps( 84 | (args, kwargs), 85 | # TODO: We should _probably_ use the same serializer as the task. 86 | "json", 87 | ) 88 | 89 | h = hashlib.md5() 90 | h.update(task.name.encode("utf-8")) 91 | h.update(data.encode("utf-8")) 92 | return f"{prefix}{h.hexdigest()}" 93 | 94 | 95 | def rate_limited_countdown(task: "HeimdallTask", key, args, kwargs): 96 | # Based on improvements to Vigrond's original implementation by mlissner 97 | # on stack overflow. 
def rate_limited_countdown(task: "HeimdallTask", key, args, kwargs):
    """
    Return the number of seconds a rate-limited task should wait before
    running, or 0 if it may run immediately.

    Based on improvements to Vigrond's original implementation by mlissner
    on stack overflow.

    :param task: The task being rate limited.
    :param key: The Redis key holding the counter for this rate limit.
    :param args: Positional arguments the task was called with.
    :param kwargs: Keyword arguments the task was called with.
    """
    h = getattr(task, "heimdall", {})
    r = task.heimdall_redis

    if "rate_limit" in h:
        rate = h["rate_limit"].rate_limit
        try:
            # A static limit is a plain (times, per) tuple.
            times, per = rate
        except TypeError:
            # Otherwise it's a callable; pass it only the keyword arguments
            # its signature actually declares.
            wanted = inspect.signature(rate).parameters
            available = {
                "key": key,
                "task": task,
                "args": args,
                "kwargs": kwargs,
            }
            rate_limit_args = {
                name: value
                for name, value in available.items()
                if name in wanted
            }
            times, per = rate(**rate_limit_args)
    else:
        times, per = h["times"], h["per"]

    number_of_running_tasks = r.get(key)
    if number_of_running_tasks is None:
        # First task in this window; start the counter with a TTL of `per`.
        r.set(key, 1, ex=per)
        return 0

    if int(number_of_running_tasks) < times:
        if r.incr(key, 1) == 1:
            # The counter expired between get() and incr(); re-apply the TTL
            # so the window still closes.
            r.expire(key, per)
        return 0

    # Over the limit - figure out when the next slot opens.
    schedule_key = f"{key}.schedule"
    now = datetime.datetime.now(tz=datetime.timezone.utc)

    delay = r.get(schedule_key)
    if delay is None or int(delay) < now.timestamp():
        # Either not scheduled, or scheduled in the past.
        ttl = r.ttl(key)
        if ttl < 0:
            # Counter key vanished or has no expiry; safe to run now.
            return 0

        r.set(
            schedule_key,
            int((now + datetime.timedelta(seconds=ttl)).timestamp()),
            ex=ttl + 20,
        )
        return ttl

    # Queue this task behind the ones already scheduled, spacing them
    # `per // times` seconds apart.
    new_time = datetime.datetime.fromtimestamp(
        int(delay), tz=datetime.timezone.utc
    ) + datetime.timedelta(seconds=per // times)
    new_delay = int((new_time - now).total_seconds())
    r.set(schedule_key, int(new_time.timestamp()), ex=new_delay + 20)
    return new_delay
class HeimdallTask(celery.Task, ABC):
    """
    An all-seeing base task for Celery, it provides useful global utilities
    for common Celery behaviors, such as global rate limiting and singleton
    (only one at a time) tasks.
    """

    abstract = True

    def __init__(self):
        super().__init__()
        # Both are created lazily; see the properties below.
        self._heimdall_config = None
        self._heimdall_redis = None

    @property
    def heimdall_config(self) -> Config:
        # Lazily constructed so defining a task doesn't require a
        # fully-configured Celery app.
        if not self._heimdall_config:
            self._heimdall_config = Config(self.app, task=self)
        return self._heimdall_config

    @property
    def heimdall_redis(self) -> redis.Redis:
        # Lazily constructed so importing the task never opens a connection.
        if not self._heimdall_redis:
            self._heimdall_redis = self.setup_redis()
        return self._heimdall_redis

    def setup_redis(self) -> redis.Redis:
        """
        Sets up the Redis connection. By default, it'll use any Redis instance
        it can find (in order):

        - the Celery result backend
        - the Celery broker

        Any URL scheme understood by ``redis.Redis.from_url`` is accepted:
        ``redis://``, ``rediss://`` (TLS), and ``unix://``.

        If nothing can be found, or if you want to explicitly specify a Redis
        connection you'll need to implement this method yourself, ex:

        .. code::

            from redis import Redis
            from celery_heimdall import HeimdallTask

            class MyHeimdallTask(HeimdallTask):
                def setup_redis(self):
                    return Redis.from_url('redis://')
        """
        # Schemes redis-py's from_url() knows how to connect with.
        schemes = ("redis://", "rediss://", "unix://")

        # Try to use the Celery result backend, if it's configured for redis.
        backend = self.app.conf.get("result_backend") or ""
        if backend.startswith(schemes):
            return redis.Redis.from_url(backend)

        # If not the backend, try the broker....
        broker = self.app.conf.get("broker_url") or ""
        if broker.startswith(schemes):
            return redis.Redis.from_url(broker)

        # Nope, we can't find a usable redis, user will need to implement
        # setup_redis() themselves.
        raise NotImplementedError()

    def apply_async(self, args=None, kwargs=None, task_id=None, **options):
        """
        Queue the task, enforcing the global-uniqueness lock first when the
        task was declared with ``heimdall={"unique": True}``.
        """
        h = getattr(self, "heimdall", {})
        if h and "unique" in h:
            # Generate the ID up front so it can be recorded as the lock
            # holder ("likely culprit") for later conflicts.
            task_id = task_id or uuid()

            # Task has been configured to be globally unique, so we check for
            # the presence of a global lock before allowing it to be queued.
            try:
                acquire_lock(
                    self,
                    unique_key_for_task(
                        self,
                        args,
                        kwargs,
                        prefix=self.heimdall_config.lock_prefix,
                    ),
                    h.get(
                        "unique_timeout", self.heimdall_config.unique_timeout
                    ),
                    task_id=task_id,
                )
            except AlreadyQueuedError as exc:
                if not self.heimdall_config.unique_raises:
                    # If we were unable to get the task ID for whatever
                    # reason, we just fall through and raise anyway.
                    if exc.likely_culprit is not None:
                        return self.AsyncResult(exc.likely_culprit)

                raise

        # TODO: If we kept track of queued, but not running, tasks, we should
        #       be able to estimate _when_ it would be okay to run a
        #       rate-limited task, rather than just checking when it runs.

        return super().apply_async(
            args=args, kwargs=kwargs, task_id=task_id, **options
        )

    def __call__(self, *args, **kwargs):
        """
        Run the task, applying rate limiting (via a self-retry with a
        countdown) and a runtime uniqueness check first.
        """
        h = getattr(self, "heimdall", {})
        # Rate limiting applies when either times+per or rate_limit is set.
        if h and (("per" in h and "times" in h) or "rate_limit" in h):
            delay = rate_limited_countdown(
                self,
                unique_key_for_task(
                    self,
                    args,
                    kwargs,
                    prefix=self.heimdall_config.rate_limit_prefix,
                ),
                args,
                kwargs,
            )
            if delay > 0:
                # We don't want our rescheduling retry to count against
                # any normal retry limits the user might have set on the
                # task or globally.
                self.request.retries -= 1
                # Max retries needs to be set to None _before_ calling
                # retry(). This value will not propagate, allowing the user's
                # normal retry behaviour to apply on the next call.
                self.max_retries = None
                raise self.retry(countdown=delay)

        # Normally, we check for uniqueness before calling the task, but if
        # celery beat is being used, it appears to bypass the apply_async
        # method, so we need to check again at run time.
        if h and "unique" in h:
            task_id = self.request.id
            try:
                acquire_lock(
                    self,
                    unique_key_for_task(
                        self,
                        args,
                        kwargs,
                        prefix=self.heimdall_config.lock_prefix,
                    ),
                    h.get(
                        "unique_timeout", self.heimdall_config.unique_timeout
                    ),
                    task_id=task_id,
                )
            except AlreadyQueuedError as exc:
                # If this task is the one holding the lock, we can just
                # continue on and run it.
                if exc.likely_culprit != task_id:
                    # We can't raise an exception here because it breaks
                    # celery's funky custom tracing if an exception occurs
                    # outside of self.run().
                    return

        return self.run(*args, **kwargs)

    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        # Handles post-task cleanup, when a task exits cleanly. This will be
        # called if a task raises an exception (stored in `einfo`), but not
        # if a worker straight up dies (say, because of running out of memory)
        h = getattr(self, "heimdall", {})

        # Cleanup the unique task lock when the task finishes, unless the user
        # told us to wait for the remaining interval.
        if h and "unique" in h and not h.get("unique_wait_for_expiry"):
            release_lock(
                self,
                unique_key_for_task(
                    self, args, kwargs, prefix=self.heimdall_config.lock_prefix
                ),
            )

        super().after_return(status, retval, task_id, args, kwargs, einfo)

    def only_after(self, key: str, seconds: int) -> bool:
        """
        A utility for writing sub-blocks in tasks that only execute if
        `seconds` has passed since the last time it was run.

        Imagine you have a task that runs every 5 minutes, but there's one line
        in that task you only want to run after at least an hour. You'd use
        `only_after` to accomplish that.

        :param key: The Redis key used to track the interval.
        :param seconds: Minimum number of seconds between executions.
        :return: True if the protected block may run now, False otherwise.
        """
        task_id = getattr(self.request, "id", uuid())
        return bool(
            redis.lock.Lock(
                self.heimdall_redis,
                key,
                timeout=seconds,
                blocking=self.heimdall_config.unique_lock_blocking,
                blocking_timeout=self.heimdall_config.unique_lock_timeout,
            ).acquire(token=task_id)
        )
5 | authors = ["Tyler Kennedy "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/tktech/celery-heimdall" 9 | repository = "https://github.com/tktech/celery-heimdall" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.8" 13 | celery = "*" 14 | redis = "*" 15 | click = {version = "*", optional = true} 16 | SQLAlchemy = {version = "*", optional = true} 17 | 18 | [tool.poetry.dev-dependencies] 19 | pytest = "^7.1.2" 20 | bumpversion = "^0.6.0" 21 | coverage = "^6.4.4" 22 | pytest-cov = "^3.0.0" 23 | black = "^23.9.1" 24 | 25 | [tool.poetry.extras] 26 | inspector = ["click", "sqlalchemy"] 27 | 28 | [tool.poetry.scripts] 29 | heimdall-inspector = "celery_heimdall.contrib.inspector.cli:cli" 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0"] 33 | build-backend = "poetry.core.masonry.api" 34 | 35 | [tool.black] 36 | line-length = 80 37 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest_plugins = ("celery.contrib.pytest",) 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def celery_config(): 8 | return { 9 | "broker_url": "redis://", 10 | "result_backend": "redis://", 11 | "worker_send_task_events": True, 12 | } 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def celery_worker_parameters(): 17 | return {"without_heartbeat": False} 18 | -------------------------------------------------------------------------------- /tests/test_only_after.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from celery import shared_task 4 | from celery.result import AsyncResult 5 | 6 | from celery_heimdall import HeimdallTask 7 | 8 | 9 | @shared_task(base=HeimdallTask, bind=True) 10 | def task_with_block(self: HeimdallTask): 11 | if self.only_after("only_after", 5): 12 | return True 13 | return False 14 | 15 | 16 | def 
test_only_after(celery_session_worker): 17 | """ 18 | Ensure that the blocks protected by `only_after()` only run after X 19 | seconds. 20 | """ 21 | task1: AsyncResult = task_with_block.apply_async() 22 | assert task1.get() is True 23 | task2: AsyncResult = task_with_block.apply_async() 24 | assert task2.get() is False 25 | time.sleep(10) 26 | task3: AsyncResult = task_with_block.apply_async() 27 | assert task3.get() is True 28 | -------------------------------------------------------------------------------- /tests/test_rate_limited.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | from celery import shared_task 5 | 6 | from celery_heimdall import HeimdallTask, RateLimit 7 | 8 | 9 | @shared_task(base=HeimdallTask, heimdall={"times": 2, "per": 10}) 10 | def default_rate_limit_task(): 11 | pass 12 | 13 | 14 | @shared_task(base=HeimdallTask, heimdall={"rate_limit": RateLimit((2, 10))}) 15 | def tuple_rate_limit_task(): 16 | pass 17 | 18 | 19 | @shared_task( 20 | base=HeimdallTask, heimdall={"rate_limit": RateLimit(lambda key: (2, 10))} 21 | ) 22 | def callable_rate_limit_task(): 23 | pass 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "func", 28 | [default_rate_limit_task, tuple_rate_limit_task, callable_rate_limit_task], 29 | ) 30 | def test_default_rate_limit(celery_session_worker, func): 31 | """ 32 | Ensure rate-limited tasks only run `times` times within `per` seconds.
33 | """ 34 | start = time.time() 35 | # Immediate 36 | task1 = func.apply_async() 37 | # Immediate 38 | task2 = func.apply_async() 39 | # After at least 10 seconds 40 | task3 = func.apply_async() 41 | # After at least 10 seconds 42 | task4 = func.apply_async() 43 | # After at least 20 seconds 44 | task5 = func.apply_async() 45 | # After at least 20 seconds 46 | task6 = func.apply_async() 47 | 48 | task1.get() 49 | task2.get() 50 | 51 | elapsed = time.time() - start 52 | assert elapsed < 2 53 | 54 | task3.get() 55 | task4.get() 56 | 57 | elapsed = time.time() - start 58 | assert 10 < elapsed < 20 59 | 60 | task5.get() 61 | task6.get() 62 | 63 | elapsed = time.time() - start 64 | assert 20 < elapsed < 30 65 | -------------------------------------------------------------------------------- /tests/test_unique.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for unique tasks. 3 | """ 4 | import celery 5 | import pytest 6 | from celery.result import AsyncResult 7 | 8 | from celery_heimdall import HeimdallTask, AlreadyQueuedError 9 | from celery_heimdall.task import release_lock, unique_key_for_task 10 | 11 | 12 | @celery.shared_task( 13 | base=HeimdallTask, 14 | heimdall={"unique": True}, 15 | ) 16 | def default_unique_task(dummy_arg=None): 17 | return 18 | 19 | 20 | @celery.shared_task( 21 | base=HeimdallTask, 22 | heimdall={ 23 | "unique": True, 24 | "unique_wait_for_expiry": True, 25 | }, 26 | name="wait_for_task", 27 | bind=True, 28 | ) 29 | def wait_for_task(self, dummy_arg=None): 30 | return self.request.id 31 | 32 | 33 | @celery.shared_task( 34 | base=HeimdallTask, heimdall={"unique": True, "unique_raises": True} 35 | ) 36 | def unique_raises_task(): 37 | return 38 | 39 | 40 | @celery.shared_task( 41 | base=HeimdallTask, 42 | heimdall={"unique": True, "key": lambda _, __: "MyTaskKey"}, 43 | ) 44 | def explicit_key_task(): 45 | return 46 | 47 | 48 | @celery.shared_task( 49 | base=HeimdallTask, 50 | 
heimdall={"unique": True, "key": "MyTaskKeyStr"}, 51 | ) 52 | def explicit_key_task_str(): 53 | return 54 | 55 | 56 | @celery.shared_task( 57 | base=HeimdallTask, 58 | bind=True, 59 | heimdall={"unique": True, "lock_prefix": "new-prefix:"}, 60 | ) 61 | def task_with_override_config(task: HeimdallTask): 62 | return task.heimdall_config.lock_prefix 63 | 64 | 65 | def test_default_unique(celery_session_worker): 66 | """ 67 | Ensure a unique task with no other configuration "just works". 68 | """ 69 | task1: AsyncResult = default_unique_task.apply_async() 70 | result: AsyncResult = default_unique_task.apply_async() 71 | assert result.id == task1.id 72 | 73 | # Ensure the key gets erased after the task finishes, and we can queue 74 | # again. 75 | task1.get() 76 | default_unique_task.apply_async() 77 | 78 | 79 | def test_raises_unique(celery_session_worker): 80 | """ 81 | Ensure a unique task raises an exception on conflicts. 82 | """ 83 | task1: AsyncResult = unique_raises_task.apply_async() 84 | with pytest.raises(AlreadyQueuedError) as exc_info: 85 | result: AsyncResult = unique_raises_task.apply_async() 86 | 87 | # Ensure we populate the ID of the task most likely holding onto the lock 88 | # preventing us from running. 89 | assert exc_info.value.likely_culprit == task1.id 90 | # 60 * 60 is the default Heimdall task timeout. 91 | assert 0 < exc_info.value.expires_in <= 60 * 60 92 | assert task1.id in repr(exc_info.value) 93 | 94 | 95 | def test_unique_explicit_key(celery_session_worker): 96 | """ 97 | Ensure a unique task with an explicitly provided key works. 98 | """ 99 | task1: AsyncResult = explicit_key_task.apply_async() 100 | result: AsyncResult = explicit_key_task.apply_async() 101 | assert task1.id == result.id 102 | 103 | # Ensure the key gets erased after the task finishes, and we can queue 104 | # again. 105 | task1.get() 106 | explicit_key_task.apply_async() 107 | 108 | # Ensure we can use a simple string instead of a lambda. 
109 | task1: AsyncResult = explicit_key_task_str.apply_async() 110 | result: AsyncResult = explicit_key_task_str.apply_async() 111 | assert task1.id == result.id 112 | 113 | 114 | def test_different_keys(celery_session_worker): 115 | """ 116 | Ensure tasks enqueued with different args (and thus different auto keys) 117 | works as expected. 118 | """ 119 | default_unique_task.delay("Task1") 120 | default_unique_task.delay("Task2") 121 | 122 | 123 | def test_task_with_override_config(celery_session_worker): 124 | """ 125 | Ensure we can override Config values from the `heimdall` task argument. 126 | """ 127 | task1: AsyncResult = task_with_override_config.apply_async() 128 | result: AsyncResult = task_with_override_config.apply_async() 129 | 130 | assert task1.id == result.id 131 | assert task1.get() == "new-prefix:" 132 | 133 | 134 | def test_send_task(celery_session_app, celery_session_worker): 135 | """ 136 | Ensure that tasks triggered with send_task (like celery beat) will also 137 | be unique. 138 | """ 139 | # This celery pytest plugin doesn't appear to run more than 1 worker at 140 | # a time, even when configured for higher concurrency and a prefork model, 141 | # so we use a unique task that doesn't clear its lock until the timeout to 142 | # test the 2nd task. 143 | 144 | # First we clear any locks that might be hanging around from a previous 145 | # test run. 146 | task = celery_session_app.tasks["wait_for_task"] 147 | release_lock( 148 | task, 149 | unique_key_for_task( 150 | task, (), {}, prefix=task.heimdall_config.lock_prefix 151 | ), 152 | ) 153 | 154 | # Then we queue up the task, which will complete almost immediately but 155 | # leave a lock behind because of unique_wait_for_expiry. 156 | task1: AsyncResult = celery_session_app.send_task("wait_for_task") 157 | # Then we queue up a second task, bypassing `apply_async()`, which will 158 | # check the lock at runtime. 
159 | task2: AsyncResult = celery_session_app.send_task("wait_for_task") 160 | 161 | assert task1.get() == task1.id 162 | assert task2.get() is None 163 | --------------------------------------------------------------------------------