├── .env.default ├── .github └── workflows │ ├── dockerhub.yml │ ├── mkdocs.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── distributask ├── __init__.py ├── distributask.py ├── example │ ├── __init__.py │ ├── distributed.py │ ├── local.py │ ├── shared.py │ └── worker.py └── tests │ ├── __init__.py │ ├── tests.py │ └── worker.py ├── docs ├── assets │ ├── DeepAI.png │ ├── banner.png │ ├── diagram.png │ ├── favicon.ico │ └── logo.png ├── distributask.md ├── getting_started.md ├── index.md └── more_info.md ├── mkdocs.yml ├── requirements.txt ├── scripts └── kill_redis_connections.sh └── setup.py /.env.default: -------------------------------------------------------------------------------- 1 | REDIS_HOST=localhost 2 | REDIS_PORT=6379 3 | REDIS_USER=default 4 | REDIS_PASSWORD= 5 | VAST_API_KEY= 6 | HF_TOKEN=hf_*** 7 | HF_REPO_ID=RaccoonResearch/test_dataset 8 | -------------------------------------------------------------------------------- /.github/workflows/dockerhub.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | env: 8 | REGISTRY: index.docker.io 9 | IMAGE_NAME: antbaez/distributask-test-worker 10 | 11 | jobs: 12 | push_to_registry: 13 | name: Push Docker image to Docker Hub 14 | runs-on: ubuntu-latest 15 | permissions: 16 | packages: write 17 | contents: read 18 | attestations: write 19 | id-token: write 20 | steps: 21 | - name: Check out the repo 22 | uses: actions/checkout@v4 23 | 24 | - name: Log in to Docker Hub 25 | uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a 26 | with: 27 | username: ${{ secrets.DOCKER_USERNAME }} 28 | password: ${{ secrets.DOCKER_PASSWORD }} 29 | 30 | - name: Extract metadata (tags, labels) for Docker 31 | id: meta 32 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 33 | with: 34 | images: ${{ env.IMAGE_NAME }} 35 | 36 | - name: Build and push Docker image 37 | id: push 38 | uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671 39 | with: 40 | context: . 41 | file: ./Dockerfile 42 | push: true 43 | tags: ${{ steps.meta.outputs.tags }} 44 | labels: ${{ steps.meta.outputs.labels }} 45 | 46 | 47 | - name: Generate artifact attestation 48 | uses: actions/attest-build-provenance@v1 49 | with: 50 | subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 51 | subject-digest: ${{ steps.push.outputs.digest }} 52 | push-to-registry: true 53 | 54 | -------------------------------------------------------------------------------- /.github/workflows/mkdocs.yml: -------------------------------------------------------------------------------- 1 | name: mkdocs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: write 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Configure Git Credentials 14 | run: | 15 | git config user.name github-actions[bot] 16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: 3.x 20 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
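      # cache_id is the UTC ISO week number, so the mkdocs-material cache key below rotates weekly and restore-keys falls back to the most recent week's cache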
21 | - uses: actions/cache@v4 22 | with: 23 | key: mkdocs-material-${{ env.cache_id }} 24 | path: .cache 25 | restore-keys: | 26 | mkdocs-material- 27 | - run: pip install mkdocs-material mkdocstrings mkdocstrings-python 28 | - run: mkdocs gh-deploy --force 29 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install build 25 | - name: Extract package version 26 | id: extract_version 27 | run: echo "package_version=$(echo $GITHUB_REF | cut -d / -f 3)" >> $GITHUB_ENV 28 | - name: Write package version to file 29 | run: echo "${{ env.package_version }}" > version.txt 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: ${{ secrets.PYPI_USERNAME }} 36 | password: ${{ secrets.PYPI_PASSWORD }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Test 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.11"] 11 | env: 12 | REDIS_HOST: ${{ secrets.REDIS_HOST }} 13 | REDIS_PORT: ${{ secrets.REDIS_PORT }} 14 | REDIS_USER: ${{ secrets.REDIS_USER }} 15 | REDIS_PASSWORD: ${{ secrets.REDIS_PASSWORD }} 16 | VAST_API_KEY: ${{ secrets.VAST_API_KEY }} 17 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 18 | HF_REPO_ID: ${{ secrets.HF_REPO_ID }} 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install pytest 29 | pip install -r requirements.txt 30 | - name: Write package version 31 | run: echo $GITHUB_REF | cut -d / -f 3 > version.txt 32 | - name: Running tests 33 | run: | 34 | pytest distributask/tests/tests.py 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .DS_Store 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | .vscode/ 164 | .chroma 165 | memory 166 | test 167 | version.txt 168 | config.json -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/x86_64 ubuntu:24.04 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y \ 5 | wget \ 6 | xz-utils \ 7 | bzip2 \ 8 | git \ 9 | git-lfs \ 10 | python3-pip \ 11 | python3 \ 12 | && apt-get install -y software-properties-common \ 13 | && apt-get clean \ 14 | && rm -rf /var/lib/apt/lists/* 15 | 16 | COPY requirements.txt . 17 | 18 | RUN pip install -r requirements.txt --break-system-packages 19 | 20 | COPY distributask/ ./distributask/ 21 | 22 | CMD ["celery", "-A", "distributask.example.worker", "worker", "--loglevel=info", "--concurrency=1"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 M̵̞̗̝̼̅̏̎͝Ȯ̴̝̻̊̃̋̀Õ̷̼͋N̸̩̿͜ ̶̜̠̹̼̩͒ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include version.txt 2 | include README.md 3 | include requirements.txt 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributask 2 | 3 | 4 | A simple way to distribute rendering tasks across multiple machines. 5 | 6 | [![Lint and Test](https://github.com/DeepAI-Research/Distributask/actions/workflows/test.yml/badge.svg)](https://github.com/DeepAI-Research/Distributask/actions/workflows/test.yml) 7 | [![PyPI version](https://badge.fury.io/py/distributask.svg)](https://badge.fury.io/py/distributask) 8 | [![License](https://img.shields.io/badge/License-MIT-blue)](https://github.com/DeepAI-Research/Distributask/blob/main/LICENSE) 9 | 10 | 11 | # Description 12 | 13 | Distributask is a package that automatically queues, executes, and uploads the result of any task you want using Vast.ai, a decentralized network of GPUs. It works by first creating a Celery queue of the tasks, which contain the code that you want to be run on a GPU.
The tasks are then passed to the Vast.ai GPU workers using Redis as a message broker. Once a worker has completed a task, the result is uploaded to Hugging Face. 14 | 15 | # Installation 16 | 17 | ```bash 18 | pip install distributask 19 | ``` 20 | 21 | # Development 22 | 23 | ### Setup 24 | 25 | Clone the repository and navigate to the project directory: 26 | 27 | ```bash 28 | git clone https://github.com/DeepAI-Research/Distributask.git 29 | cd Distributask 30 | ``` 31 | 32 | Install the required packages: 33 | 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | Or install Distributask as a package: 39 | 40 | ```bash 41 | pip install distributask 42 | ``` 43 | 44 | ### Configuration 45 | 46 | Create a `.env` file in the root directory of your project or set environment variables to create your desired setup: 47 | 48 | ```plaintext 49 | REDIS_HOST="name of your redis server" 50 | REDIS_PORT="port of your redis server" 51 | REDIS_USER="username to login to redis server" 52 | REDIS_PASSWORD="password to login to redis server" 53 | VAST_API_KEY="your Vast.ai API key" 54 | HF_TOKEN="your Hugging Face token" 55 | HF_REPO_ID="name of your Hugging Face repository" 56 | BROKER_POOL_LIMIT="your broker pool limit setting" 57 | ``` 58 | 59 | ## Getting Started 60 | 61 | ### Running an Example Task 62 | 63 | To run an example task and see Distributask in action, you can execute the example script provided in the project: 64 | 65 | ```bash 66 | # Run the example task locally using either a Docker container or a Celery worker: 67 | python -m distributask.example.local 68 | 69 | # Run the example task on Vast.ai ("kitchen sink" example): 70 | python -m distributask.example.distributed 71 | 72 | ``` 73 | 74 | This script configures the environment, registers a sample function, creates a queue of tasks, and monitors their execution on the workers. 75 | 76 | ### Command Options 77 | 78 | - `--max_price` is the max price (in $/hour) a node can be rented for. 79 | - `--max_nodes` is the max number of Vast.ai nodes that can be rented. 80 | - `--docker_image` is the name of the Docker image to load to the Vast.ai node. 81 | - `--module_name` is the name of the Celery worker. 82 | - `--number_of_tasks` is the number of example tasks that will be added to the queue and executed by the workers. 83 | 84 | ## Documentation 85 | 86 | For more info, check out our in-depth [documentation](https://deepai-research.github.io/Distributask)! 87 | 88 | ## Contributing 89 | 90 | Contributions are welcome! For any changes you would like to see, please open an issue to discuss what you would like changed. 91 | 92 | ## License 93 | 94 | This project is licensed under the MIT License - see the `LICENSE` file for details. 95 | 96 | ## Citation 97 | 98 | ```bibtex 99 | @misc{Distributask, 100 | author = {DeepAIResearch}, 101 | title = {Distributask: a simple way to distribute rendering tasks across multiple machines}, 102 | year = {2024}, 103 | publisher = {GitHub}, 104 | howpublished = {\url{https://github.com/DeepAI-Research/Distributask}} 105 | } 106 | ``` 107 | 108 | ## Contributors 109 | 110 | - M̵̞̗̝̼̅̏̎͝Ȯ̴̝̻̊̃̋̀Õ̷̼͋N̸̩̿͜ ̶̜̠̹̼̩͒ 111 | - Anthony
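## Quick API sketch

A minimal, illustrative sketch of the core loop (not a drop-in script): it assumes a filled-in `.env` or `config.json` and a worker process that registers the same function (mirroring `distributask/example/worker.py`); `my_function` is a hypothetical task name.

```python
from distributask.distributask import create_from_config

distributask = create_from_config()

# register_function returns the function unchanged, so it also works as a decorator
@distributask.register_function
def my_function(index):
    return f"hello from task {index}"

# queue three tasks, then block until the workers finish them
tasks = [distributask.execute_function("my_function", {"index": i}) for i in range(3)]
distributask.monitor_tasks(tasks)
```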
-------------------------------------------------------------------------------- /distributask/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributask import * 2 | -------------------------------------------------------------------------------- /distributask/distributask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import requests 5 | from tqdm import tqdm 6 | from typing import Dict, List 7 | import atexit 8 | import tempfile 9 | 10 | from celery import Celery 11 | from redis import ConnectionPool, Redis 12 | from omegaconf import OmegaConf 13 | from dotenv import load_dotenv 14 | from huggingface_hub import HfApi, Repository 15 | from requests.exceptions import HTTPError 16 | from celery.utils.log import get_task_logger 17 | 18 | 19 | class Distributask: 20 | """ 21 | The Distributask class contains the core features of distributask, including creating and 22 | executing the task queue, managing workers using the Vast.ai API, and uploading files and directories 23 | using the Hugging Face API. 24 | """ 25 | 26 | app: Celery = None 27 | redis_client: Redis = None 28 | registered_functions: dict = {} 29 | pool: ConnectionPool = None 30 | 31 | def __init__( 32 | self, 33 | hf_repo_id=os.getenv("HF_REPO_ID"), 34 | hf_token=os.getenv("HF_TOKEN"), 35 | vast_api_key=os.getenv("VAST_API_KEY"), 36 | redis_host=os.getenv("REDIS_HOST", "localhost"), 37 | redis_password=os.getenv("REDIS_PASSWORD", ""), 38 | redis_port=os.getenv("REDIS_PORT", 6379), 39 | redis_username=os.getenv("REDIS_USER", "default"), 40 | broker_pool_limit=os.getenv("BROKER_POOL_LIMIT", 1), 41 | ) -> None: 42 | """ 43 | Initialize the Distributask object with the provided configuration parameters. Also sets some 44 | default settings in Celery and handles cleanup of Celery queue and Redis server on exit. 45 | 46 | Args: 47 | hf_repo_id (str): Hugging Face repository ID. 48 | hf_token (str): Hugging Face API token. 49 | vast_api_key (str): Vast.ai API key. 50 | redis_host (str): Redis host. Defaults to "localhost". 51 | redis_password (str): Redis password. Defaults to an empty string. 52 | redis_port (int): Redis port. Defaults to 6379. 53 | redis_username (str): Redis username. Defaults to "default". 54 | broker_pool_limit (int): Celery broker pool limit. Defaults to 1. 55 | 56 | Raises: 57 | ValueError: If any of the required parameters (hf_repo_id, hf_token, vast_api_key) are not provided. 58 | """ 59 | if hf_repo_id is None: 60 | raise ValueError( 61 | "HF_REPO_ID is not provided to the Distributask constructor" 62 | ) 63 | 64 | if hf_token is None: 65 | raise ValueError("HF_TOKEN is not provided to the Distributask constructor") 66 | 67 | if vast_api_key is None: 68 | raise ValueError( 69 | "VAST_API_KEY is not provided to the Distributask constructor" 70 | ) 71 | 72 | if redis_host == "localhost": 73 | print( 74 | "WARNING: Using default Redis host 'localhost'. This is not recommended for production use and won't work for distributed rendering." 
75 | ) 76 | 77 | self.settings = { 78 | "HF_REPO_ID": hf_repo_id, 79 | "HF_TOKEN": hf_token, 80 | "VAST_API_KEY": vast_api_key, 81 | "REDIS_HOST": redis_host, 82 | "REDIS_PASSWORD": redis_password, 83 | "REDIS_PORT": redis_port, 84 | "REDIS_USER": redis_username, 85 | "BROKER_POOL_LIMIT": broker_pool_limit, 86 | } 87 | 88 | redis_url = self.get_redis_url() 89 | # start Celery app instance 90 | self.app = Celery("distributask", broker=redis_url, backend=redis_url) 91 | self.app.conf.broker_pool_limit = self.settings["BROKER_POOL_LIMIT"] 92 | 93 | def cleanup_redis(): 94 | """ 95 | Deletes keys in Redis related to Celery tasks on exit 96 | """ 97 | patterns = ["celery-task*", "task_status*"] 98 | redis_connection = self.get_redis_connection() 99 | for pattern in patterns: 100 | for key in redis_connection.scan_iter(match=pattern): 101 | redis_connection.delete(key) 102 | print("Redis server cleared") 103 | 104 | def cleanup_celery(): 105 | """ 106 | Clears Celery task queue on exit 107 | """ 108 | self.app.control.purge() 109 | print("Celery queue cleared") 110 | 111 | # At exit, close Celery instance, delete all previous task info from queue and Redis, and close Redis 112 | atexit.register(self.app.close) 113 | atexit.register(cleanup_redis) 114 | atexit.register(cleanup_celery) 115 | 116 | self.redis_client = self.get_redis_connection() 117 | 118 | # Tasks are acknowledged after they have been executed 119 | self.app.conf.task_acks_late = True 120 | self.call_function_task = self.app.task( 121 | bind=True, name="call_function_task", max_retries=3, default_retry_delay=30 122 | )(self.call_function_task) 123 | 124 | def __del__(self): 125 | """Destructor to clean up resources.""" 126 | if self.pool is not None: 127 | self.pool.disconnect() 128 | if self.redis_client is not None: 129 | self.redis_client.close() 130 | if self.app is not None: 131 | self.app.close() 132 | 133 | def log(self, message: str, level: str = "info") -> None: 134 | """ 135 | Log a message with the specified level. 136 | 137 | Args: 138 | message (str): The message to log. 139 | level (str): The logging level. Defaults to "info". 140 | """ 141 | logger = get_task_logger(__name__) 142 | getattr(logger, level)(message) 143 | 144 | def get_settings(self) -> dict: 145 | """ 146 | Return the settings of the distributask instance. 147 | """ 148 | return self.settings 149 | 150 | def get_redis_url(self) -> str: 151 | """ 152 | Construct a Redis URL from the configuration settings. 153 | 154 | Returns: 155 | str: A Redis URL string. 156 | 157 | Raises: 158 | ValueError: If any required Redis connection parameter is missing. 159 | """ 160 | host = self.settings["REDIS_HOST"] 161 | password = self.settings["REDIS_PASSWORD"] 162 | port = self.settings["REDIS_PORT"] 163 | username = self.settings["REDIS_USER"] 164 | 165 | if None in [host, password, port, username]: 166 | raise ValueError("Missing required Redis configuration values") 167 | 168 | redis_url = f"redis://{username}:{password}@{host}:{port}" 169 | return redis_url 170 | 171 | def get_redis_connection(self, force_new: bool = False) -> Redis: 172 | """ 173 | Returns Redis connection. If it already exists, returns current connection. 174 | If it does not exist, it creates a new Redis connection using a connection pool. 175 | 176 | Args: 177 | force_new (bool): Force the creation of a new connection if set to True. Defaults to False. 178 | 179 | Returns: 180 | Redis: A Redis connection object.
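        Example (illustrative sketch; assumes a configured instance ``dtask``):
            >>> client = dtask.get_redis_connection()
            >>> client.ping()
            True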
181 | """ 182 | if self.redis_client is not None and not force_new: 183 | return self.redis_client 184 | else: 185 | self.pool = ConnectionPool(host=self.settings["REDIS_HOST"], 186 | port=self.settings["REDIS_PORT"], 187 | password=self.settings["REDIS_PASSWORD"], 188 | max_connections=1) 189 | self.redis_client = Redis(connection_pool=self.pool) 190 | atexit.register(self.pool.disconnect) 191 | 192 | return self.redis_client 193 | 194 | def get_env(self, key: str, default: any = None) -> any: 195 | """ 196 | Retrieve a value from the configuration or .env file, with an optional default if the key is not found. 197 | 198 | Args: 199 | key (str): The key to look for in the settings. 200 | default (any): The default value to return if the key is not found. Defaults to None. 201 | 202 | Returns: 203 | any: The value from the settings if the key exists, otherwise the default value. 204 | """ 205 | return self.settings.get(key, default) 206 | 207 | def call_function_task(self, func_name: str, args_json: str) -> any: 208 | """ 209 | Creates Celery task that executes a registered function with provided JSON arguments. 210 | 211 | Args: 212 | func_name (str): The name of the registered function to execute. 213 | args_json (str): JSON string representation of the arguments for the function. 214 | 215 | Returns: 216 | any: The return value of the registered function, as recorded by the Celery task. 217 | 218 | Raises: 219 | ValueError: If the function name is not registered. 220 | Exception: If an error occurs during the execution of the function. The task will retry in this case. 221 | """ 222 | try: 223 | if func_name not in self.registered_functions: 224 | raise ValueError(f"Function '{func_name}' is not registered.") 225 | 226 | func = self.registered_functions[func_name] 227 | args = json.loads(args_json) 228 | result = func(**args) 229 | # self.update_function_status(self.call_function_task.request.id, "success") 230 | 231 | return result 232 | except Exception as e: 233 | self.log(f"Error in call_function_task: {str(e)}", "error") 234 | # self.call_function_task.retry(exc=e) 235 | 236 | 237 | def register_function(self, func: callable) -> callable: 238 | """ 239 | Decorator to register a function so that it can be invoked as a Celery task. 240 | 241 | Args: 242 | func (callable): The function to register. 243 | 244 | Returns: 245 | callable: The original function, now registered as a callable task. 246 | """ 247 | self.registered_functions[func.__name__] = func 248 | return func 249 | 250 | def execute_function(self, func_name: str, args: dict) -> Celery.AsyncResult: 251 | """ 252 | Execute a registered function as a Celery task with provided arguments. 253 | 254 | Args: 255 | func_name (str): The name of the function to execute. 256 | args (dict): Arguments to pass to the function. 257 | 258 | Returns: 259 | celery.result.AsyncResult: An object representing the asynchronous result of the task. 260 | """ 261 | args_json = json.dumps(args) 262 | async_result = self.call_function_task.delay(func_name, args_json) 263 | return async_result 264 | 265 | def update_function_status(self, task_id: str, status: str) -> None: 266 | """ 267 | Update the status of a function task as a new Redis key. 268 | 269 | Args: 270 | task_id (str): The ID of the task. 271 | status (str): The new status to set.
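        Example (illustrative sketch; assumes a configured instance ``dtask``; mirrors the round trip exercised in tests.py):
            >>> dtask.update_function_status("test_task_123", "COMPLETED")
            >>> dtask.get_redis_connection().get("task_status:test_task_123")
            b'COMPLETED'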
272 | """ 273 | redis_client = self.get_redis_connection() 274 | redis_client.set(f"task_status:{task_id}", status) 275 | 276 | def initialize_dataset(self, **kwargs) -> None: 277 | """ 278 | Initialize a Hugging Face repository if it doesn't exist. Reads the Hugging Face info from the config or .env file. 279 | 280 | Args: 281 | kwargs: kwargs that can be passed into the HfApi.create_repo function. 282 | 283 | Raises: 284 | HTTPError: If the repository check fails with an HTTP error other than a 404 (repository not found). 285 | """ 286 | repo_id = self.settings.get("HF_REPO_ID") 287 | hf_token = self.settings.get("HF_TOKEN") 288 | api = HfApi(token=hf_token) 289 | 290 | # creates new repo if desired repo is not found 291 | try: 292 | repo_info = api.repo_info(repo_id=repo_id, repo_type="dataset", timeout=30) 293 | except HTTPError as e: 294 | if e.response.status_code == 404: 295 | self.log( 296 | f"Repository {repo_id} does not exist. Creating a new repository.", 297 | "warn", 298 | ) 299 | api.create_repo( 300 | repo_id=repo_id, token=hf_token, repo_type="dataset", **kwargs 301 | ) 302 | else: 303 | raise e 304 | 305 | # Create config.json file 306 | config = { 307 | "data_loader_name": "custom", 308 | "data_loader_kwargs": { 309 | "path": repo_id, 310 | "format": "files", 311 | "fields": ["file_path", "text"], 312 | }, 313 | } 314 | 315 | # apply config.json to created repo 316 | with tempfile.TemporaryDirectory() as temp_dir: 317 | with Repository( 318 | local_dir=temp_dir, 319 | clone_from=repo_id, 320 | repo_type="dataset", 321 | use_auth_token=hf_token, 322 | ).commit(commit_message="Add config.json"): 323 | with open(os.path.join(temp_dir, "config.json"), "w") as f: 324 | json.dump(config, f, indent=2) 325 | 326 | self.log(f"Initialized repository {repo_id}.") 327 | 328 | # upload a single file to the Hugging Face repository 329 | def upload_file(self, file_path: str) -> None: 330 | """ 331 | Upload a file to a Hugging Face repository. 332 | 333 | Args: 334 | file_path (str): The path of the file to upload. 335 | 336 | Raises: 337 | Exception: If an error occurs during the upload process. 338 | 339 | """ 340 | hf_token = self.settings.get("HF_TOKEN") 341 | repo_id = self.settings.get("HF_REPO_ID") 342 | 343 | api = HfApi(token=hf_token) 344 | 345 | try: 346 | self.log(f"Uploading {file_path} to Hugging Face repo {repo_id}") 347 | api.upload_file( 348 | path_or_fileobj=file_path, 349 | path_in_repo=os.path.basename(file_path), 350 | repo_id=repo_id, 351 | token=hf_token, 352 | repo_type="dataset", 353 | ) 354 | self.log(f"Uploaded {file_path} to Hugging Face repo {repo_id}") 355 | except Exception as e: 356 | self.log( 357 | f"Failed to upload {file_path} to Hugging Face repo {repo_id}: {e}", 358 | "error", 359 | ) 360 | 361 | def upload_directory(self, dir_path: str) -> None: 362 | """ 363 | Upload a directory to a Hugging Face repository. Can be used to reduce frequency of Hugging Face API 364 | calls if you are rate limited while using the upload_file function. 365 | 366 | Args: 367 | dir_path (str): The path of the directory to upload. 368 | 369 | Raises: 370 | Exception: If an error occurs during the upload process.
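        Example (illustrative sketch; assumes a local directory ./results of finished outputs and valid Hugging Face credentials in the settings):
            >>> dtask.upload_directory("./results")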
371 | 372 | """ 373 | hf_token = self.settings.get("HF_TOKEN") 374 | repo_id = self.settings.get("HF_REPO_ID") 375 | 376 | try: 377 | self.log(f"Uploading {dir_path} to Hugging Face repo {repo_id}") 378 | 379 | api = HfApi(token=hf_token) 380 | api.upload_folder( 381 | folder_path=dir_path, 382 | repo_id=repo_id, 383 | repo_type="dataset", 384 | ) 385 | self.log(f"Uploaded {dir_path} to Hugging Face repo {repo_id}") 386 | except Exception as e: 387 | self.log( 388 | f"Failed to upload {dir_path} to Hugging Face repo {repo_id}: {e}", 389 | "error", 390 | ) 391 | 392 | def delete_file(self, repo_id: str, path_in_repo: str) -> None: 393 | """ 394 | Delete a file from a Hugging Face repository. 395 | 396 | Args: 397 | repo_id (str): The ID of the repository. 398 | path_in_repo (str): The path of the file to delete within the repository. 399 | 400 | Raises: 401 | Exception: If an error occurs during the deletion process. 402 | 403 | """ 404 | hf_token = self.settings.get("HF_TOKEN") 405 | api = HfApi(token=hf_token) 406 | 407 | try: 408 | api.delete_file( 409 | repo_id=repo_id, 410 | path_in_repo=path_in_repo, 411 | repo_type="dataset", 412 | token=hf_token, 413 | ) 414 | self.log(f"Deleted {path_in_repo} from Hugging Face repo {repo_id}") 415 | except Exception as e: 416 | self.log( 417 | f"Failed to delete {path_in_repo} from Hugging Face repo {repo_id}: {e}", 418 | "error", 419 | ) 420 | 421 | def file_exists(self, repo_id: str, path_in_repo: str) -> bool: 422 | """ 423 | Check if a file exists in a Hugging Face repository. 424 | 425 | Args: 426 | repo_id (str): The ID of the repository. 427 | path_in_repo (str): The path of the file to check within the repository. 428 | 429 | Returns: 430 | bool: True if the file exists in the repository, False otherwise. 431 | 432 | Raises: 433 | Exception: If an error occurs while checking the existence of the file. 434 | """ 435 | hf_token = self.settings.get("HF_TOKEN") 436 | api = HfApi(token=hf_token) 437 | 438 | try: 439 | repo_files = api.list_repo_files( 440 | repo_id=repo_id, repo_type="dataset", token=hf_token 441 | ) 442 | return path_in_repo in repo_files 443 | except Exception as e: 444 | self.log( 445 | f"Failed to check if {path_in_repo} exists in Hugging Face repo {repo_id}: {e}", 446 | "error", 447 | ) 448 | return False 449 | 450 | def list_files(self, repo_id: str) -> list: 451 | """ 452 | Get a list of files from a Hugging Face repository. 453 | 454 | Args: 455 | repo_id (str): The ID of the repository. 456 | 457 | Returns: 458 | list: A list of file paths in the repository. 459 | 460 | Raises: 461 | Exception: If an error occurs while retrieving the list of files. 462 | """ 463 | hf_token = self.settings.get("HF_TOKEN") 464 | api = HfApi(token=hf_token) 465 | 466 | try: 467 | repo_files = api.list_repo_files( 468 | repo_id=repo_id, repo_type="dataset", token=hf_token 469 | ) 470 | return repo_files 471 | except Exception as e: 472 | self.log( 473 | f"Failed to get the list of files from Hugging Face repo {repo_id}: {e}", 474 | "error", 475 | ) 476 | return [] 477 | 478 | def search_offers(self, max_price: float) -> List[Dict]: 479 | """ 480 | Search for available offers to rent a node as an instance on the Vast.ai platform. 481 | 482 | Args: 483 | max_price (float): The maximum price per hour for the instance. 484 | 485 | Returns: 486 | List[Dict]: A list of dictionaries representing the available offers. 487 | 488 | Raises: 489 | requests.exceptions.RequestException: If there is an error while making the API request. 
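        Example (illustrative sketch; assumes a valid VAST_API_KEY; ``dph_total`` is the same price-per-hour field used by rent_nodes):
            >>> offers = dtask.search_offers(max_price=0.25)
            >>> cheapest = min(offers, key=lambda offer: offer["dph_total"])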
490 | """ 491 | api_key = self.get_env("VAST_API_KEY") 492 | base_url = "https://console.vast.ai/api/v0/bundles/" 493 | headers = { 494 | "Accept": "application/json", 495 | "Content-Type": "application/json", 496 | "Authorization": f"Bearer {api_key}", 497 | } 498 | url = ( 499 | base_url 500 | + '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":' 501 | + str(max_price) 502 | + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}' 503 | ) 504 | response = None 505 | try: 506 | response = requests.get(url, headers=headers) 507 | response.raise_for_status() 508 | json_response = response.json() 509 | return json_response["offers"] 510 | 511 | except requests.exceptions.RequestException as e: 512 | self.log( 513 | f"Error: {e}\nResponse: {response.text if response else 'No response'}" 514 | ) 515 | raise 516 | 517 | def create_instance( 518 | self, offer_id: str, image: str, module_name: str, env_settings: Dict, command: str 519 | ) -> Dict: 520 | """ 521 | Create an instance on the Vast.ai platform. Passes in some useful Celery settings by default. 522 | 523 | Args: 524 | offer_id (str): The ID of the offer to create the instance from. 525 | image (str): The image to use for the instance. (example: RaccoonResearch/distributask-test-worker) 526 | module_name (str): The name of the Celery module to run on the instance, as packaged in the Docker image (example: distributask.example.worker) 527 | command (str): Command that starts the Celery worker. If not passed in, a default command is used whose settings have 528 | been found to be beneficial to the stability and simplicity of a Distributask run. 529 | env_settings (Dict): Used to pass in environment variables to the Vast.ai instance. This is a dictionary with keys of the 530 | environment variable name and values of the desired value of the environment variable. 531 | 532 | Returns: 533 | Dict: A dictionary representing the created instance. 534 | 535 | Raises: 536 | ValueError: If the Vast.ai API key is not set in the environment. 537 | Exception: If there is an error while creating the instance. 538 | """ 539 | if self.get_env("VAST_API_KEY") is None: 540 | self.log("VAST_API_KEY is not set in the environment", "error") 541 | raise ValueError("VAST_API_KEY is not set in the environment") 542 | 543 | if command is None: 544 | command = f"celery -A {module_name} worker --loglevel=info --concurrency=1 --without-heartbeat --prefetch-multiplier=1" 545 | 546 | if env_settings is None: 547 | env_settings = self.settings 548 | 549 | json_blob = { 550 | "client_id": "me", 551 | "image": image, 552 | "env": env_settings, 553 | "disk": 32, # Set a non-zero value for disk 554 | "onstart": f"export PATH=$PATH:/ && cd ../ && {command}", 555 | "runtype": "ssh ssh_proxy", 556 | } 557 | url = f"https://console.vast.ai/api/v0/asks/{offer_id}/?api_key={self.get_env('VAST_API_KEY')}" 558 | headers = {"Authorization": f"Bearer {self.get_env('VAST_API_KEY')}"} 559 | response = requests.put(url, headers=headers, json=json_blob) 560 | 561 | if response.status_code != 200: 562 | self.log(f"Failed to create instance: {response.text}", "error") 563 | raise Exception(f"Failed to create instance: {response.text}") 564 | 565 | return response.json() 566 | 567 | def destroy_instance(self, instance_id: str) -> requests.Response: 568 | """ 569 | Destroy an instance on the Vast.ai platform. 570 | 571 | Args: 572 | instance_id (str): The ID of the instance to destroy. 573 | 574 | Returns: 575 | requests.Response: The HTTP response from the destroy request.
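        Example (illustrative sketch; ``node`` is one of the dictionaries returned by rent_nodes):
            >>> response = dtask.destroy_instance(node["instance_id"])
            >>> response.status_code
            200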
576 | """ 577 | api_key = self.get_env("VAST_API_KEY") 578 | headers = {"Authorization": f"Bearer {api_key}"} 579 | url = ( 580 | f"https://console.vast.ai/api/v0/instances/{instance_id}/?api_key={api_key}" 581 | ) 582 | response = requests.delete(url, headers=headers) 583 | return response 584 | 585 | def rent_nodes( 586 | self, 587 | max_price: float, 588 | max_nodes: int, 589 | image: str, 590 | module_name: str, 591 | env_settings: Dict = None, 592 | command: str = None, 593 | ) -> List[Dict]: 594 | """ 595 | Rent nodes as instances on the Vast.ai platform. 596 | 597 | Args: 598 | max_price (float): The maximum price per hour for the nodes. 599 | max_nodes (int): The maximum number of nodes to rent. 600 | image (str): The image to use for the nodes. 601 | module_name (str): The name of the module to run on the nodes. 602 | env_settings (Dict): Environment variables to pass to each node. Defaults to this instance's settings. command (str): Celery worker start command. Defaults to the create_instance default command. 603 | Returns: 604 | List[Dict]: A list of dictionaries representing the rented nodes. If an error is encountered 605 | while searching for offers, the search is retried every 10 seconds, up to 10 times. 606 | """ 607 | rented_nodes: List[Dict] = [] 608 | while len(rented_nodes) < max_nodes: 609 | search_retries = 10 610 | while search_retries > 0: 611 | try: 612 | offers = self.search_offers(max_price) 613 | break 614 | except Exception as e: 615 | self.log( 616 | f"Error searching for offers: {str(e)} - retrying in 10 seconds...", 617 | "error", 618 | ) 619 | search_retries -= 1 620 | # sleep for 10 seconds before retrying 621 | time.sleep(10) 622 | continue 623 | 624 | offers = sorted( 625 | offers, key=lambda offer: offer["dph_total"] 626 | ) # Sort offers by price, lowest to highest 627 | for offer in offers: 628 | time.sleep(5) 629 | if len(rented_nodes) >= max_nodes: 630 | break 631 | try: 632 | instance = self.create_instance( 633 | offer["id"], image, module_name, env_settings=env_settings, command=command 634 | ) 635 | rented_nodes.append( 636 | { 637 | "offer_id": offer["id"], 638 | "instance_id": instance["new_contract"], 639 | } 640 | ) 641 | except Exception as e: 642 | self.log( 643 | f"Error renting node: {str(e)} - searching for new offers", 644 | "error", 645 | ) 646 | break # Break out of the current offer iteration 647 | else: 648 | # If the loop completes without breaking, all offers have been tried 649 | self.log("No more offers available - stopping node rental", "warning") 650 | break 651 | 652 | atexit.register(self.terminate_nodes, rented_nodes) 653 | return rented_nodes 654 | 655 | def get_node_log(self, node: Dict, wait_time: int = 2): 656 | """ 657 | Get the log of the Vast.ai instance that is passed in. Makes an API call to tell the instance to send the log, 658 | and another one to actually retrieve the log. 659 | Args: 660 | node (Dict): the node that corresponds to the Vast.ai instance you want the log from 661 | wait_time (int): how long to wait in between the two API calls described above 662 | 663 | Returns: 664 | requests.Response: The response containing the log of the instance requested. If any request returns a status code other than 200, returns None.
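        Example (illustrative sketch; ``rented_nodes`` comes from rent_nodes):
            >>> log_response = dtask.get_node_log(rented_nodes[0], wait_time=2)
            >>> if log_response is not None:
            ...     print(log_response.text)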
665 | """ 666 | node_id = node["instance_id"] 667 | url = f"https://console.vast.ai/api/v0/instances/request_logs/{node_id}/" 668 | 669 | payload = {"tail": "1000"} 670 | headers = { 671 | "Accept": "application/json", 672 | "Authorization": f"Bearer {self.settings['VAST_API_KEY']}", 673 | } 674 | 675 | response = requests.request( 676 | "PUT", url, headers=headers, json=payload, timeout=5 677 | ) 678 | 679 | if response.status_code == 200: 680 | log_url = response.json()["result_url"] 681 | time.sleep(wait_time) 682 | log_response = requests.get(log_url, timeout=5) 683 | if log_response.status_code == 200: 684 | return log_response 685 | else: 686 | return None 687 | else: 688 | return None 689 | 690 | def terminate_nodes(self, nodes: List[Dict]) -> None: 691 | """ 692 | Terminate the instances of rented nodes on Vast.ai. 693 | 694 | Args: 695 | nodes (List[Dict]): A list of dictionaries representing the rented nodes. 696 | 697 | Raises: 698 | Exception: If error in destroying instances. 699 | """ 700 | print("Terminating nodes...") 701 | for node in nodes: 702 | time.sleep(1) 703 | try: 704 | response = self.destroy_instance(node["instance_id"]) 705 | if response.status_code != 200: 706 | time.sleep(5) 707 | self.destroy_instance(node["instance_id"]) 708 | except Exception as e: 709 | self.log( 710 | f"Error terminating node: {node['instance_id']}, {str(e)}", "error" 711 | ) 712 | 713 | def monitor_tasks( 714 | self, tasks, update_interval=1, show_time_left=True, print_statements=True 715 | ): 716 | """ 717 | Monitor the status of the tasks on the Vast.ai nodes. 718 | 719 | Args: 720 | tasks (List): A list of the tasks to monitor. Should be a list of the results of execute_function. 721 | update_interval (int): Number of seconds between task status updates. 722 | show_time_left (bool): Show the estimated time left to complete tasks using the tqdm progress bar. 723 | print_statements (bool): Allow printing of the status of the task queue. 724 | 725 | Raises: 726 | Exception: If an error occurs while executing the tasks. 727 | """ 728 | 729 | try: 730 | # Wait for the tasks to complete 731 | if print_statements: 732 | print("Tasks submitted to queue. Starting queue...") 733 | print("Elapsed time [... remainder of monitor_tasks (the tqdm progress loop) is missing from this dump ...] 748 | distributask = None 749 | def create_from_config(config_path="config.json", env_path=".env") -> Distributask: 750 | """ 751 | Create Distributask object using settings that merge config.json and .env files present in distributask directory. 752 | If there are conflicting values, the .env takes priority. 753 | 754 | Args: 755 | config_path (str): path to config.json file 756 | env_path (str): path to .env file 757 | 758 | Returns: 759 | Distributask object initialized with settings from config or .env file 760 | """ 761 | print("**** CREATE_FROM_CONFIG ****") 762 | global distributask 763 | if distributask is not None: 764 | return distributask 765 | # Load environment variables from .env file 766 | try: 767 | load_dotenv(env_path) 768 | except: 769 | print("No .env file found. Using system environment variables only.") 770 | 771 | # Load configuration from JSON file 772 | try: 773 | settings = OmegaConf.load(config_path) 774 | if not all(settings.values()): 775 | print("Configuration file is missing necessary values.") 776 | except: 777 | print( 778 | "Configuration file not found. Falling back to system environment variables."
779 | ) 780 | settings = {} 781 | 782 | env_dict = {key: value for key, value in os.environ.items()} 783 | settings = OmegaConf.merge(settings, OmegaConf.create(env_dict)) 784 | 785 | distributask = Distributask( 786 | hf_repo_id=settings.get("HF_REPO_ID"), 787 | hf_token=settings.get("HF_TOKEN"), 788 | vast_api_key=settings.get("VAST_API_KEY"), 789 | redis_host=settings.get("REDIS_HOST"), 790 | redis_password=settings.get("REDIS_PASSWORD"), 791 | redis_port=settings.get("REDIS_PORT"), 792 | redis_username=settings.get("REDIS_USER"), 793 | broker_pool_limit=int(settings.get("BROKER_POOL_LIMIT", 1)), 794 | ) 795 | 796 | return distributask 797 | -------------------------------------------------------------------------------- /distributask/example/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import * 2 | from .local import * 3 | from .worker import * 4 | from .shared import * 5 | -------------------------------------------------------------------------------- /distributask/example/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import atexit 5 | 6 | from .shared import distributask, example_function 7 | 8 | if __name__ == "__main__": 9 | # Create an ArgumentParser object 10 | parser = argparse.ArgumentParser(description="Distributask example script") 11 | 12 | # Add arguments with default values 13 | parser.add_argument( 14 | "--max_price", 15 | type=float, 16 | default=0.20, 17 | help="Max price per node, in dollars (default: 0.20)", 18 | ) 19 | parser.add_argument( 20 | "--max_nodes", 21 | type=int, 22 | default=1, 23 | help="Max number of nodes to rent (default: 1)", 24 | ) 25 | parser.add_argument( 26 | "--docker_image", 27 | type=str, 28 | default="antbaez/distributask-test-worker", 29 | help="Docker image to use for the worker (default: antbaez/distributask-test-worker)", 30 | ) 31 | parser.add_argument( 32 | "--module_name", 33 | type=str, 34 | default="distributask.example.worker", 35 | help="Module name (default: distributask.example.worker)", 36 | ) 37 | parser.add_argument( 38 | "--number_of_tasks", type=int, default=10, help="Number of tasks (default: 10)" 39 | ) 40 | 41 | args = parser.parse_args() 42 | 43 | completed = False 44 | 45 | # Register function to distributask object 46 | distributask.register_function(example_function) 47 | 48 | # Initialize the dataset on Hugging Face 49 | distributask.initialize_dataset() 50 | 51 | # Create a file with the current date and time and save it as "datetime.txt" 52 | with open("datetime.txt", "w") as f: 53 | f.write(time.strftime("%Y-%m-%d %H:%M:%S")) 54 | 55 | # Upload file to the repository 56 | distributask.upload_file("datetime.txt") 57 | 58 | # Remove the example file from local 59 | os.remove("datetime.txt") 60 | 61 | vast_api_key = distributask.get_env("VAST_API_KEY") 62 | if not vast_api_key: 63 | raise ValueError("Vast API key not found in configuration.") 64 | 65 | job_configs = [] 66 | 67 | # Compile parameters for tasks 68 | for i in range(args.number_of_tasks): 69 | job_configs.append( 70 | { 71 | "outputs": [f"result_{i}.txt"], 72 | "task_params": {"index": i, "arg1": 1, "arg2": 2}, 73 | } 74 | ) 75 | 76 | # Rent Vast.ai nodes and get list of node ids 77 | print("Renting nodes...") 78 | rented_nodes = distributask.rent_nodes( 79 | args.max_price, args.max_nodes, args.docker_image, args.module_name 80 | ) 81 | 82 | print("Total rented nodes: ", 
len(rented_nodes)) 83 | 84 | tasks = [] 85 | 86 | # Submit the tasks to the queue for the Vast.ai worker nodes to execute 87 | for i in range(args.number_of_tasks): 88 | job_config = job_configs[i] 89 | print(f"Task {i}") 90 | print(job_config) 91 | print("Task params: ", job_config["task_params"]) 92 | 93 | params = job_config["task_params"] 94 | 95 | # Each task executes the function "example_function", defined in shared.py 96 | task = distributask.execute_function(example_function.__name__, params) 97 | 98 | # Add the task to the list of tasks 99 | tasks.append(task) 100 | 101 | def terminate_workers(): 102 | distributask.terminate_nodes(rented_nodes) 103 | print("Workers terminated.") 104 | 105 | # Terminate Vast.ai nodes on exit of script 106 | atexit.register(terminate_workers) 107 | 108 | # Monitor the status of the tasks with tqdm 109 | distributask.monitor_tasks(tasks) 110 | -------------------------------------------------------------------------------- /distributask/example/local.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import os 3 | import subprocess 4 | import time 5 | 6 | from .shared import distributask, example_function 7 | 8 | 9 | if __name__ == "__main__": 10 | completed = False 11 | 12 | # Register function to distributask object 13 | distributask.register_function(example_function) 14 | 15 | # First, initialize the dataset on Hugging Face 16 | distributask.initialize_dataset() 17 | 18 | # Create a file with the current date and time and save it as "datetime.txt" 19 | with open("datetime.txt", "w") as f: 20 | f.write(time.strftime("%Y-%m-%d %H:%M:%S")) 21 | 22 | # Upload this to the repository 23 | distributask.upload_file("datetime.txt") 24 | 25 | # Remove the example file from local 26 | os.remove("datetime.txt") 27 | 28 | vast_api_key = distributask.get_env("VAST_API_KEY") 29 | if not vast_api_key: 30 | raise ValueError("Vast API key not found in configuration.") 31 | 32 | job_configs = [] 33 | number_of_tasks = 3 34 | 35 | # Compile parameters for tasks 36 | for i in range(number_of_tasks): 37 | job_configs.append( 38 | { 39 | "outputs": [f"result_{i}.txt"], 40 | "task_params": {"index": i, "arg1": 1, "arg2": 2}, 41 | } 42 | ) 43 | 44 | tasks = [] 45 | 46 | repo_id = distributask.get_env("HF_REPO_ID") 47 | 48 | # Submit the tasks to the queue for the local worker to execute 49 | for i in range(number_of_tasks): 50 | job_config = job_configs[i] 51 | print(f"Task {i}") 52 | print(job_config) 53 | print("Task params: ", job_config["task_params"]) 54 | 55 | params = job_config["task_params"] 56 | 57 | # Each task executes the function "example_function", defined in shared.py 58 | task = distributask.execute_function(example_function.__name__, params) 59 | 60 | # Add the task to the list of tasks 61 | tasks.append(task) 62 | 63 | # Start the local worker 64 | docker_installed = False 65 | # Check if docker is installed 66 | try: 67 | subprocess.run(["docker", "version"], check=True) 68 | docker_installed = True 69 | except Exception as e: 70 | print("Docker is not installed. Starting worker locally.") 71 | print(e) 72 | 73 | docker_process = None 74 | # If docker is installed, start local Docker worker 75 | # If docker is not installed, start local Celery worker 76 | if docker_installed is False: 77 | print("Docker is not installed. Starting worker locally.")
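        # (illustrative note) the subprocess below is equivalent to running, from the repo root:
        #   celery -A distributask.example.worker worker --loglevel=info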
78 | celery_worker = subprocess.Popen( 79 | ["celery", "-A", "distributask.example.worker", "worker", "--loglevel=info"] 80 | ) 81 | 82 | else: 83 | build_process = subprocess.Popen( 84 | [ 85 | "docker", 86 | "build", 87 | "-t", 88 | "distributask-example-worker", 89 | ".", 90 | ] 91 | ) 92 | build_process.wait() 93 | 94 | docker_process = subprocess.Popen( 95 | [ 96 | "docker", 97 | "run", 98 | "-e", 99 | f"VAST_API_KEY={vast_api_key}", 100 | "-e", 101 | f"REDIS_HOST={distributask.get_env('REDIS_HOST')}", 102 | "-e", 103 | f"REDIS_PORT={distributask.get_env('REDIS_PORT')}", 104 | "-e", 105 | f"REDIS_PASSWORD={distributask.get_env('REDIS_PASSWORD')}", 106 | "-e", 107 | f"REDIS_USER={distributask.get_env('REDIS_USER')}", 108 | "-e", 109 | f"HF_TOKEN={distributask.get_env('HF_TOKEN')}", 110 | "-e", 111 | f"HF_REPO_ID={repo_id}", 112 | "distributask-example-worker", 113 | ] 114 | ) 115 | 116 | def kill_docker(): 117 | if docker_process is not None: 118 | print("Killing docker container") 119 | docker_process.terminate() 120 | # Terminate Docker worker on exit of script 121 | atexit.register(kill_docker) 122 | 123 | # Monitor the status of the tasks with tqdm 124 | distributask.monitor_tasks(tasks) 125 | 126 | -------------------------------------------------------------------------------- /distributask/example/shared.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | from ..distributask import create_from_config 6 | 7 | # Create distributask instance 8 | distributask = create_from_config() 9 | 10 | # This is the function that will be executed on the nodes 11 | # You can make your own function and pass in whatever arguments you want 12 | def example_function(index, arg1, arg2): 13 | 14 | result = arg1 + arg2 15 | 16 | time.sleep(random.randint(1, 6)) 17 | 18 | # Save the result to a file 19 | with open(f"result_{index}.txt", "w") as f: 20 | f.write(f"{str(arg1)} plus {str(arg2)} is {str(result)}") 21 | 22 | # Write the file to huggingface 23 | distributask.upload_file(f"result_{index}.txt") 24 | 25 | # Delete local file 26 | os.remove(f"result_{index}.txt") 27 | 28 | # Return the result - you can get this value from the task object 29 | return f"Task {index} completed. Result ({str(arg1)} + {str(arg2)}): {str(result)}"
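# Illustrative sketch (not part of the original example): because the Celery app is
# configured with a Redis result backend, the string returned above can be read back
# from the AsyncResult that distributask.execute_function hands to the submitter:
#
#   task = distributask.execute_function("example_function", {"index": 0, "arg1": 1, "arg2": 2})
#   print(task.get(timeout=60))  # -> "Task 0 completed. Result (1 + 2): 3"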
30 | -------------------------------------------------------------------------------- /distributask/example/worker.py: -------------------------------------------------------------------------------- 1 | from .shared import distributask, example_function 2 | 3 | # Register function to worker using distributask instance 4 | distributask.register_function(example_function) 5 | 6 | # Create Celery worker 7 | celery = distributask.app 8 | -------------------------------------------------------------------------------- /distributask/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .tests import * 2 | from .worker import * 3 | -------------------------------------------------------------------------------- /distributask/tests/tests.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pytest 3 | import time 4 | import os 5 | import tempfile 6 | from unittest.mock import MagicMock, patch 7 | 8 | from huggingface_hub import HfApi 9 | 10 | from ..distributask import create_from_config 11 | from .worker import example_test_function 12 | 13 | 14 | @pytest.fixture 15 | def mock_task_function(): 16 | """ 17 | Fixture that returns a mock task function. 18 | """ 19 | return MagicMock() 20 | 21 | 22 | def test_register_function(mock_task_function): 23 | """ 24 | Test the register_function function. 25 | """ 26 | mock_task_function.__name__ = "mock_task" # Set the __name__ attribute 27 | distributask = create_from_config() 28 | decorated_task = distributask.register_function(mock_task_function) 29 | 30 | assert callable(decorated_task) 31 | assert mock_task_function.__name__ in distributask.registered_functions 32 | assert ( 33 | distributask.registered_functions[mock_task_function.__name__] 34 | == mock_task_function 35 | ) 36 | print("Test passed") 37 | 38 | 39 | @patch("distributask.distributask.call_function_task.delay") 40 | def test_execute_function(mock_delay, mock_task_function): 41 | """ 42 | Test the execute_function function.
43 | """ 44 | mock_task_function.__name__ = "mock_task" # Set the __name__ attribute 45 | distributask = create_from_config() 46 | distributask.register_function(mock_task_function) 47 | 48 | params = {"arg1": 1, "arg2": 2} 49 | distributask.execute_function(mock_task_function.__name__, params) 50 | 51 | mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params)) 52 | print("Test passed") 53 | 54 | 55 | def test_register_function(): 56 | distributask = create_from_config() 57 | 58 | distributask.register_function(example_test_function) 59 | assert "example_test_function" in distributask.registered_functions 60 | assert ( 61 | distributask.registered_functions["example_test_function"] 62 | == example_test_function 63 | ) 64 | print("Task registration test passed") 65 | 66 | 67 | def test_execute_function(): 68 | distributask = create_from_config() 69 | 70 | distributask.register_function(example_test_function) 71 | task_params = {"arg1": 10, "arg2": 20} 72 | task = distributask.execute_function("example_test_function", task_params) 73 | assert task.id is not None 74 | print("Task execution test passed") 75 | 76 | 77 | # def test_worker_task_execution(): 78 | # distributask = create_from_config() 79 | 80 | # distributask.register_function(example_test_function) 81 | 82 | # worker_cmd = [ 83 | # "celery", 84 | # "-A", 85 | # "distributask.tests.worker", 86 | # "worker", 87 | # "--loglevel=info", 88 | # ] 89 | # print("worker_cmd") 90 | # print(worker_cmd) 91 | # worker_process = subprocess.Popen(worker_cmd) 92 | 93 | # time.sleep(2) 94 | 95 | # task_params = {"arg1": 10, "arg2": 20} 96 | # print("executing task") 97 | # task = distributask.execute_function("example_test_function", task_params) 98 | # result = task.get(timeout=30) 99 | 100 | # assert result == "+arg2=30" 101 | 102 | # worker_process.terminate() 103 | # worker_process.wait() 104 | 105 | # print("Worker task execution test passed") 106 | 107 | 108 | def test_task_status_update(): 109 | distributask = create_from_config() 110 | redis_client = distributask.get_redis_connection() 111 | 112 | task_status_keys = redis_client.keys("task_status:*") 113 | if task_status_keys: 114 | redis_client.delete(*task_status_keys) 115 | 116 | task_id = "test_task_123" 117 | status = "COMPLETED" 118 | 119 | distributask.update_function_status(task_id, status) 120 | 121 | status_from_redis = redis_client.get(f"task_status:{task_id}").decode() 122 | assert status_from_redis == status 123 | 124 | redis_client.delete(f"task_status:{task_id}") 125 | 126 | print("Task status update test passed") 127 | 128 | 129 | def test_initialize_repo(): 130 | distributask = create_from_config() 131 | 132 | # Initialize the repository 133 | distributask.initialize_dataset() 134 | hf_token = distributask.get_env("HF_TOKEN") 135 | repo_id = distributask.get_env("HF_REPO_ID") 136 | 137 | print("repo_id") 138 | print(repo_id) 139 | 140 | print("hf_token") 141 | print(hf_token) 142 | 143 | # Check if the repository exists 144 | api = HfApi(token=hf_token) 145 | repo_info = api.repo_info(repo_id=repo_id, repo_type="dataset", timeout=30) 146 | assert repo_info.id == repo_id 147 | 148 | # Check if the config.json file exists in the repository 149 | repo_files = api.list_repo_files( 150 | repo_id=repo_id, repo_type="dataset", token=hf_token 151 | ) 152 | assert "config.json" in repo_files 153 | 154 | # Clean up the repository 155 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token) 156 | 157 | 158 | def test_upload_directory(): 159 | 
158 | def test_upload_directory():
159 |     distributask = create_from_config()
160 |     distributask.initialize_dataset()
161 |     # Create a temporary directory for testing
162 |     with tempfile.TemporaryDirectory() as temp_dir:
163 |         # Create test files
164 |         test_files = ["test1.txt", "test2.txt"]
165 |         for file in test_files:
166 |             file_path = os.path.join(temp_dir, file)
167 |             with open(file_path, "w") as f:
168 |                 f.write("Test content")
169 | 
170 |         hf_token = distributask.get_env("HF_TOKEN")
171 |         repo_id = distributask.get_env("HF_REPO_ID")
172 | 
173 | 
174 |         # Upload the directory to the repository
175 |         distributask.upload_directory(temp_dir)
176 | 
177 |         # Check if the files exist in the Hugging Face repository
178 |         api = HfApi(token=hf_token)
179 |         repo_files = api.list_repo_files(
180 |             repo_id=repo_id, repo_type="dataset", token=hf_token
181 |         )
182 |         for file in test_files:
183 |             assert file in repo_files
184 | 
185 | 
186 |         for file in test_files:
187 |             os.remove(os.path.join(temp_dir, file))
188 | 
189 |         # Clean up the repository
190 |         api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
191 | 
192 | 
193 | def test_delete_file():
194 |     distributask = create_from_config()
195 |     distributask.initialize_dataset()
196 |     hf_token = distributask.get_env("HF_TOKEN")
197 |     repo_id = distributask.get_env("HF_REPO_ID")
198 | 
199 |     # Create a test file in the repository
200 |     test_file = "test.txt"
201 |     with open(test_file, "w") as f:
202 |         f.write("Test content")
203 | 
204 |     api = HfApi(token=hf_token)
205 |     api.upload_file(
206 |         path_or_fileobj=test_file,
207 |         path_in_repo=test_file,
208 |         repo_id=repo_id,
209 |         token=hf_token,
210 |         repo_type="dataset",
211 |     )
212 | 
213 |     # delete the file on disk
214 |     os.remove(test_file)
215 | 
216 |     # Delete the file from the repository
217 |     distributask.delete_file(repo_id, test_file)
218 | 
219 |     # Check if the file is deleted from the repository
220 |     repo_files = api.list_repo_files(
221 |         repo_id=repo_id, repo_type="dataset", token=hf_token
222 |     )
223 |     assert test_file not in repo_files
224 | 
225 |     # Clean up the repository
226 |     api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
227 | 
228 | 
229 | def test_file_exists():
230 |     distributask = create_from_config()
231 |     distributask.initialize_dataset()
232 |     hf_token = distributask.get_env("HF_TOKEN")
233 |     repo_id = distributask.get_env("HF_REPO_ID")
234 | 
235 |     # Create a test file in the repository
236 |     test_file = "test.txt"
237 |     with open(test_file, "w") as f:
238 |         f.write("Test content")
239 | 
240 |     api = HfApi(token=hf_token)
241 |     api.upload_file(
242 |         path_or_fileobj=test_file,
243 |         path_in_repo=test_file,
244 |         repo_id=repo_id,
245 |         token=hf_token,
246 |         repo_type="dataset",
247 |     )
248 | 
249 |     # delete the file on disk
250 |     os.remove(test_file)
251 | 
252 |     # Check if the file exists in the repository
253 |     assert distributask.file_exists(repo_id, test_file)
254 | 
255 |     # Check if a non-existent file exists in the repository
256 |     assert not distributask.file_exists(repo_id, "nonexistent.txt")
257 | 
258 |     # Clean up the repository
259 |     api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
260 | 
261 | 
262 | def test_list_files():
263 |     distributask = create_from_config()
264 |     distributask.initialize_dataset()
265 |     hf_token = distributask.get_env("HF_TOKEN")
266 |     repo_id = distributask.get_env("HF_REPO_ID")
267 | 
268 |     # Create test files in the repository
269 |     test_files = ["test1.txt", "test2.txt"]
270 |     # write each test file locally before uploading
271 |     for file in test_files:
272 |         with open(file, "w") as f:
273 |             f.write("Test content")
274 |     api = HfApi(token=hf_token)
275 |     for file in test_files:
276 |         api.upload_file(
277 |             path_or_fileobj=file,
278 |             path_in_repo=file,
279 |             repo_id=repo_id,
280 |             token=hf_token,
281 |             repo_type="dataset",
282 |         )
283 | 
284 |     # List the files in the repository
285 |     repo_files = distributask.list_files(repo_id)
286 | 
287 |     for file in test_files:
288 |         os.remove(file)
289 | 
290 |     # Check if the test files are present in the repository
291 |     for file in test_files:
292 |         assert file in repo_files
293 | 
294 |     # Clean up the repository
295 |     api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
296 | 
297 | 
298 | @pytest.fixture(scope="module")
299 | def rented_nodes():
300 |     distributask = create_from_config()
301 | 
302 |     max_price = 0.5
303 |     max_nodes = 1
304 |     image = "antbaez/distributask-worker:latest"
305 |     module_name = "distributask.example.worker"
306 | 
307 |     nodes = distributask.rent_nodes(max_price, max_nodes, image, module_name)
308 |     yield nodes
309 | 
310 |     distributask.terminate_nodes(nodes)
311 | 
312 | 
313 | def test_rent_run_terminate(rented_nodes):
314 |     assert len(rented_nodes) == 1
315 |     time.sleep(3)  # sleep for 3 seconds to simulate runtime
316 | 
317 | 
318 | def test_get_redis_url():
319 |     distributask = create_from_config()
320 |     redis_url = distributask.get_redis_url()
321 | 
322 |     assert redis_url.startswith("redis://")
323 |     assert distributask.settings["REDIS_USER"] in redis_url
324 |     assert distributask.settings["REDIS_PASSWORD"] in redis_url
325 |     assert distributask.settings["REDIS_HOST"] in redis_url
326 |     assert str(distributask.settings["REDIS_PORT"]) in redis_url
327 | 
328 | 
329 | def test_get_redis_connection_force_new():
330 |     distributask = create_from_config()
331 |     redis_client1 = distributask.get_redis_connection()
332 |     redis_client2 = distributask.get_redis_connection(force_new=True)
333 | 
334 |     assert redis_client1 is not redis_client2
335 | 
336 | 
345 | def test_get_env_with_default():
346 |     distributask = create_from_config()
347 |     default_value = "default"
348 |     value = distributask.get_env("NON_EXISTENT_KEY", default_value)
349 | 
350 |     assert value == default_value
351 | 
352 | 
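# The two tests below stub the Vast.ai HTTP layer with unittest.mock, so they
# run without renting real instances; only the rented_nodes fixture above
# talks to the live API.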
353 | @patch("requests.get")
354 | def test_search_offers(mock_get):
355 |     distributask = create_from_config()
356 |     max_price = 1.0
357 | 
358 |     mock_response = MagicMock()
359 |     mock_response.json.return_value = {"offers": [{"id": "offer1"}, {"id": "offer2"}]}
360 |     mock_get.return_value = mock_response
361 | 
362 |     offers = distributask.search_offers(max_price)
363 | 
364 |     assert len(offers) == 2
365 |     assert offers[0]["id"] == "offer1"
366 |     assert offers[1]["id"] == "offer2"
367 | 
368 | 
369 | @patch("requests.put")
370 | def test_create_instance(mock_put):
371 |     distributask = create_from_config()
372 |     offer_id = "offer1"
373 |     image = "test_image"
374 |     module_name = "distributask.example.worker"
375 |     command = f"celery -A {module_name} worker --loglevel=info"
376 | 
377 |     mock_response = MagicMock()
378 |     mock_response.status_code = 200
379 |     mock_response.json.return_value = {"new_contract": "instance1"}
380 |     mock_put.return_value = mock_response
381 | 
382 |     instance = distributask.create_instance(offer_id, image, module_name, distributask.settings, command)
383 | 
384 |     assert instance["new_contract"] == "instance1"
385 | 
386 | # def test_get_node_log():
387 | 
388 | #     distributask = create_from_config()
389 | 
390 | #     max_price = 0.5
391 | #     max_nodes = 1
392 | #     image = "antbaez/distributask-test-worker"
393 | #     module_name = "distributask.example.worker"
394 | 
395 | #     nodes = distributask.rent_nodes(max_price, max_nodes, image, module_name)
396 | 
397 | #     time.sleep(60)
398 | 
399 | #     response = distributask.get_node_log(nodes[0], wait_time=5)
400 | 
401 | #     assert response is not None
402 | #     assert response.status_code == 200
403 | 
404 | 
405 | import subprocess
406 | 
407 | 
408 | def test_local_example_run():
409 |     # Run the example in a child process and check its exit status directly.
410 |     # Patching sys.stdout/sys.stderr in the parent cannot capture a
411 |     # subprocess's streams, so the old StringIO-based stderr check was a no-op.
412 |     process = subprocess.run(
413 |         ["python", "-m", "distributask.example.local"],
414 |         capture_output=True,
415 |         text=True,
416 |         timeout=180,  # if the process hasn't ended in 3 min, the test fails
417 |     )
418 | 
419 |     # Assert that the example exited cleanly
420 |     assert process.returncode == 0, process.stderr
421 | 
422 |     try:
423 |         stop_command = "docker stop $(docker ps -q)"
424 |         subprocess.run(stop_command, shell=True, check=True)
425 |         print("All containers stopped successfully")
426 |     except subprocess.CalledProcessError:
427 |         # no running containers to stop, or docker is unavailable
428 |         pass
429 | 
430 | 
431 | def test_distributed_example_run():
432 |     # Same pattern as above: run the distributed example and check it exits cleanly
433 |     process = subprocess.run(
434 |         ["python", "-m", "distributask.example.distributed", "--number_of_tasks=3"],
435 |         capture_output=True,
436 |         text=True,
437 |         timeout=120,  # if the process hasn't ended in 2 min, the test fails
438 |     )
439 | 
440 |     assert process.returncode == 0, process.stderr
441 | 
--------------------------------------------------------------------------------
/distributask/tests/worker.py:
--------------------------------------------------------------------------------
1 | from ..distributask import create_from_config
2 | 
3 | distributask = create_from_config()
4 | 
5 | 
6 | # Define and register the test function
7 | def example_test_function(arg1, arg2):
8 |     return f"Result: arg1+arg2={arg1+arg2}"
9 | 
10 | 
11 | celery = distributask.app
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     distributask.register_function(example_test_function)
--------------------------------------------------------------------------------
/docs/assets/DeepAI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/DeepAI.png
--------------------------------------------------------------------------------
/docs/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/banner.png
--------------------------------------------------------------------------------
/docs/assets/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/diagram.png
--------------------------------------------------------------------------------
/docs/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/favicon.ico
--------------------------------------------------------------------------------
/docs/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/logo.png
--------------------------------------------------------------------------------
/docs/distributask.md:
--------------------------------------------------------------------------------
1 | # Distributask Class
2 | 
3 | ::: distributask.Distributask
4 |     options:
5 |       members: true
6 |       show_root_heading: true
7 |       show_source: true
--------------------------------------------------------------------------------
/docs/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 | 
3 | Below are instructions to get distributask running on your machine. Please read through the rest of the documentation for more detailed information.
4 | 
5 | ## Installation
6 | 
7 | ```bash
8 | pip install distributask
9 | ```
10 | 
11 | ## Development
12 | 
13 | ### Prerequisites
14 | 
15 | - Python 3.8 or newer (tested on Python 3.11)
16 | - Redis server
17 | - Vast.ai API key
18 | - HuggingFace API key
19 | 
20 | 
21 | ### Setup
22 | 
23 | Clone the repository and navigate to the project directory:
24 | 
25 | ```bash
26 | git clone https://github.com/RaccoonResearch/Distributask.git
27 | cd Distributask
28 | ```
29 | 
30 | Install the required packages:
31 | 
32 | ```bash
33 | pip install -r requirements.txt
34 | ```
35 | 
36 | Install the distributask package:
37 | 
38 | ```bash
39 | python setup.py install
40 | ```
41 | 
42 | ### Configuration
43 | 
44 | Create a `.env` file in the root directory of your project or set environment variables to match your setup:
45 | 
46 | ```plaintext
47 | REDIS_HOST=redis_host
48 | REDIS_PORT=redis_port
49 | REDIS_USER=redis_user
50 | REDIS_PASSWORD=redis_password
51 | VAST_API_KEY=your_vastai_api_key
52 | HF_TOKEN=your_huggingface_token
53 | HF_REPO_ID=your_huggingface_repo
54 | BROKER_POOL_LIMIT=broker_pool_limit
55 | ```
56 | 
57 | ### Running an Example Task
58 | 
59 | To run an example task and see distributask in action, you can execute the example scripts provided in the project:
60 | 
61 | ```bash
62 | # Run an example task locally
63 | python -m distributask.example.local
64 | 
65 | # Run an example task on Vast.ai ("kitchen sink" example)
66 | python -m distributask.example.distributed
67 | ```
68 | 
69 | ### Command Options
70 | 
71 | Below are options you can pass into your distributask example run; a combined example follows the list.
72 | 
73 | - `--max_price` is the max price (in $/hour) a node can be rented for.
74 | - `--max_nodes` is the max number of Vast.ai nodes that can be rented.
75 | - `--docker_image` is the name of the Docker image to load onto the Vast.ai nodes.
76 | - `--module_name` is the module name of the Celery worker (for example, `distributask.example.worker`).
77 | - `--number_of_tasks` is the number of example tasks that will be added to the queue and completed by the workers.
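For example, a distributed run combining these options might look like the following. The values shown are illustrative, not defaults; adjust the price, node count, image, and task count to your own setup:

```bash
python -m distributask.example.distributed \
  --max_price=0.25 \
  --max_nodes=2 \
  --docker_image=antbaez/distributask-worker:latest \
  --module_name=distributask.example.worker \
  --number_of_tasks=10
```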
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | Distributask is a simple way to distribute rendering tasks across multiple machines.
2 | 
3 | This documentation is intended to help you understand the structure of the Distributask API and codebase and how to use it to distribute rendering tasks across multiple machines for your own projects.
4 | 
5 | ## Core Use Cases
6 | Distributask can be used for any task that is parallelizable. Some specific use cases include:
7 | 
8 | - Rendering videos
9 | - Running simulations
10 | - Generating or processing large datasets
11 | 
12 | ## Getting Started
13 | 
14 | Visit the [Getting Started](getting_started.md) page to learn how to set up your environment and get started with distributing tasks with Distributask.
15 | 
16 | ## Overview
17 | 
18 | Distributed rendering using Distributask can be broken into four steps:
19 | 
20 | #### Creating the task queue
21 | 
22 | Distributask uses Celery, an asynchronous distributed task-processing package, to create the task queue on your local machine. Each task on the queue is a function that tells remote machines, or workers, what to do. For example, if we wanted to render videos, each task would be a function that contains the code to render a different video.
23 | 
24 | #### Passing the tasks to workers
25 | 
26 | Distributask uses Redis, an in-memory data store, as a message broker. This means that Redis carries tasks that are yet to be done from the task queue to the workers.
27 | 
28 | #### Executing the tasks
29 | 
30 | Distributask uses Vast.ai, a decentralized GPU market, to create workers that execute the tasks. Each task is given to a worker, executed, and the completed task status is passed back to the central machine via Redis.
31 | 
32 | #### Storing results of the tasks
33 | 
34 | Distributask uses Huggingface, a platform for sharing AI models and datasets, to store the results of the tasks. The results are uploaded to Huggingface using API calls in Distributask. For example, our rendered videos would be uploaded as a dataset on Huggingface. The short sketch below ties these four steps together.
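As a rough sketch only — assuming `create_from_config` is importable from the package root, as in the bundled examples, and using placeholder names for the task function, Docker image, worker module, and output folder — the four steps map to code like this:

```python
from distributask import create_from_config

distributask = create_from_config()

# 1. Create the task queue: register the function each worker will run
def render_video(video_id):
    # ... render one video and save it under "renders/" (placeholder logic) ...
    return f"rendered {video_id}"

distributask.register_function(render_video)

# 2. Rent Vast.ai workers; Redis passes queued tasks to them
nodes = distributask.rent_nodes(
    max_price=0.5,
    max_nodes=2,
    image="your-dockerhub-user/your-worker-image",
    module_name="your_project.worker",
)

# 3. Queue tasks for the workers to execute
tasks = [
    distributask.execute_function("render_video", {"video_id": i})
    for i in range(10)
]

# 4. Store results: upload the rendered output to the Huggingface dataset
distributask.initialize_dataset()
distributask.upload_directory("renders/")

distributask.terminate_nodes(nodes)
```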
35 | 
36 | ## Flowchart of Distributask process
37 | 
38 | ![Flowchart of the Distributask process](assets/diagram.png)
39 | 
--------------------------------------------------------------------------------
/docs/more_info.md:
--------------------------------------------------------------------------------
1 | # Summary of most relevant functions
2 | 
3 | #### Settings, Environment, and Help
4 | 
5 | - `create_from_config()` - creates a Distributask instance using environment variables
6 | - `get_env(key)` - gets a value from the `.env` file
7 | - `get_settings(key)` - gets a value from the settings dictionary
8 | 
9 | #### Celery tasks
10 | 
11 | - `register_function(func)` - registers a function so workers can run it as a task
12 | - `execute_function(func_name, args)` - creates a Celery task from a registered function
13 | 
14 | #### Redis server
15 | 
16 | - `get_redis_url()` - gets the Redis host URL
17 | - `get_redis_connection()` - gets a Redis connection instance
18 | 
19 | #### Worker management via Vast.ai API
20 | 
21 | - `search_offers(max_price)` - searches for available instances on Vast.ai
22 | - `rent_nodes(max_price, max_nodes, image, module_name, command)` - rents nodes on Vast.ai
23 | - `terminate_nodes(node_list)` - terminates the given Vast.ai instances
24 | 
25 | 
26 | #### HuggingFace repositories and uploading
27 | 
28 | - `initialize_dataset()` - initializes the dataset repo on HuggingFace
29 | - `upload_file(path_to_file)` - uploads a file to HuggingFace
30 | - `upload_directory(path_to_directory)` - uploads a folder to a HuggingFace repo
31 | - `delete_file(repo_id, path_in_repo)` - deletes a file in a HuggingFace repo (a short usage sketch of these helpers appears at the end of this page)
32 | 
33 | #### Visit the [Distributask Class](distributask.md) page for full, detailed documentation of the distributask class.
34 | 
35 | # Docker Setup
36 | 
37 | Distributask uses a Docker image to transfer the environment and necessary files to the Vast.ai nodes. In your own project, you can use the Dockerfile in the Distributask repository as a base for your own. If you do this, be sure to add distributask to the list of packages installed in your image.
38 | 
39 | # Important Packages
40 | 
41 | Visit the websites of these wonderful packages to learn more about how they work and how to use them.
42 | 
43 | Celery: `https://docs.celeryq.dev/en/stable/`
44 | Redis: `https://redis.io/docs/latest/`
45 | Hugging Face: `https://huggingface.co/docs/huggingface_hub/en/guides/upload`
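Finally, here is a minimal sketch of the HuggingFace helpers summarized above, assuming a configured `.env`; `results/` and `output.txt` are placeholder names, and `list_files` is used here as it is in the project's own tests:

```python
from distributask import create_from_config

distributask = create_from_config()
repo_id = distributask.get_env("HF_REPO_ID")

# create the dataset repo on HuggingFace if it does not exist yet
distributask.initialize_dataset()

# push a local folder of results to the repo
distributask.upload_directory("results/")

# list what is currently in the repo, then delete one file
repo_files = distributask.list_files(repo_id)
if "output.txt" in repo_files:
    distributask.delete_file(repo_id, "output.txt")
```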
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Distributask
2 | theme:
3 |   name: material
4 |   logo: assets/logo.png
5 |   favicon: assets/favicon.ico
6 | plugins:
7 |   - search
8 |   - autorefs
9 |   - mkdocstrings:
10 |       enabled: true
11 |       default_handler: python
12 | nav:
13 |   - Home: index.md
14 |   - Getting Started: getting_started.md
15 |   - More Information: more_info.md
16 |   - Distributask Class: distributask.md
17 | 
18 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | fsspec
3 | celery
4 | redis
5 | huggingface_hub
6 | python-dotenv
7 | omegaconf
8 | tqdm
--------------------------------------------------------------------------------
/scripts/kill_redis_connections.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | # Load environment variables from .env file
5 | source .env
6 | 
7 | # Check if REDIS_PORT is set in the .env file
8 | if [ -z "$REDIS_PORT" ]; then
9 |     echo "REDIS_PORT not found in .env file. Please set it and try again."
10 |     exit 1
11 | fi
12 | 
13 | # Use lsof to find all PIDs for the given port and store them in an array
14 | PIDS=($(lsof -i TCP:$REDIS_PORT -t))
15 | 
16 | # Check if there are any PIDs to kill
17 | if [ ${#PIDS[@]} -eq 0 ]; then
18 |     echo "No processes found using port $REDIS_PORT."
19 |     exit 0
20 | fi
21 | 
22 | # Loop through each PID and kill it
23 | for PID in "${PIDS[@]}"; do
24 |     echo "Killing process $PID"
25 |     sudo kill -9 $PID
26 | done
27 | 
28 | echo "All processes using port $REDIS_PORT have been killed."
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import os
3 | 
4 | # get the cwd where the setup.py file is located
5 | file_path = os.path.dirname(os.path.realpath(__file__))
6 | 
7 | long_description = ""
8 | with open(os.path.join(file_path, "README.md"), "r") as fh:
9 |     long_description = fh.read()
10 | long_description = long_description.split("\n")
11 | long_description = [line for line in long_description if not "<" in line]
12 | long_description = "\n".join(long_description)
13 | 
14 | # read the version from version.txt
15 | with open(os.path.join(file_path, "version.txt"), "r") as fh:
16 |     version = fh.read().strip()
17 | 
18 | # read requirements.txt, stripping any version pins
19 | with open(os.path.join(file_path, "requirements.txt"), "r") as fh:
20 |     install_requires = fh.read().split("\n")
21 | install_requires = [
22 |     line.split(">")[0].split("<")[0] for line in install_requires
23 | ]
24 | 
25 | setup(
26 |     name="distributask",
27 |     version=version,
28 |     description="Simple task manager and job queue for distributed rendering. Built on celery and redis.",
29 |     long_description=long_description,
30 |     long_description_content_type="text/markdown",
31 |     url="https://github.com/DeepAI-Research/Distributask",
32 |     author="DeepAIResearch",
33 |     author_email="team@deepai.org",
34 |     license="MIT",
35 |     packages=find_packages(),
36 |     install_requires=install_requires,
37 |     classifiers=[
38 |         "Development Status :: 4 - Beta",
39 |         "Intended Audience :: Science/Research",
40 |         "License :: OSI Approved :: MIT License",
41 |         "Operating System :: POSIX :: Linux",
42 |         "Programming Language :: Python :: 3",
43 |         "Operating System :: MacOS :: MacOS X",
44 |         "Operating System :: Microsoft :: Windows",
45 |     ],
46 | )
47 | 
--------------------------------------------------------------------------------