├── .env.default
├── .github
│   └── workflows
│       ├── dockerhub.yml
│       ├── mkdocs.yml
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── README.md
├── distributask
│   ├── __init__.py
│   ├── distributask.py
│   ├── example
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── local.py
│   │   ├── shared.py
│   │   └── worker.py
│   └── tests
│       ├── __init__.py
│       ├── tests.py
│       └── worker.py
├── docs
│   ├── assets
│   │   ├── DeepAI.png
│   │   ├── banner.png
│   │   ├── diagram.png
│   │   ├── favicon.ico
│   │   └── logo.png
│   ├── distributask.md
│   ├── getting_started.md
│   ├── index.md
│   └── more_info.md
├── mkdocs.yml
├── requirements.txt
├── scripts
│   └── kill_redis_connections.sh
└── setup.py
/.env.default:
--------------------------------------------------------------------------------
1 | REDIS_HOST=localhost
2 | REDIS_PORT=6379
3 | REDIS_USER=default
4 | REDIS_PASSWORD=
5 | VAST_API_KEY=
6 | HF_TOKEN=hf_***
7 | HF_REPO_ID=RaccoonResearch/test_dataset
8 |
--------------------------------------------------------------------------------
/.github/workflows/dockerhub.yml:
--------------------------------------------------------------------------------
1 | name: Publish Docker image
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | env:
8 | REGISTRY: index.docker.io
9 | IMAGE_NAME: antbaez/distributask-test-worker
10 |
11 | jobs:
12 | push_to_registry:
13 | name: Push Docker image to Docker Hub
14 | runs-on: ubuntu-latest
15 | permissions:
16 | packages: write
17 | contents: read
18 | attestations: write
19 | id-token: write
20 | steps:
21 | - name: Check out the repo
22 | uses: actions/checkout@v4
23 |
24 | - name: Log in to Docker Hub
25 | uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
26 | with:
27 | username: ${{ secrets.DOCKER_USERNAME }}
28 | password: ${{ secrets.DOCKER_PASSWORD }}
29 |
30 | - name: Extract metadata (tags, labels) for Docker
31 | id: meta
32 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
33 | with:
34 | images: ${{ env.IMAGE_NAME }}
35 |
36 | - name: Build and push Docker image
37 | id: push
38 | uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
39 | with:
40 | context: .
41 | file: ./Dockerfile
42 | push: true
43 | tags: ${{ steps.meta.outputs.tags }}
44 | labels: ${{ steps.meta.outputs.labels }}
45 |
46 |
47 | - name: Generate artifact attestation
48 | uses: actions/attest-build-provenance@v1
49 | with:
50 |           subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
51 | subject-digest: ${{ steps.push.outputs.digest }}
52 | push-to-registry: true
53 |
54 |
--------------------------------------------------------------------------------
/.github/workflows/mkdocs.yml:
--------------------------------------------------------------------------------
1 | name: mkdocs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | permissions:
7 | contents: write
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Configure Git Credentials
14 | run: |
15 | git config user.name github-actions[bot]
16 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: 3.x
20 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV # (3)!
21 | - uses: actions/cache@v4
22 | with:
23 | key: mkdocs-material-${{ env.cache_id }}
24 | path: .cache
25 | restore-keys: |
26 | mkdocs-material-
27 | - run: pip install mkdocs-material mkdocstrings mkdocstrings-python
28 | - run: mkdocs gh-deploy --force
29 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v3
17 | - name: Set up Python
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install build
25 | - name: Extract package version
26 | id: extract_version
27 | run: echo "package_version=$(echo $GITHUB_REF | cut -d / -f 3)" >> $GITHUB_ENV
28 | - name: Write package version to file
29 | run: echo "${{ env.package_version }}" > version.txt
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: ${{ secrets.PYPI_USERNAME }}
36 | password: ${{ secrets.PYPI_PASSWORD }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Lint and Test
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ["3.11"]
11 | env:
12 | REDIS_HOST: ${{ secrets.REDIS_HOST }}
13 | REDIS_PORT: ${{ secrets.REDIS_PORT }}
14 | REDIS_USER: ${{ secrets.REDIS_USER }}
15 | REDIS_PASSWORD: ${{ secrets.REDIS_PASSWORD }}
16 | VAST_API_KEY: ${{ secrets.VAST_API_KEY }}
17 | HF_TOKEN: ${{ secrets.HF_TOKEN }}
18 | HF_REPO_ID: ${{ secrets.HF_REPO_ID }}
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v3
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | pip install pytest
29 | pip install -r requirements.txt
30 | - name: Write package version
31 |         run: echo $(echo $GITHUB_REF | cut -d / -f 3) > version.txt
32 | - name: Running tests
33 | run: |
34 | pytest distributask/tests/tests.py
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .DS_Store
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
163 | .vscode/
164 | .chroma
165 | memory
166 | test
167 | version.txt
168 | config.json
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=linux/x86_64 ubuntu:24.04
2 |
3 | RUN apt-get update && \
4 | apt-get install -y \
5 | wget \
6 | xz-utils \
7 | bzip2 \
8 | git \
9 | git-lfs \
10 | python3-pip \
11 | python3 \
12 | && apt-get install -y software-properties-common \
13 | && apt-get clean \
14 | && rm -rf /var/lib/apt/lists/*
15 |
16 | COPY requirements.txt .
17 |
18 | RUN pip install -r requirements.txt --break-system-packages
19 |
20 | COPY distributask/ ./distributask/
21 |
22 | CMD ["celery", "-A", "distributask.example.worker", "worker", "--loglevel=info", "--concurrency=1"]
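
# Example local usage (hypothetical invocation; distributask/example/local.py automates the same build/run):
#   docker build -t distributask-example-worker .
#   docker run --env-file .env distributask-example-worker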
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 M̵̞̗̝̼̅̏̎͝Ȯ̴̝̻̊̃̋̀Õ̷̼͋N̸̩̿͜ ̶̜̠̹̼̩͒
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include version.txt
2 | include README.md
3 | include requirements.txt
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Distributask
2 |
3 |
4 | A simple way to distribute rendering tasks across multiple machines.
5 |
6 | [](https://github.com/DeepAI-Research/Distributask/actions/workflows/test.yml)
7 | [](https://badge.fury.io/py/distributask)
8 | [](https://github.com/DeepAI-Research/Distributask/blob/main/LICENSE)
9 |
10 |
11 | # Description
12 |
13 | Distributask is a package that automatically queues, executes, and uploads the results of any task you want using Vast.ai, a decentralized network of GPUs. It works by first creating a Celery queue of tasks, each containing the code you want run on a GPU. The tasks are then passed to Vast.ai GPU workers using Redis as the message broker. Once a worker has completed a task, the result is uploaded to Hugging Face.
14 |
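In code, a run looks roughly like this (condensed from `distributask/example/distributed.py`; the image and numbers below are just the example defaults):

```python
from distributask.example.shared import distributask, example_function

# Register the function the workers will execute, then queue tasks for it
distributask.register_function(example_function)
distributask.initialize_dataset()
tasks = [
    distributask.execute_function("example_function", {"index": i, "arg1": 1, "arg2": 2})
    for i in range(3)
]

# Rent Vast.ai nodes that run the worker image and drain the queue
nodes = distributask.rent_nodes(
    max_price=0.20,
    max_nodes=1,
    image="antbaez/distributask-test-worker",
    module_name="distributask.example.worker",
)
distributask.monitor_tasks(tasks)
```
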
15 | # Installation
16 |
17 | ```bash
18 | pip install distributask
19 | ```
20 |
21 | # Development
22 |
23 | ### Setup
24 |
25 | Clone the repository and navigate to the project directory:
26 |
27 | ```bash
28 | git clone https://github.com/DeepAI-Research/Distributask.git
29 | cd Distributask
30 | ```
31 |
32 | Install the required packages:
33 |
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | Or install Distributask as a package:
39 |
40 | ```bash
41 | pip install distributask
42 | ```
43 |
44 | ### Configuration
45 |
46 | Create a `.env` file in the root directory of your project or set environment variables to create your desired setup:
47 |
48 | ```plaintext
49 | REDIS_HOST="name of your redis server"
50 | REDIS_PORT="port of your redis server"
51 | REDIS_USER="username to login to redis server"
52 | REDIS_PASSWORD="password to login to redis server"
53 | VAST_API_KEY="your Vast.ai API key"
54 | HF_TOKEN="your Hugging Face token"
55 | HF_REPO_ID="name of your Hugging Face repository"
56 | BROKER_POOL_LIMIT="your broker pool limit setting"
57 | ```
58 |
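The same keys can also live in a `config.json` file; `create_from_config` merges both sources, with the `.env` / environment values taking priority on conflict. A minimal sketch:

```python
from distributask import create_from_config

# Reads config.json and .env from the working directory;
# environment variables override config.json values on conflict.
distributask = create_from_config()
print(distributask.get_settings())
```
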
59 | ## Getting Started
60 |
61 | ### Running an Example Task
62 |
63 | To run an example task and see Distributask in action, you can execute the example script provided in the project:
64 |
65 | ```bash
66 | # Run the example task locally using either a Docker container or a Celery worker:
67 | python -m distributask.example.local
68 |
69 | # Run the example task on Vast.ai ("kitchen sink" example):
70 | python -m distributask.example.distributed
71 |
72 | ```
73 |
74 | This script configures the environment, registers a sample function, creates a queue of tasks, and monitors their execution on the workers.
75 |
76 | ### Command Options
77 |
78 | - `--max_price` is the max price (in $/hour) a node can be rented for.
79 | - `--max_nodes` is the max number of Vast.ai nodes that can be rented.
80 | - `--docker_image` is the name of the Docker image to load onto the Vast.ai node.
81 | - `--module_name` is the name of the Celery worker module.
82 | - `--number_of_tasks` is the number of example tasks that will be added to the queue and completed by the workers (see the example below).
83 |
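For example, to rent up to two nodes at $0.25/hour and queue 20 example tasks:

```bash
python -m distributask.example.distributed --max_price 0.25 --max_nodes 2 --number_of_tasks 20
```
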
84 | ## Documentation
85 |
86 | For more info, check out our in-depth [documentation](https://deepai-research.github.io/Distributask)!
87 |
88 | ## Contributing
89 |
90 | Contributions are welcome! For any changes you would like to see, please open an issue to discuss what you would like changed, or to propose the change yourself.
91 |
92 | ## License
93 |
94 | This project is licensed under the MIT License - see the `LICENSE` file for details.
95 |
96 | ## Citation
97 |
98 | ```bibtex
99 | @misc{Distributask,
100 | author = {DeepAIResearch},
101 | title = {Distributask: a simple way to distribute rendering tasks across multiple machines},
102 | year = {2024},
103 | publisher = {GitHub},
104 | howpublished = {\url{https://github.com/DeepAI-Research/Distributask}}
105 | }
106 | ```
107 |
108 | ## Contributors
109 |
110 |
--------------------------------------------------------------------------------
/distributask/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributask import *
2 |
--------------------------------------------------------------------------------
/distributask/distributask.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | import requests
5 | from tqdm import tqdm
6 | from typing import Dict, List
7 | import atexit
8 | import tempfile
9 |
10 | from celery import Celery
11 | from redis import ConnectionPool, Redis
12 | from omegaconf import OmegaConf
13 | from dotenv import load_dotenv
14 | from huggingface_hub import HfApi, Repository
15 | from requests.exceptions import HTTPError
16 | from celery.utils.log import get_task_logger
17 |
18 |
19 | class Distributask:
20 | """
21 | The Distributask class contains the core features of distributask, including creating and
22 | executing the task queue, managing workers using the Vast.ai API, and uploading files and directories
23 | using the Hugging Face API.
24 | """
25 |
26 | app: Celery = None
27 | redis_client: Redis = None
28 | registered_functions: dict = {}
29 | pool: ConnectionPool = None
30 |
31 | def __init__(
32 | self,
33 | hf_repo_id=os.getenv("HF_REPO_ID"),
34 | hf_token=os.getenv("HF_TOKEN"),
35 | vast_api_key=os.getenv("VAST_API_KEY"),
36 | redis_host=os.getenv("REDIS_HOST", "localhost"),
37 | redis_password=os.getenv("REDIS_PASSWORD", ""),
38 | redis_port=os.getenv("REDIS_PORT", 6379),
39 | redis_username=os.getenv("REDIS_USER", "default"),
40 | broker_pool_limit=os.getenv("BROKER_POOL_LIMIT", 1),
41 | ) -> None:
42 | """
43 | Initialize the Distributask object with the provided configuration parameters. Also sets some
44 | default settings in Celery and handles cleanup of Celery queue and Redis server on exit.
45 |
46 | Args:
47 | hf_repo_id (str): Hugging Face repository ID.
48 | hf_token (str): Hugging Face API token.
49 | vast_api_key (str): Vast.ai API key.
50 | redis_host (str): Redis host. Defaults to "localhost".
51 | redis_password (str): Redis password. Defaults to an empty string.
52 | redis_port (int): Redis port. Defaults to 6379.
53 | redis_username (str): Redis username. Defaults to "default".
54 | broker_pool_limit (int): Celery broker pool limit. Defaults to 1.
55 |
56 | Raises:
57 | ValueError: If any of the required parameters (hf_repo_id, hf_token, vast_api_key) are not provided.
58 | """
59 | if hf_repo_id is None:
60 | raise ValueError(
61 | "HF_REPO_ID is not provided to the Distributask constructor"
62 | )
63 |
64 | if hf_token is None:
65 | raise ValueError("HF_TOKEN is not provided to the Distributask constructor")
66 |
67 | if vast_api_key is None:
68 | raise ValueError(
69 | "VAST_API_KEY is not provided to the Distributask constructor"
70 | )
71 |
72 | if redis_host == "localhost":
73 | print(
74 | "WARNING: Using default Redis host 'localhost'. This is not recommended for production use and won't work for distributed rendering."
75 | )
76 |
77 | self.settings = {
78 | "HF_REPO_ID": hf_repo_id,
79 | "HF_TOKEN": hf_token,
80 | "VAST_API_KEY": vast_api_key,
81 | "REDIS_HOST": redis_host,
82 | "REDIS_PASSWORD": redis_password,
83 | "REDIS_PORT": redis_port,
84 | "REDIS_USER": redis_username,
85 | "BROKER_POOL_LIMIT": broker_pool_limit,
86 | }
87 |
88 | redis_url = self.get_redis_url()
89 | # start Celery app instance
90 | self.app = Celery("distributask", broker=redis_url, backend=redis_url)
91 | self.app.conf.broker_pool_limit = self.settings["BROKER_POOL_LIMIT"]
92 |
93 | def cleanup_redis():
94 | """
95 |             Deletes keys in Redis related to Celery tasks on exit
96 | """
97 | patterns = ["celery-task*", "task_status*"]
98 | redis_connection = self.get_redis_connection()
99 | for pattern in patterns:
100 | for key in redis_connection.scan_iter(match=pattern):
101 | redis_connection.delete(key)
102 | print("Redis server cleared")
103 |
104 | def cleanup_celery():
105 | """
106 | Clears Celery task queue on exit
107 | """
108 | self.app.control.purge()
109 | print("Celery queue cleared")
110 |
111 | # At exit, close Celery instance, delete all previous task info from queue and Redis, and close Redis
112 | atexit.register(self.app.close)
113 | atexit.register(cleanup_redis)
114 | atexit.register(cleanup_celery)
115 |
116 | self.redis_client = self.get_redis_connection()
117 |
118 | # Tasks are acknowledged after they have been executed
119 | self.app.conf.task_acks_late = True
120 | self.call_function_task = self.app.task(
121 | bind=True, name="call_function_task", max_retries=3, default_retry_delay=30
122 | )(self.call_function_task)
123 |
124 | def __del__(self):
125 | """Destructor to clean up resources."""
126 | if self.pool is not None:
127 | self.pool.disconnect()
128 | if self.redis_client is not None:
129 | self.redis_client.close()
130 | if self.app is not None:
131 | self.app.close()
132 |
133 | def log(self, message: str, level: str = "info") -> None:
134 | """
135 | Log a message with the specified level.
136 |
137 | Args:
138 | message (str): The message to log.
139 | level (str): The logging level. Defaults to "info".
140 | """
141 | logger = get_task_logger(__name__)
142 | getattr(logger, level)(message)
143 |
144 |     def get_settings(self) -> dict:
145 |         """
146 |         Return the settings of the Distributask instance.
147 |         """
148 | return self.settings
149 |
150 | def get_redis_url(self) -> str:
151 | """
152 | Construct a Redis URL from the configuration settings.
153 |
154 | Returns:
155 | str: A Redis URL string.
156 |
157 | Raises:
158 | ValueError: If any required Redis connection parameter is missing.
159 | """
160 | host = self.settings["REDIS_HOST"]
161 | password = self.settings["REDIS_PASSWORD"]
162 | port = self.settings["REDIS_PORT"]
163 | username = self.settings["REDIS_USER"]
164 |
165 | if None in [host, password, port, username]:
166 | raise ValueError("Missing required Redis configuration values")
167 |
168 | redis_url = f"redis://{username}:{password}@{host}:{port}"
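        # Example (hypothetical values): with REDIS_USER="default", REDIS_PASSWORD="secret",
        # REDIS_HOST="redis.example.com", and REDIS_PORT=6379, redis_url becomes
        # "redis://default:secret@redis.example.com:6379".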
169 | return redis_url
170 |
171 | def get_redis_connection(self, force_new: bool = False) -> Redis:
172 | """
173 |         Returns a Redis connection. If one already exists, returns the current connection.
174 |         If one does not exist, creates a new Redis connection using a connection pool.
175 |
176 | Args:
177 | force_new (bool): Force the creation of a new connection if set to True. Defaults to False.
178 |
179 | Returns:
180 | Redis: A Redis connection object.
181 | """
182 | if self.redis_client is not None and not force_new:
183 | return self.redis_client
184 | else:
185 | self.pool = ConnectionPool(host=self.settings["REDIS_HOST"],
186 | port=self.settings["REDIS_PORT"],
187 | password=self.settings["REDIS_PASSWORD"],
188 | max_connections=1)
189 | self.redis_client = Redis(connection_pool=self.pool)
190 | atexit.register(self.pool.disconnect)
191 |
192 | return self.redis_client
193 |
194 | def get_env(self, key: str, default: any = None) -> any:
195 | """
196 | Retrieve a value from the configuration or .env file, with an optional default if the key is not found.
197 |
198 | Args:
199 | key (str): The key to look for in the settings.
200 | default (any): The default value to return if the key is not found. Defaults to None.
201 |
202 | Returns:
203 | any: The value from the settings if the key exists, otherwise the default value.
204 | """
205 | return self.settings.get(key, default)
206 |
207 | def call_function_task(self, func_name: str, args_json: str) -> any:
208 | """
209 | Creates Celery task that executes a registered function with provided JSON arguments.
210 |
211 | Args:
212 | func_name (str): The name of the registered function to execute.
213 | args_json (str): JSON string representation of the arguments for the function.
214 |
215 | Returns:
216 |             any: The result of the registered function.
217 |
218 | Raises:
219 | ValueError: If the function name is not registered.
220 |             Exception: If an error occurs during execution of the function; the error is logged and the task returns None.
221 | """
222 | try:
223 | if func_name not in self.registered_functions:
224 | raise ValueError(f"Function '{func_name}' is not registered.")
225 |
226 | func = self.registered_functions[func_name]
227 | args = json.loads(args_json)
228 | result = func(**args)
229 | # self.update_function_status(self.call_function_task.request.id, "success")
230 |
231 | return result
232 | except Exception as e:
233 | self.log(f"Error in call_function_task: {str(e)}", "error")
234 | # self.call_function_task.retry(exc=e)
235 |
236 |
237 | def register_function(self, func: callable) -> callable:
238 | """
239 | Decorator to register a function so that it can be invoked as a Celery task.
240 |
241 | Args:
242 | func (callable): The function to register.
243 |
244 | Returns:
245 | callable: The original function, now registered as a callable task.
246 | """
247 | self.registered_functions[func.__name__] = func
248 | return func
249 |
250 | def execute_function(self, func_name: str, args: dict) -> Celery.AsyncResult:
251 | """
252 | Execute a registered function as a Celery task with provided arguments.
253 |
254 | Args:
255 | func_name (str): The name of the function to execute.
256 | args (dict): Arguments to pass to the function.
257 |
258 | Returns:
259 | celery.result.AsyncResult: An object representing the asynchronous result of the task.
260 | """
261 | args_json = json.dumps(args)
262 | async_result = self.call_function_task.delay(func_name, args_json)
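        # Example (sketch; "my_func" is a placeholder registered via register_function):
        #   result = distributask.execute_function("my_func", {"x": 1})
        #   value = result.get(timeout=60)  # blocks on the Celery result backend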
263 | return async_result
264 |
265 | def update_function_status(self, task_id: str, status: str) -> None:
266 | """
267 | Update the status of a function task as a new Redis key.
268 |
269 | Args:
270 | task_id (str): The ID of the task.
271 | status (str): The new status to set.
272 | """
273 | redis_client = self.get_redis_connection()
274 | redis_client.set(f"task_status:{task_id}", status)
275 |
276 | def initialize_dataset(self, **kwargs) -> None:
277 | """
278 |         Initialize a Hugging Face repository if it doesn't exist. Reads Hugging Face info from the config or .env file.
279 |
280 | Args:
281 | kwargs: kwargs that can be passed into the HfApi.create_repo function.
282 |
283 | Raises:
284 |             HTTPError: If the repo info request fails for a reason other than the repo not existing.
285 | """
286 | repo_id = self.settings.get("HF_REPO_ID")
287 | hf_token = self.settings.get("HF_TOKEN")
288 | api = HfApi(token=hf_token)
289 |
290 | # creates new repo if desired repo is not found
291 | try:
292 | repo_info = api.repo_info(repo_id=repo_id, repo_type="dataset", timeout=30)
293 | except HTTPError as e:
294 | if e.response.status_code == 404:
295 | self.log(
296 | f"Repository {repo_id} does not exist. Creating a new repository.",
297 | "warn",
298 | )
299 | api.create_repo(
300 | repo_id=repo_id, token=hf_token, repo_type="dataset", **kwargs
301 | )
302 | else:
303 | raise e
304 |
305 | # Create config.json file
306 | config = {
307 | "data_loader_name": "custom",
308 | "data_loader_kwargs": {
309 | "path": repo_id,
310 | "format": "files",
311 | "fields": ["file_path", "text"],
312 | },
313 | }
314 |
315 | # apply config.json to created repo
316 | with tempfile.TemporaryDirectory() as temp_dir:
317 | with Repository(
318 | local_dir=temp_dir,
319 | clone_from=repo_id,
320 | repo_type="dataset",
321 | use_auth_token=hf_token,
322 | ).commit(commit_message="Add config.json"):
323 | with open(os.path.join(temp_dir, "config.json"), "w") as f:
324 | json.dump(config, f, indent=2)
325 |
326 | self.log(f"Initialized repository {repo_id}.")
327 |
328 | # upload a single file to the Hugging Face repository
329 | def upload_file(self, file_path: str) -> None:
330 | """
331 | Upload a file to a Hugging Face repository.
332 |
333 | Args:
334 | file_path (str): The path of the file to upload.
335 |
336 | Raises:
337 | Exception: If an error occurs during the upload process.
338 |
339 | """
340 | hf_token = self.settings.get("HF_TOKEN")
341 | repo_id = self.settings.get("HF_REPO_ID")
342 |
343 | api = HfApi(token=hf_token)
344 |
345 | try:
346 | self.log(f"Uploading {file_path} to Hugging Face repo {repo_id}")
347 | api.upload_file(
348 | path_or_fileobj=file_path,
349 | path_in_repo=os.path.basename(file_path),
350 | repo_id=repo_id,
351 | token=hf_token,
352 | repo_type="dataset",
353 | )
354 | self.log(f"Uploaded {file_path} to Hugging Face repo {repo_id}")
355 | except Exception as e:
356 | self.log(
357 | f"Failed to upload {file_path} to Hugging Face repo {repo_id}: {e}",
358 | "error",
359 | )
360 |
361 | def upload_directory(self, dir_path: str) -> None:
362 | """
363 | Upload a directory to a Hugging Face repository. Can be used to reduce frequency of Hugging Face API
364 | calls if you are rate limited while using the upload_file function.
365 |
366 | Args:
367 | dir_path (str): The path of the directory to upload.
368 |
369 | Raises:
370 | Exception: If an error occurs during the upload process.
371 |
372 | """
373 | hf_token = self.settings.get("HF_TOKEN")
374 | repo_id = self.settings.get("HF_REPO_ID")
375 |
376 | try:
377 | self.log(f"Uploading {dir_path} to Hugging Face repo {repo_id}")
378 |
379 | api = HfApi(token=hf_token)
380 | api.upload_folder(
381 | folder_path=dir_path,
382 | repo_id=repo_id,
383 | repo_type="dataset",
384 | )
385 | self.log(f"Uploaded {dir_path} to Hugging Face repo {repo_id}")
386 | except Exception as e:
387 | self.log(
388 | f"Failed to upload {dir_path} to Hugging Face repo {repo_id}: {e}",
389 | "error",
390 | )
391 |
392 | def delete_file(self, repo_id: str, path_in_repo: str) -> None:
393 | """
394 | Delete a file from a Hugging Face repository.
395 |
396 | Args:
397 | repo_id (str): The ID of the repository.
398 | path_in_repo (str): The path of the file to delete within the repository.
399 |
400 | Raises:
401 | Exception: If an error occurs during the deletion process.
402 |
403 | """
404 | hf_token = self.settings.get("HF_TOKEN")
405 | api = HfApi(token=hf_token)
406 |
407 | try:
408 | api.delete_file(
409 | repo_id=repo_id,
410 | path_in_repo=path_in_repo,
411 | repo_type="dataset",
412 | token=hf_token,
413 | )
414 | self.log(f"Deleted {path_in_repo} from Hugging Face repo {repo_id}")
415 | except Exception as e:
416 | self.log(
417 | f"Failed to delete {path_in_repo} from Hugging Face repo {repo_id}: {e}",
418 | "error",
419 | )
420 |
421 | def file_exists(self, repo_id: str, path_in_repo: str) -> bool:
422 | """
423 | Check if a file exists in a Hugging Face repository.
424 |
425 | Args:
426 | repo_id (str): The ID of the repository.
427 | path_in_repo (str): The path of the file to check within the repository.
428 |
429 | Returns:
430 | bool: True if the file exists in the repository, False otherwise.
431 |
432 | Raises:
433 | Exception: If an error occurs while checking the existence of the file.
434 | """
435 | hf_token = self.settings.get("HF_TOKEN")
436 | api = HfApi(token=hf_token)
437 |
438 | try:
439 | repo_files = api.list_repo_files(
440 | repo_id=repo_id, repo_type="dataset", token=hf_token
441 | )
442 | return path_in_repo in repo_files
443 | except Exception as e:
444 | self.log(
445 | f"Failed to check if {path_in_repo} exists in Hugging Face repo {repo_id}: {e}",
446 | "error",
447 | )
448 | return False
449 |
450 | def list_files(self, repo_id: str) -> list:
451 | """
452 | Get a list of files from a Hugging Face repository.
453 |
454 | Args:
455 | repo_id (str): The ID of the repository.
456 |
457 | Returns:
458 | list: A list of file paths in the repository.
459 |
460 | Raises:
461 | Exception: If an error occurs while retrieving the list of files.
462 | """
463 | hf_token = self.settings.get("HF_TOKEN")
464 | api = HfApi(token=hf_token)
465 |
466 | try:
467 | repo_files = api.list_repo_files(
468 | repo_id=repo_id, repo_type="dataset", token=hf_token
469 | )
470 | return repo_files
471 | except Exception as e:
472 | self.log(
473 | f"Failed to get the list of files from Hugging Face repo {repo_id}: {e}",
474 | "error",
475 | )
476 | return []
477 |
478 | def search_offers(self, max_price: float) -> List[Dict]:
479 | """
480 | Search for available offers to rent a node as an instance on the Vast.ai platform.
481 |
482 | Args:
483 | max_price (float): The maximum price per hour for the instance.
484 |
485 | Returns:
486 | List[Dict]: A list of dictionaries representing the available offers.
487 |
488 | Raises:
489 | requests.exceptions.RequestException: If there is an error while making the API request.
490 | """
491 | api_key = self.get_env("VAST_API_KEY")
492 | base_url = "https://console.vast.ai/api/v0/bundles/"
493 | headers = {
494 | "Accept": "application/json",
495 | "Content-Type": "application/json",
496 | "Authorization": f"Bearer {api_key}",
497 | }
498 | url = (
499 | base_url
500 | + '?q={"gpu_ram":">=4","rentable":{"eq":true},"dph_total":{"lte":'
501 | + str(max_price)
502 | + '},"sort_option":{"0":["dph_total","asc"],"1":["total_flops","asc"]}}'
503 | )
504 |
505 | try:
506 | response = requests.get(url, headers=headers)
507 | response.raise_for_status()
508 | json_response = response.json()
509 | return json_response["offers"]
510 |
511 | except requests.exceptions.RequestException as e:
512 | self.log(
513 | f"Error: {e}\nResponse: {response.text if response else 'No response'}"
514 | )
515 | raise
516 |
517 | def create_instance(
518 | self, offer_id: str, image: str, module_name: str, env_settings: Dict, command: str
519 | ) -> Dict:
520 | """
521 | Create an instance on the Vast.ai platform. Passes in some useful Celery settings by default.
522 |
523 | Args:
524 | offer_id (str): The ID of the offer to create the instance from.
525 |             image (str): The image to use for the instance (example: RaccoonResearch/distributask-test-worker).
526 |             module_name (str): The name of the module to run on the instance, as configured in the Docker image (example: distributask.example.worker).
527 |             env_settings (Dict): Environment variables to pass to the Vast.ai instance, given as a dictionary with keys of the
528 |                 environment variable names and values of the desired values of those variables.
529 |             command (str): Command that starts the Celery worker. If not passed in, a default command is used whose settings
530 |                 have been found to benefit the stability and simplicity of a Distributask run.
531 |
532 | Returns:
533 | Dict: A dictionary representing the created instance.
534 |
535 | Raises:
536 | ValueError: If the Vast.ai API key is not set in the environment.
537 | Exception: If there is an error while creating the instance.
538 | """
539 | if self.get_env("VAST_API_KEY") is None:
540 | self.log("VAST_API_KEY is not set in the environment", "error")
541 | raise ValueError("VAST_API_KEY is not set in the environment")
542 |
543 | if command is None:
544 | command = f"celery -A {module_name} worker --loglevel=info --concurrency=1 --without-heartbeat --prefetch-multiplier=1"
545 |
546 | if env_settings is None:
547 | env_settings = self.settings
548 |
549 | json_blob = {
550 | "client_id": "me",
551 | "image": image,
552 | "env": env_settings,
553 | "disk": 32, # Set a non-zero value for disk
554 | "onstart": f"export PATH=$PATH:/ && cd ../ && {command}",
555 | "runtype": "ssh ssh_proxy",
556 | }
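        # Example (sketch): with command=None and env_settings=None, a node rented with
        # module_name="distributask.example.worker" runs the default command
        # "celery -A distributask.example.worker worker --loglevel=info --concurrency=1 ..."
        # on startup, with this object's settings exported as environment variables.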
557 | url = f"https://console.vast.ai/api/v0/asks/{offer_id}/?api_key={self.get_env('VAST_API_KEY')}"
558 | headers = {"Authorization": f"Bearer {self.get_env('VAST_API_KEY')}"}
559 | response = requests.put(url, headers=headers, json=json_blob)
560 |
561 | if response.status_code != 200:
562 | self.log(f"Failed to create instance: {response.text}", "error")
563 | raise Exception(f"Failed to create instance: {response.text}")
564 |
565 | return response.json()
566 |
567 |     def destroy_instance(self, instance_id: str) -> requests.Response:
568 | """
569 | Destroy an instance on the Vast.ai platform.
570 |
571 | Args:
572 | instance_id (str): The ID of the instance to destroy.
573 |
574 | Returns:
575 |             requests.Response: The response from the destroy request.
576 | """
577 | api_key = self.get_env("VAST_API_KEY")
578 | headers = {"Authorization": f"Bearer {api_key}"}
579 | url = (
580 | f"https://console.vast.ai/api/v0/instances/{instance_id}/?api_key={api_key}"
581 | )
582 | response = requests.delete(url, headers=headers)
583 | return response
584 |
585 | def rent_nodes(
586 | self,
587 | max_price: float,
588 | max_nodes: int,
589 | image: str,
590 | module_name: str,
591 | env_settings: Dict = None,
592 | command: str = None,
593 | ) -> List[Dict]:
594 | """
595 | Rent nodes as an instance on the Vast.ai platform.
596 |
597 | Args:
598 | max_price (float): The maximum price per hour for the nodes.
599 | max_nodes (int): The maximum number of nodes to rent.
600 | image (str): The image to use for the nodes.
601 | module_name (str): The name of the module to run on the nodes.
602 |
603 | Returns:
604 |             List[Dict]: A list of dictionaries representing the rented nodes. If an error is encountered
605 |                 while searching for offers, the search is retried every 10 seconds, up to 10 times.
606 | """
607 | rented_nodes: List[Dict] = []
608 | while len(rented_nodes) < max_nodes:
609 |             offers, search_retries = [], 10
610 |             while search_retries > 0:
611 | try:
612 | offers = self.search_offers(max_price)
613 | break
614 | except Exception as e:
615 | self.log(
616 | f"Error searching for offers: {str(e)} - retrying in 5 seconds...",
617 | "error",
618 | )
619 | search_retries -= 1
620 | # sleep for 10 seconds before retrying
621 | time.sleep(10)
622 | continue
623 |
624 | offers = sorted(
625 | offers, key=lambda offer: offer["dph_total"]
626 | ) # Sort offers by price, lowest to highest
627 | for offer in offers:
628 | time.sleep(5)
629 | if len(rented_nodes) >= max_nodes:
630 | break
631 | try:
632 | instance = self.create_instance(
633 | offer["id"], image, module_name, env_settings=env_settings, command=command
634 | )
635 | rented_nodes.append(
636 | {
637 | "offer_id": offer["id"],
638 | "instance_id": instance["new_contract"],
639 | }
640 | )
641 | except Exception as e:
642 | self.log(
643 | f"Error renting node: {str(e)} - searching for new offers",
644 | "error",
645 | )
646 | break # Break out of the current offer iteration
647 | else:
648 | # If the loop completes without breaking, all offers have been tried
649 | self.log("No more offers available - stopping node rental", "warning")
650 | break
651 |
652 | atexit.register(self.terminate_nodes, rented_nodes)
653 | return rented_nodes
654 |
655 | def get_node_log(self, node: Dict, wait_time: int = 2):
656 | """
657 |         Get the log of the Vast.ai instance that is passed in. Makes one API call to tell the instance to send the log,
658 |         and another to actually retrieve the log.
659 |         Args:
660 |             node (Dict): The node that corresponds to the Vast.ai instance you want the log from.
661 |             wait_time (int): How long to wait, in seconds, between the two API calls described above.
662 |
663 |         Returns:
664 |             requests.Response: The log of the requested instance, or None if anything other than a 200 status code is received.
665 | """
666 | node_id = node["instance_id"]
667 | url = f"https://console.vast.ai/api/v0/instances/request_logs/{node_id}/"
668 |
669 | payload = {"tail": "1000"}
670 | headers = {
671 | "Accept": "application/json",
672 | "Authorization": f"Bearer {self.settings['VAST_API_KEY']}",
673 | }
674 |
675 | response = requests.request(
676 | "PUT", url, headers=headers, json=payload, timeout=5
677 | )
678 |
679 | if response.status_code == 200:
680 | log_url = response.json()["result_url"]
681 | time.sleep(wait_time)
682 | log_response = requests.get(log_url, timeout=5)
683 | if log_response.status_code == 200:
684 | return log_response
685 | else:
686 | return None
687 | else:
688 | return None
689 |
690 | def terminate_nodes(self, nodes: List[Dict]) -> None:
691 | """
692 | Terminate the instances of rented nodes on Vast.ai.
693 |
694 | Args:
695 | nodes (List[Dict]): A list of dictionaries representing the rented nodes.
696 |
697 | Raises:
698 | Exception: If error in destroying instances.
699 | """
700 | print("Terminating nodes...")
701 | for node in nodes:
702 | time.sleep(1)
703 | try:
704 | response = self.destroy_instance(node["instance_id"])
705 | if response.status_code != 200:
706 | time.sleep(5)
707 | self.destroy_instance(node["instance_id"])
708 | except Exception as e:
709 | self.log(
710 | f"Error terminating node: {node['instance_id']}, {str(e)}", "error"
711 | )
712 |
713 | def monitor_tasks(
714 | self, tasks, update_interval=1, show_time_left=True, print_statements=True
715 | ):
716 | """
717 | Monitor the status of the tasks on the Vast.ai nodes.
718 |
719 | Args:
720 | tasks (List): A list of the tasks to monitor. Should be a list of the results of execute_function.
721 |             update_interval (int): Number of seconds between updates of the tasks' status.
722 |             show_time_left (bool): Show the estimated time left to complete tasks using the tqdm progress bar.
723 |             print_statements (bool): Allow printing of the status of the task queue.
724 |
725 | Raises:
726 | Exception: If error in the process of executing the tasks
727 | """
728 |
729 | try:
730 | # Wait for the tasks to complete
731 | if print_statements:
732 | print("Tasks submitted to queue. Starting queue...")
733 | print("Elapsed time Distributask:
750 | """
752 |     Create a Distributask object using settings merged from the config.json and .env files in the working directory.
753 |     If there are conflicting values, the .env file takes priority.
753 |
754 | Args:
755 | config_path (str): path to config.json file
756 | env_path (str): path to .env file
757 |
758 | Returns:
759 | Distributask object initialized with settings from config or .env file
760 | """
761 | print("**** CREATE_FROM_CONFIG ****")
762 | global distributask
763 | if distributask is not None:
764 | return distributask
765 | # Load environment variables from .env file
766 | try:
767 | load_dotenv(env_path)
768 |     except Exception:
769 | print("No .env file found. Using system environment variables only.")
770 |
771 | # Load configuration from JSON file
772 | try:
773 | settings = OmegaConf.load(config_path)
774 | if not all(settings.values()):
775 | print(f"Configuration file is missing necessary values.")
776 | except:
777 | print(
778 | "Configuration file not found. Falling back to system environment variables."
779 | )
780 | settings = {}
781 |
782 | env_dict = {key: value for key, value in os.environ.items()}
783 | settings = OmegaConf.merge(settings, OmegaConf.create(env_dict))
784 |
785 | distributask = Distributask(
786 | hf_repo_id=settings.get("HF_REPO_ID"),
787 | hf_token=settings.get("HF_TOKEN"),
788 | vast_api_key=settings.get("VAST_API_KEY"),
789 | redis_host=settings.get("REDIS_HOST"),
790 | redis_password=settings.get("REDIS_PASSWORD"),
791 | redis_port=settings.get("REDIS_PORT"),
792 | redis_username=settings.get("REDIS_USER"),
793 | broker_pool_limit=int(settings.get("BROKER_POOL_LIMIT", 1)),
794 | )
795 |
796 | return distributask
797 |
--------------------------------------------------------------------------------
/distributask/example/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributed import *
2 | from .local import *
3 | from .worker import *
4 | from .shared import *
5 |
--------------------------------------------------------------------------------
/distributask/example/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import atexit
5 |
6 | from .shared import distributask, example_function
7 |
8 | if __name__ == "__main__":
9 | # Create an ArgumentParser object
10 | parser = argparse.ArgumentParser(description="Distributask example script")
11 |
12 | # Add arguments with default values
13 | parser.add_argument(
14 | "--max_price",
15 | type=float,
16 | default=0.20,
17 | help="Max price per node, in dollars (default: 0.20)",
18 | )
19 | parser.add_argument(
20 | "--max_nodes",
21 | type=int,
22 | default=1,
23 | help="Max number of nodes to rent (default: 1)",
24 | )
25 | parser.add_argument(
26 | "--docker_image",
27 | type=str,
28 | default="antbaez/distributask-test-worker",
29 | help="Docker image to use for the worker (default: antbaez/distributask-test-worker)",
30 | )
31 | parser.add_argument(
32 | "--module_name",
33 | type=str,
34 | default="distributask.example.worker",
35 | help="Module name (default: distributask.example.worker)",
36 | )
37 | parser.add_argument(
38 | "--number_of_tasks", type=int, default=10, help="Number of tasks (default: 10)"
39 | )
40 |
41 | args = parser.parse_args()
42 |
43 | completed = False
44 |
45 | # Register function to distributask object
46 | distributask.register_function(example_function)
47 |
48 | # Initialize the dataset on Hugging Face
49 | distributask.initialize_dataset()
50 |
51 | # Create a file with the current date and time and save it as "datetime.txt"
52 | with open("datetime.txt", "w") as f:
53 | f.write(time.strftime("%Y-%m-%d %H:%M:%S"))
54 |
55 | # Upload file to the repository
56 | distributask.upload_file("datetime.txt")
57 |
58 | # Remove the example file from local
59 | os.remove("datetime.txt")
60 |
61 | vast_api_key = distributask.get_env("VAST_API_KEY")
62 | if not vast_api_key:
63 | raise ValueError("Vast API key not found in configuration.")
64 |
65 | job_configs = []
66 |
67 | # Compile parameters for tasks
68 | for i in range(args.number_of_tasks):
69 | job_configs.append(
70 | {
71 | "outputs": [f"result_{i}.txt"],
72 | "task_params": {"index": i, "arg1": 1, "arg2": 2},
73 | }
74 | )
75 |
76 | # Rent Vast.ai nodes and get list of node ids
77 | print("Renting nodes...")
78 | rented_nodes = distributask.rent_nodes(
79 | args.max_price, args.max_nodes, args.docker_image, args.module_name
80 | )
81 |
82 | print("Total rented nodes: ", len(rented_nodes))
83 |
84 | tasks = []
85 |
86 | # Submit the tasks to the queue for the Vast.ai worker nodes to execute
87 | for i in range(args.number_of_tasks):
88 | job_config = job_configs[i]
89 | print(f"Task {i}")
90 | print(job_config)
91 | print("Task params: ", job_config["task_params"])
92 |
93 | params = job_config["task_params"]
94 |
95 | # Each task executes the function "example_function", defined in shared.py
96 | task = distributask.execute_function(example_function.__name__, params)
97 |
98 | # Add the task to the list of tasks
99 | tasks.append(task)
100 |
101 | def terminate_workers():
102 | distributask.terminate_nodes(rented_nodes)
103 | print("Workers terminated.")
104 |
105 | # Terminate Vast.ai nodes on exit of script
106 | atexit.register(terminate_workers)
107 |
108 | # Monitor the status of the tasks with tqdm
109 | distributask.monitor_tasks(tasks)
110 |
--------------------------------------------------------------------------------
/distributask/example/local.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import os
3 | import subprocess
4 | import time
5 |
6 | from .shared import distributask, example_function
7 |
8 |
9 | if __name__ == "__main__":
10 | completed = False
11 |
12 | # Register function to distributask object
13 | distributask.register_function(example_function)
14 |
15 | # First, initialize the dataset on Hugging Face
16 | distributask.initialize_dataset()
17 |
18 | # Create a file with the current date and time and save it as "datetime.txt"
19 | with open("datetime.txt", "w") as f:
20 | f.write(time.strftime("%Y-%m-%d %H:%M:%S"))
21 |
22 | # Upload this to the repository
23 | distributask.upload_file("datetime.txt")
24 |
25 | # Remove the example file from local
26 | os.remove("datetime.txt")
27 |
28 | vast_api_key = distributask.get_env("VAST_API_KEY")
29 | if not vast_api_key:
30 | raise ValueError("Vast API key not found in configuration.")
31 |
32 | job_configs = []
33 | number_of_tasks = 3
34 |
35 | # Compile parameters for tasks
36 | for i in range(number_of_tasks):
37 | job_configs.append(
38 | {
39 | "outputs": [f"result_{i}.txt"],
40 | "task_params": {"index": i, "arg1": 1, "arg2": 2},
41 | }
42 | )
43 |
44 | tasks = []
45 |
46 | repo_id = distributask.get_env("HF_REPO_ID")
47 |
48 | # Submit the tasks to the queue for the Vast.ai worker nodes to execute
49 | for i in range(number_of_tasks):
50 | job_config = job_configs[i]
51 | print(f"Task {i}")
52 | print(job_config)
53 | print("Task params: ", job_config["task_params"])
54 |
55 | params = job_config["task_params"]
56 |
57 | # Each task executes the function "example_function", defined in shared.py
58 | task = distributask.execute_function(example_function.__name__, params)
59 |
60 | # Add the task to the list of tasks
61 | tasks.append(task)
62 |
63 | # Start the local worker
64 | docker_installed = False
65 | # Check if docker is installed
66 | try:
67 | subprocess.run(["docker", "version"], check=True)
68 | docker_installed = True
69 | except Exception as e:
70 | print("Docker is not installed. Starting worker locally.")
71 | print(e)
72 |
73 | docker_process = None
74 | # If docker is installed, start local Docker worker
75 | # If docker is not installed, start local Celery worker
76 | if docker_installed is False:
77 | print("Docker is not installed. Starting worker locally.")
78 | celery_worker = subprocess.Popen(
79 | ["celery", "-A", "distributask.example.worker", "worker", "--loglevel=info"]
80 | )
81 |
82 | else:
83 | build_process = subprocess.Popen(
84 | [
85 | "docker",
86 | "build",
87 | "-t",
88 | "distributask-example-worker",
89 | ".",
90 | ]
91 | )
92 | build_process.wait()
93 |
94 | docker_process = subprocess.Popen(
95 | [
96 | "docker",
97 | "run",
98 | "-e",
99 | f"VAST_API_KEY={vast_api_key}",
100 | "-e",
101 | f"REDIS_HOST={distributask.get_env('REDIS_HOST')}",
102 | "-e",
103 | f"REDIS_PORT={distributask.get_env('REDIS_PORT')}",
104 | "-e",
105 | f"REDIS_PASSWORD={distributask.get_env('REDIS_PASSWORD')}",
106 | "-e",
107 | f"REDIS_USER={distributask.get_env('REDIS_USER')}",
108 | "-e",
109 | f"HF_TOKEN={distributask.get_env('HF_TOKEN')}",
110 | "-e",
111 | f"HF_REPO_ID={repo_id}",
112 | "distributask-example-worker",
113 | ]
114 | )
115 |
116 | def kill_docker():
117 | print("Killing docker container")
118 | docker_process.terminate()
119 |
120 | # Terminate Docker worker on exit of script
121 | atexit.register(kill_docker)
122 |
123 | # Monitor the status of the tasks with tqdm
124 | distributask.monitor_tasks(tasks)
125 |
126 |
--------------------------------------------------------------------------------
/distributask/example/shared.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 |
5 | from ..distributask import create_from_config
6 |
7 | # Create distributask instance
8 | distributask = create_from_config()
9 |
10 | # This is the function that will be executed on the nodes
11 | # You can make your own function and pass in whatever arguments you want
12 | def example_function(index, arg1, arg2):
13 |
14 | result = arg1 + arg2
15 |
16 | time.sleep(random.randint(1, 6))
17 |
18 | # Save the result to a file
19 | with open(f"result_{index}.txt", "w") as f:
20 | f.write(f"{str(arg1)} plus {str(arg2)} is {str(result)}")
21 |
22 | # Write the file to huggingface
23 | distributask.upload_file(f"result_{index}.txt")
24 |
25 | # Delete local file
26 | os.remove(f"result_{index}.txt")
27 |
28 | # Return the result - you can get this value from the task object
29 | return f"Task {index} completed. Result ({str(arg1)} + {str(arg2)}): {str(result)}"
30 |
--------------------------------------------------------------------------------
/distributask/example/worker.py:
--------------------------------------------------------------------------------
1 | from .shared import distributask, example_function
2 |
3 | # Register function to worker using distributask instance
4 | distributask.register_function(example_function)
5 |
6 | # Create Celery worker
7 | celery = distributask.app
8 |
--------------------------------------------------------------------------------
/distributask/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .tests import *
2 | from .worker import *
3 |
--------------------------------------------------------------------------------
/distributask/tests/tests.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pytest
3 | import time
4 | import os
5 | import tempfile
6 | from unittest.mock import MagicMock, patch
7 |
8 | from huggingface_hub import HfApi
9 |
10 | from ..distributask import create_from_config
11 | from .worker import example_test_function
12 |
13 |
14 | @pytest.fixture
15 | def mock_task_function():
16 | """
17 | Fixture that returns a mock task function.
18 | """
19 | return MagicMock()
20 |
21 |
22 | def test_register_function(mock_task_function):
23 | """
24 | Test the register_function function.
25 | """
26 | mock_task_function.__name__ = "mock_task" # Set the __name__ attribute
27 | distributask = create_from_config()
28 | decorated_task = distributask.register_function(mock_task_function)
29 |
30 | assert callable(decorated_task)
31 | assert mock_task_function.__name__ in distributask.registered_functions
32 | assert (
33 | distributask.registered_functions[mock_task_function.__name__]
34 | == mock_task_function
35 | )
36 | print("Test passed")
37 |
38 |
39 | @patch("distributask.distributask.call_function_task.delay")
40 | def test_execute_function(mock_delay, mock_task_function):
41 | """
42 | Test the execute_function function.
43 | """
44 | mock_task_function.__name__ = "mock_task" # Set the __name__ attribute
45 | distributask = create_from_config()
46 | distributask.register_function(mock_task_function)
47 |
48 | params = {"arg1": 1, "arg2": 2}
49 | distributask.execute_function(mock_task_function.__name__, params)
50 |
51 | mock_delay.assert_called_once_with(mock_task_function.__name__, json.dumps(params))
52 | print("Test passed")
53 |
54 |
55 | def test_register_function_with_worker():
56 | distributask = create_from_config()
57 |
58 | distributask.register_function(example_test_function)
59 | assert "example_test_function" in distributask.registered_functions
60 | assert (
61 | distributask.registered_functions["example_test_function"]
62 | == example_test_function
63 | )
64 | print("Task registration test passed")
65 |
66 |
67 | def test_execute_function_with_worker():
68 | distributask = create_from_config()
69 |
70 | distributask.register_function(example_test_function)
71 | task_params = {"arg1": 10, "arg2": 20}
72 | task = distributask.execute_function("example_test_function", task_params)
73 | assert task.id is not None
74 | print("Task execution test passed")
75 |
76 |
77 | # def test_worker_task_execution():
78 | # distributask = create_from_config()
79 |
80 | # distributask.register_function(example_test_function)
81 |
82 | # worker_cmd = [
83 | # "celery",
84 | # "-A",
85 | # "distributask.tests.worker",
86 | # "worker",
87 | # "--loglevel=info",
88 | # ]
89 | # print("worker_cmd")
90 | # print(worker_cmd)
91 | # worker_process = subprocess.Popen(worker_cmd)
92 |
93 | # time.sleep(2)
94 |
95 | # task_params = {"arg1": 10, "arg2": 20}
96 | # print("executing task")
97 | # task = distributask.execute_function("example_test_function", task_params)
98 | # result = task.get(timeout=30)
99 |
100 | # assert result == "+arg2=30"
101 |
102 | # worker_process.terminate()
103 | # worker_process.wait()
104 |
105 | # print("Worker task execution test passed")
106 |
107 |
108 | def test_task_status_update():
109 | distributask = create_from_config()
110 | redis_client = distributask.get_redis_connection()
111 |
112 | task_status_keys = redis_client.keys("task_status:*")
113 | if task_status_keys:
114 | redis_client.delete(*task_status_keys)
115 |
116 | task_id = "test_task_123"
117 | status = "COMPLETED"
118 |
119 | distributask.update_function_status(task_id, status)
120 |
121 | status_from_redis = redis_client.get(f"task_status:{task_id}").decode()
122 | assert status_from_redis == status
123 |
124 | redis_client.delete(f"task_status:{task_id}")
125 |
126 | print("Task status update test passed")
127 |
128 |
129 | def test_initialize_repo():
130 | distributask = create_from_config()
131 |
132 | # Initialize the repository
133 | distributask.initialize_dataset()
134 | hf_token = distributask.get_env("HF_TOKEN")
135 | repo_id = distributask.get_env("HF_REPO_ID")
136 |
137 |     print("repo_id")
138 |     print(repo_id)
139 | 
140 |     # hf_token is intentionally not printed here: it is a secret and would
141 |     # leak into test logs
142 |
143 | # Check if the repository exists
144 | api = HfApi(token=hf_token)
145 | repo_info = api.repo_info(repo_id=repo_id, repo_type="dataset", timeout=30)
146 | assert repo_info.id == repo_id
147 |
148 | # Check if the config.json file exists in the repository
149 | repo_files = api.list_repo_files(
150 | repo_id=repo_id, repo_type="dataset", token=hf_token
151 | )
152 | assert "config.json" in repo_files
153 |
154 | # Clean up the repository
155 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
156 |
157 |
158 | def test_upload_directory():
159 | distributask = create_from_config()
160 | distributask.initialize_dataset()
161 | # Create a temporary directory for testing
162 | with tempfile.TemporaryDirectory() as temp_dir:
163 | # Create test files
164 | test_files = ["test1.txt", "test2.txt"]
165 | for file in test_files:
166 | file_path = os.path.join(temp_dir, file)
167 | with open(file_path, "w") as f:
168 | f.write("Test content")
169 |
170 | hf_token = distributask.get_env("HF_TOKEN")
171 | repo_id = distributask.get_env("HF_REPO_ID")
172 | repo_path = distributask.get_env("HF_REPO_PATH", "data")
173 |
174 | # Upload the directory to the repository
175 | distributask.upload_directory(temp_dir)
176 |
177 | # Check if the files exist in the Hugging Face repository
178 | api = HfApi(token=hf_token)
179 | repo_files = api.list_repo_files(
180 | repo_id=repo_id, repo_type="dataset", token=hf_token
181 | )
182 | for file in test_files:
183 | print(file)
184 | assert file in repo_files
185 |
186 | for file in test_files:
187 | os.remove(os.path.join(temp_dir, file))
188 |
189 | # Clean up the repository
190 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
191 |
192 |
193 | def test_delete_file():
194 | distributask = create_from_config()
195 | distributask.initialize_dataset()
196 | hf_token = distributask.get_env("HF_TOKEN")
197 | repo_id = distributask.get_env("HF_REPO_ID")
198 |
199 | # Create a test file in the repository
200 | test_file = "test.txt"
201 | with open(test_file, "w") as f:
202 | f.write("Test content")
203 |
204 | api = HfApi(token=hf_token)
205 | api.upload_file(
206 | path_or_fileobj=test_file,
207 | path_in_repo=test_file,
208 | repo_id=repo_id,
209 | token=hf_token,
210 | repo_type="dataset",
211 | )
212 |
213 | # delete the file on disk
214 | os.remove(test_file)
215 |
216 | # Delete the file from the repository
217 | distributask.delete_file(repo_id, test_file)
218 |
219 | # Check if the file is deleted from the repository
220 | repo_files = api.list_repo_files(
221 | repo_id=repo_id, repo_type="dataset", token=hf_token
222 | )
223 | assert test_file not in repo_files
224 |
225 | # Clean up the repository
226 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
227 |
228 |
229 | def test_file_exists():
230 | distributask = create_from_config()
231 | distributask.initialize_dataset()
232 | hf_token = distributask.get_env("HF_TOKEN")
233 | repo_id = distributask.get_env("HF_REPO_ID")
234 |
235 | # Create a test file in the repository
236 | test_file = "test.txt"
237 | with open(test_file, "w") as f:
238 | f.write("Test content")
239 |
240 | api = HfApi(token=hf_token)
241 | api.upload_file(
242 | path_or_fileobj=test_file,
243 | path_in_repo=test_file,
244 | repo_id=repo_id,
245 | token=hf_token,
246 | repo_type="dataset",
247 | )
248 |
249 | # delete the file on disk
250 | os.remove(test_file)
251 |
252 | # Check if the file exists in the repository
253 | assert distributask.file_exists(repo_id, test_file)
254 |
255 | # Check if a non-existent file exists in the repository
256 | assert not distributask.file_exists(repo_id, "nonexistent.txt")
257 |
258 | # Clean up the repository
259 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
260 |
261 |
262 | def test_list_files():
263 | distributask = create_from_config()
264 | distributask.initialize_dataset()
265 | hf_token = distributask.get_env("HF_TOKEN")
266 | repo_id = distributask.get_env("HF_REPO_ID")
267 |
268 | # Create test files in the repository
269 | test_files = ["test1.txt", "test2.txt"]
270 | # for each test_file, write the file
271 | for file in test_files:
272 | with open(file, "w") as f:
273 | f.write("Test content")
274 | api = HfApi(token=hf_token)
275 | for file in test_files:
276 | api.upload_file(
277 | path_or_fileobj=file,
278 | path_in_repo=file,
279 | repo_id=repo_id,
280 | token=hf_token,
281 | repo_type="dataset",
282 | )
283 |
284 | # List the files in the repository
285 | repo_files = distributask.list_files(repo_id)
286 |
287 | for file in test_files:
288 | os.remove(file)
289 |
290 | # Check if the test files are present in the repository
291 | for file in test_files:
292 | assert file in repo_files
293 |
294 | # Clean up the repository
295 | api.delete_repo(repo_id=repo_id, repo_type="dataset", token=hf_token)
296 |
297 |
298 | @pytest.fixture(scope="module")
299 | def rented_nodes():
300 | distributask = create_from_config()
301 |
302 | max_price = 0.5
303 | max_nodes = 1
304 | image = "antbaez/distributask-worker:latest"
305 | module_name = "distributask.example.worker"
306 |
307 | nodes = distributask.rent_nodes(max_price, max_nodes, image, module_name)
308 | yield nodes
309 |
310 | distributask.terminate_nodes(nodes)
311 |
312 |
313 | def test_rent_run_terminate(rented_nodes):
314 | assert len(rented_nodes) == 1
315 | time.sleep(3) # sleep for 3 seconds to simulate runtime
316 |
317 |
318 | def test_get_redis_url():
319 | distributask = create_from_config()
320 | redis_url = distributask.get_redis_url()
321 |
322 | assert redis_url.startswith("redis://")
323 | assert distributask.settings["REDIS_USER"] in redis_url
324 | assert distributask.settings["REDIS_PASSWORD"] in redis_url
325 | assert distributask.settings["REDIS_HOST"] in redis_url
326 | assert str(distributask.settings["REDIS_PORT"]) in redis_url
327 |
328 |
329 | def test_get_redis_connection_force_new():
330 | distributask = create_from_config()
331 | redis_client1 = distributask.get_redis_connection()
332 | redis_client2 = distributask.get_redis_connection(force_new=True)
333 |
334 | assert redis_client1 is not redis_client2
335 |
336 |
345 | def test_get_env_with_default():
346 | distributask = create_from_config()
347 | default_value = "default"
348 | value = distributask.get_env("NON_EXISTENT_KEY", default_value)
349 |
350 | assert value == default_value
351 |
352 |
353 | @patch("requests.get")
354 | def test_search_offers(mock_get):
355 | distributask = create_from_config()
356 | max_price = 1.0
357 |
358 | mock_response = MagicMock()
359 | mock_response.json.return_value = {"offers": [{"id": "offer1"}, {"id": "offer2"}]}
360 | mock_get.return_value = mock_response
361 |
362 | offers = distributask.search_offers(max_price)
363 |
364 | assert len(offers) == 2
365 | assert offers[0]["id"] == "offer1"
366 | assert offers[1]["id"] == "offer2"
367 |
368 |
369 | @patch("requests.put")
370 | def test_create_instance(mock_put):
371 | distributask = create_from_config()
372 | offer_id = "offer1"
373 | image = "test_image"
374 | module_name = "distributask.example.worker"
375 | command = f"celery -A {module_name} worker --loglevel=info"
376 |
377 | mock_response = MagicMock()
378 | mock_response.status_code = 200
379 | mock_response.json.return_value = {"new_contract": "instance1"}
380 | mock_put.return_value = mock_response
381 |
382 | instance = distributask.create_instance(offer_id, image, module_name, distributask.settings, command)
383 |
384 | assert instance["new_contract"] == "instance1"
385 |
386 | # def test_get_node_log():
387 |
388 | # distributask = create_from_config()
389 |
390 | # max_price = 0.5
391 | # max_nodes = 1
392 | # image = "antbaez/distributask-test-worker"
393 | # module_name = "distributask.example.worker"
394 |
395 | # nodes = distributask.rent_nodes(max_price, max_nodes, image, module_name)
396 |
397 | # time.sleep(60)
398 |
399 | # response = distributask.get_node_log(nodes[0], wait_time=5)
400 |
401 | # assert response is not None
402 | # assert response.status_code == 200
403 |
404 |
405 | import subprocess
406 | 
407 | 
408 | def test_local_example_run():
409 |     # Run the local example in a subprocess. Patching sys.stdout / sys.stderr
410 |     # in this process cannot capture a child process's streams, so the
411 |     # subprocess output is piped directly instead.
412 |     process = subprocess.run(
413 |         ["python", "-m", "distributask.example.local"],
414 |         capture_output=True,
415 |         text=True,
416 |         timeout=180,  # if the process hasn't ended in 3 min, the test fails
417 |     )
418 | 
419 |     # Assert that the example exited cleanly
420 |     assert process.returncode == 0
421 | 
422 |     # Stop any Docker containers the example may have left running
423 |     try:
424 |         stop_command = "docker stop $(docker ps -q)"
425 |         subprocess.run(stop_command, shell=True, check=True)
426 |         print("All containers stopped successfully")
427 |     except subprocess.CalledProcessError:
428 |         pass
429 | 
430 | 
431 | def test_distributed_example_run():
432 |     # Run the distributed example in a subprocess with a small task count
433 |     process = subprocess.run(
434 |         ["python", "-m", "distributask.example.distributed", "--number_of_tasks=3"],
435 |         capture_output=True,
436 |         text=True,
437 |         timeout=120,  # if the process hasn't ended in 2 min, the test fails
438 |     )
439 | 
440 |     # Assert that the example exited cleanly
441 |     assert process.returncode == 0
442 | 
--------------------------------------------------------------------------------
/distributask/tests/worker.py:
--------------------------------------------------------------------------------
1 | from ..distributask import create_from_config
2 | 
3 | distributask = create_from_config()
4 | 
5 | 
6 | # Define the example test function used by the test suite
7 | def example_test_function(arg1, arg2):
8 |     return f"Result: arg1+arg2={arg1+arg2}"
9 | 
10 | 
11 | celery = distributask.app
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     distributask.register_function(example_test_function)
16 |
--------------------------------------------------------------------------------
/docs/assets/DeepAI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/DeepAI.png
--------------------------------------------------------------------------------
/docs/assets/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/banner.png
--------------------------------------------------------------------------------
/docs/assets/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/diagram.png
--------------------------------------------------------------------------------
/docs/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/favicon.ico
--------------------------------------------------------------------------------
/docs/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepAI-Research/Distributask/f7ba1b72720b48750d8909651d17c8a16a5e0b82/docs/assets/logo.png
--------------------------------------------------------------------------------
/docs/distributask.md:
--------------------------------------------------------------------------------
1 | # Distributask Class
2 |
3 | ::: distributask.Distributask
4 | options:
5 | members: true
6 | show_root_heading: true
7 | show_source: true
--------------------------------------------------------------------------------
/docs/getting_started.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | Below are instructions to get distributask running on your machine. Please read through the rest of the documentation for more detailed information.
4 |
5 | ## Installation
6 |
7 | ```bash
8 | pip install distributask
9 | ```
10 |
11 | ## Development
12 |
13 | ### Prerequisites
14 |
15 | - Python 3.8 or newer (tested on Python 3.11)
16 | - Redis server
17 | - Vast.ai API key
18 | - HuggingFace API key
19 |
20 |
21 | ### Setup
22 |
23 | Clone the repository and navigate to the project directory:
24 |
25 | ```bash
26 | git clone https://github.com/RaccoonResearch/Distributask.git
27 | cd Distributask
28 | ```
29 |
30 | Install the required packages:
31 |
32 | ```bash
33 | pip install -r requirements.txt
34 | ```
35 |
36 | Install the distributask package:
37 |
38 | ```bash
39 | pip install .
40 | ```
41 |
42 | ### Configuration
43 |
44 | Create a `.env` file in the root directory of your project or set environment variables to match your setup:
45 |
46 | ```plaintext
47 | REDIS_HOST=redis_host
48 | REDIS_PORT=redis_port
49 | REDIS_USER=redis_user
50 | REDIS_PASSWORD=redis_password
51 | VAST_API_KEY=your_vastai_api_key
52 | HF_TOKEN=your_huggingface_token
53 | HF_REPO_ID=your_huggingface_repo
54 | BROKER_POOL_LIMIT=broker_pool_limit
55 | ```
56 |
57 | ### Running an Example Task
58 |
59 | To run an example task and see distributask in action, you can execute the example script provided in the project:
60 |
61 | ```bash
62 | # Run an example task locally
63 | python -m distributask.example.local
64 |
65 | # Run an example task on Vast.ai ("kitchen sink" example)
66 | python -m distributask.example.distributed
67 | ```
68 |
69 | ### Command Options
70 |
71 | Below are the options you can pass to a distributask example run.
72 | 
73 | - `--max_price` is the maximum price (in $/hour) a node can be rented for.
74 | - `--max_nodes` is the maximum number of Vast.ai nodes that can be rented.
75 | - `--docker_image` is the name of the Docker image to load onto the Vast.ai node.
76 | - `--module_name` is the module name of the Celery worker.
77 | - `--number_of_tasks` is the number of example tasks that will be added to the queue and completed by the workers.
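78 | 
79 | For example, a distributed run with explicit options might look like the following; the values shown are illustrative, not defaults:
80 | 
81 | ```bash
82 | python -m distributask.example.distributed \
83 |     --max_price=0.25 \
84 |     --max_nodes=2 \
85 |     --docker_image=antbaez/distributask-worker:latest \
86 |     --module_name=distributask.example.worker \
87 |     --number_of_tasks=10
88 | ```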
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | Distributask is a simple way to distribute rendering tasks across multiple machines.
2 |
3 | This documentation is intended to help you understand the structure of the Distributask API and codebase and how to use it to distribute rendering tasks across multiple machines for your own projects.
4 |
5 | ## Core Use Cases
6 | Distributask can be used for any task that is parallelizable. Some specific use cases include:
7 |
8 | - Rendering videos
9 | - Running simulations
10 | - Generating or processing large datasets
11 |
12 | ## Getting Started
13 |
14 | Visit the [Getting Started](getting_started.md) page to learn how to set up your environment and start distributing tasks with Distributask.
15 |
16 | ## Overview
17 |
18 | Distributed rendering using Distributask can be broken into four steps:
19 |
20 | #### Creating the task queue
21 |
22 | Distributask uses Celery, an asynchronous distributed task-processing package, to create the task queue on your local machine. Each task on the queue is a function that tells remote machines, or workers, what to do. For example, if we wanted to render videos, each task would be a function that contains the code to render a different video.
23 |
24 | #### Passing the tasks to workers
25 |
26 | Distributask uses Redis, an in-memory data store, as a message broker. This means Redis delivers pending tasks from the task queue to the workers so the jobs can be executed.
27 |
28 | #### Executing the tasks
29 |
30 | Distributask uses Vast.ai, a decentralized GPU market, to create workers that execute the task. The task is given to the worker, executed, and the completed task status is passed back to the central machine via Redis.
31 |
32 | #### Storing results of the tasks
33 |
34 | Distributask uses Hugging Face, a platform for sharing AI models and datasets, to store the results of the tasks. The results are uploaded to Hugging Face using API calls in Distributask. For example, our rendered videos would be uploaded as a dataset on Hugging Face.
35 |
36 | ## Flowchart of Distributask process
37 |
38 | ![Flowchart of the Distributask process](assets/diagram.png)
39 |
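40 | Putting the four steps together, a minimal driver script might look like the sketch below. It is an illustration rather than code from the repository: it assumes `create_from_config` is importable from the package root, and `render_video` and `my_project.worker` are hypothetical names.
41 | 
42 | ```python
43 | from distributask import create_from_config
44 | 
45 | # Build a Distributask instance from .env / environment variables
46 | distributask = create_from_config()
47 | 
48 | 
49 | # A hypothetical task: each queued task renders one video
50 | def render_video(video_id):
51 |     return f"rendered video {video_id}"
52 | 
53 | 
54 | # Step 1: register the task and create the Hugging Face dataset repo
55 | distributask.register_function(render_video)
56 | distributask.initialize_dataset()
57 | 
58 | # Steps 2 and 3: rent Vast.ai workers (the worker module must register the
59 | # same task) and queue tasks, which reach the workers through Redis
60 | nodes = distributask.rent_nodes(
61 |     max_price=0.5,
62 |     max_nodes=1,
63 |     image="antbaez/distributask-worker:latest",
64 |     module_name="my_project.worker",  # hypothetical worker module
65 | )
66 | tasks = [
67 |     distributask.execute_function("render_video", {"video_id": i})
68 |     for i in range(10)
69 | ]
70 | results = [task.get(timeout=600) for task in tasks]  # wait for completion
71 | 
72 | # Step 4: workers upload results to the Hugging Face dataset; terminate the
73 | # rented nodes once the run is finished
74 | distributask.terminate_nodes(nodes)
75 | ```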
--------------------------------------------------------------------------------
/docs/more_info.md:
--------------------------------------------------------------------------------
1 | # Summary of most relevant functions
2 |
3 | #### Settings, Environment, and Help
4 |
5 | - `create_from_config()` - creates Distributask instance using environment variables
6 | - `get_env(key)` - gets value from .env
7 | - `get_settings(key)` - gets value from settings dictionary
8 |
9 | #### Celery tasks
10 |
11 | - `register_function(func)` - registers function to be task for worker
12 | - `execute_function(func_name, args)` - creates Celery task using registered function
13 |
14 | #### Redis server
15 |
16 | - `get_redis_url()` - gets Redis host url
17 | - `get_redis_connection()` - gets Redis connection instance
18 |
19 | #### Worker management via Vast.ai API
20 |
21 | - `search_offers(max_price)` - searches for available instances on Vast.ai
22 | - `rent_nodes(max_price, max_nodes, image, module_name, command)` - rents nodes on Vast.ai
23 | - `terminate_nodes(node_id_lists)` - terminates the rented Vast.ai instances
24 |
25 |
26 | #### HuggingFace repositories and uploading
27 |
28 | - `initialize_dataset()` - initializes dataset repo on Hugging Face
29 | - `upload_file(path_to_file)` - uploads file to Hugging Face repo
30 | - `upload_directory(path_to_directory)` - uploads folder to Hugging Face repo
31 | - `delete_file(path_to_file)` - deletes file on Hugging Face repo
32 |
33 | #### Visit the [Distributask Class](distributask.md) page for full, detailed documentation of the distributask class.
34 |
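35 | As a quick illustration of the settings and Redis helpers, the sketch below reads an environment value with a fallback and inspects task statuses, which are stored in Redis under `task_status:<task_id>` (as exercised in the test suite). It assumes `create_from_config` is importable from the package root.
36 | 
37 | ```python
38 | from distributask import create_from_config
39 | 
40 | distributask = create_from_config()
41 | 
42 | # Read an environment value, falling back to a default if it is unset
43 | redis_host = distributask.get_env("REDIS_HOST", "localhost")
44 | 
45 | # Inspect task statuses stored in Redis under "task_status:<task_id>"
46 | redis_client = distributask.get_redis_connection()
47 | for key in redis_client.keys("task_status:*"):
48 |     print(key.decode(), redis_client.get(key).decode())
49 | ```
50 | 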
51 | # Docker Setup
52 | 
53 | Distributask uses a Docker image to transfer the environment and necessary files to the Vast.ai nodes. In your own implementation, you can use the Dockerfile in the Distributask repository as a base for your own Dockerfile. If you do this, be sure to add Distributask to the list of packages installed in your Dockerfile, as sketched below.
54 | 
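55 | The following is a minimal sketch of such a Dockerfile, not the repository's actual one; the base image and file layout are illustrative assumptions.
56 | 
57 | ```dockerfile
58 | # Hypothetical base image; the repository's Dockerfile may differ
59 | FROM python:3.11-slim
60 | 
61 | WORKDIR /app
62 | 
63 | # Install your project's requirements plus Distributask itself
64 | COPY requirements.txt .
65 | RUN pip install --no-cache-dir -r requirements.txt distributask
66 | 
67 | # Copy the code that defines and registers your Celery tasks
68 | COPY . .
69 | 
70 | # The Vast.ai node then starts the worker, e.g.:
71 | #   celery -A your_package.worker worker --loglevel=info
72 | ```
73 | 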
74 | # Important Packages
75 | 
76 | Visit the websites of these wonderful packages to learn more about how they work and how to use them.
77 | 
78 | Celery: `https://docs.celeryq.dev/en/stable/`
79 | Redis: `https://redis.io/docs/latest/`
80 | Hugging Face: `https://huggingface.co/docs/huggingface_hub/en/guides/upload`
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Distributask
2 | theme:
3 | name: material
4 | logo: assets/logo.png
5 | favicon: assets/favicon.ico
6 | plugins:
7 | - search
8 | - autorefs
9 | - mkdocstrings:
10 | enabled: true
11 | default_handler: python
12 | nav:
13 | - Home: index.md
14 | - Getting Started: getting_started.md
15 | - More Information: more_info.md
16 | - Distributask Class: distributask.md
17 |
18 |
19 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | fsspec
3 | celery
4 | redis
5 | huggingface_hub
6 | python-dotenv
7 | omegaconf
8 | tqdm
--------------------------------------------------------------------------------
/scripts/kill_redis_connections.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Load environment variables from .env file
5 | source .env
6 |
7 | # Check if REDIS_PORT is set in the .env file
8 | if [ -z "$REDIS_PORT" ]; then
9 | echo "REDIS_PORT not found in .env file. Please set it and try again."
10 | exit 1
11 | fi
12 |
13 | # Use lsof to find all PIDs for the given port and store them in an array
14 | PIDS=($(lsof -i TCP:$REDIS_PORT -t))
15 |
16 | # Check if there are any PIDs to kill
17 | if [ ${#PIDS[@]} -eq 0 ]; then
18 | echo "No processes found using port $REDIS_PORT."
19 | exit 0
20 | fi
21 |
22 | # Loop through each PID and kill it
23 | for PID in "${PIDS[@]}"; do
24 | echo "Killing process $PID"
25 | sudo kill -9 $PID
26 | done
27 |
28 | echo "All processes using port $REDIS_PORT have been killed."
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import os
3 |
4 | # get the cwd where the setup.py file is located
5 | file_path = os.path.dirname(os.path.realpath(__file__))
6 |
7 | long_description = ""
8 | with open(os.path.join(file_path, "README.md"), "r") as fh:
9 | long_description = fh.read()
10 | long_description = long_description.split("\n")
11 | long_description = [line for line in long_description if not "<img" in line]
12 | long_description = "\n".join(long_description)
13 | 
14 | version = "0.0.0"  # placeholder: the original version string is not shown here
15 | 
16 | # read install_requires from requirements.txt
17 | with open(os.path.join(file_path, "requirements.txt"), "r") as fh:
18 |     install_requires = fh.read().splitlines()
19 | 
20 | # strip version pins (e.g. ">=1.0") from each requirement line
21 | install_requires = [
22 |     line.split(">")[0].split("<")[0] for line in install_requires
23 | ]
24 |
25 | setup(
26 | name="distributask",
27 | version=version,
28 | description="Simple task manager and job queue for distributed rendering. Built on celery and redis.",
29 | long_description=long_description,
30 | long_description_content_type="text/markdown",
31 | url="https://github.com/DeepAI-Research/Distributask",
32 | author="DeepAIResearch",
33 | author_email="team@deepai.org",
34 | license="MIT",
35 | packages=find_packages(),
36 | install_requires=install_requires,
37 | classifiers=[
38 | "Development Status :: 4 - Beta",
39 | "Intended Audience :: Science/Research",
40 | "License :: OSI Approved :: MIT License",
41 | "Operating System :: POSIX :: Linux",
42 | "Programming Language :: Python :: 3",
43 | "Operating System :: MacOS :: MacOS X",
44 | "Operating System :: Microsoft :: Windows",
45 | ],
46 | )
47 |
--------------------------------------------------------------------------------