├── .github └── workflows │ ├── main.yml │ └── publish-to-pypi.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── VERSION ├── dask_pytorch_ddp ├── __init__.py ├── data.py ├── dispatch.py └── results.py ├── setup.py └── tests ├── test_data.py ├── test_dispatch.py └── test_results.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: GitHub Actions 2 | 3 | # only run this workflow on new commits to main 4 | # or PRs into main 5 | on: 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | schedule: 13 | # Run every Monday morning at 11:00a UTC, 6:00a CST 14 | - cron: '0 11 * * 1' 15 | 16 | jobs: 17 | test: 18 | name: ${{ matrix.task }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | include: 24 | - task: linting 25 | - task: unit-tests 26 | - task: sdist 27 | - task: todo-checks 28 | steps: 29 | - name: Checkout repository 30 | uses: actions/checkout@v1 31 | - name: Set up Python 3.7 32 | if: matrix.task != 'todo-checks' 33 | uses: s-weigand/setup-conda@v1 34 | with: 35 | python-version: 3.7 36 | - name: linting 37 | if: matrix.task == 'linting' 38 | shell: bash 39 | run: | 40 | pip install --upgrade black flake8 mypy pylint 41 | make lint 42 | - name: unit-tests 43 | if: matrix.task == 'unit-tests' 44 | shell: bash 45 | run: | 46 | pip install --upgrade cloudpickle pytest pytest-cov responses 47 | make unit-tests 48 | - name: test source distribution 49 | if: matrix.task == 'sdist' 50 | shell: bash 51 | run: | 52 | python setup.py sdist 53 | pip install dist/dask-pytorch-ddp-$(cat VERSION).tar.gz 54 | - name: todo-checks 55 | if: matrix.task == 'todo-checks' 56 | shell: bash 57 | run: | 58 | num_todos=$(git grep -i -E "TODO|FIXME" | wc -l) 59 | echo "found ${num_todos} TODOs in code" 60 | num_allowed=10 61 | if [[ $num_todos -gt $num_allowed ]]; then 62 | exit ${num_todos} 63 | else 64 | exit 0 65 | fi 66 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish dask-pytorch-ddp to PyPI 2 | on: 3 | release: 4 | types: 5 | - published 6 | jobs: 7 | build-and-publish: 8 | name: Build and publish dask-pytorch-ddp to PyPI 9 | runs-on: ubuntu-18.04 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.7 16 | - name: Install build dependencies 17 | run: >- 18 | python -m 19 | pip install 20 | setuptools 21 | wheel 22 | --upgrade 23 | --user 24 | - name: Build a binary wheel and a source tarball 25 | run: >- 26 | python 27 | setup.py 28 | sdist 29 | bdist_wheel 30 | - name: Publish distribution to PyPI 31 | uses: pypa/gh-action-pypi-publish@v1.3.1 32 | with: 33 | # Password is set in GitHub UI to an API secret for pypi 34 | user: __token__ 35 | password: ${{ secrets.pypi_api_key }} 36 | 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | 
wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # contributing to dask-pytorch-ddp 2 | 3 | This document contains details on contributing to `dask-pytorch-ddp`. 4 | 5 | * [Documentation](#documentation) 6 | * [Installation](#installation) 7 | * [Testing](#testing) 8 | * [Releasing](#releasing) 9 | 10 | ## Documentation 11 | 12 | This repository is the main source of documentation for `dask-pytorch-ddp`. If you would like to request a feature, report a bug, or ask a question of the maintainers, please [create an issue](https://github.com/saturncloud/dask-pytorch-ddp/issues). 13 | 14 | For general documentation on Saturn Cloud and its components, please visit https://docs.saturncloud.io/en/. 15 | 16 | ## Installation 17 | 18 | To develop `dask-pytorch-ddp`, install it locally with the following command 19 | 20 | ```shell 21 | python setup.py develop 22 | ``` 23 | 24 | NOTE: If you have previously `pip install`'d `dask-pytorch-ddp`, some steps in this project might not work for you. Run `pip uninstall -y dask-pytorch-ddp` to remove any `pip install`'d versions. 25 | 26 | ## Testing 27 | 28 | Every commit to this repository is tested automatically using continuous integration (CI). All CI checks must pass for a pull request to be accepted. 
29 | 30 | To try running the tests locally, run the following: 31 | 32 | ```shell 33 | make test 34 | ``` 35 | 36 | ### Linting 37 | 38 | `dask-pytorch-ddp` uses the following static analysis tools: 39 | 40 | * `black` 41 | * `flake8` 42 | * `mypy` 43 | 44 | ```shell 45 | make format 46 | make lint 47 | ``` 48 | 49 | ### Unit tests 50 | 51 | Unit tests for the project use `pytest` and `pytest-cov`. All tests are stored in the `tests/` directory. 52 | 53 | ```shell 54 | make unit-tests 55 | ``` 56 | 57 | The `unit-tests` recipe in `Makefile` includes a minimum code coverage threshold. All pull requests must pass all tests with more than this level of code coverage. The current coverage is reported in the results of `make unit-tests`. 58 | 59 | ### Integration tests 60 | 61 | `dask-pytorch-ddp`'s unit tests mock out its interactions with the rest of Saturn Cloud. Integration tests that test those interactions contain some sensitive information, and are stored in a private repository. 62 | 63 | If you experience issues using `dask-pytorch-ddp` and Saturn Cloud, please see [the Saturn documentation](#documentation) or contact us at by following the `Contact Us` navigation at https://www.saturncloud.io/s. 64 | 65 | ## Releasing 66 | 67 | This section describes how to release a new version of `dask-pytorch-ddp` to PyPi. It is intended only for maintainers. 68 | 69 | 1. Open a new pull request which bumps the version in `VERSION`. Merge that PR. 70 | 2. [Create a new release](https://github.com/saturncloud/dask-pytorch-ddp/releases/new) 71 | - the tag should be a version number, like `v0.0.1` 72 | - choose the target from "recent commits", and select the most recent commit on `main` 73 | 3. Once this release is created, a GitHub Actions build will automatically start. That build publishes a release to PyPi. 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Saturn Cloud Developers 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | .PHONY: clean 3 | clean: 4 | rm -rf ./build 5 | rm -rf ./dist 6 | rm -rf ./mypy_cache 7 | rm -rf ./pytest_cache 8 | 9 | .PHONY: format 10 | format: 11 | black --line-length 100 . 12 | 13 | .PHONY: lint 14 | lint: 15 | flake8 --count --max-line-length 100 . 16 | black --check --diff --line-length 100 . 17 | mypy --ignore-missing-imports . 18 | # pylint disables: 19 | # * C0301: line too long 20 | # * C0103: snake-case naming 21 | # * C0330: wrong hanging indent before block 22 | # * E0401: unable to import 23 | # * R0903: too few public methods 24 | # * W0212: access to protected member 25 | pylint --disable=C0103,C0301,C0330,E0401,R0903,W0212 dask_pytorch_ddp/ 26 | 27 | .PHONY: unit-tests 28 | unit-tests: 29 | pip uninstall -y dask-pytorch-ddp 30 | python setup.py develop 31 | pytest --cov=dask_pytorch_ddp tests/ 32 | 33 | .PHONY: test 34 | test: clean lint unit-tests 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dask-pytorch-ddp 2 | 3 | 4 | 5 | `dask-pytorch-ddp` is a Python package that makes it easy to train PyTorch models on Dask clusters using distributed data parallel. The intended scope of the project is: 6 | - bootstrapping PyTorch workers on top of a Dask cluster 7 | - using distributed data stores (e.g., S3) as normal PyTorch datasets 8 | - mechanisms for tracking and logging intermediate results, training statistics, and checkpoints. 9 | 10 | At this point, the library and the examples provided are tailored to computer vision tasks, but it is intended to be useful for any sort of PyTorch task. The only thing really specific to image processing is the `S3ImageFolder` dataset class. Implementing a PyTorch dataset (assuming map-style random access) outside of images currently requires implementing `__getitem__(self, idx: int)` and `__len__(self)`. We plan to add more varied examples for other use cases in the future, and welcome PRs extending functionality.
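For illustration, a minimal map-style dataset for non-image data only needs those two methods. The following is a hypothetical sketch, not part of the library:

```python
import torch
from torch.utils.data import Dataset


class ListDataset(Dataset):
    """Minimal map-style dataset over an in-memory list of (features, label) pairs."""

    def __init__(self, records):
        self.records = records

    def __len__(self):
        # total number of samples
        return len(self.records)

    def __getitem__(self, idx: int):
        # return one (features, label) pair, with features as a float tensor
        features, label = self.records[idx]
        return torch.tensor(features, dtype=torch.float32), label
```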
11 | 12 | ## Typical non-dask workflow 13 | 14 | A typical example of non-dask PyTorch usage is as follows: 15 | 16 | ### Loading Data 17 | Create an dataset (`ImageFolder`), and wrap it in a `DataLoader` 18 | 19 | ```python 20 | transform = transforms.Compose([ 21 | transforms.Resize(256), 22 | transforms.CenterCrop(250), 23 | transforms.ToTensor() 24 | ]) 25 | 26 | whole_dataset = ImageFolder(path, transform=transform) 27 | 28 | batch_size = 100 29 | num_workers = 64 30 | indices = list(range(len(data))) 31 | np.random.shuffle(indices) 32 | train_idx = indices[:num] 33 | test_idx = indices[num:num+num] 34 | 35 | train_sampler = SubsetRandomSampler(train_idx) 36 | train_loader = DataLoader(data, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers) 37 | ``` 38 | 39 | ### Training a Model 40 | Loop over the dataset, and train the model by stepping the optimizer 41 | 42 | ```python 43 | device = torch.device(0) 44 | net = models.resnet18(pretrained=False) 45 | model = net.to(device) 46 | device_ids = [0] 47 | 48 | criterion = nn.CrossEntropyLoss().cuda() 49 | lr = 0.001 50 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) 51 | count = 0 52 | for epoch in range(n_epochs): 53 | model.train() # Set model to training mode 54 | for inputs, labels in train_loader: 55 | inputs = inputs.to(device) 56 | labels = labels.to(device) 57 | outputs = model(inputs) 58 | _, preds = torch.max(outputs, 1) 59 | loss = criterion(outputs, labels) 60 | 61 | # zero the parameter gradients 62 | optimizer.zero_grad() 63 | loss.backward() 64 | optimizer.step() 65 | count += 1 66 | ``` 67 | 68 | ## Now on Dask 69 | 70 | With dask_pytorch_ddp and PyTorch Distributed Data Parallel, we can train on multiple workers as follows: 71 | 72 | ### Loading Data 73 | Load the dataset from S3, and explicitly set the multiprocessing context (Dask defaults to spawn, but pytorch is generally configured to use fork) 74 | 75 | ```python 76 | from dask_pytorch_ddp.data import S3ImageFolder 77 | 78 | whole_dataset = S3ImageFolder(bucket, prefix, transform=transform) 79 | train_loader = torch.utils.data.DataLoader( 80 | whole_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers, multiprocessing_context=mp.get_context('fork') 81 | ) 82 | ``` 83 | 84 | ### Training in Parallel 85 | 86 | Wrap the training loop in a function (and add metrics logging. Not necessary, but very useful). Convert the model into a PyTorch Distributed Data Parallel (`DDP`) model which knows how to sync gradients together across workers. 
87 | 88 | ```python 89 | import uuid 90 | import pickle 91 | import logging 92 | import json 93 | 94 | 95 | key = uuid.uuid4().hex 96 | rh = DaskResultsHandler(key) 97 | 98 | def run_transfer_learning(bucket, prefix, samplesize, n_epochs, batch_size, num_workers, train_sampler): 99 | worker_rank = int(dist.get_rank()) 100 | device = torch.device(0) 101 | net = models.resnet18(pretrained=False) 102 | model = net.to(device) 103 | model = DDP(model, device_ids=[0]) 104 | 105 | criterion = nn.CrossEntropyLoss().cuda() 106 | lr = 0.001 107 | optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) 108 | whole_dataset = S3ImageFolder(bucket, prefix, transform=transform) 109 | 110 | train_loader = torch.utils.data.DataLoader( 111 | whole_dataset, 112 | sampler=train_sampler, 113 | batch_size=batch_size, 114 | num_workers=num_workers, 115 | multiprocessing_context=mp.get_context('fork') 116 | ) 117 | 118 | count = 0 119 | for epoch in range(n_epochs): 120 | # Each epoch has a training and validation phase 121 | model.train() # Set model to training mode 122 | for inputs, labels in train_loader: 123 | dt = datetime.datetime.now().isoformat() 124 | inputs = inputs.to(device) 125 | labels = labels.to(device) 126 | outputs = model(inputs) 127 | _, preds = torch.max(outputs, 1) 128 | loss = criterion(outputs, labels) 129 | 130 | # zero the parameter gradients 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | count += 1 135 | 136 | # statistics 137 | rh.submit_result( 138 | f"worker/{worker_rank}/data-{dt}.json", 139 | json.dumps({'loss': loss.item(), 'epoch': epoch, 'count': count, 'worker': worker_rank}) 140 | ) 141 | if (count % 100) == 0 and worker_rank == 0: 142 | rh.submit_result(f"checkpoint-{dt}.pkl", pickle.dumps(model.state_dict())) 143 | 144 | ``` 145 | 146 | ## How does it work? 147 | 148 | `dask-pytorch-ddp` is largely a wrapper around existing `pytorch` functionality. `pytorch.distributed` provides infrastructure for [Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) (DDP). 149 | 150 | In DDP, you create N workers, and the 0th worker is the "master", and coordinates the synchronization of buffers and gradients. In SGD, gradients are normally averaged between all data points in a batch. By running batches on multiple workers, and averaging the gradients, DDP enables you to run SGD with a much bigger batch size `(N * batch_size)` 151 | 152 | `dask-pytorch-ddp` sets some environment variables to configure the "master" host and port, and then calls `init_process_group` before training, and calls `destroy_process_group` after training. This is the same process normally done manually by the data scientist. 153 | 154 | ### Multi GPU machines 155 | `dask_cuda_worker` automatically rotates `CUDA_VISIBLE_DEVICES` for each worker it creates (typically one per GPU). As a result, your PyTorch code should always start with the 0th GPU. 156 | 157 | For example, if I have an 8 GPU machine, the 3rd worker will have `CUDA_VISIBLE_DEVICES` set to `2,3,4,5,6,7,0,1`. On that worker, if I call `torch.device(0)`, I will get GPU 2. 158 | 159 | ## What else? 160 | 161 | `dask-pytorch-ddp` also implements an S3 based `ImageFolder`. More distributed friendly datasets are planned. `dask-pytorch-ddp` also implements a basic results aggregation framework so that it is easy to collect training metrics across different workers. 
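As a rough sketch, the client-side driver that ties the pieces above together might look like the following. It uses the library's `dispatch.run` and `DaskResultsHandler.process_results`; the scheduler address, the local results directory, and the use of `functools.partial` to bind the hyperparameters are illustrative assumptions, and `run_transfer_learning` / `rh` refer to the snippet above.

```python
# Client-side sketch (illustrative): dispatch the training function above to
# every Dask worker and stream results back as they are submitted.
from functools import partial

from dask.distributed import Client
from dask_pytorch_ddp import dispatch

client = Client("tcp://dask-scheduler:8786")  # assumes a running Dask cluster

# bind the hyperparameters up front so every worker runs the same job
job = partial(
    run_transfer_learning,
    bucket, prefix, samplesize, n_epochs, batch_size, num_workers, train_sampler,
)

futures = dispatch.run(client, job)  # one future per Dask worker

# blocks until all workers finish; each submitted result is written under
# "./training-results/{path}" as it arrives
rh.process_results("./training-results", futures, raise_errors=False)
```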
Currently, only `DaskResultsHandler` which leverages [Dask pub-sub communication protocols][1] is implemented, but an S3 based result handler is planned. 162 | 163 | [1]:https://docs.dask.org/en/latest/futures.html#publish-subscribe 164 | 165 | ## Some Notes 166 | 167 | Dask generally spawns processes. PyTorch generally forks. When using a multiprocessing enabled data loader, it is a good idea to pass the `Fork` multiprocessing context to force the use of Forking in the data loader. 168 | 169 | Some Dask deployments do not permit spawning processes. To override this, you can change the [distributed.worker.daemon](https://docs.dask.org/en/latest/configuration-reference.html#distributed.worker.daemon) setting. 170 | 171 | Environment variables are a convenient way to do this: 172 | 173 | ``` 174 | DASK_DISTRIBUTED__WORKER__DAEMON=False 175 | ``` 176 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.2.2 2 | -------------------------------------------------------------------------------- /dask_pytorch_ddp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saturncloud/dask-pytorch-ddp/1dac8c60e3574e99d2b2c79d403f3e8d1f1984fc/dask_pytorch_ddp/__init__.py -------------------------------------------------------------------------------- /dask_pytorch_ddp/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for loading PyTorch data in distributed environments 3 | """ 4 | 5 | 6 | import tempfile 7 | from os.path import basename, dirname 8 | from typing import List, Callable, Optional 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | 13 | """ 14 | In the following, we are explicitly avoiding s3fs because it does not behave well with 15 | multiprocessing (which is commonly used in PyTorch dataloaders). 
16 | 17 | https://github.com/dask/s3fs/issues/369 18 | """ # pylint: disable=pointless-string-statement 19 | 20 | 21 | def _list_all_files(bucket: str, prefix: str, s3_client=None, anon=False) -> List[str]: 22 | """ 23 | Get list of all files from an s3 bucket matching a certain prefix 24 | """ 25 | import boto3 # pylint: disable=import-outside-toplevel 26 | from botocore import UNSIGNED # pylint: disable=import-outside-toplevel 27 | from botocore.client import Config # pylint: disable=import-outside-toplevel 28 | 29 | if s3_client is None: 30 | if anon: 31 | s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED)) 32 | else: 33 | s3_client = boto3.client("s3") 34 | 35 | paginator = s3_client.get_paginator("list_objects") 36 | all_files = [] 37 | for page in paginator.paginate(Bucket=bucket, Prefix=prefix): 38 | files = [x["Key"] for x in page["Contents"]] 39 | all_files.extend(files) 40 | return all_files 41 | 42 | 43 | def _read_s3_fileobj(bucket, path, fileobj, anon=False): 44 | """ 45 | read an obj from s3 to a file like object 46 | """ 47 | import boto3 # pylint: disable=import-outside-toplevel 48 | from botocore import UNSIGNED # pylint: disable=import-outside-toplevel 49 | from botocore.client import Config # pylint: disable=import-outside-toplevel 50 | 51 | if anon: 52 | s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED)) 53 | else: 54 | s3 = boto3.resource("s3") 55 | 56 | bucket = s3.Bucket(bucket) 57 | bucket.download_fileobj(path, fileobj) 58 | fileobj.seek(0) 59 | return fileobj 60 | 61 | 62 | def _load_image_obj(fileobj): 63 | """ 64 | turn a file like object into an image 65 | """ 66 | return Image.open(fileobj).convert("RGB") 67 | 68 | 69 | class S3ImageFolder(Dataset): 70 | """ 71 | An image folder that lives in S3. Directories containing the image are classes. 
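    Example (an illustrative sketch; the bucket name and prefix are placeholders,
    and objects are assumed to be laid out like {prefix}/{class_name}/{image}.jpg):

        dataset = S3ImageFolder("my-public-bucket", "train-images", anon=True)
        img, label = dataset[0]   # downloads one object, returns (PIL image, class index)
        print(len(dataset), dataset.classes)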
72 | """ 73 | 74 | # pylint: disable=too-many-instance-attributes 75 | # pylint: disable=too-many-arguments 76 | 77 | def __init__( 78 | self, 79 | s3_bucket: str, 80 | s3_prefix: str, 81 | transform: Optional[Callable] = None, 82 | target_transform: Optional[Callable] = None, 83 | anon: Optional[bool] = False, 84 | ): 85 | self.s3_bucket = s3_bucket 86 | self.s3_prefix = s3_prefix 87 | self.anon = anon 88 | self.all_files = _list_all_files(s3_bucket, s3_prefix, anon=anon) 89 | self.classes = sorted({self._get_class(x) for x in self.all_files}) 90 | self.class_to_idx = {k: idx for idx, k in enumerate(self.classes)} 91 | self.transform = transform 92 | self.target_transform = target_transform 93 | 94 | @classmethod 95 | def _get_class(cls, path): 96 | """ 97 | parse the path to extract the class name 98 | """ 99 | return basename(dirname(path)) 100 | 101 | def __getitem__(self, idx): 102 | """ 103 | get the nth (idx) image and label 104 | """ 105 | path = self.all_files[idx] 106 | label = self.class_to_idx[self._get_class(path)] 107 | with tempfile.TemporaryFile() as f: 108 | f = _read_s3_fileobj(self.s3_bucket, path, f, self.anon) 109 | img = _load_image_obj(f) 110 | if self.transform is not None: 111 | img = self.transform(img) 112 | if self.target_transform is not None: 113 | label = self.target_transform(label) 114 | return img, label 115 | 116 | def __len__(self): 117 | """ 118 | total number of images 119 | """ 120 | return len(self.all_files) 121 | -------------------------------------------------------------------------------- /dask_pytorch_ddp/dispatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the user-facing API to submit PyTorch jobs to a Dask cluster 3 | """ 4 | 5 | import os 6 | from typing import List, Callable, Any, Dict 7 | from dask.distributed import Client 8 | import torch.distributed as dist 9 | 10 | 11 | def _get_worker_info(client: Client) -> List[Dict]: 12 | """ 13 | returns a list of workers (sorted), and the DNS name for the master host 14 | The master is the 0th worker's host 15 | """ 16 | workers = client.scheduler_info()["workers"] 17 | worker_keys = sorted(workers.keys()) 18 | workers_by_host: Dict[str, List[str]] = {} 19 | for key in worker_keys: 20 | worker = workers[key] 21 | host = worker["host"] 22 | workers_by_host.setdefault(host, []).append(key) 23 | host = workers[worker_keys[0]]["host"] 24 | all_workers = [] 25 | global_rank = 0 26 | for host in sorted(workers_by_host.keys()): 27 | local_rank = 0 28 | for worker in workers_by_host[host]: 29 | all_workers.append( 30 | dict( 31 | worker=worker, 32 | local_rank=local_rank, 33 | global_rank=global_rank, 34 | host=host, 35 | ) 36 | ) 37 | local_rank += 1 38 | global_rank += 1 39 | return all_workers 40 | 41 | 42 | def run( 43 | client: Client, 44 | pytorch_function: Callable, 45 | *args, 46 | backend: str = "nccl", 47 | pass_local_rank: bool = False, 48 | **kwargs 49 | ): 50 | """ 51 | Dispatch a pytorch function over a dask cluster, and returns a list of futures 52 | for the resulting tasks 53 | """ 54 | all_workers = _get_worker_info(client) 55 | world_size = len(all_workers) 56 | port = 23456 # pick a free port? 
57 | host = all_workers[0]["host"] 58 | futures = [] 59 | for worker in all_workers: 60 | if pass_local_rank: 61 | fut = client.submit( 62 | dispatch_with_ddp, 63 | pytorch_function=pytorch_function, 64 | master_addr=host, 65 | master_port=port, 66 | rank=worker["global_rank"], 67 | world_size=world_size, 68 | *args, 69 | local_rank=worker["local_rank"], 70 | backend=backend, 71 | workers=[worker["worker"]], 72 | **kwargs 73 | ) 74 | else: 75 | fut = client.submit( 76 | dispatch_with_ddp, 77 | pytorch_function=pytorch_function, 78 | master_addr=host, 79 | master_port=port, 80 | rank=worker["global_rank"], 81 | world_size=world_size, 82 | *args, 83 | backend=backend, 84 | workers=[worker["worker"]], 85 | **kwargs 86 | ) 87 | futures.append(fut) 88 | return futures 89 | 90 | 91 | # pylint: disable=too-many-arguments 92 | def dispatch_with_ddp( 93 | pytorch_function: Callable, 94 | master_addr: Any, 95 | master_port: Any, 96 | rank: Any, 97 | world_size: Any, 98 | *args, 99 | backend: str = "nccl", 100 | **kwargs 101 | ) -> Any: 102 | """ 103 | runs a pytorch function, setting up torch.distributed before execution 104 | and tearing it down afterwards. 105 | """ 106 | # These are the parameters used to initialize the process group 107 | master_addr = str(master_addr) 108 | master_port = str(master_port) 109 | rank = str(rank) 110 | world_size = str(world_size) 111 | 112 | os.environ["MASTER_ADDR"] = master_addr 113 | os.environ["MASTER_PORT"] = master_port 114 | os.environ["RANK"] = rank 115 | os.environ["WORLD_SIZE"] = world_size 116 | 117 | try: 118 | dist.init_process_group(backend=backend) 119 | val = pytorch_function(*args, **kwargs) 120 | finally: 121 | dist.destroy_process_group() 122 | return val 123 | -------------------------------------------------------------------------------- /dask_pytorch_ddp/results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Infrastructure for retrieving and logging intermediate results from pytorch training jobs. 3 | 4 | Currently using dask pub/sub, but will create an S3 version in the future. 5 | """ 6 | import uuid 7 | import logging 8 | import os 9 | from typing import List, Optional 10 | from os.path import join, exists, dirname 11 | 12 | from distributed.pubsub import Pub, Sub 13 | from distributed.utils import TimeoutError as DistributedTimeoutError 14 | from distributed.client import wait, FIRST_COMPLETED, Future 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class DaskResultsHandler: 21 | """ 22 | This class use Dask pubsub infra to pass intermediate results back from PyTorch 23 | jobs to the client. 24 | """ 25 | 26 | def __init__(self, pub_sub_key: Optional[str] = None): 27 | """ 28 | pub_sub_key is an arbitrary string (topic) for the pub sub channel. 29 | It's a good idea to change it. 
Sometimes old topics can get "clogged" 30 | """ 31 | if pub_sub_key is None: 32 | pub_sub_key = uuid.uuid4().hex 33 | self.pub_sub_key = pub_sub_key 34 | 35 | @classmethod 36 | def _get_all(cls, sub: Sub): 37 | while True: 38 | try: 39 | yield sub.get(timeout=1.0) 40 | except DistributedTimeoutError: 41 | break 42 | 43 | def _get_results(self, futures: List[Future], raise_errors: bool = True): 44 | sub = Sub(self.pub_sub_key) 45 | while True: 46 | for obj in self._get_all(sub): 47 | yield obj 48 | if not futures: 49 | break 50 | try: 51 | result = wait(futures, 0.1, FIRST_COMPLETED) 52 | except DistributedTimeoutError: 53 | continue 54 | 55 | for fut in result.done: 56 | try: 57 | fut.result() 58 | except Exception as e: # pylint: disable=broad-except 59 | logging.exception(e) 60 | if raise_errors: 61 | raise 62 | futures = result.not_done 63 | 64 | def process_results( 65 | self, prefix: str, futures: List[Future], raise_errors: bool = True 66 | ) -> None: 67 | """ 68 | Process the intermediate results: 69 | result objects will be dictionaries of the form {'path': path, 'data': data} 70 | As results come in, data will be written to f"prefix/{path}" 71 | 72 | prefix: directory where you want results to be written 73 | futures: list of futures for your jobs (output of dask_pytorch_ddp.dispatch.run) 74 | raise_errors: If any of the jobs fail, either raise an exception, or log it and continue. 75 | """ 76 | for result in self._get_results(futures, raise_errors=raise_errors): 77 | path = result["path"] 78 | data = result["data"] 79 | fpath = join(prefix, path) 80 | if not exists(dirname(fpath)): 81 | os.makedirs(dirname(fpath)) 82 | if isinstance(data, str): 83 | data = data.encode("utf-8") 84 | with open(fpath, "wb+") as f: 85 | f.write(data) 86 | 87 | def submit_result(self, path: str, data: str): 88 | """ 89 | To be used in jobs. Call this function with a path, and some data. 
90 | Client will write {data} to a file at {path} 91 | """ 92 | pub = Pub(self.pub_sub_key) 93 | pub.put({"path": path, "data": data}) 94 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | with open("README.md", "r") as f: 5 | readme = f.read() 6 | 7 | with open("VERSION", "r") as f: 8 | version = f.read().strip() 9 | 10 | 11 | install_requires = ["dask", "distributed", "pillow", "torch"] 12 | testing_deps = ["black", "flake8", "mypy", "pytest", "pytest-cov"] 13 | 14 | setup( 15 | name="dask-pytorch-ddp", 16 | version=version, 17 | maintainer="Saturn Cloud Developers", 18 | maintainer_email="open-source@saturncloud.io", 19 | license="BSD 3-clause", 20 | classifiers=[ 21 | "Development Status :: 3 - Alpha", 22 | "License :: OSI Approved :: BSD License", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "Topic :: Scientific/Engineering :: Image Processing", 26 | "Topic :: Scientific/Engineering :: Image Recognition", 27 | "Topic :: Scientific/Engineering", 28 | "Topic :: System :: Distributed Computing", 29 | "Programming Language :: Python", 30 | "Programming Language :: Python :: 3", 31 | "Programming Language :: Python :: 3.7", 32 | ], 33 | keywords="saturn cloud dask pytorch torch", 34 | description="library for setting up torch DDP on a dask cluster", 35 | long_description=readme, 36 | long_description_content_type="text/markdown", 37 | url="https://github.com/saturncloud/dask-pytorch-ddp", 38 | project_urls={ 39 | "Documentation": "https://github.com/saturncloud/dask-pytorch-ddp", 40 | "Source": "https://github.com/saturncloud/dask-pytorch-ddp", 41 | "Issue Tracker": "https://github.com/saturncloud/dask-pytorch-ddp/issues", 42 | }, 43 | packages=find_packages(), 44 | install_requires=install_requires, 45 | python_requires=">=3.7", 46 | extras_require={"dev": install_requires + testing_deps}, 47 | test_suite="tests", 48 | tests_require=testing_deps, 49 | zip_safe=False, 50 | ) 51 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch, ANY 2 | 3 | 4 | from dask_pytorch_ddp.data import S3ImageFolder 5 | 6 | 7 | def test_image_folder_constructor(): 8 | fake_file_list = ["d/a.jpg", "c/b.jpg"] 9 | with patch("dask_pytorch_ddp.data._list_all_files", return_value=fake_file_list): 10 | fake_transform = Mock() 11 | fake_target_transform = Mock() 12 | folder = S3ImageFolder( 13 | "fake-bucket", 14 | "fake-prefix/fake-prefix", 15 | fake_transform, 16 | fake_target_transform, 17 | ) 18 | assert folder.all_files == fake_file_list 19 | assert folder.classes == ["c", "d"] 20 | assert folder.class_to_idx == {"c": 0, "d": 1} 21 | assert folder.transform == fake_transform 22 | assert folder.target_transform == fake_target_transform 23 | 24 | 25 | def test_image_folder_len(): 26 | fake_file_list = ["d/a.jpg", "c/b.jpg"] 27 | with patch("dask_pytorch_ddp.data._list_all_files", return_value=fake_file_list): 28 | folder = S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") 29 | assert len(folder) == 2 30 | 31 | 32 | def test_image_folder_getitem(): 33 | fake_file_list = ["d/a.jpg", "c/b.jpg"] 34 | with patch("dask_pytorch_ddp.data._list_all_files", return_value=fake_file_list): 35 | folder = 
S3ImageFolder("fake-bucket", "fake-prefix/fake-prefix") 36 | with patch("dask_pytorch_ddp.data._read_s3_fileobj") as read_s3_fileobj, patch( 37 | "dask_pytorch_ddp.data._load_image_obj" 38 | ) as load_image_obj: 39 | 40 | read_s3_fileobj.return_value = Mock() 41 | load_image_obj.return_value = Mock() 42 | val, label = folder[0] 43 | read_s3_fileobj.assert_called_once_with("fake-bucket", fake_file_list[0], ANY, False) 44 | load_image_obj.assert_called_once_with(read_s3_fileobj()) 45 | assert val == load_image_obj.return_value 46 | assert label == 1 47 | -------------------------------------------------------------------------------- /tests/test_dispatch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from unittest.mock import Mock, patch 4 | 5 | from dask_pytorch_ddp.dispatch import run, dispatch_with_ddp 6 | 7 | 8 | workers = { 9 | "tcp://1.2.3.4:8786": {"host": "1.2.3.4"}, 10 | "tcp://2.2.3.4:8786": {"host": "2.2.3.4"}, 11 | "tcp://3.2.3.4:8786": {"host": "3.2.3.4"}, 12 | "tcp://4.2.3.4:8786": {"host": "4.2.3.4"}, 13 | } 14 | host_name = sorted(workers.keys())[0] 15 | host = workers[host_name]["host"] 16 | 17 | 18 | def test_run(): 19 | client = Mock() 20 | client.scheduler_info = Mock(return_value={"workers": workers}) 21 | 22 | fake_pytorch_func = Mock() 23 | 24 | fake_results = [] 25 | worker_keys = sorted(workers.keys()) 26 | for idx, worker in enumerate(worker_keys): 27 | r = Mock() 28 | r.result = Mock(return_value=idx) 29 | fake_results.append(r) 30 | 31 | client.submit = Mock(side_effect=fake_results) 32 | output = run(client, fake_pytorch_func) 33 | 34 | client.submit.assert_any_call( 35 | dispatch_with_ddp, 36 | pytorch_function=fake_pytorch_func, 37 | master_addr=host, 38 | master_port=23456, 39 | rank=0, 40 | world_size=len(workers), 41 | workers=[worker_keys[0]], 42 | backend="nccl", 43 | ) 44 | client.submit.assert_any_call( 45 | dispatch_with_ddp, 46 | pytorch_function=fake_pytorch_func, 47 | master_addr=host, 48 | master_port=23456, 49 | rank=1, 50 | workers=[worker_keys[1]], 51 | world_size=len(workers), 52 | backend="nccl", 53 | ) 54 | client.submit.assert_any_call( 55 | dispatch_with_ddp, 56 | pytorch_function=fake_pytorch_func, 57 | master_addr=host, 58 | master_port=23456, 59 | rank=2, 60 | workers=[worker_keys[2]], 61 | world_size=len(workers), 62 | backend="nccl", 63 | ) 64 | client.submit.assert_any_call( 65 | dispatch_with_ddp, 66 | pytorch_function=fake_pytorch_func, 67 | master_addr=host, 68 | master_port=23456, 69 | rank=3, 70 | workers=[worker_keys[3]], 71 | world_size=len(workers), 72 | backend="nccl", 73 | ) 74 | assert output == fake_results 75 | 76 | 77 | def test_run_with_local_rank_simple(): 78 | client = Mock() 79 | client.scheduler_info = Mock(return_value={"workers": workers}) 80 | 81 | fake_pytorch_func = Mock() 82 | 83 | fake_results = [] 84 | worker_keys = sorted(workers.keys()) 85 | for idx, worker in enumerate(worker_keys): 86 | r = Mock() 87 | r.result = Mock(return_value=idx) 88 | fake_results.append(r) 89 | 90 | client.submit = Mock(side_effect=fake_results) 91 | output = run(client, fake_pytorch_func, pass_local_rank=True) 92 | 93 | client.submit.assert_any_call( 94 | dispatch_with_ddp, 95 | pytorch_function=fake_pytorch_func, 96 | master_addr=host, 97 | master_port=23456, 98 | rank=0, 99 | local_rank=0, 100 | world_size=len(workers), 101 | workers=[worker_keys[0]], 102 | backend="nccl", 103 | ) 104 | client.submit.assert_any_call( 105 | dispatch_with_ddp, 106 | 
pytorch_function=fake_pytorch_func, 107 | master_addr=host, 108 | master_port=23456, 109 | rank=1, 110 | local_rank=0, 111 | workers=[worker_keys[1]], 112 | world_size=len(workers), 113 | backend="nccl", 114 | ) 115 | client.submit.assert_any_call( 116 | dispatch_with_ddp, 117 | pytorch_function=fake_pytorch_func, 118 | master_addr=host, 119 | master_port=23456, 120 | rank=2, 121 | local_rank=0, 122 | workers=[worker_keys[2]], 123 | world_size=len(workers), 124 | backend="nccl", 125 | ) 126 | client.submit.assert_any_call( 127 | dispatch_with_ddp, 128 | pytorch_function=fake_pytorch_func, 129 | master_addr=host, 130 | master_port=23456, 131 | rank=3, 132 | local_rank=0, 133 | workers=[worker_keys[3]], 134 | world_size=len(workers), 135 | backend="nccl", 136 | ) 137 | assert output == fake_results 138 | 139 | 140 | def test_run_with_local_rank_complex(): 141 | workers = { 142 | "tcp://1.2.3.4:8786": {"host": "1.2.3.4"}, 143 | "tcp://1.2.3.4:8787": {"host": "1.2.3.4"}, 144 | "tcp://3.2.3.4:8786": {"host": "3.2.3.4"}, 145 | "tcp://3.2.3.4:8787": {"host": "3.2.3.4"}, 146 | } 147 | host_name = sorted(workers.keys())[0] 148 | host = workers[host_name]["host"] 149 | client = Mock() 150 | client.scheduler_info = Mock(return_value={"workers": workers}) 151 | 152 | fake_pytorch_func = Mock() 153 | 154 | fake_results = [] 155 | worker_keys = sorted(workers.keys()) 156 | for idx, worker in enumerate(worker_keys): 157 | r = Mock() 158 | r.result = Mock(return_value=idx) 159 | fake_results.append(r) 160 | 161 | client.submit = Mock(side_effect=fake_results) 162 | output = run(client, fake_pytorch_func, pass_local_rank=True) 163 | 164 | client.submit.assert_any_call( 165 | dispatch_with_ddp, 166 | pytorch_function=fake_pytorch_func, 167 | master_addr=host, 168 | master_port=23456, 169 | rank=0, 170 | local_rank=0, 171 | world_size=len(workers), 172 | workers=[worker_keys[0]], 173 | backend="nccl", 174 | ) 175 | client.submit.assert_any_call( 176 | dispatch_with_ddp, 177 | pytorch_function=fake_pytorch_func, 178 | master_addr=host, 179 | master_port=23456, 180 | rank=1, 181 | local_rank=1, 182 | workers=[worker_keys[1]], 183 | world_size=len(workers), 184 | backend="nccl", 185 | ) 186 | client.submit.assert_any_call( 187 | dispatch_with_ddp, 188 | pytorch_function=fake_pytorch_func, 189 | master_addr=host, 190 | master_port=23456, 191 | rank=2, 192 | local_rank=0, 193 | workers=[worker_keys[2]], 194 | world_size=len(workers), 195 | backend="nccl", 196 | ) 197 | client.submit.assert_any_call( 198 | dispatch_with_ddp, 199 | pytorch_function=fake_pytorch_func, 200 | master_addr=host, 201 | master_port=23456, 202 | rank=3, 203 | local_rank=1, 204 | workers=[worker_keys[3]], 205 | world_size=len(workers), 206 | backend="nccl", 207 | ) 208 | assert output == fake_results 209 | 210 | 211 | def test_dispatch_with_ddp(): 212 | pytorch_func = Mock() 213 | 214 | with patch.object(os, "environ", {}) as environ, patch( 215 | "dask_pytorch_ddp.dispatch.dist", return_value=Mock() 216 | ) as dist: 217 | dispatch_with_ddp( 218 | pytorch_func, 219 | "master_addr", 220 | 2343, 221 | 1, 222 | 10, 223 | "a", 224 | "b", 225 | backend="nccl", 226 | foo="bar", 227 | ) 228 | assert environ["MASTER_ADDR"] == "master_addr" 229 | assert environ["MASTER_PORT"] == "2343" 230 | assert environ["RANK"] == "1" 231 | assert environ["WORLD_SIZE"] == "10" 232 | 233 | dist.init_process_group.assert_called() 234 | dist.destroy_process_group.assert_called() 235 | 236 | pytorch_func.assert_called_once_with("a", "b", foo="bar") 237 | 
-------------------------------------------------------------------------------- /tests/test_results.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | import pickle 3 | 4 | from pytest import raises 5 | 6 | from dask_pytorch_ddp.results import DaskResultsHandler 7 | from distributed.utils import TimeoutError # pylint: disable=redefined-builtin 8 | 9 | 10 | class FakeException(Exception): 11 | pass 12 | 13 | 14 | def test_dask_results_handler_constructor(): 15 | handler1 = DaskResultsHandler() 16 | handler2 = DaskResultsHandler() 17 | assert handler1.pub_sub_key != handler2.pub_sub_key 18 | 19 | 20 | def test_dask_results_handler_pickle(): 21 | handler = DaskResultsHandler() 22 | handler2 = pickle.loads(pickle.dumps(handler)) 23 | assert handler2.pub_sub_key == handler.pub_sub_key 24 | 25 | 26 | def test_get_all_futures(): 27 | sub = Mock() 28 | real_results = [{"path", "a", "data", "b"}, {"path", "b", "data", "c"}] 29 | sub.get = Mock(side_effect=real_results + [TimeoutError]) 30 | results = list(DaskResultsHandler._get_all(sub)) 31 | assert results == real_results 32 | 33 | 34 | def mock_waiting_result(done=None, not_done=None): 35 | if not done: 36 | done = [] 37 | if not not_done: 38 | not_done = [] 39 | result = Mock() 40 | result.done = done 41 | result.not_done = not_done 42 | return result 43 | 44 | 45 | def fake_future(result): 46 | future = Mock() 47 | future.result = Mock(return_value=result) 48 | return future 49 | 50 | 51 | def fake_error_future(error): 52 | future = Mock() 53 | future.result = Mock(side_effect=error) 54 | return future 55 | 56 | 57 | def test_get_results_retrieves_all_data(): 58 | with patch.object(DaskResultsHandler, "_get_all") as _get_all, patch( 59 | "dask_pytorch_ddp.results.wait" 60 | ) as wait, patch("dask_pytorch_ddp.results.Sub"): 61 | _get_all.side_effect = [["a", "b"], ["c", "d", "e"], ["f", "g"]] 62 | wait.side_effect = [TimeoutError, mock_waiting_result()] 63 | result = DaskResultsHandler(None) 64 | fake_futures = ["one", "two"] 65 | results = list(result._get_results(fake_futures)) 66 | assert results == ["a", "b", "c", "d", "e", "f", "g"] 67 | 68 | 69 | def test_get_results_throws_exceptions(): 70 | with patch.object(DaskResultsHandler, "_get_all") as _get_all, patch( 71 | "dask_pytorch_ddp.results.wait" 72 | ) as wait, patch("dask_pytorch_ddp.results.Sub"): 73 | _get_all.side_effect = [["a", "b"], ["c", "d", "e"], ["f", "g"]] 74 | wait.side_effect = [ 75 | mock_waiting_result( 76 | done=[fake_future(None), fake_error_future(FakeException("hello"))] 77 | ), 78 | ] 79 | result = DaskResultsHandler(None) 80 | fake_futures = ["one", "two"] 81 | with raises(FakeException): 82 | list(result._get_results(fake_futures)) 83 | 84 | 85 | def test_get_results_masks_exceptions(): 86 | with patch.object(DaskResultsHandler, "_get_all") as _get_all, patch( 87 | "dask_pytorch_ddp.results.wait" 88 | ) as wait, patch("dask_pytorch_ddp.results.Sub"): 89 | _get_all.side_effect = [["a", "b"], ["c", "d", "e"]] 90 | wait.side_effect = [ 91 | mock_waiting_result( 92 | done=[fake_future(None), fake_error_future(FakeException("hello"))] 93 | ), 94 | ] 95 | result = DaskResultsHandler(None) 96 | fake_futures = ["one", "two"] 97 | results = list(result._get_results(fake_futures, raise_errors=False)) 98 | assert results == ["a", "b", "c", "d", "e"] 99 | --------------------------------------------------------------------------------