├── spacy_ray
│   ├── tests
│   │   ├── __init__.py
│   │   ├── mock_ray.py
│   │   └── test_worker.py
│   ├── __init__.py
│   ├── util.py
│   ├── loggers.py
│   ├── train_cli.py
│   ├── proxies.py
│   └── worker.py
├── MANIFEST.in
├── requirements.txt
├── setup.py
├── bin
│   ├── push-tag.sh
│   └── get-data.sh
├── LICENSE
├── README.md
├── setup.cfg
├── .github
│   └── workflows
│       └── tests.yml
└── .gitignore

--------------------------------------------------------------------------------
/spacy_ray/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include LICENSE
include README.md

--------------------------------------------------------------------------------
/spacy_ray/__init__.py:
--------------------------------------------------------------------------------
from . import train_cli  # noqa

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
ray>=0.8,<1.0.0
spacy>=3.1.0,<3.2.0
pytest

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

if __name__ == "__main__":
    from setuptools import setup, find_packages

    setup(name="spacy_ray", packages=find_packages())

--------------------------------------------------------------------------------
/spacy_ray/tests/mock_ray.py:
--------------------------------------------------------------------------------
def get(*args, **kwargs):
    return None


def init(*args, **kwargs):
    return None


def remote(*args, **kwargs):
    return None

--------------------------------------------------------------------------------
/bin/push-tag.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -e

# Make sure the repository is clean
git diff-index --quiet HEAD
git checkout master
git pull origin master
git push origin master

version=$(grep "version = " setup.cfg)
version=${version/version = }
git tag "v$version"
git push origin "v$version"

--------------------------------------------------------------------------------
/bin/get-data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
set -e

mkdir -p tmp
cd tmp
wget https://raw.githubusercontent.com/explosion/projects/master/ner-fashion-brands/fashion_brands_training.jsonl
wget https://raw.githubusercontent.com/explosion/projects/master/ner-fashion-brands/fashion_brands_eval.jsonl

cd ..

mkdir -p examples/fashion-ner
python -m spacy convert tmp/fashion_brands_training.jsonl examples/fashion-ner --lang en
python -m spacy convert tmp/fashion_brands_eval.jsonl examples/fashion-ner --lang en

--------------------------------------------------------------------------------
/spacy_ray/tests/test_worker.py:
--------------------------------------------------------------------------------
import spacy
from spacy.language import Language

from . import mock_ray
from ..worker import Worker

# I don't normally go for so much mocking; I usually think it's pointless to
# test something so different. But this time it seems worth trying, just to
# get some basic tests, which would otherwise be hard.

# These tests currently really do nothing; I'm just laying out a skeleton while
# I think of what could be meaningfully tested this way. Maybe I'll delete it
# all in the end.


class TestWorker(Worker):
    def _load_nlp_and_config(self, config):
        return None, None

    def _initialize_models(self, nlp, config):
        return None, None


def test_worker_init():
    # Get a blank valid config
    nlp = spacy.blank("en")
    nlp.config["paths"]["train"] = ""
    nlp.config["paths"]["dev"] = ""
    worker = Worker(nlp.config, rank=1, num_workers=2, use_gpu=-1, ray=mock_ray)
    assert isinstance(worker.nlp, Language)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 ExplosionAI GmbH

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# spacy-ray: Parallel and distributed training with spaCy and Ray

> ⚠️ This repo is still a work in progress and requires **spaCy v3.1**.

[Ray](https://ray.io/) is a fast and simple framework for building and running
**distributed applications**. This very lightweight extension package lets you
use Ray for parallel and distributed training with [spaCy](https://spacy.io). If
`spacy-ray` is installed in the same environment as spaCy, it will automatically
add `spacy ray` commands to your spaCy CLI.

The main command is `spacy ray train` for parallel and distributed training, but
we expect to add `spacy ray pretrain` and `spacy ray parse` as well.
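To train across several machines, you can also connect the workers to an
existing Ray cluster via the `--address` option. A minimal sketch — the
cluster address below is a placeholder for your own:

```bash
# On the head node: ray start --head
# Then launch training against the cluster:
python -m spacy ray train config.cfg --n-workers 4 --address 192.168.1.10:6379
```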

[![tests](https://github.com/explosion/spacy-ray/actions/workflows/tests.yml/badge.svg)](https://github.com/explosion/spacy-ray/actions/workflows/tests.yml)
[![Current Release Version](https://img.shields.io/github/v/release/explosion/spacy-ray.svg?include_prereleases&sort=semver&style=flat-square&logo=github)](https://github.com/explosion/spacy-ray/releases)
[![PyPi Version](https://img.shields.io/pypi/v/spacy-ray.svg?include_prereleases&sort=semver&style=flat-square&logo=pypi&logoColor=white)](https://pypi.python.org/pypi/spacy-ray)

## 🚀 Quickstart

You can install `spacy-ray` from pip:

```bash
pip install spacy-ray
```

To check if the command has been registered successfully:

```bash
python -m spacy ray --help
```

Train a model using the same API as `spacy train`:

```bash
python -m spacy ray train config.cfg --n-workers 2
```

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
version = 0.1.4
description = Parallel and distributed training with spaCy and Ray
url = https://spacy.io
author = Explosion
author_email = contact@explosion.ai
license = MIT
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
    Development Status :: 4 - Beta
    Environment :: Console
    Intended Audience :: Developers
    Intended Audience :: Science/Research
    Topic :: Scientific/Engineering
    Topic :: Scientific/Engineering :: Artificial Intelligence
    License :: OSI Approved :: MIT License
    Operating System :: POSIX :: Linux
    Operating System :: MacOS :: MacOS X
    Operating System :: Microsoft :: Windows
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.6
    Programming Language :: Python :: 3.7
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9

[options]
zip_safe = true
include_package_data = true
python_requires = >=3.6
install_requires =
    ray>=0.8,<1.0.0
    spacy>=3.1.0,<3.2.0

[options.entry_points]
# This is a sneaky lie: we're only doing this to get spaCy to import the
# package, so that it installs our CLI extension.
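# (When spaCy's CLI starts up, it loads the `spacy_cli` entry points group,
# which imports spacy_ray.train_cli; that module calls app.add_typer() at
# import time, registering the `ray` subcommand.)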
spacy_cli =
    ray = spacy_ray.train_cli:ray_cli
spacy_loggers =
    spacy-ray.ConsoleLogger.v1 = spacy_ray.loggers:ray_console_logger

[bdist_wheel]
universal = true

[sdist]
formats = gztar

[flake8]
ignore = E203, E266, E501, E731, W503
max-line-length = 80
select = B,C,E,F,W,T4,B9
exclude =
    .env,
    .git,
    __pycache__,

[mypy]
ignore_missing_imports = True
no_implicit_optional = True
plugins = pydantic.mypy, thinc.mypy

--------------------------------------------------------------------------------
/spacy_ray/util.py:
--------------------------------------------------------------------------------
from typing import Tuple, Dict
import time
from collections import defaultdict


KeyT = Tuple[int, str]


class Timer:
    state: str
    sum: float
    n: int

    def __init__(self, state: str):
        self.state = state
        self.sum = 0
        self.n = 0

    def __enter__(self):
        self.start = time.time()
        self.n += 1
        return self

    def __exit__(self, *args):
        interval = time.time() - self.start
        self.sum += interval


class ManyTimer:
    timers: Dict[str, Timer]

    def __init__(self):
        self.timers = {}

    def __call__(self, key: str) -> Timer:
        if key not in self.timers:
            self.timers[key] = Timer(key)
        return self.timers[key]


def set_params_proxy(model, proxy):
    """Set a 'proxy' on the internal ParamServer object for the model and
    its children. Experimental.
    """
    for node in model.walk():
        node._params.proxy = None
        for name in node.param_names:
            if node.has_param(name):
                proxy.set_param(node.id, name, node.get_param(name))
        node._params.proxy = proxy


def make_key(model_id: int, name: str) -> Tuple[int, str]:
    return (model_id, name)


def divide_params(model, num_workers):
    """Assign each node's parameter keys to a worker, giving each worker a
    contiguous block of parameter groups. Any leftover groups go to the last
    worker.
    """
    keys_by_node = defaultdict(list)
    for node in model.walk():
        keys = [make_key(node.id, name) for name in node.param_names]
        if keys:
            keys_by_node[node.id].extend(keys)
    key_groups = list(keys_by_node.values())
    n = max(1, len(key_groups) // num_workers)
    worker_keys = []
    start = 0
    for i in range(num_workers):
        worker_keys.append([])
        for kg in key_groups[start : start + n]:
            worker_keys[-1].extend(kg)
        start += n
    for kg in key_groups[start:]:
        worker_keys[-1].extend(kg)
    assert len(worker_keys) == num_workers, (len(worker_keys), num_workers)
    return worker_keys

--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: tests

on:
  push:
    paths-ignore:
      - "*.md"
  pull_request:
    types: [opened, synchronize, reopened, edited]
    paths-ignore:
      - "*.md"

env:
  MODULE_NAME: 'spacy_ray'
  RUN_MYPY: 'false'

jobs:
  tests:
    name: Test
    if: github.repository_owner == 'explosion'
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python_version: ["3.7", "3.8", "3.9"]
        include:
          - os: windows-2019
            python_version: "3.6"
          - os: ubuntu-20.04
            python_version: "3.6"
    runs-on: ${{ matrix.os }}

    steps:
      - name: Check out repo
        uses: actions/checkout@v3
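      # The remaining steps round-trip the package: build an sdist, delete the
      # source tree and all installed packages, reinstall from the sdist, and
      # check that the module imports and its tests pass.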

      - name: Configure Python version
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python_version }}
          architecture: x64

      - name: Build sdist
        run: |
          python -m pip install -U build pip setuptools
          python -m pip install -U -r requirements.txt
          python -m build --sdist

      - name: Run mypy
        shell: bash
        if: ${{ env.RUN_MYPY == 'true' }}
        run: |
          python -m mypy $MODULE_NAME

      - name: Delete source directory
        shell: bash
        run: |
          rm -rf $MODULE_NAME

      - name: Uninstall all packages
        run: |
          python -m pip freeze > installed.txt
          python -m pip uninstall -y -r installed.txt

      - name: Install from sdist
        shell: bash
        run: |
          SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
          python -m pip install dist/$SDIST

      - name: Test import
        shell: bash
        run: |
          python -Werror -c "import $MODULE_NAME"

      - name: Install test requirements
        run: |
          python -m pip install -U -r requirements.txt

      - name: Run tests
        shell: bash
        run: |
          python -m pytest --pyargs $MODULE_NAME -Werror

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
env3.8

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# vim
.*.sw*

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Pycharm project files
*.idea

--------------------------------------------------------------------------------
/spacy_ray/loggers.py:
--------------------------------------------------------------------------------
from typing import Dict, Any, Tuple, Callable
from datetime import timedelta
from spacy.util import registry
from spacy.errors import Errors
from wasabi import msg


@registry.loggers("spacy-ray.ConsoleLogger.v1")
def ray_console_logger():
    def setup_printer(
        nlp: "Language",
    ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
        score_cols = list(nlp.config["training"]["score_weights"])
        score_widths = [max(len(col), 6) for col in score_cols]
        loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
        loss_widths = [max(len(col), 8) for col in loss_cols]
        table_header = ["T", "E", "#", "W"] + loss_cols + score_cols + ["Score"]
        table_header = [col.upper() for col in table_header]
        table_widths = [8, 3, 6, 6] + loss_widths + score_widths + [6]
        table_aligns = ["r" for _ in table_widths]
        msg.row(table_header, widths=table_widths)
        msg.row(["-" * width for width in table_widths])

        def log_step(info: Dict[str, Any]):
            try:
                losses = [
                    "{0:.2f}".format(float(info["losses"][pipe_name]))
                    for pipe_name in nlp.pipe_names
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (losses)",
                        key=str(e),
                        keys=list(info["losses"].keys()),
                    )
                ) from None

            try:
                scores = [
                    "{0:.2f}".format(float(info["other_scores"].get(col, 0.0)) * 100)
                    for col in score_cols
                ]
            except KeyError as e:
                raise KeyError(
                    Errors.E983.format(
                        dict="scores (other)",
                        key=str(e),
                        keys=list(info["other_scores"].keys()),
                    )
                ) from None
            time = timedelta(seconds=info["seconds"])
            data = (
                [str(time), info["epoch"], info["step"], info["words"]]
                + losses
                + scores
                + ["{0:.2f}".format(float(info["score"]))]
            )
            msg.row(data, widths=table_widths, aligns=table_aligns)

        def finalize():
            pass

        return log_step, finalize

    return setup_printer

--------------------------------------------------------------------------------
/spacy_ray/train_cli.py:
--------------------------------------------------------------------------------
from typing import Optional
import time
import typer
import logging
from pathlib import Path
from spacy.util import load_config, logger
from spacy.cli._util import parse_config_overrides, Arg, Opt, app
from spacy.cli._util import setup_gpu, show_validation_error
from thinc.api import Config

from .worker import Worker, Evaluator


RAY_HELP = """CLI for parallel and distributed computing via
Ray. See the Ray documentation for details: https://ray.io.
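The main subcommand is `spacy ray train`, which takes the same config files
and overrides as `spacy train`.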
16 | """ 17 | 18 | # Create our subcommand, and install it within spaCy's CLI 19 | ray_cli = typer.Typer(name="ray", help=RAY_HELP, no_args_is_help=True) 20 | app.add_typer(ray_cli) 21 | 22 | 23 | @ray_cli.command( 24 | "train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True} 25 | ) 26 | def ray_train_cli( 27 | # fmt: off 28 | ctx: typer.Context, # This is only used to read additional arguments 29 | config_path: Path = Arg(..., help="Path to config file", exists=True), 30 | code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"), 31 | output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory or remote storage URL for saving trained pipeline"), 32 | num_workers: int = Opt(1, "--n-workers", "-w", help="Number of workers"), 33 | ray_address: Optional[str] = Opt(None, "--address", "-a", help="Address of ray cluster"), 34 | use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"), 35 | verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"), 36 | # fmt: on 37 | ): 38 | """ 39 | Train a spaCy pipeline using Ray for parallel training. 40 | """ 41 | # TODO: wire up output path 42 | logger.setLevel(logging.DEBUG if verbose else logging.ERROR) 43 | setup_gpu(use_gpu) 44 | overrides = parse_config_overrides(ctx.args) 45 | with show_validation_error(config_path): 46 | config = load_config(config_path, overrides=overrides, interpolate=False) 47 | ray_train( 48 | config, 49 | ray_address=ray_address, 50 | num_workers=num_workers, 51 | use_gpu=use_gpu, 52 | code_path=code_path, 53 | ) 54 | 55 | 56 | def ray_train( 57 | config: Config, 58 | *, 59 | ray_address: Optional[str] = None, 60 | num_workers: int = 1, 61 | use_gpu: int = -1, 62 | code_path: Optional[Path] = None, 63 | ) -> None: 64 | # We're importing Ray here so it doesn't need to be imported when spaCy / 65 | # spaCy's CLI is imported (which would otherwise take too long) 66 | import ray 67 | 68 | if ray_address is not None: 69 | ray.init(address=ray_address) 70 | else: 71 | ray.init(ignore_reinit_error=True) 72 | RemoteWorker = ray.remote(Worker).options(num_gpus=int(use_gpu >= 0), num_cpus=2) 73 | workers = [ 74 | RemoteWorker.remote( 75 | config, 76 | rank=rank, 77 | num_workers=num_workers, 78 | use_gpu=use_gpu, 79 | code_path=code_path, 80 | ) 81 | for rank in range(num_workers) 82 | ] 83 | for worker in workers: 84 | ray.get(worker.set_proxy.remote(workers)) 85 | evaluator = ray.remote(Evaluator).remote() 86 | for worker in workers: 87 | ray.get(worker.train.remote(workers, evaluator)) 88 | todo = list(workers) 89 | while todo: 90 | time.sleep(1) 91 | todo = [w for w in workers if ray.get(w.is_running.remote())] 92 | -------------------------------------------------------------------------------- /spacy_ray/proxies.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Set, Iterable, Any, Optional, cast 2 | from collections import Counter 3 | from thinc.types import FloatsXd 4 | from thinc.api import Optimizer 5 | 6 | from .util import make_key, KeyT 7 | 8 | 9 | class RayPeerProxy: 10 | """Proxy for workers where each worker owns some of the parameters. For 11 | parameters they don't own, workers will pull parameters and push gradients. 12 | For parameters they do own, they pull gradients, make the update, and push 13 | parameters. 
14 | """ 15 | 16 | ray: Any 17 | optimizer: Optimizer 18 | grads_per_update: int 19 | peers: Dict 20 | other_workers: Set 21 | _params: Dict[KeyT, FloatsXd] 22 | _grads: Dict[KeyT, Optional[FloatsXd]] 23 | _versions: Dict[KeyT, int] 24 | _owned_keys: Set[KeyT] 25 | _grad_counts: Dict[KeyT, int] 26 | 27 | def __init__( 28 | self, 29 | peers: Dict[KeyT, Any], 30 | optimizer, 31 | keys: Iterable[KeyT], 32 | *, 33 | grads_per_update: int = 2, 34 | ray=None 35 | ): 36 | if ray is None: 37 | import ray # type: ignore 38 | # Pass in 'ray' so that we can test with a mock object. 39 | self.ray = ray 40 | self.optimizer = optimizer 41 | self.grads_per_update = grads_per_update 42 | self.peers = dict(peers) 43 | self._owned_keys = set(keys) 44 | self.other_workers = set() 45 | for key, peer in self.peers.items(): 46 | if key not in self._owned_keys and peer not in self.other_workers: 47 | self.other_workers.add(peer) 48 | self._params = {} 49 | self._versions = Counter() 50 | self._next_params = {} 51 | self._grads = {} 52 | self._grad_counts = Counter() 53 | 54 | def check_version(self, key: KeyT, version: int) -> Optional[bool]: 55 | if key not in self._versions: 56 | return None 57 | elif self._versions[key] != version: 58 | return False 59 | else: 60 | return True 61 | 62 | def set_param(self, id, name, value: FloatsXd) -> None: 63 | """Set a parameter to the connection.""" 64 | key = make_key(id, name) 65 | if key in self._owned_keys or key not in self._params: 66 | self._params[key] = value 67 | self._versions[key] += 1 68 | self._grads[key] = None 69 | self._grad_counts[key] = 0 70 | 71 | def send_param(self, key): 72 | param = self._params[key] 73 | version = self._versions[key] 74 | for peer in self.other_workers: 75 | peer.set_param.remote(key, version, param) 76 | 77 | def receive_param(self, key, version, value: FloatsXd) -> None: 78 | """Let the connection push a parameter to us.""" 79 | # We have to store this in a separate place, to make sure we don't 80 | # fetch the wrong version when we submit the gradient. For instance, 81 | # imagine if we received the param in between the forward and backward 82 | # pass. If we set the version to this one, we'd calculate a gradient 83 | # on the basis of the old param, but think we had a new version. 
        self._next_params[key] = (version, value)

    def get_param(self, id, name) -> FloatsXd:
        key = make_key(id, name)
        self._maybe_update_param(key)
        return self._params[key]

    def set_grad(self, id, name, value: FloatsXd) -> None:
        """Set a gradient to the connection."""
        key = make_key(id, name)
        if key in self._owned_keys:
            self._grads[key] = value
            self._grad_counts[key] = 1

    def inc_grad(self, id, name, value: FloatsXd) -> None:
        """Increment a gradient to the connection."""
        key = make_key(id, name)
        self._grad_counts[key] += 1
        if key not in self._owned_keys:
            peer = self.peers[key]
            peer.inc_grad.remote(key, self._versions[key], value)
        else:
            if self._grads.get(key) is None:
                self._grads[key] = value.copy()
            else:
                self._grads[key] += value

    def _maybe_update_param(self, key: KeyT) -> bool:
        if key in self._next_params:
            version, value = self._next_params.pop(key)
            self._params[key] = value
            self._versions[key] = version
            self._grad_counts[key] = 0
            self._grads[key] = None
            return True
        elif key not in self._owned_keys:
            return False
        elif self._grad_counts[key] < self.grads_per_update:
            return False
        elif self._grads.get(key) is None:
            return False
        else:
            grad = cast(FloatsXd, self._grads[key])
            self._versions[key] += 1
            param, _ = self.optimizer(key, self._params[key], grad)
            self._params[key] = param
            self._grads[key] = None
            self._grad_counts[key] = 0
            self.send_param(key)
            return True

--------------------------------------------------------------------------------
/spacy_ray/worker.py:
--------------------------------------------------------------------------------
from typing import List, Dict, Union, Any, Optional
import time
import os
import threading
from pathlib import Path
from thinc.config import Config
from thinc.types import FloatsXd
from spacy.cli._util import import_code
from spacy.training.loop import train_while_improving, create_train_batches
from spacy.training.loop import create_evaluation_callback
from spacy.training.loop import create_before_to_disk_callback
from spacy.training.loop import update_meta
from spacy.training.initialize import init_nlp
from spacy.language import Language
from spacy.util import registry, logger, resolve_dot_names
from spacy.schemas import ConfigSchemaTraining
from thinc.api import require_gpu, set_gpu_allocator

from .proxies import RayPeerProxy
from .util import set_params_proxy, divide_params, KeyT


class Worker:
    """Actor class for spaCy parallel training.

    Okay, this is pretty mind-bending stuff... But the idea is that the remote
    workers need to communicate directly, to avoid extra copies. The mechanics
    of this are super twisted though, because it mixes all sorts of levels. But
    it has to be *exactly this* worker object that is passed through, because
    we need the remote actor. That's why it's so strange.

    The workers install a "proxy" object into the Thinc models. When this object
    is installed, Thinc will direct calls to get_param, inc_grad, set_param etc.
    through the proxy.

    On each worker, a subset of the weights will be "local". The proxy will
    receive a list of the keys that are local, and a mapping of keys to workers.
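    (The split is computed by `divide_params` in util.py, which gives each
    worker a contiguous block of parameter groups.)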

    Workers optimize the parameters that are 'local' to them, and then push
    the updates to the other workers. For parameters that aren't local, they
    find the worker that owns that parameter, and publish the gradient to it.

    This strategy is non-blocking, because the gradients and the parameters
    are both pushed, not pulled.

    In order to make this work, we need some concurrency within the workers,
    because the workers need to be listening for updates while continuing to
    work. Currently this is implemented by putting the main training work
    on a thread, and letting the main thread continue to listen for connections.

    Finally, note that there's a pretty tangled circular reference here. I hate
    circular references: they make the code hard to understand and make Python
    use GC. But the circular reference here is necessary:

    * Workers hold a reference to the nlp object. Within the nlp object, models
      hold references to the "proxy" object.
    * The proxy object holds a reference to the peer mapping, whose values are
      the workers.
    """

    rank: int
    num_workers: int
    gpu_id: int
    nlp: Language
    config: Union[Dict[str, Any], Config]
    proxy: Optional[RayPeerProxy]
    thread: Optional[threading.Thread]
    _results: List
    _evaluation_callback: Any

    def __init__(
        self,
        config: Config,
        *,
        rank: int = 0,
        num_workers: int = 1,
        use_gpu: int = 0,
        code_path: Optional[Path] = None,
        ray=None,
    ):
        if ray is None:
            # Avoid importing ray in the module. This allows a test-ray to
            # be passed in, and speeds up the CLI.
            import ray  # type: ignore

        self.ray = ray
        import_code(code_path)
        self.rank = rank
        self.num_workers = num_workers
        self.gpu_id = self._resolve_gpu(use_gpu)
        self.nlp = init_nlp(Config(config), use_gpu=self.gpu_id)
        config = self.nlp.config.interpolate()
        self.T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
        dot_names = [self.T["train_corpus"], self.T["dev_corpus"]]
        self.train_corpus, self.dev_corpus = resolve_dot_names(config, dot_names)
        self.before_to_disk = create_before_to_disk_callback(self.T["before_to_disk"])
        allocator = self.T["gpu_allocator"]
        if use_gpu >= 0 and allocator:
            set_gpu_allocator(allocator)
        self._evaluation_callback = lambda: {}
        self._results = []
        self._has_evaluation_callback = False
        self.thread = None
        self.proxy = None
        self.n_grads_used = 0
        self.n_grads_discarded = 0

    ########################################################################
    # Inter-worker communication
    #
    # It'd be nice to have this stuff in a different object, but we need
    # to pass the actual 'actor' handle around; we can't use a shared reference.
    # And if we made another actor, it would run within a different process.
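    # (Ray actors each run in their own OS process, so a separate
    # communication actor couldn't share the nlp object's memory.)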
    #
    #########################################################################

    def inc_grad(self, key: KeyT, version: int, value: FloatsXd) -> None:
        if self.proxy is None:
            raise ValueError("Proxy object not set")
        if self.proxy.check_version(key, version):
            self.proxy.inc_grad(key[0], key[1], value)

    def set_param(self, key: KeyT, version: int, value: FloatsXd) -> Optional[FloatsXd]:
        return self.proxy.receive_param(key, version, value)

    def get_param(self, key: KeyT, version: int) -> Optional[FloatsXd]:
        if self.proxy is None:
            raise ValueError("Proxy object not set")
        elif self.proxy.check_version(key, version):
            return self.proxy.get_param(key[0], key[1])
        else:
            return None

    #########################################################################
    # Process control. These are used by the script or function coordinating
    # the work.
    #
    ########################################################################

    def sync_params(self):
        for key in self.proxy._owned_keys:
            self.proxy.send_param(key)

    def get_percent_grads_used(self):
        total = self.n_grads_used + self.n_grads_discarded
        if total == 0:
            return None
        else:
            return self.n_grads_used / total

    def get_quorum(self) -> int:
        # Default to setting the 'quorum' to be the number of workers multiplied
        # by the accumulate_gradient value. This is how many gradients for a
        # parameter we will accumulate before running the optimizer.
        return self.num_workers * self.T["accumulate_gradient"]

    def train(self, peers: List, evaluator: "Evaluator") -> None:
        def evaluate():
            if self.rank == 0:
                scores = self.evaluate()
                self.ray.get(evaluator.set_scores.remote(scores))
                return scores
            else:
                scores = None
                while scores is None:
                    time.sleep(5)
                    scores = self.ray.get(evaluator.get_scores.remote())
                return scores

        train_batches = create_train_batches(
            self.nlp,
            self.train_corpus,
            self.T["batcher"],
            self.T["max_epochs"],
        )
        training_step_iterator = train_while_improving(
            self.nlp,
            FakeOptimizer(),
            train_batches,
            evaluate=evaluate,
            dropout=self.T["dropout"],
            accumulate_gradient=1,
            patience=self.T["patience"],
            max_steps=self.T["max_steps"],
            eval_frequency=self.T["eval_frequency"],
            exclude=self.T["frozen_components"],
            annotating_components=self.T["annotating_components"],
            before_update=self.T["before_update"],
        )
        if self.rank == 0:
            print_row, finalize_logger = self.T["logger"](self.nlp)
        else:
            print_row = lambda *args: None
        self.thread = threading.Thread(
            target=thread_training,
            args=(
                training_step_iterator,
                print_row,
                self.rank,
                self.num_workers,
                self.gpu_id,
            ),
        )
        self.thread.start()

    def is_running(self):
        return self.thread.is_alive()

    def evaluate(self) -> Dict[str, Union[Dict[str, float], float]]:
        if not self._has_evaluation_callback:
            self._evaluation_callback = create_evaluation_callback(
                self.nlp,
                self.dev_corpus,
                self.T["score_weights"],
            )
            self._has_evaluation_callback = True
        return self._evaluation_callback()

    def save_checkpoint(self, info: Dict, output_path: Path) -> None:
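        """Update the pipeline meta from the latest training info and save the
        pipeline to output_path, with frozen components disabled for the save.
        """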
        with self.nlp.select_pipes(disable=self.T["frozen_components"]):
            update_meta(self.T, self.nlp, info)
            self.before_to_disk(self.nlp).to_disk(output_path)

    def get_owned_keys(self):
        owned_keys = []
        for name, component in self.nlp.pipeline:
            if hasattr(component, "model"):
                worker_keys = divide_params(component.model, self.num_workers)
                owned_keys.extend(worker_keys[self.rank])
        return owned_keys

    def get_peer_map(self, workers):
        peer_map = {}
        for name, component in self.nlp.pipeline:
            if hasattr(component, "model"):
                worker_keys = divide_params(component.model, self.num_workers)
                for worker, keys in zip(workers, worker_keys):
                    for key in keys:
                        peer_map[key] = worker
        return peer_map

    def set_proxy(self, peers) -> None:
        proxy = RayPeerProxy(
            self.get_peer_map(peers),
            self.T["optimizer"],
            self.get_owned_keys(),
            ray=self.ray,
        )
        for name, component in self.nlp.pipeline:
            if hasattr(component, "model"):
                set_params_proxy(component.model, proxy)
        self.proxy = proxy

    def _resolve_gpu(self, use_gpu: int) -> int:
        if use_gpu >= 0:
            gpu_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", -1))
            logger.info(f"Using GPU (isolated): {gpu_id}")
            require_gpu(0)
        else:
            logger.info("Using CPU")
            gpu_id = -1
        return gpu_id


class FakeOptimizer:
    def __init__(self):
        self.averages = {}

    def __call__(self, key, weights, gradient):
        # This shouldn't be called, because when we have the parameter proxy
        # installed, the gradients should never appear, and the `has_grad`
        # check in `model.finish_update` should return False.
        # However, it's difficult to guarantee that for all subclasses and shims,
        # so it's safer to noop instead of raising.
        return weights, gradient

    def step_schedules(self):
        pass


class Evaluator:
    """Share evaluation results between workers.

    One worker should publish evaluation results to the evaluator,
    while the other workers should retrieve them (using a wait-loop if
    necessary).
    """

    def __init__(self):
        self.scores = []

    def set_scores(self, scores):
        self.scores.append(scores)
        return scores

    def get_scores(self):
        if not self.scores:
            return None
        else:
            return self.scores[-1]


def thread_training(training_step_iterator, print_row, rank, num_workers, gpu_id):
    if gpu_id >= 0:
        # I don't fully understand why we need to do this within the thread.
        # I think 0 is also correct here, because Ray sets the available devices?
        require_gpu(0)
    for batch, info, is_best_checkpoint in training_step_iterator:
        if rank == 0 and is_best_checkpoint is not None:
            info["words"] *= num_workers
            print_row(info)

--------------------------------------------------------------------------------
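For illustration, here is a minimal sketch (not part of the package; it assumes
`spacy-ray` and `thinc` are installed) of how `divide_params` deals parameter
keys out to workers:

```python
from thinc.api import Linear, chain

from spacy_ray.util import divide_params

# Two stacked layers; each Linear node owns a "W" and a "b" parameter.
model = chain(Linear(8, 4), Linear(2, 8))

# Deal the per-node parameter groups out to two workers.
worker_keys = divide_params(model, num_workers=2)
for rank, keys in enumerate(worker_keys):
    # Each key is a (node_id, param_name) tuple that this worker will own
    # and optimize; gradients for it are pushed here by the other workers.
    print(rank, keys)
```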