├── setup.py ├── docs ├── api.rst ├── index.rst ├── examples.rst ├── example_rpc.slurm ├── conf.py └── example_rpc.py ├── rpcdataloader ├── launch.py ├── __init__.py ├── utils.py ├── dataloader.py └── rpc.py ├── .github └── workflows │ └── tests.yml ├── tests ├── test_dataloader.py └── test_rpc.py ├── pyproject.toml ├── README.rst └── LICENSE.txt /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | ================= 3 | 4 | .. automodule:: rpcdataloader 5 | :members: -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. meta:: 2 | :google-site-verification: dYDeEB1C_v5wqiUX9ZdWTK7bCfGizdeCaYSqHAZkwQk 3 | 4 | :og:description: A variant of the PyTorch Dataloader using remote workers. 5 | 6 | .. include:: ../README.rst 7 | :start-line: 7 8 | :end-line: 60 9 | 10 | Further reading 11 | =============== 12 | 13 | - :doc:`api` 14 | - :ref:`ImageNet training example` 15 | - :ref:`Slurm integration example` 16 | 17 | .. toctree:: 18 | :hidden: 19 | :maxdepth: 2 20 | 21 | self 22 | api 23 | examples 24 | Github repository -------------------------------------------------------------------------------- /rpcdataloader/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | 5 | from rpcdataloader import run_worker 6 | 7 | 8 | if __name__ == "__main__": 9 | argparser = argparse.ArgumentParser( 10 | description="run RPC worker, prints hostname:port when ready." 
@pytest.fixture(scope="function")
def workers():
    """Start two local RPC workers and yield their ``host:port`` addresses.

    The workers listen on two consecutive ports of the loopback interface
    and are terminated (and reaped) once the test is over.
    """
    n_workers = 2
    host = "127.0.0.1"
    # random.randint() includes both bounds and the highest valid TCP port
    # is 65535, so leave room for the consecutive ports used below.
    port = random.randint(32000, 65535 - (n_workers - 1))

    workers = [(host, port + i) for i in range(n_workers)]
    procs = [
        # third positional argument of run_worker is the timeout, as in
        # rpcdataloader/launch.py
        multiprocessing.Process(target=run_worker, args=(h, p, 10))
        for h, p in workers
    ]
    for proc in procs:
        proc.start()

    yield [f"{h}:{p}" for h, p in workers]

    for proc in procs:
        proc.terminate()
    for proc in procs:
        # reap the child so that it does not linger as a zombie process
        proc.join()


def test_rpcdataloader(workers):
    """Iterate a remote ``torch.rand`` "dataset" and check every batch."""
    # torch.rand(size=(1000, 128)) is instantiated on the workers and used
    # as a 1000-item dataset of 128-d vectors.
    dataset = RPCDataset(
        workers=workers,
        dataset=torch.rand,
        size=(1000, 128))

    dataloader = RPCDataloader(
        dataset,
        batch_size=5,
    )

    n_batches = 0
    for batch in dataloader:
        assert isinstance(batch, torch.Tensor) and batch.shape == (5, 128)
        n_batches += 1

    # 1000 items / batch size of 5
    assert n_batches == 200
16 | 17 | from .dataloader import RPCDataloader as RPCDataloader, RPCDataset as RPCDataset 18 | from .rpc import rpc_async as rpc_async, run_worker as run_worker 19 | from .utils import set_random_seeds as set_random_seeds 20 | 21 | __all__ = ["rpc_async", "run_worker", "RPCDataloader", "RPCDataset", "set_random_seeds"] 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "rpcdataloader" 7 | authors = [{ name = "Nicolas Granger", email = "nicolas.granger@cea.fr" }] 8 | description = "A Dataloader using rpc-based workers" 9 | readme = "README.rst" 10 | license = { text = "CECILL-C" } 11 | requires-python = ">=3.7" 12 | classifiers = [ 13 | "License :: CeCILL-C Free Software License Agreement (CECILL-C)", 14 | "Development Status :: 4 - Beta", 15 | "Topic :: Scientific/Engineering", 16 | "Topic :: Software Development :: Libraries :: Python Modules", 17 | ] 18 | dependencies = [ 19 | "tblib", 20 | "typing;python_version<'3.9'", 21 | "pickle5;python_version<'3.8'", 22 | "torch", 23 | "numpy" 24 | ] 25 | dynamic = ["version"] 26 | 27 | [project.urls] 28 | repository = "https://github.com/CEA-LIST/RPCDataloader" 29 | documentation = "https://cea-list.github.io/RPCDataloader" 30 | 31 | [project.optional-dependencies] 32 | test = ["pytest"] 33 | doc = [ 34 | "sphinx", 35 | "docutils>=0.17", 36 | "sphinx-rtd-theme>=1.0", 37 | "sphinxext-opengraph", 38 | "sphinx-copybutton", 39 | "sphinx-sitemap" 40 | ] 41 | 42 | [tool.setuptools_scm] -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. 
_ImageNet training example: 5 | 6 | ImageNet training example 7 | ------------------------- 8 | 9 | This example uses rpcdataloader for the training of a ResNet50 on ImageNet. 10 | It supports distributed training and mixed-precision. 11 | Modifications to make use of rpcdataloader are highlighted and evaluation routines are omitted for readability. 12 | 13 | Prior to running this script, you should :ref:`spawn workers <Usage>`. 14 | 15 | .. literalinclude:: example_rpc.py 16 | :linenos: 17 | :emphasize-lines: 13,20,48-50,71-73,84 18 | 19 | 20 | .. _Slurm integration example: 21 | 22 | Slurm integration example 23 | ------------------------- 24 | 25 | To use rpcdataloader on a `Slurm `_ cluster, the `heterogeneous jobs `_ functionality will let you reserve two groups of resources: one for workers and one for training scripts. 26 | The sample script below demonstrates how to do this. 27 | 28 | Note that you might need to adjust port numbers to avoid collisions between jobs. 29 | You might also need to adjust resource specifications depending on the slurm configuration. 30 | 31 | ..
def do(a):
    """Remote test payload: echo ``a`` along with a fresh random tensor."""
    return a, torch.rand(1024)


@pytest.fixture(scope='function')
def worker():
    """Start one local RPC worker and yield its ``host:port`` address."""
    host = '127.0.0.1'
    # random.randint() includes both bounds; 65535 is the highest valid port.
    port = random.randint(32000, 65535)

    # third positional argument of run_worker is the timeout, as in
    # rpcdataloader/launch.py
    proc = multiprocessing.Process(target=run_worker, args=(host, port, 10))
    proc.start()

    yield f"{host}:{port}"

    proc.terminate()
    # reap the child so that it does not linger as a zombie process
    proc.join()


def test_rpc_async(worker):
    """Arguments and return values survive a round trip to the worker."""
    a = bytes([random.randint(0, 255) for _ in range(1000)])
    f = rpc_async(worker, do, (a,))
    b, c = f.wait()

    assert a == b
    assert len(c) == 1024


def test_rref(worker):
    """A result kept on the worker (rref) can be fed to a later call."""
    a = torch.rand([500])
    b = torch.rand([500])
    f = rpc_async(worker, torch.add, (a, b), rref=True).wait()
    actual = rpc_async(worker, torch.sum, (f,), rref=False).wait().item()
    expected = (a + b).sum().item()

    assert actual == pytest.approx(expected)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda unvailable")
def test_rpc_pin(worker):
    """Tensors received with ``pin_memory=True`` live in pinned memory."""
    a = bytes([random.randint(0, 255) for _ in range(1000)])
    f = rpc_async(worker, do, (a,), pin_memory=True)
    _, c = f.wait()

    # is_pinned is a method: asserting the bound method itself is always true
    assert c.is_pinned()
--cpus-per-task=2 9 | #SBATCH --mem=32G 10 | #SBATCH --gres=gpu:2 11 | 12 | #SBATCH hetjob 13 | 14 | # Resource specfification for workers 15 | #SBATCH --partition=cpu 16 | #SBATCH --nodes=2 17 | #SBATCH --ntasks-per-node=4 18 | #SBATCH --cpus-per-task=1 19 | #SBATCH --mem=16G 20 | 21 | source ~/miniconda3/etc/profile.d/conda.sh 22 | conda activate rpcdataloader 23 | 24 | export rpc_port_start=15000 25 | 26 | # identify workers 27 | export tmpfile="${TMPDIR:-/tmp}/rpcdataloader_workers.$SLURM_JOB_ID" 28 | srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c ' 29 | echo $(hostname):$(( $rpc_port_start + $SLURM_LOCALID )) 30 | ' > "${tmpfile}" 31 | readarray -t workers < "${tmpfile}" 32 | rm $tmpfile 33 | 34 | # start workers in background 35 | srun --het-group=1 -I --exclusive --kill-on-bad-exit=1 sh -c ' 36 | python -u -m rpcdataloader.launch \ 37 | --host=0.0.0.0 \ 38 | --port=$(( $rpc_port_start + $SLURM_LOCALID )) 39 | ' & 40 | worker_task_pid=$! 41 | 42 | # run training script 43 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_0 | head -n 1) 44 | export MASTER_PORT=16000 45 | srun --het-group=0 -I --exclusive --kill-on-bad-exit=1 \ 46 | python -u example_rpc.py \ 47 | --workers ${workers[@]} \ 48 | --data-path=/media/ILSVRC \ 49 | --batch-size=128 50 | 51 | # stop workers 52 | kill $worker_task_pid 53 | -------------------------------------------------------------------------------- /rpcdataloader/utils.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2022-2023 CEA LIST. All Rights Reserved. 2 | # Contributor(s): Nicolas Granger 3 | # 4 | # This software is governed by the CeCILL-C license under French law and 5 | # abiding by the rules of distribution of free software. You can use, 6 | # modify and/ or redistribute the software under the terms of the CeCILL-C 7 | # license as circulated by CEA, CNRS and INRIA at the following URL 8 | # "http://www.cecill.info". 
def unpickle_tensor(buffer, dtype, shape):
    """Rebuild a tensor from the raw bytes produced by :func:`pickle_tensor`.

    :param buffer:
        object exposing the buffer protocol that holds the tensor bytes
    :param dtype:
        torch dtype of the tensor
    :param shape:
        shape of the tensor
    :return:
        a tensor of the requested dtype and shape; when ``buffer`` is not
        empty, it is created by ``torch.frombuffer`` directly from ``buffer``
    """
    # torch.frombuffer refuses zero-length buffers, which would make tensors
    # without elements (e.g. shape (0, 4)) impossible to transfer.
    if memoryview(buffer).nbytes == 0:
        return torch.empty(shape, dtype=dtype)

    return torch.frombuffer(buffer, dtype=dtype).view(shape)


def pickle_tensor(t):
    """Reduce function for tensors: serialize them as a flat byte array.

    :param t:
        tensor to serialize; it goes through ``Tensor.numpy()``, so it
        presumably has to be a CPU tensor that does not require grad
    :return:
        a ``(callable, args)`` pair as expected from a pickle reducer, where
        args are the bytes (numpy array viewed as int8), dtype and shape
    """
    return unpickle_tensor, (t.ravel().numpy().view("b"), t.dtype, t.shape)


# Dispatch table meant to be installed on a Pickler (``dispatch_table``
# attribute) to route tensors through pickle_tensor.
pkl_dispatch_table = {torch.Tensor: pickle_tensor}


def set_random_seeds(base_seed, worker_id):
    """Set the seed of default random generators from python, torch and numpy.

    This should be called once on each worker.
    Note that workers may run tasks out of order, so this does not ensure
    reproducibility, only non-redundancy between workers.

    :param base_seed:
        seed shared by all the workers
    :param worker_id:
        index of the worker, combined with :attr:`base_seed`

    Example (``workers`` is a list of ``"host:port"`` addresses):

    >>> base_seed = torch.randint(0, 2**32-1, [1]).item()
    >>> for i, worker in enumerate(workers):
    ...     rpc_async(worker, set_random_seeds, (base_seed, i))
    """

    seed = base_seed + worker_id
    random.seed(seed)
    torch.manual_seed(seed)
    # numpy state is derived with the helper torch uses for its own
    # dataloader workers (private torch API).
    np_seed = torch.utils.data._utils.worker._generate_state(base_seed, worker_id)
    np.random.seed(np_seed)
image:: https://github.com/CEA-LIST/RPCDataloader/actions/workflows/tests.yml/badge.svg 5 | :target: https://github.com/CEA-LIST/RPCDataloader/actions/workflows/tests.yml 6 | :alt: Continuous tests 7 | 8 | ============== 9 | RPC Dataloader 10 | ============== 11 | 12 | This library implements a variant of the PyTorch Dataloader using remote workers. 13 | It allows distributing workers over remote servers rather than the one running the main script. 14 | 15 | To use it, start one or several worker daemons on remote computers. 16 | The machines running the data loaders will dispatch requests for items to the workers and await the returned values. 17 | 18 | Though similar to `torch.rpc `_, this library uses its own implementation of RPC (Remote Procedure Call) which is simpler (no initialization) and does not conflict with the one from pytorch. 19 | 20 | 21 | Installation 22 | ============ 23 | 24 | .. code:: shell 25 | 26 | pip install rpcdataloader 27 | 28 | 29 | .. _Usage: 30 | 31 | Usage 32 | ===== 33 | 34 | To use the RPC dataloader, start a few workers either from the command line: 35 | 36 | .. code:: shell 37 | 38 | python -m rpcdataloader.launch --host=0.0.0.0 --port=6543 39 | 40 | or by calling :code:`rpcdataloader.run_worker` directly from a python script. 41 | 42 | Then instantiate a remote dataset and dataloader: 43 | 44 | .. code:: python 45 | 46 | dataset = rpcdataloader.RPCDataset( 47 | workers=['node01:6543', 'node02:5432'], 48 | dataset=torchvision.datasets.ImageFolder, 49 | root=args.data_path + "/train", 50 | transform=train_transform, 51 | ) 52 | 53 | dataloader = rpcdataloader.RPCDataloader( 54 | dataset, 55 | batch_size=2, 56 | shuffle=True, 57 | pin_memory=True) 58 | 59 | for minibatch in dataloader: 60 | ...
# -- Project information -----------------------------------------------------

project = 'RPCDataloader'
copyright = '2022, CEA LIST'
author = 'Nicolas Granger'

# The full version, including alpha/beta/rc tags.
pkg_version = version("rpcdataloader")

# A "+" introduces a local version segment, which is taken to mean that the
# docs are built from an untagged commit. The first character of that
# segment is dropped to obtain the commit id (presumably the "g" prefix
# written by setuptools_scm -- TODO confirm).
version_parts = pkg_version.split("+")
release = version_parts[0]

# NOTE: from here on ``version`` is the Sphinx setting (a string) and no
# longer the importlib.metadata function imported at the top of the file.
if len(version_parts) > 1:
    commit = version_parts[1].split('.')[0][1:]
    version = f"latest ({release})"
else:
    commit = f"v{release}"
    version = f"stable ({release})"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.intersphinx',
    'sphinx.ext.linkcode',
    'sphinxext.opengraph',
    'sphinx_copybutton',
    'sphinx_sitemap'
]

typehints_defaults = 'braces-after'
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'torch': ('https://pytorch.org/docs/stable', None)
}
ogp_site_url = "https://cea-list.github.io/RPCDataloader/"
sitemap_url_scheme = "{link}"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
def main():  # NOTE: docs/examples.rst highlights lines of this file by number
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--data-path", required=True)
    argparser.add_argument("--model", default="resnet50")
    argparser.add_argument("--workers", type=str, nargs="+", required=True)
    argparser.add_argument("--batch-size", default=2, type=int)
    argparser.add_argument("--lr", default=0.1, type=float)
    argparser.add_argument("--momentum", default=0.9, type=float)
    argparser.add_argument("--weight-decay", default=1e-4, type=float)
    argparser.add_argument("--epochs", default=100, type=int)
    argparser.add_argument("--amp", action="store_true")
    args = argparser.parse_args()

    # Distributed
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:  # torchrun launch
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    elif int(os.environ.get("SLURM_NPROCS", 1)) > 1:  # srun launch
        rank = int(os.environ["SLURM_PROCID"])
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_size = int(os.environ["SLURM_NPROCS"])
    else:  # single gpu & process launch
        rank = 0
        local_rank = 0
        world_size = 1  # never 0: it is the slice step used to split workers

    if world_size > 1:  # a lone process does not need a process group
        torch.distributed.init_process_group(
            backend="nccl", world_size=world_size, rank=rank
        )

    # split workers between GPUs (optional but recommended)
    if len(args.workers) > 0:
        args.workers = args.workers[rank::world_size]

    print(args)

    # Device
    device = torch.device("cuda", index=local_rank)

    # Preprocessing
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandAugment(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Datasets
    train_dataset = RPCDataset(
        args.workers,
        ImageFolder,
        root=args.data_path + "/train",
        transform=train_transform,
    )

    # Data loading
    if torch.distributed.is_initialized():
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = RandomSampler(train_dataset)

    train_loader = RPCDataloader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        pin_memory=True,
    )

    # Model
    model = get_model(args.model, num_classes=1000)
    model.to(device)
    if torch.distributed.is_initialized():
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank]
        )

    # Optimization
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # Training
    for epoch in range(args.epochs):
        # give the distributed sampler a new epoch number so that it
        # reshuffles
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)

        for it, (images, targets) in enumerate(train_loader):
            t0 = time.monotonic()

            optimizer.zero_grad(set_to_none=True)

            # the loader is created with pin_memory=True above, which is
            # what allows these copies to be asynchronous
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=args.amp):
                predictions = model(images)
                loss = loss_fn(predictions, targets)

            # with --amp unset the scaler is disabled and these calls
            # reduce to a plain backward() / optimizer.step()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # report from the first process only, every 20 iterations
            if (it + 1) % 20 == 0 and rank == 0:
                t1 = time.monotonic()
                print(
                    f"[epoch {epoch:<3d}"
                    f" it {it:-5d}/{len(train_loader)}]"
                    f" loss: {loss.item():2.3f}"
                    f" time: {t1 - t0:.1f}"
                )

        scheduler.step()


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /rpcdataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2022-2023 CEA LIST. All Rights Reserved. 2 | # Contributor(s): Nicolas Granger 3 | # 4 | # This software is governed by the CeCILL-C license under French law and 5 | # abiding by the rules of distribution of free software. You can use, 6 | # modify and/ or redistribute the software under the terms of the CeCILL-C 7 | # license as circulated by CEA, CNRS and INRIA at the following URL 8 | # "http://www.cecill.info". 9 | # As a counterpart to the access to the source code and rights to copy, 10 | # modify and redistribute granted by the license, users are provided only 11 | # with a limited warranty and the software's author, the holder of the 12 | # economic rights, and the successive licensors have only limited 13 | # liability. 14 | # The fact that you are presently reading this means that you have had 15 | # knowledge of the CeCILL-C license and that you accept its terms. 16 | 17 | import itertools 18 | from typing import Any, Callable, List 19 | 20 | from torch.utils.data import ( 21 | BatchSampler, 22 | Dataset, 23 | RandomSampler, 24 | SequentialSampler, 25 | default_collate, 26 | ) 27 | 28 | from .rpc import rpc_async 29 | 30 | 31 | def _get_item(dataset, item): 32 | return dataset[item] 33 | 34 | 35 | def _get_batch(dataset, items, collate_fn=None): 36 | values = [dataset[i] for i in items] 37 | return values if collate_fn is None else collate_fn(values) 38 | 39 | 40 | class RPCDataset(Dataset): 41 | """Handle to instanciate and manage datasets on remote workers. 
class RPCDataloader:
    """A dataloader using remote rpc-based workers.

    :param dataset:
        A remote dataset
    :param batch_size:
        how many samples per batch to load. ``None`` disables batching
        unless a :attr:`batch_sampler` is given.
    :param shuffle:
        set to ``True`` to have the data reshuffled at every epoch.
    :param sampler:
        defines the strategy to draw samples from the dataset. Can be any
        ``Iterable`` with ``__len__`` implemented. If specified,
        :attr:`shuffle` must not be specified.
    :param batch_sampler:
        like :attr:`sampler`, but returns a batch of indices at a time.
        Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`,
        :attr:`sampler`, and :attr:`drop_last`.
    :param collate_fn:
        merges a list of samples to form a mini-batch of Tensor(s). Used
        when using batched loading from a map-style dataset. It is sent to
        the workers and applied there.
    :param pin_memory:
        If ``True``, the data loader will copy Tensors into CUDA pinned
        memory before returning them.
    :param drop_last: set to ``True`` to drop the last incomplete batch, if
        the dataset size is not divisible by the batch size. If ``False``
        and the size of dataset is not divisible by the batch size, then the
        last batch will be smaller.
    :param timeout:
        timeout forwarded to every remote call.
    :param prefetch_factor: Number of samples loaded in advance by each worker.
        ``2`` means there will be a total of 2 * num_workers samples
        prefetched across all workers. (default: ``2``)

    Notable differences with pytorch dataloader:

    - :attr:`timeout` is the timeout on individual network operations.
    - :attr:`worker_init_fn` and :attr:`generator` are not supported.
    - Random seeds are not supported because workers may execute requests
      out of order anyway, thus breaking reproducibility.
    """

    def __init__(
        self,
        dataset: "RPCDataset",
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        timeout=120,
        *,
        prefetch_factor: int = 2,
    ):
        # Samplers
        if sampler is None:
            if shuffle:
                sampler = RandomSampler(dataset)

            else:
                sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        # Remaining attributes
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.collate_fn = default_collate if collate_fn is None else collate_fn
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.prefetch_factor = prefetch_factor

    def __len__(self):
        """Number of items produced by one pass over the dataloader."""
        if self.batch_sampler is None:
            return len(self.sampler)  # type: ignore
        else:
            return len(self.batch_sampler)  # type: ignore

    def _iter_tasks(self):
        """Yield one ``(worker, function, args)`` remote call per item.

        Calls are assigned to the workers in round-robin order.
        """
        remotes = zip(self.dataset.workers, self.dataset.rrefs)
        remotes_it = itertools.cycle(remotes)

        if self.batch_sampler is None:
            # no batching: one sample per task
            for (worker, rref), i in zip(remotes_it, self.sampler):
                yield worker, _get_item, (rref, i)

        else:
            for (worker, rref), i in zip(remotes_it, self.batch_sampler):
                yield worker, _get_batch, (rref, i, self.collate_fn)

    def __iter__(self):
        """Yield results in sampler order while keeping workers busy.

        Up to ``prefetch_factor * len(workers)`` remote calls are in flight
        at any time.
        """
        task_it = iter(self._iter_tasks())

        # futures of the calls in flight, oldest first
        queue = []

        def submit_next():
            # Send the next task, if any. Returns False once exhausted.
            try:
                task = next(task_it)
            except StopIteration:
                return False

            queue.append(
                rpc_async(*task, pin_memory=self.pin_memory, timeout=self.timeout)
            )
            return True

        try:
            # preload jobs
            for _ in range(self.prefetch_factor * len(self.dataset.workers)):
                if not submit_next():
                    break

            while len(queue) > 0:
                result = queue.pop(0).wait()

                # queue another job before handing the result over, so that
                # workers stay busy while the consumer processes it
                submit_next()

                # return value
                yield result

        finally:
            # Iteration stopped early or failed: wait for the calls still in
            # flight on a best-effort basis. Only regular errors are
            # ignored, so that KeyboardInterrupt still gets through.
            for f in queue:
                try:
                    f.wait()
                except Exception:
                    pass
9 | # As a counterpart to the access to the source code and rights to copy, 10 | # modify and redistribute granted by the license, users are provided only 11 | # with a limited warranty and the software's author, the holder of the 12 | # economic rights, and the successive licensors have only limited 13 | # liability. 14 | # The fact that you are presently reading this means that you have had 15 | # knowledge of the CeCILL-C license and that you accept its terms. 16 | 17 | import io 18 | import select 19 | import socket 20 | import struct 21 | import sys 22 | import threading 23 | import time 24 | import weakref 25 | from typing import Any, Callable, Dict, TypeVar 26 | 27 | if sys.version_info < (3, 8): 28 | import pickle5 as pickle 29 | else: 30 | import pickle 31 | 32 | import torch 33 | from tblib import pickling_support 34 | 35 | # absolute import required for unpickling 36 | from rpcdataloader.utils import pkl_dispatch_table 37 | 38 | pickling_support.install() 39 | 40 | 41 | _T = TypeVar("_T") 42 | 43 | 44 | def _serialize(obj, buffer_cb=None): 45 | buffer = io.BytesIO() 46 | pickler = pickle.Pickler( 47 | buffer, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=buffer_cb 48 | ) 49 | pickler.dispatch_table = pkl_dispatch_table 50 | pickler.dump(obj) 51 | return buffer.getvalue() 52 | 53 | 54 | def _sock_read(sock, size, buffer=None): 55 | """Read size bytes from sock.""" 56 | buffer = bytearray(size) if buffer is None else buffer 57 | received = 0 58 | while received < size: 59 | nread = sock.recv_into(memoryview(buffer)[received:]) 60 | if not nread: 61 | raise RuntimeError("Unexpected socket shutdown.") 62 | received += nread 63 | 64 | return buffer 65 | 66 | 67 | def _create_connection(host, timeout, *kargs, **kwargs): 68 | host, port = host.split(":") 69 | port = int(port) 70 | 71 | for i in range(int(timeout)): 72 | try: 73 | return socket.create_connection( 74 | (host, port), *kargs, timeout=timeout, **kwargs) 75 | except OSError as e: 76 | if i + 1 == 
timeout: 77 | raise e from None 78 | else: 79 | time.sleep(1) 80 | 81 | 82 | tls = threading.local() 83 | 84 | 85 | def _rpc_send_command(host, fut, func, args, kwargs, pin_memory, rref, timeout): 86 | tls.host = host 87 | 88 | try: 89 | payload = _serialize((func, args, kwargs, rref)) 90 | 91 | # connect to server 92 | with _create_connection(host, timeout=timeout) as s: 93 | s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 94 | 95 | # send command 96 | s.sendall(struct.pack("L", len(payload))) 97 | s.sendall(payload) 98 | 99 | select.select([s], [], []) 100 | 101 | # receive buffers 102 | payload = _sock_read(s, struct.calcsize("L")) 103 | (nbuffers,) = struct.unpack("L", payload) 104 | buffers = [] 105 | for _ in range(nbuffers): 106 | payload = _sock_read(s, struct.calcsize("L")) 107 | (n,) = struct.unpack("L", payload) 108 | 109 | if pin_memory: 110 | b = torch.empty( 111 | n, dtype=torch.uint8, pin_memory=pin_memory 112 | ).numpy() 113 | else: 114 | b = bytearray(n) 115 | 116 | _sock_read(s, n, b) 117 | 118 | buffers.append(b) 119 | 120 | # receive object 121 | payload = _sock_read(s, struct.calcsize("L")) 122 | (n,) = struct.unpack("L", payload) 123 | payload = _sock_read(s, n) 124 | 125 | out, err = pickle.loads(payload, buffers=buffers) 126 | 127 | if err: 128 | fut.set_exception(err) 129 | else: 130 | fut.set_result(out) 131 | 132 | except Exception as e: 133 | fut.set_exception(e) 134 | 135 | 136 | def rpc_async( 137 | host: str, 138 | func: Callable[..., _T], 139 | args=None, 140 | kwargs=None, 141 | pin_memory=False, 142 | rref: bool = False, 143 | timeout=120.0, 144 | ) -> torch.futures.Future[_T]: 145 | """Execute function on remote worker and return the result as a future. 
def _handle_client(sock, parallel_sem):
    """Serve one remote call on ``sock``, then close the connection.

    Request: a native ``struct`` ``"L"`` length followed by a pickled
    ``(func, args, kwargs, rref)`` tuple.  Reply: the number of out-of-band
    buffers, then each buffer prefixed by its length, then the
    length-prefixed pickle of ``(result, error)``.

    :param sock: connected client socket; it is owned by this function and
        closed before returning, including on failure.
    :param parallel_sem: semaphore bounding how many procedures execute
        concurrently; network transfers happen outside of it.
    """
    # The socket is handed over by the accept loop and used nowhere else,
    # so close it deterministically rather than leaking the descriptor
    # until garbage collection.
    with sock:
        # Hold back partial frames until the whole reply is queued
        # (TCP_CORK is only available on Linux).
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)

        header = _sock_read(sock, struct.calcsize("L"))
        (n,) = struct.unpack("L", header)
        payload = _sock_read(sock, n)
        # SECURITY: unpickling data received from the network executes
        # arbitrary code.  Workers must only be reachable from trusted
        # hosts (see the warning in run_worker).
        cmd, args, kwargs, rref = pickle.loads(payload)

        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs

        try:
            with parallel_sem:
                out = cmd(*args, **kwargs)
            if rref:
                # keep the object here, send back a reference only
                out = RRef(obj=out)
            err = None
        except Exception as e:
            out = None
            err = e

        try:
            buffers = []
            payload = _serialize((out, err), buffer_cb=buffers.append)
        except Exception as e:
            # The result (or the error) could not be pickled: report the
            # serialization failure to the caller instead.
            buffers = []
            payload = _serialize((None, e))

        buffers = [memoryview(b).tobytes() for b in buffers]

        sock.sendall(struct.pack("L", len(buffers)))
        for b in buffers:
            sock.sendall(struct.pack("L", len(b)))
            sock.sendall(b)

        sock.sendall(struct.pack("L", len(payload)))
        sock.sendall(payload)

        # uncork to flush whatever is still queued
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)
class RRef:
    """Reference to an object that stays in the process that created it.

    An instance is in one of two states:

    - owner side (``uid is None``): wraps the actual object in ``obj``;
    - proxy side (``uid`` set): only stores the ``uid`` under which the
      owner registered the object in ``_handles``, plus the ``host`` of
      that owner.

    Pickling converts between the two states, see :meth:`__reduce__`.
    """

    def __init__(self, obj=None, uid=None):
        if uid is None:
            # Owner side: hold the object itself.
            self.obj = obj
            self.uid = None

        else:
            # Proxy side: the owning host is read from the thread-local
            # ``tls.host``, which ``_rpc_send_command`` sets before it
            # unpickles the reply.
            # NOTE(review): building a proxy in a thread where tls.host was
            # never set presumably fails with AttributeError -- confirm.
            self.uid = uid
            self.host = tls.host

            # When this proxy is garbage collected, issue an RPC to the
            # owner that pops the entry from its ``_handles`` table.
            weakref.finalize(self, rpc_async, self.host, _handles.pop, [uid])

    @staticmethod
    def wrap(func, args, kwargs):
        """Call ``func(*args, **kwargs)`` and wrap the result in an owner-side RRef."""
        return RRef(obj=func(*args, **kwargs))

    @staticmethod
    def _rebuild_remote(uid):
        """Unpickling hook: build a proxy RRef for ``uid``."""
        return RRef(uid=uid)

    @staticmethod
    def _rebuild_local(uid):
        """Unpickling hook: return the object registered under ``uid`` in ``_handles``."""
        return _handles[uid]

    def __reduce__(self):
        """Pickle support converting between owner and proxy states.

        A proxy pickles to a ``_handles`` lookup, so unpickling it yields
        the registered object itself rather than an RRef.  An owner-side
        RRef registers its object in ``_handles`` and pickles to a proxy.

        :raises RuntimeError: if the object is already registered in
            ``_handles``.
        """
        if self.uid is not None:
            return RRef._rebuild_local, (self.uid,)

        else:
            # id() is used as the handle; ``_handles`` keeps a reference to
            # the object for as long as the entry exists.
            uid = id(self.obj)
            if uid in _handles:
                raise RuntimeError("Only one rref can exist for a given object")

            _handles[uid] = self.obj

            return RRef._rebuild_remote, (uid,)
governing law, French law, with which 15 | it is conformant, both as regards the law of torts and 16 | intellectual property law, and the protection that it offers to 17 | both authors and holders of the economic rights over software. 18 | 19 | The authors of the CeCILL-C (for Ce[a] C[nrs] I[nria] L[ogiciel] L[ibre]) 20 | license are: 21 | 22 | Commissariat à l'Energie Atomique - CEA, a public scientific, technical 23 | and industrial research establishment, having its principal place of 24 | business at 25 rue Leblanc, immeuble Le Ponant D, 75015 Paris, France. 25 | 26 | Centre National de la Recherche Scientifique - CNRS, a public scientific 27 | and technological establishment, having its principal place of business 28 | at 3 rue Michel-Ange, 75794 Paris cedex 16, France. 29 | 30 | Institut National de Recherche en Informatique et en Automatique - 31 | INRIA, a public scientific and technological establishment, having its 32 | principal place of business at Domaine de Voluceau, Rocquencourt, BP 33 | 105, 78153 Le Chesnay cedex, France. 34 | 35 | 36 | Preamble 37 | 38 | The purpose of this Free Software license agreement is to grant users 39 | the right to modify and re-use the software governed by this license. 40 | 41 | The exercising of this right is conditional upon the obligation to make 42 | available to the community the modifications made to the source code of 43 | the software so as to contribute to its evolution. 44 | 45 | In consideration of access to the source code and the rights to copy, 46 | modify and redistribute granted by the license, users are provided only 47 | with a limited warranty and the software's author, the holder of the 48 | economic rights, and the successive licensors only have limited liability. 
49 | 50 | In this respect, the risks associated with loading, using, modifying 51 | and/or developing or reproducing the software by the user are brought to 52 | the user's attention, given its Free Software status, which may make it 53 | complicated to use, with the result that its use is reserved for 54 | developers and experienced professionals having in-depth computer 55 | knowledge. Users are therefore encouraged to load and test the 56 | suitability of the software as regards their requirements in conditions 57 | enabling the security of their systems and/or data to be ensured and, 58 | more generally, to use and operate it in the same conditions of 59 | security. This Agreement may be freely reproduced and published, 60 | provided it is not altered, and that no provisions are either added or 61 | removed herefrom. 62 | 63 | This Agreement may apply to any or all software for which the holder of 64 | the economic rights decides to submit the use thereof to its provisions. 65 | 66 | 67 | Article 1 - DEFINITIONS 68 | 69 | For the purpose of this Agreement, when the following expressions 70 | commence with a capital letter, they shall have the following meaning: 71 | 72 | Agreement: means this license agreement, and its possible subsequent 73 | versions and annexes. 74 | 75 | Software: means the software in its Object Code and/or Source Code form 76 | and, where applicable, its documentation, "as is" when the Licensee 77 | accepts the Agreement. 78 | 79 | Initial Software: means the Software in its Source Code and possibly its 80 | Object Code form and, where applicable, its documentation, "as is" when 81 | it is first distributed under the terms and conditions of the Agreement. 82 | 83 | Modified Software: means the Software modified by at least one 84 | Integrated Contribution. 85 | 86 | Source Code: means all the Software's instructions and program lines to 87 | which access is required so as to modify the Software. 
88 | 89 | Object Code: means the binary files originating from the compilation of 90 | the Source Code. 91 | 92 | Holder: means the holder(s) of the economic rights over the Initial 93 | Software. 94 | 95 | Licensee: means the Software user(s) having accepted the Agreement. 96 | 97 | Contributor: means a Licensee having made at least one Integrated 98 | Contribution. 99 | 100 | Licensor: means the Holder, or any other individual or legal entity, who 101 | distributes the Software under the Agreement. 102 | 103 | Integrated Contribution: means any or all modifications, corrections, 104 | translations, adaptations and/or new functions integrated into the 105 | Source Code by any or all Contributors. 106 | 107 | Related Module: means a set of sources files including their 108 | documentation that, without modification to the Source Code, enables 109 | supplementary functions or services in addition to those offered by the 110 | Software. 111 | 112 | Derivative Software: means any combination of the Software, modified or 113 | not, and of a Related Module. 114 | 115 | Parties: mean both the Licensee and the Licensor. 116 | 117 | These expressions may be used both in singular and plural form. 118 | 119 | 120 | Article 2 - PURPOSE 121 | 122 | The purpose of the Agreement is the grant by the Licensor to the 123 | Licensee of a non-exclusive, transferable and worldwide license for the 124 | Software as set forth in Article 5 hereinafter for the whole term of the 125 | protection granted by the rights over said Software. 
126 | 127 | 128 | Article 3 - ACCEPTANCE 129 | 130 | 3.1 The Licensee shall be deemed as having accepted the terms and 131 | conditions of this Agreement upon the occurrence of the first of the 132 | following events: 133 | 134 | * (i) loading the Software by any or all means, notably, by 135 | downloading from a remote server, or by loading from a physical 136 | medium; 137 | * (ii) the first time the Licensee exercises any of the rights 138 | granted hereunder. 139 | 140 | 3.2 One copy of the Agreement, containing a notice relating to the 141 | characteristics of the Software, to the limited warranty, and to the 142 | fact that its use is restricted to experienced users has been provided 143 | to the Licensee prior to its acceptance as set forth in Article 3.1 144 | hereinabove, and the Licensee hereby acknowledges that it has read and 145 | understood it. 146 | 147 | 148 | Article 4 - EFFECTIVE DATE AND TERM 149 | 150 | 151 | 4.1 EFFECTIVE DATE 152 | 153 | The Agreement shall become effective on the date when it is accepted by 154 | the Licensee as set forth in Article 3.1. 155 | 156 | 157 | 4.2 TERM 158 | 159 | The Agreement shall remain in force for the entire legal term of 160 | protection of the economic rights over the Software. 161 | 162 | 163 | Article 5 - SCOPE OF RIGHTS GRANTED 164 | 165 | The Licensor hereby grants to the Licensee, who accepts, the following 166 | rights over the Software for any or all use, and for the term of the 167 | Agreement, on the basis of the terms and conditions set forth hereinafter. 168 | 169 | Besides, if the Licensor owns or comes to own one or more patents 170 | protecting all or part of the functions of the Software or of its 171 | components, the Licensor undertakes not to enforce the rights granted by 172 | these patents against successive Licensees using, exploiting or 173 | modifying the Software. 
If these patents are transferred, the Licensor 174 | undertakes to have the transferees subscribe to the obligations set 175 | forth in this paragraph. 176 | 177 | 178 | 5.1 RIGHT OF USE 179 | 180 | The Licensee is authorized to use the Software, without any limitation 181 | as to its fields of application, with it being hereinafter specified 182 | that this comprises: 183 | 184 | 1. permanent or temporary reproduction of all or part of the Software 185 | by any or all means and in any or all form. 186 | 187 | 2. loading, displaying, running, or storing the Software on any or 188 | all medium. 189 | 190 | 3. entitlement to observe, study or test its operation so as to 191 | determine the ideas and principles behind any or all constituent 192 | elements of said Software. This shall apply when the Licensee 193 | carries out any or all loading, displaying, running, transmission 194 | or storage operation as regards the Software, that it is entitled 195 | to carry out hereunder. 196 | 197 | 198 | 5.2 RIGHT OF MODIFICATION 199 | 200 | The right of modification includes the right to translate, adapt, 201 | arrange, or make any or all modifications to the Software, and the right 202 | to reproduce the resulting software. It includes, in particular, the 203 | right to create a Derivative Software. 204 | 205 | The Licensee is authorized to make any or all modification to the 206 | Software provided that it includes an explicit notice that it is the 207 | author of said modification and indicates the date of the creation thereof. 208 | 209 | 210 | 5.3 RIGHT OF DISTRIBUTION 211 | 212 | In particular, the right of distribution includes the right to publish, 213 | transmit and communicate the Software to the general public on any or 214 | all medium, and by any or all means, and the right to market, either in 215 | consideration of a fee, or free of charge, one or more copies of the 216 | Software by any means. 
217 | 218 | The Licensee is further authorized to distribute copies of the modified 219 | or unmodified Software to third parties according to the terms and 220 | conditions set forth hereinafter. 221 | 222 | 223 | 5.3.1 DISTRIBUTION OF SOFTWARE WITHOUT MODIFICATION 224 | 225 | The Licensee is authorized to distribute true copies of the Software in 226 | Source Code or Object Code form, provided that said distribution 227 | complies with all the provisions of the Agreement and is accompanied by: 228 | 229 | 1. a copy of the Agreement, 230 | 231 | 2. a notice relating to the limitation of both the Licensor's 232 | warranty and liability as set forth in Articles 8 and 9, 233 | 234 | and that, in the event that only the Object Code of the Software is 235 | redistributed, the Licensee allows effective access to the full Source 236 | Code of the Software at a minimum during the entire period of its 237 | distribution of the Software, it being understood that the additional 238 | cost of acquiring the Source Code shall not exceed the cost of 239 | transferring the data. 240 | 241 | 242 | 5.3.2 DISTRIBUTION OF MODIFIED SOFTWARE 243 | 244 | When the Licensee makes an Integrated Contribution to the Software, the 245 | terms and conditions for the distribution of the resulting Modified 246 | Software become subject to all the provisions of this Agreement. 247 | 248 | The Licensee is authorized to distribute the Modified Software, in 249 | source code or object code form, provided that said distribution 250 | complies with all the provisions of the Agreement and is accompanied by: 251 | 252 | 1. a copy of the Agreement, 253 | 254 | 2. 
a notice relating to the limitation of both the Licensor's 255 | warranty and liability as set forth in Articles 8 and 9, 256 | 257 | and that, in the event that only the object code of the Modified 258 | Software is redistributed, the Licensee allows effective access to the 259 | full source code of the Modified Software at a minimum during the entire 260 | period of its distribution of the Modified Software, it being understood 261 | that the additional cost of acquiring the source code shall not exceed 262 | the cost of transferring the data. 263 | 264 | 265 | 5.3.3 DISTRIBUTION OF DERIVATIVE SOFTWARE 266 | 267 | When the Licensee creates Derivative Software, this Derivative Software 268 | may be distributed under a license agreement other than this Agreement, 269 | subject to compliance with the requirement to include a notice 270 | concerning the rights over the Software as defined in Article 6.4. 271 | In the event the creation of the Derivative Software required modification 272 | of the Source Code, the Licensee undertakes that: 273 | 274 | 1. the resulting Modified Software will be governed by this Agreement, 275 | 2. the Integrated Contributions in the resulting Modified Software 276 | will be clearly identified and documented, 277 | 3. the Licensee will allow effective access to the source code of the 278 | Modified Software, at a minimum during the entire period of 279 | distribution of the Derivative Software, such that such 280 | modifications may be carried over in a subsequent version of the 281 | Software; it being understood that the additional cost of 282 | purchasing the source code of the Modified Software shall not 283 | exceed the cost of transferring the data. 
284 | 285 | 286 | 5.3.4 COMPATIBILITY WITH THE CeCILL LICENSE 287 | 288 | When a Modified Software contains an Integrated Contribution subject to 289 | the CeCILL license agreement, or when a Derivative Software contains a 290 | Related Module subject to the CeCILL license agreement, the provisions 291 | set forth in the third item of Article 6.4 are optional. 292 | 293 | 294 | Article 6 - INTELLECTUAL PROPERTY 295 | 296 | 297 | 6.1 OVER THE INITIAL SOFTWARE 298 | 299 | The Holder owns the economic rights over the Initial Software. Any or 300 | all use of the Initial Software is subject to compliance with the terms 301 | and conditions under which the Holder has elected to distribute its work 302 | and no one shall be entitled to modify the terms and conditions for the 303 | distribution of said Initial Software. 304 | 305 | The Holder undertakes that the Initial Software will remain ruled at 306 | least by this Agreement, for the duration set forth in Article 4.2. 307 | 308 | 309 | 6.2 OVER THE INTEGRATED CONTRIBUTIONS 310 | 311 | The Licensee who develops an Integrated Contribution is the owner of the 312 | intellectual property rights over this Contribution as defined by 313 | applicable law. 314 | 315 | 316 | 6.3 OVER THE RELATED MODULES 317 | 318 | The Licensee who develops a Related Module is the owner of the 319 | intellectual property rights over this Related Module as defined by 320 | applicable law and is free to choose the type of agreement that shall 321 | govern its distribution under the conditions defined in Article 5.3.3. 322 | 323 | 324 | 6.4 NOTICE OF RIGHTS 325 | 326 | The Licensee expressly undertakes: 327 | 328 | 1. not to remove, or modify, in any manner, the intellectual property 329 | notices attached to the Software; 330 | 331 | 2. to reproduce said notices, in an identical manner, in the copies 332 | of the Software modified or not; 333 | 334 | 3. 
to ensure that use of the Software, its intellectual property 335 | notices and the fact that it is governed by the Agreement is 336 | indicated in a text that is easily accessible, specifically from 337 | the interface of any Derivative Software. 338 | 339 | The Licensee undertakes not to directly or indirectly infringe the 340 | intellectual property rights of the Holder and/or Contributors on the 341 | Software and to take, where applicable, vis-à-vis its staff, any and all 342 | measures required to ensure respect of said intellectual property rights 343 | of the Holder and/or Contributors. 344 | 345 | 346 | Article 7 - RELATED SERVICES 347 | 348 | 7.1 Under no circumstances shall the Agreement oblige the Licensor to 349 | provide technical assistance or maintenance services for the Software. 350 | 351 | However, the Licensor is entitled to offer this type of services. The 352 | terms and conditions of such technical assistance, and/or such 353 | maintenance, shall be set forth in a separate instrument. Only the 354 | Licensor offering said maintenance and/or technical assistance services 355 | shall incur liability therefor. 356 | 357 | 7.2 Similarly, any Licensor is entitled to offer to its licensees, under 358 | its sole responsibility, a warranty, that shall only be binding upon 359 | itself, for the redistribution of the Software and/or the Modified 360 | Software, under terms and conditions that it is free to decide. Said 361 | warranty, and the financial terms and conditions of its application, 362 | shall be subject of a separate instrument executed between the Licensor 363 | and the Licensee. 364 | 365 | 366 | Article 8 - LIABILITY 367 | 368 | 8.1 Subject to the provisions of Article 8.2, the Licensee shall be 369 | entitled to claim compensation for any direct loss it may have suffered 370 | from the Software as a result of a fault on the part of the relevant 371 | Licensor, subject to providing evidence thereof. 
372 | 373 | 8.2 The Licensor's liability is limited to the commitments made under 374 | this Agreement and shall not be incurred as a result of in particular: 375 | (i) loss due the Licensee's total or partial failure to fulfill its 376 | obligations, (ii) direct or consequential loss that is suffered by the 377 | Licensee due to the use or performance of the Software, and (iii) more 378 | generally, any consequential loss. In particular the Parties expressly 379 | agree that any or all pecuniary or business loss (i.e. loss of data, 380 | loss of profits, operating loss, loss of customers or orders, 381 | opportunity cost, any disturbance to business activities) or any or all 382 | legal proceedings instituted against the Licensee by a third party, 383 | shall constitute consequential loss and shall not provide entitlement to 384 | any or all compensation from the Licensor. 385 | 386 | 387 | Article 9 - WARRANTY 388 | 389 | 9.1 The Licensee acknowledges that the scientific and technical 390 | state-of-the-art when the Software was distributed did not enable all 391 | possible uses to be tested and verified, nor for the presence of 392 | possible defects to be detected. In this respect, the Licensee's 393 | attention has been drawn to the risks associated with loading, using, 394 | modifying and/or developing and reproducing the Software which are 395 | reserved for experienced users. 396 | 397 | The Licensee shall be responsible for verifying, by any or all means, 398 | the suitability of the product for its requirements, its good working 399 | order, and for ensuring that it shall not cause damage to either persons 400 | or properties. 401 | 402 | 9.2 The Licensor hereby represents, in good faith, that it is entitled 403 | to grant all the rights over the Software (including in particular the 404 | rights set forth in Article 5). 
405 | 406 | 9.3 The Licensee acknowledges that the Software is supplied "as is" by 407 | the Licensor without any other express or tacit warranty, other than 408 | that provided for in Article 9.2 and, in particular, without any warranty 409 | as to its commercial value, its secured, safe, innovative or relevant 410 | nature. 411 | 412 | Specifically, the Licensor does not warrant that the Software is free 413 | from any error, that it will operate without interruption, that it will 414 | be compatible with the Licensee's own equipment and software 415 | configuration, nor that it will meet the Licensee's requirements. 416 | 417 | 9.4 The Licensor does not either expressly or tacitly warrant that the 418 | Software does not infringe any third party intellectual property right 419 | relating to a patent, software or any other property right. Therefore, 420 | the Licensor disclaims any and all liability towards the Licensee 421 | arising out of any or all proceedings for infringement that may be 422 | instituted in respect of the use, modification and redistribution of the 423 | Software. Nevertheless, should such proceedings be instituted against 424 | the Licensee, the Licensor shall provide it with technical and legal 425 | assistance for its defense. Such technical and legal assistance shall be 426 | decided on a case-by-case basis between the relevant Licensor and the 427 | Licensee pursuant to a memorandum of understanding. The Licensor 428 | disclaims any and all liability as regards the Licensee's use of the 429 | name of the Software. No warranty is given as regards the existence of 430 | prior rights over the name of the Software or as regards the existence 431 | of a trademark. 
432 | 433 | 434 | Article 10 - TERMINATION 435 | 436 | 10.1 In the event of a breach by the Licensee of its obligations 437 | hereunder, the Licensor may automatically terminate this Agreement 438 | thirty (30) days after notice has been sent to the Licensee and has 439 | remained ineffective. 440 | 441 | 10.2 A Licensee whose Agreement is terminated shall no longer be 442 | authorized to use, modify or distribute the Software. However, any 443 | licenses that it may have granted prior to termination of the Agreement 444 | shall remain valid subject to their having been granted in compliance 445 | with the terms and conditions hereof. 446 | 447 | 448 | Article 11 - MISCELLANEOUS 449 | 450 | 451 | 11.1 EXCUSABLE EVENTS 452 | 453 | Neither Party shall be liable for any or all delay, or failure to 454 | perform the Agreement, that may be attributable to an event of force 455 | majeure, an act of God or an outside cause, such as defective 456 | functioning or interruptions of the electricity or telecommunications 457 | networks, network paralysis following a virus attack, intervention by 458 | government authorities, natural disasters, water damage, earthquakes, 459 | fire, explosions, strikes and labor unrest, war, etc. 460 | 461 | 11.2 Any failure by either Party, on one or more occasions, to invoke 462 | one or more of the provisions hereof, shall under no circumstances be 463 | interpreted as being a waiver by the interested Party of its right to 464 | invoke said provision(s) subsequently. 465 | 466 | 11.3 The Agreement cancels and replaces any or all previous agreements, 467 | whether written or oral, between the Parties and having the same 468 | purpose, and constitutes the entirety of the agreement between said 469 | Parties concerning said purpose. No supplement or modification to the 470 | terms and conditions hereof shall be effective as between the Parties 471 | unless it is made in writing and signed by their duly authorized 472 | representatives. 
473 | 474 | 11.4 In the event that one or more of the provisions hereof were to 475 | conflict with a current or future applicable act or legislative text, 476 | said act or legislative text shall prevail, and the Parties shall make 477 | the necessary amendments so as to comply with said act or legislative 478 | text. All other provisions shall remain effective. Similarly, invalidity 479 | of a provision of the Agreement, for any reason whatsoever, shall not 480 | cause the Agreement as a whole to be invalid. 481 | 482 | 483 | 11.5 LANGUAGE 484 | 485 | The Agreement is drafted in both French and English and both versions 486 | are deemed authentic. 487 | 488 | 489 | Article 12 - NEW VERSIONS OF THE AGREEMENT 490 | 491 | 12.1 Any person is authorized to duplicate and distribute copies of this 492 | Agreement. 493 | 494 | 12.2 So as to ensure coherence, the wording of this Agreement is 495 | protected and may only be modified by the authors of the License, who 496 | reserve the right to periodically publish updates or new versions of the 497 | Agreement, each with a separate number. These subsequent versions may 498 | address new issues encountered by Free Software. 499 | 500 | 12.3 Any Software distributed under a given version of the Agreement may 501 | only be subsequently distributed under the same version of the Agreement 502 | or a subsequent version. 503 | 504 | 505 | Article 13 - GOVERNING LAW AND JURISDICTION 506 | 507 | 13.1 The Agreement is governed by French law. The Parties agree to 508 | endeavor to seek an amicable solution to any disagreements or disputes 509 | that may arise during the performance of the Agreement. 510 | 511 | 13.2 Failing an amicable solution within two (2) months as from their 512 | occurrence, and unless emergency proceedings are necessary, the 513 | disagreements or disputes shall be referred to the Paris Courts having 514 | jurisdiction, by the more diligent Party. 515 | 516 | 517 | Version 1.0 dated 2006-09-05. 
518 | --------------------------------------------------------------------------------