├── .gitignore ├── .style.yapf ├── README.md ├── distml ├── __init__.py ├── operator │ ├── __init__.py │ ├── base_operator.py │ ├── jax_operator.py │ └── torch_operator.py ├── strategy │ ├── __init__.py │ ├── allreduce_strategy.py │ └── base_strategy.py └── util.py ├── examples ├── jax │ ├── jax_util │ │ ├── __init__.py │ │ ├── datasets.py │ │ └── resnet.py │ └── mnist_jax_example.py └── torch │ ├── __init__.py │ ├── cifar_pytorch_example.py │ └── resnet.py ├── format.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # pycharm related
132 | .idea/
133 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style=pep8
3 | allow_split_before_dict_value=False
4 | join_multiple_lines=False
5 | allow_multiline_lambdas=True
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | *DistML* is a [Ray](https://github.com/ray-project/ray) extension library to support large-scale distributed ML training
4 | on heterogeneous multi-node multi-GPU clusters. This library is under active development, and we are adding more advanced
5 | training strategies and auto-parallelization features.
6 |
7 | DistML currently supports:
8 | * Distributed training strategies
9 |   * Data parallelism
10 |     * AllReduce strategy
11 |     * Sharded parameter server strategy
12 |     * BytePS strategy
13 |   * Pipeline parallelism
14 |     * Micro-batch pipeline parallelism
15 |
16 | * DL Frameworks:
17 |   * PyTorch
18 |   * JAX
19 |
20 | # Installation
21 |
22 | ### Install Dependencies
23 | Depending on your CUDA version, install cupy following https://docs.cupy.dev/en/stable/install.html.
24 |
25 | ### Install from source (for development)
26 | ```bash
27 | pip install -e .
28 | ```
--------------------------------------------------------------------------------
/distml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray-project/distml/b2d4766664166a0163956c00d8472a03274d4d51/distml/__init__.py
--------------------------------------------------------------------------------
/distml/operator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ray-project/distml/b2d4766664166a0163956c00d8472a03274d4d51/distml/operator/__init__.py
--------------------------------------------------------------------------------
/distml/operator/base_operator.py:
--------------------------------------------------------------------------------
1 | """Abstract class for framework-specific training operators."""
2 | from abc import ABCMeta
3 | from abc import abstractmethod
4 |
5 |
6 | class TrainingOperator(metaclass=ABCMeta):
7 |     """Abstract class to define the training loop of a model.
8 |
9 |     This class should be subclassed by the framework-specific
10 |     operator implementations. For training, this class exposes
11 |     two interfaces:
12 |     - `derive_updates()`
13 |     - `apply_updates()`
14 |     in order for the Ray collective backend to take over.
15 |
16 |     For validation, this class exposes a single `validate_batch()`
17 |     interface.
The specific training and validation logic related 18 | with frameworks (JAX, PyTorch) is implemented in its subclasses 19 | 20 | Args: 21 | operator_config (dict): operator config specified by users. 22 | """ 23 | 24 | def __init__(self, *args, operator_config=None, **kwargs): 25 | self._operator_config = operator_config 26 | 27 | @abstractmethod 28 | def register(self, *, model, optimizer, criterion, **kwargs): 29 | """Register the model, optimizer, and criterion with the training operator. 30 | 31 | The function is instantiated in the framework-specific subclass. It 32 | is expected to be called by the user in self.setup(). 33 | """ 34 | raise NotImplementedError() 35 | 36 | @abstractmethod 37 | def register_data(self, *, train_loader=None, validation_loader=None): 38 | """Register batch-based data loaders.""" 39 | raise NotImplementedError() 40 | 41 | def setup(self, operator_config): 42 | """Method to be override by users. 43 | 44 | In this method, the user should register the model, optimizer, 45 | criterion, and data loaders to the operator class via the 46 | `register()` method. 47 | """ 48 | raise NotImplementedError() 49 | 50 | @abstractmethod 51 | def derive_updates(self, *args, **kwargs): 52 | """The sub-step that derives the gradient updates. 53 | 54 | This method should be instantiated by subclass operators. 55 | 56 | Returns: 57 | Tuple(loss, grads): A tuple that contains the loss value and 58 | the gradient updates. 59 | """ 60 | raise NotImplementedError() 61 | 62 | @abstractmethod 63 | def apply_updates(self, updates): 64 | """The sub-step that applies the updates. 65 | 66 | This method should be instantiated by subclass operators. 67 | 68 | Returns: 69 | None. 70 | """ 71 | raise NotImplementedError() 72 | 73 | @abstractmethod 74 | def validate_batch(self, *args, **kwargs): 75 | """Perform validation over a batch of validation data.""" 76 | raise NotImplementedError() 77 | 78 | def get_custom_states(self, *args, **kwargs): 79 | """Functions to be optionally override by users to represent any custom states. 80 | 81 | See ``save_parameters`` for more details. 82 | """ 83 | pass 84 | 85 | def load_custom_states(self, states, *args, **kwargs): 86 | """Functions to be optionally override by users to load any custom states. 87 | 88 | See ``load_parameters`` for more details. 89 | """ 90 | pass 91 | 92 | @abstractmethod 93 | def save_states(self, checkpoint): 94 | """Save the states to a file path. 95 | 96 | This function shall be instantiated in framework-specific operator 97 | implementations. 98 | """ 99 | raise NotImplementedError() 100 | 101 | @abstractmethod 102 | def get_states(self): 103 | """Return the states for the operator as a dict.""" 104 | raise NotImplementedError() 105 | 106 | @abstractmethod 107 | def load_states(self, checkpoint): 108 | """Load the states from a file path. 109 | 110 | This functions shall be instantiated in framework-specific operators 111 | implementations. 112 | """ 113 | raise NotImplementedError() 114 | 115 | def _get_train_loader(self): 116 | if not self._train_loader: 117 | raise RuntimeError( 118 | "The operator does not have any registered train loader. " 119 | "Please register the train loader via " 120 | "`self.register_data()` inside the `self.setup()` function.") 121 | return self._train_loader 122 | 123 | def _get_validation_loader(self): 124 | if not self._validation_loader: 125 | raise RuntimeError( 126 | "The operator does not have any registered validation loader. 
" 127 | "Please register the validation loader via " 128 | "`self.register_data()` inside the `self.setup()` function.") 129 | return self._validation_loader 130 | 131 | def _get_optimizer(self): 132 | if not self._optimizer: 133 | raise RuntimeError( 134 | "The operator does not have any registered optimizer. " 135 | "Please register the optimizer via " 136 | "`self.register()` inside the `self.setup()` function.") 137 | return self._optimizer 138 | 139 | def _get_criterion(self): 140 | if not self._optimizer: 141 | raise RuntimeError( 142 | "The operator does not have any registered criterion. " 143 | "Please register the criterion via " 144 | "`self.register()` inside the `self.setup()` function.") 145 | return self._criterion 146 | -------------------------------------------------------------------------------- /distml/operator/jax_operator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cupy as cp 3 | import jax 4 | from jax import value_and_grad 5 | import jax.numpy as jnp 6 | from jax.lib import xla_client 7 | from jax.dlpack import from_dlpack 8 | from jax.tree_util import tree_flatten, tree_unflatten, tree_structure 9 | from jax._src.util import unzip2 10 | from jax.experimental.optimizers import OptimizerState 11 | 12 | from distml.operator.base_operator import TrainingOperator 13 | 14 | 15 | class JAXTrainingOperator(TrainingOperator): 16 | def __init__(self, operator_config): 17 | super(JAXTrainingOperator, self).__init__(operator_config) 18 | # Should be set by users in the `register` function. 19 | # model methods 20 | self.opt_state = None 21 | self.init_fun = None 22 | self.predict_fun = None 23 | # optimizer methods 24 | self.opt_init = None 25 | self.opt_update = None 26 | self.get_params = None 27 | 28 | self.criterion = None 29 | 30 | # Data loaders for training and validation, registered by users. 31 | self._train_loader = None 32 | self._validation_loader = None 33 | 34 | self.setup(operator_config) 35 | 36 | if hasattr(operator_config, "jit_mode"): 37 | if operator_config["jit_mode"]: 38 | raise NotImplementedError("Not support jit in jax operator.") 39 | 40 | self.train_step_num = 0 41 | 42 | def setup(self, *args, **kwargs): 43 | """Function that needs to be override by users. 44 | 45 | Example: 46 | # some code is the same for all users, 47 | # maybe we can put it in register. 48 | rng_key = random.PRNGKey(0) 49 | input_shape = (28, 28, 1, 64) 50 | lr=0.01 51 | init_fun, predict_fun = ResNet18(num_classes) 52 | _, init_params = init_fun(rng_key, input_shape) 53 | 54 | opt_init, opt_update, get_params = optimizers.adam(lr) 55 | opt_state = opt_init(init_params) 56 | 57 | criterion = lambda logits, targets:-jnp.sum(logits * targets) 58 | 59 | self.register(model=(opt_state, init_fun, predict_fun), 60 | optimizer=(opt_init, opt_update, get_params), 61 | criterion=criterion) 62 | """ 63 | raise NotImplementedError("Please override this function to register " 64 | "your model, optimizer, and criterion.") 65 | 66 | def register(self, *, model, optimizer, criterion, jit_mode=False): 67 | """Register a few critical information about the model to operator. 68 | 69 | Args: 70 | model (tuple/list): a tuple/list has three elements. The first 71 | element should be opt_states that return from opt_init.The 72 | second element should be init_fun that used to initialize 73 | model params. The third element should be predict_fun that 74 | feed params and inputs, return prediction. 
75 | optimizer (tuple/list): a tuple/list has three elements. The 76 | first element should be opt_init that used to initialize 77 | optimizer state. The second element should be opt_update 78 | that use to update the optimizer state. The third element 79 | should be get_params that feed opt_states and return the 80 | params. 81 | criterion (function): a function use to calculate the loss value. 82 | jit_mode (bool): use the jit mode in jax. 83 | """ 84 | 85 | if not isinstance(model, (tuple, list)) and len(model) != 3: 86 | raise RuntimeError("`model` must be a tuple or list and contains" 87 | "'opt_states', 'init_fun', 'predict_fun'." 88 | "Got: {} {}".format(type(model), len(model))) 89 | 90 | if not isinstance(optimizer, (tuple, list)) and len(optimizer) != 3: 91 | raise RuntimeError( 92 | "`optimizer` must be a tuple or list and contains" 93 | "'opt_init', 'opt_update' and 'get_params'." 94 | "Got: {} {}".format(type(optimizer), len(optimizer))) 95 | 96 | if not hasattr(criterion, "__call__"): 97 | raise RuntimeError( 98 | "The `criterion` must be callable function that " 99 | "feed logits and target, return the loss value. " 100 | "Got: {}".format(type(criterion))) 101 | 102 | self.criterion = criterion 103 | self._register_model(model) 104 | self._register_optimizer(optimizer) 105 | 106 | def _register_model(self, model): 107 | """register model components.""" 108 | 109 | if not isinstance(model[0], 110 | jax.experimental.optimizers.OptimizerState): 111 | raise RuntimeError( 112 | "The first elemente of `model` must be the " 113 | "`opt_states` return from optimizer `opt_init`. " 114 | "Got: {}".format(type(model[0]))) 115 | 116 | if not hasattr(model[1], "__call__"): 117 | raise RuntimeError("The second elemente of `model` must be the " 118 | "`init_fun` return from model. " 119 | "Got: {}".format(type(model[1]))) 120 | 121 | if not hasattr(model[2], "__call__"): 122 | raise RuntimeError("The third elemente of `model` must be the " 123 | "`predict_fun` return from model. " 124 | "Got: {}".format(type(model[2]))) 125 | 126 | self.opt_state = model[0] 127 | self.init_fun = model[1] 128 | self.predict_fun = model[2] 129 | 130 | def _register_optimizer(self, optimizer): 131 | """register optimizer components.""" 132 | if not hasattr(optimizer[0], "__call__"): 133 | raise RuntimeError("The fist elemente of `optimizer` must be the " 134 | "`opt_init` return from optimizer. " 135 | "Got: {}".format(type(optimizer[1]))) 136 | 137 | if not hasattr(optimizer[1], "__call__"): 138 | raise RuntimeError( 139 | "The second elemente of `optimizer` must be the " 140 | "`opt_update` return from optimizer. " 141 | "Got: {}".format(type(optimizer[1]))) 142 | 143 | if not hasattr(optimizer[2], "__call__"): 144 | raise RuntimeError("The third elemente of `optimizer` must be the " 145 | "`get_params` return from optimizer. " 146 | "Got: {}".format(type(optimizer[2]))) 147 | 148 | self.opt_init = optimizer[0] 149 | self.opt_update = optimizer[1] 150 | self.get_params = optimizer[2] 151 | 152 | def register_data(self, *, train_loader=None, validation_loader=None): 153 | self._train_loader = train_loader 154 | self._validation_loader = validation_loader 155 | 156 | def _get_train_loader(self): 157 | return self._train_loader 158 | 159 | def _get_validation_loader(self): 160 | return self._validation_loader 161 | 162 | def loss_func(self, params, batch): 163 | """A function to calculate predictions and loss value. 
164 | 165 | This function is going to be decorated by 166 | `grad` in Jax to calculate gradients. 167 | 168 | Args: 169 | params (list): The params return from get_params(opt_states). 170 | batch (tuple): a data batch containing a feature/target pair. 171 | """ 172 | inputs, targets = batch 173 | logits = self.predict_fun(params, inputs) 174 | return self.criterion(logits, targets) 175 | 176 | def derive_updates(self, batch): 177 | """Compute the parameter updates on a given batch of data. 178 | 179 | The `derive_updates` function should be called in conjunction with 180 | the next `apply_updates` function in order to finish one iteration 181 | of training. 182 | 183 | Args: 184 | batch (tuple): a data batch containing a feature/target pair. 185 | """ 186 | loss_val, gradient = self._calculate_gradient(self.opt_state, batch) 187 | 188 | gradient_dict, tree = tree_flatten(gradient) 189 | assert tree == self.opt_state[1] 190 | 191 | if hasattr(self, "preset_keys"): 192 | gradient_dict = { 193 | k: g 194 | for k, g in zip(self.preset_keys, gradient_dict) 195 | } 196 | else: 197 | gradient_dict = { 198 | f"{idx}": g 199 | for idx, g in enumerate(gradient_dict) 200 | } 201 | return loss_val.item(), gradient_dict 202 | 203 | def apply_updates(self, updates): 204 | """Set and apply the updates using the opt_update in Jax. 205 | 206 | Args: 207 | updates (dict): a dictionary of parameter name and updates. 208 | """ 209 | keys, updates = unzip2( 210 | sorted(updates.items(), key=lambda d: int(d[0]))) 211 | updates = tree_unflatten(self.opt_state[1], updates) 212 | self.opt_state = self.opt_update(self.train_step_num, updates, 213 | self.opt_state) 214 | self.train_step_num += 1 215 | 216 | def to_cupy(self, tensor): 217 | """Convert a jax GPU tensor to cupy tensor.""" 218 | if isinstance(tensor, list): 219 | return list(map(self.to_cupy, tensor)) 220 | ctensor = cp.fromDlpack(self.get_jax_dlpack(tensor)) 221 | return ctensor 222 | 223 | def to_operator_tensor(self, tensor): 224 | """Convert a cupy tensor to jax tensor. 225 | 226 | There comes a bug. The layouts of tensor explained by cupy 227 | and jax are different. But dlpack doesn't convert the layout. 228 | """ 229 | if isinstance(tensor, list): 230 | return list(map(self.to_operator_tensor, tensor)) 231 | return from_dlpack(tensor.toDlpack()) 232 | 233 | # TODO(HUI): support return logits by adding use_aux in `value_and_grad` 234 | def _calculate_gradient(self, opt_state, batch): 235 | params = self.get_params(opt_state) 236 | loss_val, gradient = value_and_grad(self.loss_func)(params, batch) 237 | return loss_val, gradient 238 | 239 | def get_jax_dlpack(self, tensor): 240 | """Get the dlpack of a jax tensor. 241 | 242 | Jax api might cause different pointer address after the conversion. 243 | We use the xla api to avoid this bug. 244 | """ 245 | return xla_client._xla.buffer_to_dlpack_managed_tensor( 246 | tensor.device_buffer, take_ownership=False) 247 | 248 | def validate_batch(self, batch): 249 | """Perform validation over a data batch. 250 | 251 | Args: 252 | batch (tuple): a data batch containing a feature/target pair. 
253 | """ 254 | params = self.get_params(self.opt_state) 255 | criterion = self.criterion 256 | predict_fun = self.predict_fun 257 | 258 | # unpack features into list to support multiple inputs model 259 | features, targets = batch 260 | 261 | outputs = predict_fun(params, features) 262 | loss = criterion(outputs, targets) 263 | prediction_class = jnp.argmax(outputs, axis=1) 264 | targets_class = jnp.argmax(targets, axis=1) 265 | 266 | acc = jnp.mean(prediction_class == targets_class) 267 | samples_num = targets.shape[0] 268 | 269 | return { 270 | "val_loss": loss.item(), 271 | "val_accuracy": acc.item(), 272 | "samples_num": samples_num 273 | } 274 | 275 | def get_parameters(self, cpu): 276 | """get the flatten parameters.""" 277 | params = self.get_params(self.opt_state) 278 | flatten_params, tree = tree_flatten(params) 279 | if not hasattr(self, "tree"): 280 | self.tree = tree 281 | 282 | if cpu: 283 | flatten_params = list(map(np.asarray, flatten_params)) 284 | return flatten_params 285 | 286 | def get_named_parameters(self, cpu): 287 | """Get the named parameters. 288 | 289 | In jax, we need to construct a dict to contain the parameters. 290 | """ 291 | params = self.get_parameters(cpu) 292 | if hasattr(self, "preset_keys"): 293 | dict_params = { 294 | name: p 295 | for name, p in zip(self.preset_keys, params) 296 | } 297 | else: 298 | dict_params = {f"{idx}": p for idx, p in enumerate(params)} 299 | return dict_params 300 | 301 | # TODO(HUI): used in load states or load parameters 302 | def set_parameters(self, new_params): 303 | """Use new parameters to replace model parameters. 304 | 305 | In jax, we need to construct a dict to contain the parameters. 306 | 307 | Args: 308 | new_params (dict): New parameters to updates the current model. 309 | """ 310 | assert isinstance(new_params, dict) 311 | 312 | keys, new_params = unzip2( 313 | sorted(new_params.items(), key=lambda d: int(d[0]))) 314 | self.preset_keys = keys 315 | 316 | if not hasattr(self, "tree"): 317 | self.tree = tree_structure(self.get_params(self.opt_state)) 318 | 319 | states_flat, tree, subtrees = self.opt_state 320 | 321 | states = map(tree_unflatten, subtrees, states_flat) 322 | 323 | def update(param, state): 324 | new_state = param, *state[1:] 325 | return new_state 326 | 327 | new_states = map(update, new_params, states) 328 | 329 | new_states_flat, new_subtrees = unzip2(map(tree_flatten, new_states)) 330 | 331 | if not new_subtrees: 332 | raise RuntimeError("subtrees of new params is empty.") 333 | for idx, (subtree, new_subtree) in enumerate( 334 | zip(subtrees, new_subtrees)): 335 | if new_subtree != subtree: 336 | msg = ( 337 | "input structur did not match the save params struture. " 338 | "input {} and output {}.") 339 | raise TypeError(msg.format(subtree, new_subtree)) 340 | 341 | self.opt_state = OptimizerState(new_states_flat, tree, new_subtrees) 342 | 343 | def reset_optimizer_for_params(self, params): 344 | if not isinstance(params, dict): 345 | raise RuntimeError("The `params` should be dict. " 346 | "Got {}".format(type(params))) 347 | 348 | keys, params = unzip2(sorted(params.items(), key=lambda d: int(d[0]))) 349 | self.tree = tree_structure(params) 350 | self.opt_state = self.opt_init(params) 351 | 352 | def clean_redundancy(self): 353 | del self._train_loader 354 | del self._validation_loader 355 | 356 | # TODO(HUI): use pickle to serialize parameters or states and save it. 
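The TODO above mentions pickle-based checkpointing. Below is a minimal, hedged sketch of what that could look like, built only on the operator's existing `get_named_parameters()` and `set_parameters()` methods; the free-standing helpers and their names are illustrative and are not part of this module.

```python
import pickle


def save_parameters_sketch(operator, checkpoint):
    """Hypothetical helper: dump the named parameters to `checkpoint`.

    `operator` is assumed to be a JAXTrainingOperator instance;
    get_named_parameters(cpu=True) returns a dict of name -> numpy array.
    """
    params = operator.get_named_parameters(cpu=True)
    with open(checkpoint, "wb") as f:
        pickle.dump(params, f)


def load_parameters_sketch(operator, checkpoint):
    """Hypothetical helper: restore parameters saved by the function above."""
    with open(checkpoint, "rb") as f:
        params = pickle.load(f)
    # set_parameters expects a dict keyed by the same parameter names.
    operator.set_parameters(params)
```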
357 | def save_parameters(self, checkpoint): 358 | raise NotImplementedError( 359 | "save_parameters is not support in jax operator.") 360 | 361 | def load_parameters(self, checkpoint): 362 | raise NotImplementedError( 363 | "load_parameters is not support in jax operator.") 364 | 365 | def save_states(self, checkpoint): 366 | raise NotImplementedError( 367 | "save_states is not support in jax operator.") 368 | 369 | def get_states(self): 370 | raise NotImplementedError("get_states is not support in jax operator.") 371 | 372 | def load_states(self, checkpoint): 373 | raise NotImplementedError( 374 | "load_states is not support in jax operator.") 375 | -------------------------------------------------------------------------------- /distml/operator/torch_operator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from distml.operator.base_operator import TrainingOperator 4 | 5 | try: 6 | import torch 7 | from torch.nn.modules.loss import _Loss 8 | except ImportError: 9 | raise ImportError("Please install PyTorch following: " 10 | "https://pytorch.org/get-started/locally/.") 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class TorchTrainingOperator(TrainingOperator): 16 | """Class to define the training logic of a PyTorch Model. 17 | 18 | Args: 19 | operator_config (dict): operator config specified by users. 20 | """ 21 | 22 | def __init__(self, *, operator_config=None, **kwargs): 23 | super(TorchTrainingOperator, 24 | self).__init__(operator_config=operator_config) 25 | # Should be set by users in the `register` function. 26 | self.model = None 27 | self.optimizer = None 28 | self.criterion = None 29 | # Models, optimizers, and criterion registered by users. 30 | self._model = None 31 | self._optimizer = None 32 | self._criterion = None 33 | self._lr_scheduler = None 34 | 35 | # Data loaders for training and validation, registered by users. 36 | self._train_loader = None 37 | self._validation_loader = None 38 | 39 | # TODO(Hao): lift the use_gpu attributes below to operator arguments, 40 | # and support CPU training (with GLOO backend). 41 | self._use_gpu = torch.cuda.is_available() 42 | if not self._use_gpu: 43 | raise RuntimeError( 44 | "ray.util.distml now only supports GPU training.") 45 | self.setup(operator_config) 46 | 47 | def register(self, 48 | *, 49 | model, 50 | optimizer, 51 | criterion, 52 | lr_scheduler=None, 53 | **kwargs): 54 | # TODO(Hao): support custom training loop by allowing multiple model, 55 | # optimizer, e.g. for GAN training 56 | if not isinstance(model, torch.nn.Module): 57 | raise RuntimeError("`model` must be torch.nn.Modules. " 58 | "Got: {}".format(model)) 59 | self._model = model 60 | if self._use_gpu: 61 | self._model.cuda() 62 | if not isinstance(optimizer, torch.optim.Optimizer): 63 | raise RuntimeError("`optimizer` must be torch.optim.Optimizer. " 64 | "Got: {}".format(optimizer)) 65 | 66 | # Note(Hao): this is problematic -- model and criterion are moved 67 | # to gpu but optimizer is constructed before the movement. 68 | # See: https://github.com/ray-project/ray/issues/15258 69 | self._optimizer = optimizer 70 | if criterion: 71 | if not isinstance(criterion, _Loss): 72 | raise RuntimeError( 73 | "`criterion` must be torch.nn.module._Loss. 
" 74 | "Got: {}".format(self._criterion)) 75 | self._criterion = criterion 76 | if self._use_gpu: 77 | self._criterion.cuda() 78 | # TODO(Hao): support lr schedulers 79 | return self._model, self._optimizer, self._criterion 80 | 81 | def register_data(self, *, train_loader=None, validation_loader=None): 82 | self._train_loader = train_loader 83 | self._validation_loader = validation_loader 84 | # TODO(Hao): convert each data loader to be distributed 85 | 86 | def setup(self, *args, **kwargs): 87 | """Function that needs to be override by users.""" 88 | raise NotImplementedError("Please override this function to register " 89 | "your model, optimizer, and criterion.") 90 | 91 | def derive_updates(self, batch): 92 | """Compute the parameter updates on a given batch of data. 93 | 94 | The `derive_updates` function should be called in conjunction with 95 | the next `apply_updates` function in order to finish one iteration 96 | of training. 97 | 98 | Args: 99 | batch (tuple): a data batch containing a feature/target pair. 100 | """ 101 | # TODO(Hao): 1. Add metric meters 102 | # 2. add lr_scheduler later. 103 | if not self.model: 104 | raise RuntimeError("Please set self.model at setup or override " 105 | "this function for deriving gradient updates.") 106 | model = self.model 107 | if not self.optimizer: 108 | raise RuntimeError( 109 | "Please set self.optimizer at setup or override " 110 | "this function for deriving gradient updates.") 111 | optimizer = self.optimizer 112 | if not self.criterion: 113 | raise RuntimeError( 114 | "Please set self.criterion at setup or override " 115 | "this function for deriving gradient updates.") 116 | criterion = self.criterion 117 | *features, target = batch 118 | model.train() 119 | 120 | if self._use_gpu: 121 | features = [ 122 | feature.cuda(non_blocking=True) for feature in features 123 | ] 124 | target = target.cuda(non_blocking=True) 125 | 126 | # TODO(Hao): scope the code below using a timer? 127 | output = model(*features) 128 | loss = criterion(output, target) 129 | optimizer.zero_grad() 130 | loss.backward() 131 | grads = self._get_gradients(model) 132 | return loss.item(), grads 133 | 134 | def apply_updates(self, updates): 135 | """Set and apply the updates using the optimizer.step() in Torch. 136 | 137 | Args: 138 | updates (dict): a dictionary of parameter name and updates. 139 | """ 140 | self._set_gradients(self.model, updates) 141 | self.optimizer.step() 142 | 143 | def validate_batch(self, batch): 144 | """Perform validation over a data batch. 145 | 146 | Args: 147 | batch (tuple): a data batch containing a feature/target pair. 148 | """ 149 | if not self.model: 150 | raise RuntimeError("Please set self.model at setup or override " 151 | "this function for validation.") 152 | model = self.model 153 | if not self.criterion: 154 | raise RuntimeError( 155 | "Please set self.criterion at setup or override " 156 | "this function for validation.") 157 | criterion = self.criterion 158 | *features, target = batch 159 | model.eval() 160 | if self._use_gpu: 161 | features = [ 162 | feature.cuda(non_blocking=True) for feature in features 163 | ] 164 | target = target.cuda(non_blocking=True) 165 | 166 | with torch.no_grad(): 167 | output = model(*features) 168 | loss = criterion(output, target) 169 | 170 | # Todo(Hao): report accuracy instead loss here. 
171 | batch_metric = {"val_loss": loss.item()} 172 | return batch_metric 173 | 174 | def get_states(self): 175 | """Return the states of this training operator.""" 176 | states = { 177 | "model": self._model.state_dict(), 178 | "optimizer": self._optimizer.state_dict(), 179 | "custom": self.get_custom_states() 180 | } 181 | if self._lr_scheduler: 182 | states.update({"lr_scheduler": self._lr_scheduler.state_dict()}) 183 | return states 184 | 185 | def load_states(self, states=None, checkpoint=None): 186 | """Load the states into the operator.""" 187 | if not states and not checkpoint: 188 | raise RuntimeError( 189 | "One of `states` and `checkpoint` should be provided. " 190 | "Got states: {}, checkpoint: {}.".format(states, checkpoint)) 191 | if not states and checkpoint: 192 | states = self._load_from_checkpoint(checkpoint) 193 | self.model.load_state_dict(states["model"]) 194 | self.optimizer.load_state_dict(states["optimizer"]) 195 | if self._lr_scheduler: 196 | self._lr_scheduler.load_state_dict(states["lr_scheduler"]) 197 | self.load_custom_states(states["custom"]) 198 | 199 | def save_states(self, checkpoint): 200 | """Save the states to a file path.""" 201 | states = self.get_states() 202 | # TODO(Hao): test this. 203 | torch.save(states, checkpoint) 204 | 205 | @staticmethod 206 | def _get_gradients(model): 207 | """Return the gradient updates of the model as a Python dict. 208 | 209 | Returns: 210 | grads (dict): a dictionary of parameter name and grad tensors. 211 | """ 212 | grads = {} 213 | for name, p in model.named_parameters(): 214 | # grad = None if p.grad is None else p.grad.data 215 | # grad = None if p.grad is None else p.grad 216 | grads[name] = p.grad.data 217 | logger.debug("grad name: {}, grad type: {}, grad value: {} " 218 | .format(name, type(grads[name]), grads[name])) 219 | return grads 220 | 221 | @staticmethod 222 | def to_cupy(torch_tensor): 223 | """Convert a torch GPU tensor to cupy tensor. 224 | 225 | Since now ray.util.collective natively support torch.Tensor, 226 | so we do nothing in this function. 227 | """ 228 | if not isinstance(torch_tensor, torch.Tensor): 229 | raise RuntimeError("Expected torch.Tensor, but got: {}. " 230 | .format(torch_tensor)) 231 | return torch_tensor 232 | 233 | @staticmethod 234 | def _set_gradients(model, grads): 235 | """Set the model gradients as grads.""" 236 | for name, p in model.named_parameters(): 237 | p.grad = grads[name] 238 | # if grads[name] is not None: 239 | # if p.grad is not None: 240 | # p.grad = torch.from_numpy(gradients[name]). 
241 | # to(p.grad.device) 242 | # else: 243 | # p.grad = torch.from_numpy(gradients[name]) 244 | -------------------------------------------------------------------------------- /distml/strategy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/distml/b2d4766664166a0163956c00d8472a03274d4d51/distml/strategy/__init__.py -------------------------------------------------------------------------------- /distml/strategy/allreduce_strategy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import ray 4 | import ray.util.collective as col 5 | from distml.strategy.base_strategy import BaseStrategy 6 | from distml.util import ThroughputCollection 7 | 8 | import numpy as np 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | class AllReduceStrategy(BaseStrategy): 15 | """Strategy that trains a model via collective AllReduce. 16 | 17 | Args: 18 | training_operator_cls (TrainingOperator): 19 | Custom training operator class. 20 | operator_config (dict): operator config specified by users. 21 | initialization_hook (function): A function to call on all training 22 | workers when they are first initialized. This could be useful to 23 | set environment variables for all the worker processes. 24 | world_size (int): The number of parallel workers. 25 | num_cpus_per_worker (int): number of CPUs allocated per worker. 26 | num_gpus_per_worker (int): number of GPUs allocated per worker. 27 | """ 28 | 29 | def __init__(self, 30 | *, 31 | training_operator_cls, 32 | operator_config=None, 33 | initialization_hook=None, 34 | world_size=2, 35 | num_cpus_per_worker=1, 36 | num_gpus_per_worker=1, 37 | **kwargs): 38 | super(AllReduceStrategy, self). \ 39 | __init__(training_operator_cls=training_operator_cls, 40 | operator_config=operator_config, 41 | initialization_hook=initialization_hook, 42 | world_size=world_size, 43 | num_cpus_per_worker=num_cpus_per_worker, 44 | num_gpus_per_worker=num_gpus_per_worker, 45 | **kwargs) 46 | self._global_batch_size = None 47 | if operator_config and operator_config.get("batch_size"): 48 | self._global_batch_size = operator_config.get("batch_size") 49 | if self._global_batch_size: 50 | self._collector = ThroughputCollection( 51 | batch_size=self._global_batch_size) 52 | else: 53 | self._collector = ThroughputCollection() 54 | 55 | def train(self, num_steps=None): 56 | """Run the training on parallel workers. 57 | 58 | Args: 59 | num_steps (int): number of steps to train. If none, the 60 | function will simply train for one epoch. 61 | 62 | Returns: 63 | None 64 | """ 65 | # TODO (Hao): add fault tolerance using `max_retries`. 66 | steps = num_steps if num_steps \ 67 | else self.data_parallel_group.get_data_loader_len() 68 | 69 | # TODO(Hao): this call should be hidden inside Replica. 70 | self.data_parallel_group.make_iterator() 71 | for idx in range(steps): 72 | with self._collector.record("train"): 73 | metrics = self.data_parallel_group.train_batch() 74 | print("Step: {}/{}".format(idx, steps)) 75 | return metrics 76 | 77 | def validate(self, num_steps=None): 78 | """Evaluates the model on the validation data. 79 | 80 | Args: 81 | num_steps (int): number of batches to evaluate. If None, the 82 | function will simply validate across the entire validation 83 | dataset. 
84 | """ 85 | steps = num_steps if num_steps \ 86 | else self.data_parallel_group.get_data_loader_len(training=False) 87 | self.data_parallel_group.make_iterator(training=False) 88 | for idx in range(steps): 89 | with self._collector.record("validate"): 90 | batch_metrics = self.data_parallel_group.validate_batch() 91 | self._collector.update( 92 | "validate", val_acc=batch_metrics[0]["val_loss"]) 93 | self._collector.save("validate") 94 | # TODO: validate result should be the same in all workers 95 | return batch_metrics 96 | 97 | def _start_workers(self): 98 | """Create distributed workers on the Ray cluster for distributed training. 99 | 100 | Specifically, this function will spawn the necessary actor processes 101 | depending on the strategy used, and arrange and pass their required 102 | arguments. 103 | """ 104 | # TODO (Hao): infer the per-replica batch size here... 105 | # so here we get multiple sets of params that will be passed around: 106 | # (1) Those for setting up replica 107 | operator_config = self._operator_config.copy() 108 | replica_params = dict( 109 | training_operator_cls=self.training_operator_cls, 110 | operator_config=operator_config) 111 | # (2) params for setting up collective group and strategy prep-ups. 112 | dist_params = dict( 113 | strategy="allreduce", 114 | backend="nccl", 115 | group_name="default", 116 | ) 117 | group_init_args = dict( 118 | replica_params=replica_params, 119 | dist_params=dist_params, 120 | initialization_hook=self.initialization_hook, 121 | num_cpus_per_worker=self.num_cpus_per_worker, 122 | num_gpus_per_worker=self.num_gpus_per_worker) 123 | self.data_parallel_group = DataParallelGroup(**group_init_args) 124 | # Once the group is created, we start it. 125 | self.data_parallel_group.start_replicas(self.world_size) 126 | 127 | def shutdown(self, force=False): 128 | self.data_parallel_group.shutdown(force=force) 129 | 130 | def save_parameters(self, checkpoint): 131 | self.data_parallel_group.save_parameters(checkpoint) 132 | 133 | def load_parameters(self, checkpoint): 134 | self.data_parallel_group.load_parameters(checkpoint) 135 | 136 | def _init_strategy(self): 137 | pass 138 | 139 | 140 | class Replica: 141 | """Express the training semantics of a data-parallel replica. 142 | 143 | This class includes some glue code between the user-provided operator 144 | and Ray collective group setup. 
145 | """ 146 | 147 | def __init__(self, training_operator_cls, operator_config): 148 | self.training_operator_cls = training_operator_cls 149 | self.operator_config = operator_config 150 | # Training operator 151 | self.training_operator = None 152 | 153 | # collective-related information 154 | self._world_size = None 155 | self._rank = None 156 | self._group_name = None 157 | 158 | # Iterators 159 | self.train_iterator = None 160 | self.validation_iterator = None 161 | 162 | def setup_operator(self): 163 | """Instantiate the training operator.""" 164 | self.training_operator = self.training_operator_cls( 165 | operator_config=self.operator_config) 166 | 167 | def setup_collective_group(self, 168 | rank, 169 | world_size, 170 | backend, 171 | group_name="default"): 172 | self._rank = rank 173 | self._group_name = group_name 174 | self._world_size = world_size 175 | col.init_collective_group( 176 | world_size, rank, backend=backend, group_name=group_name) 177 | 178 | def make_iterator(self, training=True): 179 | """Convert loader to be an iterator at the start of an epoch.""" 180 | # TODO(Hao): need to check whether reaching the boundary of iterator 181 | # instead of making a new one every time. 182 | if training: 183 | self.train_iterator = iter(self.train_loader) 184 | else: 185 | self.validation_iterator = iter(self.validation_loader) 186 | 187 | def get_data_loader_len(self, training=True): 188 | """Return the number of batches in the data loader.""" 189 | loader = self.train_loader if training \ 190 | else self.validation_loader 191 | if hasattr(loader, "__len__"): 192 | return len(loader) 193 | else: 194 | raise RuntimeError( 195 | "Data loader has no attribute `__len__`. " 196 | "Please set `num_steps` in `train()` or `validate()`.") 197 | 198 | def train_batch(self): 199 | metrics = {} 200 | try: 201 | batch = next(self.train_iterator) 202 | except StopIteration and NameError: 203 | self.make_iterator() 204 | batch = next(self.train_iterator) 205 | loss_val, updates = self.derive_updates(batch) 206 | assert isinstance(updates, dict) 207 | 208 | metrics["train_loss"] = loss_val 209 | for _, g in updates.items(): 210 | cg = self.training_operator.to_cupy(g) 211 | col.allreduce(cg) 212 | # TODO(Hao): this is conflicting with Runhui's code though. 
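One detail worth flagging in `train_batch` above: the division on the next line (`cg = cg / float(self.world_size)`) rebinds the local name `cg` to a newly allocated tensor, so only the in-place `allreduce` sum ever reaches `updates`; the computed average is discarded. A hedged sketch of an in-place variant, assuming (as with the torch operator's `to_cupy`, which returns the tensor itself) that the converted tensor aliases the gradient stored in `updates`:

```python
import ray.util.collective as col


def allreduce_average_(updates, to_cupy, world_size):
    """Allreduce-and-average a dict of gradients in place (illustrative)."""
    for _, g in updates.items():
        cg = to_cupy(g)            # assumed to share storage with `g`
        col.allreduce(cg)          # in-place sum across the collective group
        cg /= float(world_size)    # in-place divide, so `updates` keeps the mean
    return updates
```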
213 | cg = cg / float(self.world_size) 214 | self.apply_updates(updates) 215 | return metrics 216 | 217 | def derive_updates(self, batch): 218 | return self.training_operator.derive_updates(batch) 219 | 220 | def apply_updates(self, updates): 221 | # TODO(Hao): conflicting with Runhui's code on averaging grads 222 | self.training_operator.apply_updates(updates) 223 | 224 | def updates_transform(self, updates): 225 | return self.training_operator.updates_transform(updates) 226 | 227 | def validate_batch(self): 228 | try: 229 | batch = next(self.validation_iterator) 230 | except StopIteration and NameError: 231 | self.make_iterator(training=False) 232 | batch = next(self.validation_iterator) 233 | batch_metric = self.training_operator.validate_batch(batch) 234 | return batch_metric 235 | 236 | def shutdown(self): 237 | # destroy the collective group resources on this process 238 | col.destroy_collective_group(self.group_name) 239 | if self.training_operator: 240 | del self.training_operator 241 | return 1 242 | 243 | def save_parameters(self, checkpoint): 244 | self.training_operator.save_parameters(checkpoint) 245 | 246 | def load_parameters(self, checkpoint): 247 | self.training_operator.load_parameters(checkpoint) 248 | 249 | def apply(self, fn): 250 | """Apply a function in the replica process.""" 251 | return fn() 252 | 253 | @property 254 | def train_loader(self): 255 | return self.training_operator._get_train_loader() 256 | 257 | @property 258 | def validation_loader(self): 259 | return self.training_operator._get_validation_loader() 260 | 261 | @property 262 | def world_size(self): 263 | return self._world_size 264 | 265 | @property 266 | def rank(self): 267 | return self._rank 268 | 269 | @property 270 | def group_name(self): 271 | return self._group_name 272 | 273 | 274 | class DataParallelGroup: 275 | """Spawn a group a replicas for data-parallel training.""" 276 | 277 | def __init__(self, replica_params, dist_params, initialization_hook, 278 | num_cpus_per_worker, num_gpus_per_worker): 279 | self._replica_params = replica_params 280 | self._dist_params = dist_params 281 | 282 | # try to unroll the dist_params 283 | self._backend = self._dist_params["backend"] 284 | self._group_name = self._dist_params["group_name"] 285 | 286 | self._initialization_hook = initialization_hook 287 | self._num_cpus_per_worker = num_cpus_per_worker 288 | self._num_gpus_per_worker = num_gpus_per_worker 289 | self._replicas = None 290 | 291 | @property 292 | def replicas(self): 293 | return self._replicas 294 | 295 | @property 296 | def world_size(self): 297 | return len(self._replicas) 298 | 299 | @property 300 | def backend(self): 301 | return self._backend 302 | 303 | @property 304 | def group_name(self): 305 | return self._group_name 306 | 307 | def start_replicas(self, num_replicas): 308 | assert num_replicas > 1 309 | RemoteReplica = ray.remote( 310 | num_cpus=self._num_cpus_per_worker, 311 | num_gpus=self._num_gpus_per_worker)(Replica) 312 | self._replicas = [ 313 | RemoteReplica.remote(**self._replica_params) 314 | for _ in range(num_replicas) 315 | ] 316 | 317 | # apply init_hook 318 | if self._initialization_hook: 319 | self.apply_all_replicas(self._initialization_hook) 320 | 321 | # setup the rank and group in each replica 322 | group_setup_refs = self._setup_collective_group( 323 | self.world_size, self.backend, self.group_name) 324 | ray.get(group_setup_refs) 325 | 326 | # setup the model training operator 327 | operator_setups = self._setup_operator() 328 | ray.get(operator_setups) 329 | 
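For orientation, here is a hedged end-to-end sketch of how `AllReduceStrategy` (defined earlier in this file) is typically driven. The operator subclass, model, and synthetic data are hypothetical placeholders, not part of this repository, and the sketch assumes a Ray setup with at least two GPUs, since the torch operator currently requires CUDA.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from distml.operator.torch_operator import TorchTrainingOperator
from distml.strategy.allreduce_strategy import AllReduceStrategy


class ToyOperator(TorchTrainingOperator):
    """Hypothetical operator: a linear classifier on random data."""

    def setup(self, operator_config):
        model = nn.Linear(32, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        self.model, self.optimizer, self.criterion = self.register(
            model=model, optimizer=optimizer, criterion=criterion)

        dataset = TensorDataset(
            torch.randn(512, 32), torch.randint(0, 10, (512,)))
        loader = DataLoader(dataset, batch_size=operator_config["batch_size"])
        self.register_data(train_loader=loader, validation_loader=loader)


if __name__ == "__main__":
    strategy = AllReduceStrategy(
        training_operator_cls=ToyOperator,
        operator_config={"batch_size": 64},
        world_size=2)            # two data-parallel GPU workers
    strategy.train(num_steps=5)
    strategy.validate(num_steps=2)
    strategy.shutdown()
```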
330 | def _make_iterator(self, training): 331 | return [ 332 | replica.make_iterator.remote(training=training) 333 | for replica in self.replicas 334 | ] 335 | 336 | def make_iterator(self, training=True): 337 | ray.get(self._make_iterator(training=training)) 338 | 339 | def get_data_loader_len(self, training=True): 340 | """Return the number of batches in the data loader.""" 341 | lens = ray.get([ 342 | replica.get_data_loader_len.remote(training=training) 343 | for replica in self.replicas 344 | ]) 345 | if len(set(lens)) != 1: 346 | # TODO(Hao): is this correct after we add distributed data loader? 347 | raise RuntimeError( 348 | "All replica should have the same dataloader len.") 349 | return lens[0] 350 | 351 | def train_batch(self): 352 | metrics = {} 353 | loss_vals = ray.get( 354 | [replica.train_batch.remote() for replica in self.replicas]) 355 | train_loss_list = [d["train_loss"] for d in loss_vals] 356 | metrics["train_loss"] = np.mean(train_loss_list) 357 | return metrics 358 | 359 | def validate_batch(self): 360 | rets = [replica.validate_batch.remote() for replica in self.replicas] 361 | stats = ray.get(rets) 362 | return stats 363 | 364 | def shutdown(self, force=False): 365 | rets = [replica.shutdown.remote() for replica in self.replicas] 366 | stats = ray.get(rets) 367 | return stats 368 | 369 | def reset(self): 370 | pass 371 | 372 | def save_parameters(self, checkpoint): 373 | rets = [self.replicas[0].save_parameters.remote(checkpoint)] 374 | ray.get(rets) 375 | 376 | def load_parameters(self, checkpoint): 377 | rets = [ 378 | replica.load_parameters.remote(checkpoint) 379 | for _, replica in enumerate(self.replicas) 380 | ] 381 | ray.get(rets) 382 | 383 | def set_parameters(self, params): 384 | rets = [ 385 | replica.set_parameters.remote(params) 386 | for _, replica in enumerate(self.replicas) 387 | ] 388 | ray.get(rets) 389 | 390 | def get_parameters(self, cpu=False): 391 | ret = self.replicas[0].get_parameters.remote(cpu) 392 | return ray.get(ret)[0] 393 | 394 | def get_named_parameters(self, cpu=False): 395 | ret = self.replicas[0].get_named_parameters.remote(cpu) 396 | return ray.get([ret])[0] 397 | 398 | def apply_all_replicas(self, fn): 399 | """Apply fn in all replica processes and wait until completion.""" 400 | return ray.get(self._apply_all_replicas(fn)) 401 | 402 | def _apply_all_replicas(self, fn): 403 | """Apply a function fn in all replica processes.""" 404 | return [replica.apply.remote(fn) for replica in self.replicas] 405 | 406 | def _setup_collective_group(self, 407 | world_size, 408 | backend, 409 | group_name="default"): 410 | refs = [ 411 | replica.setup_collective_group.remote( 412 | rank=i, 413 | world_size=world_size, 414 | backend=backend, 415 | group_name=group_name) 416 | for i, replica in enumerate(self.replicas) 417 | ] 418 | return refs 419 | 420 | def _setup_operator(self): 421 | refs = [ 422 | replica.setup_operator.remote() 423 | for i, replica in enumerate(self.replicas) 424 | ] 425 | return refs 426 | -------------------------------------------------------------------------------- /distml/strategy/base_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | import logging 4 | 5 | import ray 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class BaseStrategy(metaclass=ABCMeta): 11 | def __init__(self, 12 | *, 13 | training_operator_cls, 14 | operator_config=None, 15 | initialization_hook=None, 16 | world_size=2, 17 | 
num_cpus_per_worker=1,
18 |                  num_gpus_per_worker=1,
19 |                  **kwargs):
20 |         self.training_operator_cls = training_operator_cls
21 |         self.initialization_hook = initialization_hook
22 |         if world_size < 2:
23 |             raise RuntimeError(
24 |                 "ray.util.distml does not support single-process training "
25 |                 "at this moment.")
26 |         self.world_size = world_size
27 |         self.num_cpus_per_worker = num_cpus_per_worker
28 |         self.num_gpus_per_worker = num_gpus_per_worker
29 |         self._operator_config = {} if not operator_config \
30 |             else operator_config
31 |         if not ray.is_initialized() and self.world_size > 1:
32 |             logger.info("Automatically initializing single-node Ray. To use "
33 |                         "multi-node training, be sure to run `ray.init("
34 |                         "address='auto')` before instantiating the Strategy.")
35 |             ray.init()
36 |         self._start_workers()
37 |
38 |     @abstractmethod
39 |     def train(self, *args, **kwargs):
40 |         """Run the training on parallel workers."""
41 |         raise NotImplementedError()
42 |
43 |     @abstractmethod
44 |     def validate(self):
45 |         """Evaluate the model on the validation data loader.
46 |         """
47 |         raise NotImplementedError()
48 |
49 |     @abstractmethod
50 |     def save_parameters(self, checkpoint):
51 |         """Save the trainer state to the provided checkpoint path.
52 |
53 |         Args:
54 |             checkpoint (str): Path to target checkpoint file.
55 |         """
56 |         raise NotImplementedError()
57 |
58 |     @abstractmethod
59 |     def load_parameters(self, checkpoint):
60 |         raise NotImplementedError()
61 |
62 |     @abstractmethod
63 |     def _start_workers(self):
64 |         """Start all the workers to be used for training."""
65 |         raise NotImplementedError()
66 |
67 |     @abstractmethod
68 |     def _init_strategy(self):
69 |         """Strategy-specific setup."""
70 |         raise NotImplementedError()
71 |
72 |     @abstractmethod
73 |     def shutdown(self, force=False):
74 |         """Kill all workers."""
75 |         raise NotImplementedError()
76 |
--------------------------------------------------------------------------------
/distml/util.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import wraps
3 | from collections import defaultdict
4 |
5 | import pandas as pd
6 | import numpy as np
7 | try:
8 |     import cupy as cp
9 | except ModuleNotFoundError:
10 |     raise ModuleNotFoundError("Please install cupy following: "
11 |                               "https://docs.cupy.dev/en/stable/install.html.")
12 |
13 | from ray.util.sgd.utils import TimerCollection
14 |
15 |
16 | def override(interface_class):
17 |     def overrider(method):
18 |         assert (method.__name__ in dir(interface_class))
19 |         return method
20 |
21 |     return overrider
22 |
23 |
24 | # some operation for this ml system.
25 | def ones(shape, cpu=True):
26 |     if cpu:
27 |         return np.ones(shape)
28 |     else:
29 |         return cp.ones(shape)
30 |
31 |
32 | def zeros_like(x, cpu=True):
33 |     if cpu:
34 |         return np.zeros_like(x)
35 |     else:
36 |         return cp.zeros_like(x)
37 |
38 |
39 | # some operation for this ml system.
40 | def zeros(shape, cpu=True): 41 | if cpu: 42 | return np.zeros(shape) 43 | else: 44 | return cp.zeros(shape) 45 | 46 | 47 | def numel(v): 48 | return np.size(v) 49 | 50 | 51 | class EmptyTimeState: 52 | def __enter__(self): 53 | pass 54 | 55 | def __exit__(self, type, value, tb): 56 | pass 57 | 58 | 59 | class ThroughputCollection(TimerCollection): 60 | def __init__( 61 | self, 62 | batch_size=32, 63 | # num_workers=1, 64 | save_freq=50, 65 | job_name="default"): 66 | self.batch_size = batch_size 67 | self.save_freq = save_freq 68 | self.job_name = job_name 69 | super(ThroughputCollection, self).__init__() 70 | 71 | def defaultdict_list(): 72 | return defaultdict(list) 73 | 74 | self.result_collection = defaultdict(defaultdict_list) 75 | self.key_steps = defaultdict(int) 76 | 77 | def record(self, key): 78 | if self._enabled: 79 | self.key_steps[key] += 1 80 | if self.key_steps[key] % self.save_freq == 0: 81 | self.save(key) 82 | return self._timers[key] 83 | else: 84 | return EmptyTimeState() 85 | 86 | def update(self, key, **kwargs): 87 | # call before this key record start. 88 | for k, v in kwargs.items(): 89 | self.result_collection[key][k].append(v) 90 | 91 | def report(self, key): 92 | aggregates = {} 93 | k, t = key, self._timers[key] 94 | aggregates[f"count_{k}"] = t.count + 1 95 | aggregates[f"mean_{k}_s"] = t.mean 96 | aggregates[f"last_{k}_s"] = t.last 97 | aggregates[f"total_{k}_s"] = t._total_time 98 | aggregates[ 99 | f"pass_data_{k}"] = aggregates[f"count_{k}"] * self.batch_size 100 | aggregates[ 101 | f"throughout_{k}_d"] = aggregates[f"pass_data_{k}"] / t._total_time 102 | for metric in aggregates.keys(): 103 | self.result_collection[k][metric].append(aggregates[metric]) 104 | return aggregates 105 | 106 | def save(self, key): 107 | self.report(key) 108 | df = pd.DataFrame.from_dict( 109 | self.result_collection[key], orient="columns") 110 | df.to_csv(f"{self.job_name}_{key}.csv", index=None) 111 | 112 | 113 | def func_timer(function): 114 | """A decorator to record time.""" 115 | 116 | @wraps(function) 117 | def function_timer(*args, **kwargs): 118 | t0 = time.time() 119 | result = function(*args, **kwargs) 120 | t1 = time.time() 121 | print("[Function: {name} finished, spent time: {time:.5f}s]".format( 122 | name=function.__name__, time=t1 - t0)) 123 | return result 124 | 125 | return function_timer 126 | -------------------------------------------------------------------------------- /examples/jax/jax_util/__init__.py: -------------------------------------------------------------------------------- 1 | from jax_util.datasets import mnist, Dataloader # noqa: F401 2 | from jax_util.resnet import ResNet18, ResNet50, ResNet101 # noqa: F401 3 | -------------------------------------------------------------------------------- /examples/jax/jax_util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
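Looking back at `distml/util.py` above, `ThroughputCollection` is the helper the strategies use for timing and throughput bookkeeping. A short hedged usage sketch follows; the job name and the toy workload are made up, and running it needs the same dependencies `distml.util` imports (cupy, pandas, ray).

```python
from distml.util import ThroughputCollection

# One collector per job. Every `save_freq` recorded steps it appends
# aggregate stats (count, mean/last/total time, throughput) and writes
# them to "<job_name>_<key>.csv".
collector = ThroughputCollection(batch_size=64, save_freq=50, job_name="demo")

for step in range(100):
    with collector.record("train"):        # time one (stand-in) training step
        sum(i * i for i in range(10_000))

collector.save("train")                    # final flush to demo_train.csv
```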
14 | """Datasets used in examples.""" 15 | 16 | import array 17 | import gzip 18 | import os 19 | from os import path 20 | import struct 21 | import urllib.request 22 | import jax.numpy as jnp 23 | 24 | import numpy as np 25 | import numpy.random as npr 26 | import pickle 27 | 28 | _DATA = "/tmp/jax_example_data/" 29 | 30 | 31 | def _download(url, filename, dataset_name="mnist"): 32 | """Download a url to a file in the JAX data temp directory.""" 33 | root = os.path.join(_DATA, dataset_name) 34 | if not path.exists(root): 35 | os.makedirs(root) 36 | out_file = path.join(root, filename) 37 | if not path.isfile(out_file): 38 | urllib.request.urlretrieve(url, out_file) 39 | print("downloaded {} to {}".format(url, root)) 40 | 41 | 42 | def _partial_flatten(x): 43 | """Flatten all but the first dimension of an ndarray.""" 44 | return np.reshape(x, (x.shape[0], -1)) 45 | 46 | 47 | def _one_hot(x, k, dtype=np.float32): 48 | """Create a one-hot encoding of x of size k.""" 49 | return np.asarray(x[:, None] == np.arange(k), dtype) 50 | 51 | 52 | # @partial(jit, static_argnums=1) 53 | def _one_hot_jit(x, k, dtype=np.float32): 54 | """Create a one-hot encoding of x of size k.""" 55 | return jnp.asarray(x[:, None] == jnp.arange(0, k), dtype) 56 | 57 | 58 | def mnist_raw(): 59 | """Download and parse the raw MNIST dataset.""" 60 | # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 61 | base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" 62 | 63 | def parse_labels(filename): 64 | with gzip.open(filename, "rb") as fh: 65 | _ = struct.unpack(">II", fh.read(8)) 66 | return np.array(array.array("B", fh.read()), dtype=np.uint8) 67 | 68 | def parse_images(filename): 69 | with gzip.open(filename, "rb") as fh: 70 | _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) 71 | return np.array( 72 | array.array("B", fh.read()), dtype=np.uint8).reshape( 73 | num_data, rows, cols) 74 | 75 | for filename in [ 76 | "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", 77 | "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz" 78 | ]: 79 | _download(base_url + filename, filename) 80 | 81 | train_images = parse_images( 82 | path.join(_DATA, "mnist", "train-images-idx3-ubyte.gz")) 83 | train_labels = parse_labels( 84 | path.join(_DATA, "mnist", "train-labels-idx1-ubyte.gz")) 85 | test_images = parse_images( 86 | path.join(_DATA, "mnist", "t10k-images-idx3-ubyte.gz")) 87 | test_labels = parse_labels( 88 | path.join(_DATA, "mnist", "t10k-labels-idx1-ubyte.gz")) 89 | 90 | return train_images, train_labels, test_images, test_labels 91 | 92 | 93 | def mnist(permute_train=False): 94 | """ 95 | Download, parse and process MNIST data 96 | to unit scale and one-hot labels. 97 | """ 98 | train_images, train_labels, test_images, test_labels = mnist_raw() 99 | 100 | train_images = _partial_flatten(train_images) / np.float32(255.) 101 | test_images = _partial_flatten(test_images) / np.float32(255.) 
102 | train_labels = _one_hot(train_labels, 10) 103 | test_labels = _one_hot(test_labels, 10) 104 | 105 | if permute_train: 106 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 107 | train_images = train_images[perm] 108 | train_labels = train_labels[perm] 109 | 110 | return train_images, train_labels, test_images, test_labels 111 | 112 | 113 | def cifa100_raw(): 114 | """Download and parse the raw MNIST dataset.""" 115 | base_url = "http://www.cs.toronto.edu/~kriz/" 116 | 117 | def load_CIFAR_batch(root, mode="train"): 118 | """ load single batch of cifar """ 119 | if mode == "train": 120 | filename = path.join(root, "train") 121 | elif mode == "test": 122 | filename = path.join(root, "test") 123 | else: 124 | raise RuntimeError("Error: unrecognized mode", 125 | " Got {}".format(mode)) 126 | 127 | with open(filename, "rb") as f: 128 | datadict = pickle.load(f, encoding="bytes") 129 | X = datadict[b"data"] 130 | Y = datadict[b"fine_labels"] 131 | if mode == "train": 132 | X = X.reshape(50000, 3, 32, 32) 133 | else: 134 | X = X.reshape(10000, 3, 32, 32) 135 | return np.array(X), np.array(Y) 136 | 137 | for filename in ["cifar-100-python.tar.gz"]: 138 | _download(base_url + filename, filename, dataset_name="cifa100") 139 | 140 | root = path.join(_DATA, "cifa100") 141 | 142 | if not os.path.exists(path.join(root, "cifar-100-python.tar.gz")): 143 | os.system("tar xvf {} -C {}".format( 144 | path.join(root, "cifar-100-python.tar.gz"), root)) 145 | 146 | train_images, train_labels = load_CIFAR_batch( 147 | path.join(root, "cifar-100-python"), mode="train") 148 | test_images, test_labels = load_CIFAR_batch( 149 | path.join(root, "cifar-100-python"), mode="test") 150 | 151 | # b"fine_label_names" b"coarse_label_names" 152 | # meta_path = path.join(root, "cifar-100-python", "meta") 153 | return train_images, train_labels, test_images, test_labels 154 | 155 | 156 | def cifa100(permute_train=False): 157 | """ 158 | Download, parse and process cida100 data to unit scale and one-hot labels. 159 | """ 160 | train_images, train_labels, test_images, test_labels = cifa100_raw() 161 | 162 | train_images = _partial_flatten(train_images) / np.float32(255.) 163 | test_images = _partial_flatten(test_images) / np.float32(255.) 
164 | train_labels = _one_hot(train_labels, 100) 165 | test_labels = _one_hot(test_labels, 100) 166 | 167 | if permute_train: 168 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 169 | train_images = train_images[perm] 170 | train_labels = train_labels[perm] 171 | 172 | return train_images, train_labels, test_images, test_labels 173 | 174 | 175 | def cifa10_raw(): 176 | """Download and parse the raw CIFAR-10 dataset.""" 177 | base_url = "http://www.cs.toronto.edu/~kriz/" 178 | 179 | def load_CIFAR_batch(root, mode="train"): 180 | """Load the CIFAR-10 batches for the given mode.""" 181 | if mode == "train": 182 | filenames = [] 183 | for i in range(1, 6): 184 | filenames.append(path.join(root, f"data_batch_{i}")) 185 | elif mode == "test": 186 | filenames = [path.join(root, "test_batch")] 187 | else: 188 | raise RuntimeError("Error: unrecognized mode", 189 | " Got {}".format(mode)) 190 | print(filenames) 191 | datas = [] 192 | labels = [] 193 | for filename in filenames: 194 | with open(filename, "rb") as f: 195 | datadict = pickle.load(f, encoding="bytes") 196 | X = datadict[b"data"] 197 | Y = datadict[b"labels"] 198 | X = X.reshape(10000, 3, 32, 32) 199 | datas.append(X) 200 | labels.append(Y) 201 | return np.concatenate(datas, axis=0), np.concatenate(labels) 202 | 203 | for filename in ["cifar-10-python.tar.gz"]: 204 | _download(base_url + filename, filename, dataset_name="cifa10") 205 | 206 | root = path.join(_DATA, "cifa10") 207 | 208 | if not os.path.exists(path.join(root, "cifar-10-batches-py")): 209 | os.system("tar xvf {} -C {}".format( 210 | path.join(root, "cifar-10-python.tar.gz"), root)) 211 | 212 | train_images, train_labels = load_CIFAR_batch( 213 | path.join(root, "cifar-10-batches-py"), mode="train") 214 | test_images, test_labels = load_CIFAR_batch( 215 | path.join(root, "cifar-10-batches-py"), mode="test") 216 | print(test_images.shape) 217 | 218 | # b"label_names" 219 | # meta_path = path.join(root, "cifar-10-batches-py", "batches.meta") 220 | return train_images, train_labels, test_images, test_labels 221 | 222 | 223 | def cifa10(permute_train=False): 224 | """ 225 | Download, parse and process CIFAR-10 data 226 | to unit scale and one-hot labels. 227 | """ 228 | train_images, train_labels, test_images, test_labels = cifa10_raw() 229 | 230 | train_images = _partial_flatten(train_images) / np.float32(255.) 231 | test_images = _partial_flatten(test_images) / np.float32(255.) 232 | train_labels = _one_hot(train_labels, 10) 233 | test_labels = _one_hot(test_labels, 10) 234 | 235 | if permute_train: 236 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 237 | train_images = train_images[perm] 238 | train_labels = train_labels[perm] 239 | 240 | return train_images, train_labels, test_images, test_labels 241 | 242 | 243 | class Dataloader: 244 | def __init__(self, data, target, batch_size=128, shuffle=False): 245 | """Init the data loader.
246 | 247 | Args: 248 | data: shape(width, height, channel, num) 249 | target: shape(num, num_classes) 250 | """ 251 | self.data = data 252 | self.target = target 253 | self.batch_size = batch_size 254 | num_data = self.target.shape[0] 255 | num_complete_batches, leftover = divmod(num_data, batch_size) 256 | self.num_batches = num_complete_batches + bool(leftover) 257 | self.shuffle = shuffle 258 | 259 | def synth_batches(self): 260 | num_imgs = self.target.shape[0] 261 | rng = npr.RandomState(npr.randint(10)) 262 | perm = rng.permutation(num_imgs) if self.shuffle else np.arange( 263 | num_imgs) 264 | for i in range(self.num_batches): 265 | batch_idx = perm[i * self.batch_size:(i + 1) * self.batch_size] 266 | img_batch = self.data[:, :, :, batch_idx] 267 | label_batch = self.target[batch_idx] 268 | yield img_batch, label_batch 269 | 270 | def __iter__(self): 271 | return self.synth_batches() 272 | 273 | def __len__(self): 274 | return self.num_batches 275 | 276 | 277 | if __name__ == "__main__": 278 | train_images, train_labels, test_images, test_labels = cifa10() 279 | 280 | print(type(train_images), type(train_labels)) 281 | print(train_images.shape, train_labels.shape) 282 | print(type(test_images), type(test_labels)) 283 | print(test_images.shape, test_labels.shape) 284 | 285 | train_images, train_labels, test_images, test_labels = cifa100() 286 | 287 | print(type(train_images), type(train_labels)) 288 | print(train_images.shape, train_labels.shape) 289 | print(type(test_images), type(test_labels)) 290 | print(test_images.shape, test_labels.shape) 291 | -------------------------------------------------------------------------------- /examples/jax/jax_util/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A mock-up showing a ResNet50 network with training on synthetic data. 15 | 16 | This file uses the stax neural network definition library and the optimizers 17 | optimization library. 
18 | """ 19 | 20 | import numpy.random as npr 21 | 22 | import jax.numpy as jnp 23 | from jax import jit, grad, random 24 | from jax.experimental import optimizers 25 | from jax.experimental import stax 26 | from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum, 27 | FanOut, Flatten, GeneralConv, Identity, 28 | MaxPool, Relu, LogSoftmax) 29 | 30 | # ResNet blocks compose other layers 31 | 32 | 33 | def ConvBlock(kernel_size, filters, strides=(2, 2)): 34 | ks = kernel_size 35 | filters1, filters2, filters3 = filters 36 | Main = stax.serial( 37 | Conv(filters1, (1, 1), strides), BatchNorm(), Relu, 38 | Conv(filters2, (ks, ks), padding="SAME"), BatchNorm(), Relu, 39 | Conv(filters3, (1, 1)), BatchNorm()) 40 | Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm()) 41 | return stax.serial( 42 | FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) 43 | 44 | 45 | def IdentityBlock(kernel_size, filters): 46 | ks = kernel_size 47 | filters1, filters2 = filters 48 | 49 | def make_main(input_shape): 50 | # the number of output channels depends on the number of input channels 51 | return stax.serial( 52 | Conv(filters1, (1, 1)), BatchNorm(), Relu, 53 | Conv(filters2, (ks, ks), padding="SAME"), BatchNorm(), Relu, 54 | Conv(input_shape[3], (1, 1)), BatchNorm()) 55 | 56 | Main = stax.shape_dependent(make_main) 57 | return stax.serial( 58 | FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) 59 | 60 | 61 | def BasicBlock(kernel_size, filters, strides=(1, 1)): 62 | ks = kernel_size 63 | filters1, filters2 = filters 64 | Main = stax.serial( 65 | Conv(filters1, (ks, ks), strides, padding="SAME"), BatchNorm(), Relu, 66 | Conv(filters2, (ks, ks), strides, padding="SAME"), BatchNorm()) 67 | 68 | Shortcut = stax.serial(Conv(filters2, (1, 1), strides), BatchNorm()) 69 | return stax.serial( 70 | FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) 71 | 72 | 73 | def BasicBlock_withoutBN(kernel_size, filters, strides=(1, 1)): 74 | ks = kernel_size 75 | filters1, filters2 = filters 76 | Main = stax.serial( 77 | Conv(filters1, (ks, ks), strides, padding="SAME"), Relu, 78 | Conv(filters2, (ks, ks), strides, padding="SAME")) 79 | 80 | Shortcut = stax.serial(Conv(filters2, (1, 1), strides)) 81 | return stax.serial( 82 | FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu) 83 | 84 | 85 | def IdentityBlock_withoutBN(kernel_size, filters): 86 | ks = kernel_size 87 | filters1, filters2 = filters 88 | 89 | def make_main(input_shape): 90 | # the number of output channels depends on the number of input channels 91 | return stax.serial( 92 | Conv(filters1, (1, 1)), Relu, 93 | Conv(filters2, (ks, ks), padding="SAME"), Relu, 94 | Conv(input_shape[3], (1, 1))) 95 | 96 | Main = stax.shape_dependent(make_main) 97 | return stax.serial( 98 | FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu) 99 | 100 | 101 | # ResNet architectures compose layers and ResNet blocks 102 | 103 | 104 | def ResNet101(num_classes): 105 | return stax.serial( 106 | GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"), 107 | BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), 108 | ConvBlock(3, [64, 64, 256], strides=(1, 109 | 1)), IdentityBlock(3, [64, 64]), 110 | IdentityBlock(3, [64, 64]), ConvBlock(3, [128, 128, 512]), 111 | IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), 112 | IdentityBlock(3, [128, 128]), ConvBlock(3, [256, 256, 1024]), 113 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 114 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 
115 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 116 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 117 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 118 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 119 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 120 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 121 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 122 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 123 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 124 | ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]), 125 | IdentityBlock(3, [512, 512]), AvgPool((7, 7), padding="SAME"), Flatten, 126 | Dense(num_classes), LogSoftmax) 127 | 128 | 129 | def ResNet50(num_classes): 130 | return stax.serial( 131 | GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"), 132 | BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), 133 | ConvBlock(3, [64, 64, 256], strides=(1, 1)), IdentityBlock( 134 | 3, [64, 64]), IdentityBlock(3, [64, 64]), 135 | ConvBlock(3, [128, 128, 512]), IdentityBlock(3, [128, 128]), 136 | IdentityBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), 137 | ConvBlock(3, [256, 256, 1024]), IdentityBlock(3, [256, 256]), 138 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 139 | IdentityBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 140 | ConvBlock(3, [512, 512, 2048]), IdentityBlock(3, [512, 512]), 141 | IdentityBlock(3, [512, 512]), AvgPool((7, 7), padding="SAME"), Flatten, 142 | Dense(num_classes), LogSoftmax) 143 | 144 | 145 | def ResNet18(num_classes): 146 | return stax.serial( 147 | GeneralConv(("HWCN", "OIHW", "NHWC"), 1, (7, 7), (2, 2), "SAME"), 148 | BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)), 149 | BasicBlock(3, [64, 64]), IdentityBlock(3, [64, 64]), 150 | BasicBlock(3, [128, 128]), IdentityBlock(3, [128, 128]), 151 | BasicBlock(3, [256, 256]), IdentityBlock(3, [256, 256]), 152 | BasicBlock(3, [512, 512]), IdentityBlock(3, [512, 512]), 153 | AvgPool((7, 7), padding="SAME"), Flatten, Dense(num_classes), 154 | LogSoftmax) 155 | 156 | 157 | def MLP(num_classes): 158 | return stax.serial(Flatten, Dense(32), BatchNorm(), Relu, Dense(128), 159 | BatchNorm(), Relu, Dense(num_classes), LogSoftmax) 160 | 161 | 162 | if __name__ == "__main__": 163 | rng_key = random.PRNGKey(0) 164 | 165 | batch_size = 8 166 | num_classes = 1001 167 | input_shape = (224, 224, 3, batch_size) 168 | step_size = 0.1 169 | num_steps = 10 170 | 171 | init_fun, predict_fun = ResNet50(num_classes) 172 | _, init_params = init_fun(rng_key, input_shape) 173 | 174 | def loss(params, batch): 175 | inputs, targets = batch 176 | logits = predict_fun(params, inputs) 177 | return -jnp.sum(logits * targets) 178 | 179 | def accuracy(params, batch): 180 | inputs, targets = batch 181 | target_class = jnp.argmax(targets, axis=-1) 182 | predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1) 183 | return jnp.mean(predicted_class == target_class) 184 | 185 | def synth_batches(): 186 | rng = npr.RandomState(0) 187 | while True: 188 | images = rng.rand(*input_shape).astype("float32") 189 | labels = rng.randint(num_classes, size=(batch_size, 1)) 190 | onehot_labels = labels == jnp.arange(num_classes) 191 | yield images, onehot_labels 192 | 193 | opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9) 194 | batches = synth_batches() 195 | 196 | @jit 197 | def update(i, opt_state, batch): 198 | params = get_params(opt_state) 
199 | return opt_update(i, grad(loss)(params, batch), opt_state) 200 | 201 | opt_state = opt_init(init_params) 202 | for i in range(num_steps): 203 | opt_state = update(i, opt_state, next(batches)) 204 | trained_params = get_params(opt_state) 205 | -------------------------------------------------------------------------------- /examples/jax/mnist_jax_example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from filelock import FileLock 5 | 6 | import ray 7 | from distml.operator.jax_operator import JAXTrainingOperator 8 | from distml.strategy.allreduce_strategy import AllReduceStrategy 9 | 10 | from ray.util.sgd.utils import override 11 | 12 | from jax import random 13 | from jax.experimental import optimizers 14 | import jax.numpy as jnp 15 | from jax_util.resnet import ResNet18, ResNet50, ResNet101 16 | from jax_util.datasets import mnist, Dataloader 17 | 18 | 19 | def initialization_hook(): 20 | # Need this for avoiding a connection restart issue on AWS. 21 | os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" 22 | os.environ["NCCL_LL_THRESHOLD"] = "0" 23 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" 24 | # set the below if needed 25 | # print("NCCL DEBUG SET") 26 | # os.environ["NCCL_DEBUG"] = "INFO" 27 | 28 | 29 | class MnistTrainingOperator(JAXTrainingOperator): 30 | @override(JAXTrainingOperator) 31 | def setup(self, config): 32 | batch_size = config["batch_size"] 33 | rng_key = random.PRNGKey(0) 34 | input_shape = (28, 28, 1, batch_size) 35 | lr = config["lr"] 36 | model_name = config["model_name"] 37 | num_classes = config["num_classes"] 38 | 39 | if model_name == "resnet18": 40 | init_fun, predict_fun = ResNet18(num_classes) 41 | elif model_name == "resnet50": 42 | init_fun, predict_fun = ResNet50(num_classes) 43 | elif model_name == "resnet101": 44 | init_fun, predict_fun = ResNet101(num_classes) 45 | else: 46 | raise RuntimeError("Unrecognized model name") 47 | 48 | _, init_params = init_fun(rng_key, input_shape) 49 | 50 | opt_init, opt_update, get_params = optimizers.adam(lr) 51 | opt_state = opt_init(init_params) 52 | 53 | with FileLock(".ray.lock"): 54 | train_images, train_labels, test_images, test_labels = mnist() 55 | 56 | train_images = train_images.reshape(train_images.shape[0], 1, 28, 57 | 28).transpose(2, 3, 1, 0) 58 | 59 | test_images = test_images.reshape(test_images.shape[0], 1, 28, 60 | 28).transpose(2, 3, 1, 0) 61 | 62 | train_loader = Dataloader( 63 | train_images, train_labels, batch_size=batch_size, shuffle=True) 64 | test_loader = Dataloader( 65 | test_images, test_labels, batch_size=batch_size) 66 | 67 | def criterion(logits, targets): 68 | return -jnp.sum(logits * targets) 69 | 70 | self.register( 71 | model=[opt_state, init_fun, predict_fun], 72 | optimizer=[opt_init, opt_update, get_params], 73 | criterion=criterion) 74 | 75 | self.register_data( 76 | train_loader=train_loader, validation_loader=test_loader) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument( 82 | "--address", 83 | required=False, 84 | type=str, 85 | help="the address to use for connecting to the Ray cluster") 86 | parser.add_argument( 87 | "--num-workers", 88 | "-n", 89 | type=int, 90 | default=2, 91 | help="Sets number of workers for training.") 92 | parser.add_argument( 93 | "--num-epochs", 94 | type=int, 95 | default=20, 96 | help="Number of epochs to train.") 97 | parser.add_argument( 98 | "--fp16", 99 | action="store_true", 100 | 
default=False, 101 | help="Enables FP16 training with apex. Requires `use-gpu`.") 102 | parser.add_argument( 103 | "--model-name", 104 | type=str, 105 | default="resnet18", 106 | help="model, Optional: resnet18, resnet50, resnet101.") 107 | 108 | args, _ = parser.parse_known_args() 109 | 110 | if args.address: 111 | ray.init(args.address) 112 | else: 113 | ray.init( 114 | num_gpus=args.num_workers, 115 | num_cpus=args.num_workers * 2, 116 | log_to_driver=True) 117 | 118 | strategy = AllReduceStrategy( 119 | training_operator_cls=MnistTrainingOperator, 120 | world_size=args.num_workers, 121 | operator_config={ 122 | "lr": 0.01, 123 | "batch_size": 128, 124 | "num_workers": args.num_workers, 125 | "num_classes": 10, 126 | "model_name": args.model_name 127 | }, 128 | initialization_hook=initialization_hook) 129 | 130 | for i in range(args.num_epochs): 131 | strategy.train() 132 | print(strategy.validate()) 133 | strategy.shutdown() 134 | print("success!") 135 | -------------------------------------------------------------------------------- /examples/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/distml/b2d4766664166a0163956c00d8472a03274d4d51/examples/torch/__init__.py -------------------------------------------------------------------------------- /examples/torch/cifar_pytorch_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from filelock import FileLock 4 | 5 | import ray 6 | import torch 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader, Subset 10 | from torchvision.datasets import CIFAR10 11 | 12 | from resnet import ResNet18 13 | from distml.util import override 14 | from distml.strategy.allreduce_strategy import AllReduceStrategy 15 | from distml.operator.torch_operator import TorchTrainingOperator 16 | 17 | 18 | def initialization_hook(): 19 | # Need this for avoiding a connection restart issue on AWS. 20 | os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" 21 | os.environ["NCCL_LL_THRESHOLD"] = "0" 22 | 23 | # set the below if needed 24 | # print("NCCL DEBUG SET") 25 | # os.environ["NCCL_DEBUG"] = "INFO" 26 | 27 | 28 | class CifarTrainingOperator(TorchTrainingOperator): 29 | @override(TorchTrainingOperator) 30 | def setup(self, config): 31 | # Create model. 32 | model = ResNet18(config) 33 | 34 | # Create optimizer. 35 | optimizer = torch.optim.SGD( 36 | model.parameters(), 37 | lr=config.get("lr", 0.1), 38 | momentum=config.get("momentum", 0.9)) 39 | 40 | # Load in training and validation data. 
41 | transform_train = transforms.Compose([ 42 | transforms.RandomCrop(32, padding=4), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.4914, 0.4822, 0.4465), 46 | (0.2023, 0.1994, 0.2010)), 47 | ]) # meanstd transformation 48 | 49 | transform_test = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.4914, 0.4822, 0.4465), 52 | (0.2023, 0.1994, 0.2010)), 53 | ]) 54 | with FileLock(".ray.lock"): 55 | train_dataset = CIFAR10( 56 | root="~/data", 57 | train=True, 58 | download=True, 59 | transform=transform_train) 60 | validation_dataset = CIFAR10( 61 | root="~/data", 62 | train=False, 63 | download=False, 64 | transform=transform_test) 65 | 66 | if config["test_mode"]: 67 | train_dataset = Subset(train_dataset, list(range(64))) 68 | validation_dataset = Subset(validation_dataset, list(range(64))) 69 | 70 | train_loader = DataLoader( 71 | train_dataset, batch_size=config["batch_size"], num_workers=2) 72 | validation_loader = DataLoader( 73 | validation_dataset, batch_size=config["batch_size"], num_workers=2) 74 | 75 | # # Create scheduler. 76 | # scheduler = torch.optim.lr_scheduler.MultiStepLR( 77 | # optimizer, milestones=[150, 250, 350], gamma=0.1) 78 | 79 | # Create loss. 80 | criterion = nn.CrossEntropyLoss() 81 | print(criterion) 82 | # Register all components. 83 | # # self.model, self.optimizer, self.criterion, self.scheduler = \ 84 | # # self.register(models=model, optimizers=optimizer, 85 | # criterion=criterion, schedulers=scheduler) 86 | self.model, self.optimizer, self.criterion = self.register( 87 | model=model, optimizer=optimizer, criterion=criterion) 88 | self.register_data( 89 | train_loader=train_loader, validation_loader=validation_loader) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument( 95 | "--address", 96 | required=False, 97 | type=str, 98 | help="the address to use for connecting to the Ray cluster") 99 | parser.add_argument( 100 | "--num-workers", 101 | "-n", 102 | type=int, 103 | default=2, 104 | help="Sets number of workers for training.") 105 | parser.add_argument( 106 | "--num-epochs", type=int, default=1, help="Number of epochs to train.") 107 | parser.add_argument( 108 | "--use-gpu", 109 | action="store_true", 110 | default=False, 111 | help="Enables GPU training") 112 | parser.add_argument( 113 | "--fp16", 114 | action="store_true", 115 | default=False, 116 | help="Enables FP16 training with apex. Requires `use-gpu`.") 117 | parser.add_argument( 118 | "--smoke-test", 119 | action="store_true", 120 | default=False, 121 | help="Finish quickly for testing.") 122 | parser.add_argument( 123 | "--tune", action="store_true", default=False, help="Tune training") 124 | 125 | args, _ = parser.parse_known_args() 126 | num_cpus = 4 if args.smoke_test else None 127 | ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True) 128 | 129 | strategy = AllReduceStrategy( 130 | training_operator_cls=CifarTrainingOperator, 131 | initialization_hook=initialization_hook, 132 | world_size=args.num_workers, 133 | operator_config={ 134 | "lr": 0.1, 135 | "test_mode": args.smoke_test, # subset the data 136 | # this will be split across workers. 
137 | "batch_size": 128 * args.num_workers 138 | }) 139 | # pbar = trange(args.num_epochs, unit="epoch") 140 | # for i in pbar: 141 | # info = {"num_steps": 1} if args.smoke_test else {} 142 | # info["epoch_idx"] = i 143 | # info["num_epochs"] = args.num_epochs 144 | # # Increase `max_retries` to turn on fault tolerance. 145 | # strategy.train(max_retries=1, info=info) 146 | # # val_stats = trainer1.validate() 147 | # # pbar.set_postfix(dict(acc=val_stats["val_accuracy"])) 148 | 149 | for i in range(args.num_epochs): 150 | strategy.train() 151 | print(strategy.validate()) 152 | strategy.shutdown() 153 | print("success!") 154 | -------------------------------------------------------------------------------- /examples/torch/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet torch implementation. 2 | 3 | Copied from https://github.com/ray-project/ray/ 4 | blob/master/python/ray/util/sgd/torch/resnet.py 5 | """ 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = nn.Conv2d( 16 | in_planes, 17 | planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d( 24 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion * planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d( 31 | in_planes, 32 | self.expansion * planes, 33 | kernel_size=1, 34 | stride=stride, 35 | bias=False), nn.BatchNorm2d(self.expansion * planes)) 36 | 37 | def forward(self, x): 38 | out = F.relu(self.bn1(self.conv1(x))) 39 | out = self.bn2(self.conv2(out)) 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, in_planes, planes, stride=1): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d( 53 | planes, 54 | planes, 55 | kernel_size=3, 56 | stride=stride, 57 | padding=1, 58 | bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d( 61 | planes, self.expansion * planes, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 63 | 64 | self.shortcut = nn.Sequential() 65 | if stride != 1 or in_planes != self.expansion * planes: 66 | self.shortcut = nn.Sequential( 67 | nn.Conv2d( 68 | in_planes, 69 | self.expansion * planes, 70 | kernel_size=1, 71 | stride=stride, 72 | bias=False), nn.BatchNorm2d(self.expansion * planes)) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += self.shortcut(x) 79 | out = F.relu(out) 80 | return out 81 | 82 | 83 | class ResNet(nn.Module): 84 | def __init__(self, block, num_blocks, num_classes=10): 85 | super(ResNet, self).__init__() 86 | self.in_planes = 64 87 | 88 | self.conv1 = nn.Conv2d( 89 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 92 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 93 | self.layer3 
= self._make_layer(block, 256, num_blocks[2], stride=2) 94 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 95 | self.linear = nn.Linear(512 * block.expansion, num_classes) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1] * (num_blocks - 1) 99 | layers = [] 100 | for stride in strides: 101 | layers.append(block(self.in_planes, planes, stride)) 102 | self.in_planes = planes * block.expansion 103 | return nn.Sequential(*layers) 104 | 105 | def forward(self, x): 106 | out = F.relu(self.bn1(self.conv1(x))) 107 | out = self.layer1(out) 108 | out = self.layer2(out) 109 | out = self.layer3(out) 110 | out = self.layer4(out) 111 | out = F.avg_pool2d(out, 4) 112 | out = out.view(out.size(0), -1) 113 | out = self.linear(out) 114 | return out 115 | 116 | 117 | def ResNet18(_): 118 | return ResNet(BasicBlock, [2, 2, 2, 2]) 119 | 120 | 121 | def ResNet34(_): 122 | return ResNet(BasicBlock, [3, 4, 6, 3]) 123 | 124 | 125 | def ResNet50(_): 126 | return ResNet(Bottleneck, [3, 4, 6, 3]) 127 | 128 | 129 | def ResNet101(_): 130 | return ResNet(Bottleneck, [3, 4, 23, 3]) 131 | 132 | 133 | def ResNet152(_): 134 | return ResNet(Bottleneck, [3, 8, 36, 3]) 135 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Hao: modified from https://github.com/ray-project/xgboost_ray/blob/master/format.sh 3 | 4 | # YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. 5 | # You are encouraged to run this locally before pushing changes for review. 6 | 7 | # Cause the script to exit if a single command fails 8 | set -eo pipefail 9 | 10 | FLAKE8_VERSION_REQUIRED="3.7.7" 11 | YAPF_VERSION_REQUIRED="0.23.0" 12 | 13 | check_command_exist() { 14 | VERSION="" 15 | case "$1" in 16 | yapf) 17 | VERSION=$YAPF_VERSION_REQUIRED 18 | ;; 19 | flake8) 20 | VERSION=$FLAKE8_VERSION_REQUIRED 21 | ;; 22 | *) 23 | echo "$1 is not a required dependency" 24 | exit 1 25 | esac 26 | if ! [ -x "$(command -v $1)" ]; then 27 | echo "$1 not installed. pip install $1==$VERSION" 28 | exit 1 29 | fi 30 | } 31 | 32 | check_command_exist yapf 33 | check_command_exist flake8 34 | 35 | ver=$(yapf --version) 36 | if ! echo $ver | grep -q 0.23.0; then 37 | echo "Wrong YAPF version installed: 0.23.0 is required, not $ver. $YAPF_DOWNLOAD_COMMAND_MSG" 38 | exit 1 39 | fi 40 | 41 | # this stops git rev-parse from failing if we run this from the .git directory 42 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 43 | 44 | ROOT="$(git rev-parse --show-toplevel)" 45 | builtin cd "$ROOT" || exit 1 46 | 47 | # Add the upstream remote if it doesn't exist 48 | if ! git remote -v | grep -q upstream; then 49 | git remote add 'upstream' 'https://github.com/ray-project/distml.git' 50 | fi 51 | 52 | FLAKE8_VERSION=$(flake8 --version | awk '{print $1}') 53 | YAPF_VERSION=$(yapf --version | awk '{print $2}') 54 | 55 | # params: tool name, tool version, required version 56 | tool_version_check() { 57 | if [[ $2 != $3 ]]; then 58 | echo "WARNING: DistML uses $1 $3, You currently are using $2. This might generate different results." 
59 | fi 60 | } 61 | 62 | tool_version_check "flake8" $FLAKE8_VERSION $FLAKE8_VERSION_REQUIRED 63 | tool_version_check "yapf" $YAPF_VERSION $YAPF_VERSION_REQUIRED 64 | 65 | if which clang-format >/dev/null; then 66 | CLANG_FORMAT_VERSION=$(clang-format --version | awk '{print $3}') 67 | tool_version_check "clang-format" $CLANG_FORMAT_VERSION "7.0.0" 68 | else 69 | echo "WARNING: clang-format is not installed!" 70 | fi 71 | 72 | # Only fetch master since that's the branch we're diffing against. 73 | git fetch upstream master || true 74 | 75 | YAPF_FLAGS=( 76 | '--style' "$ROOT/.style.yapf" 77 | '--recursive' 78 | '--parallel' 79 | ) 80 | 81 | YAPF_EXCLUDES=( 82 | # '--exclude' 'python/ray/cloudpickle/*' 83 | # '--exclude' 'python/build/*' 84 | # '--exclude' 'python/ray/core/src/ray/gcs/*' 85 | # '--exclude' 'python/ray/thirdparty_files/*' 86 | ) 87 | 88 | # Format specified files 89 | format() { 90 | yapf --in-place "${YAPF_FLAGS[@]}" -- "$@" 91 | } 92 | 93 | # Format files that differ from main branch. Ignores dirs that are not slated 94 | # for autoformat yet. 95 | format_changed() { 96 | # The `if` guard ensures that the list of filenames is not empty, which 97 | # could cause yapf to receive 0 positional arguments, making it hang 98 | # waiting for STDIN. 99 | # 100 | # `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that 101 | # exist on both branches. 102 | MERGEBASE="$(git merge-base upstream/master HEAD)" 103 | 104 | if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then 105 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 106 | yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" 107 | if which flake8 >/dev/null; then 108 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 109 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 110 | fi 111 | fi 112 | 113 | if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then 114 | if which flake8 >/dev/null; then 115 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ 116 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 117 | fi 118 | fi 119 | } 120 | 121 | # Format all files, and print the diff to stdout for travis. 122 | format_all() { 123 | yapf --diff "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" distml 124 | flake8 --inline-quotes '"' --no-avoid-escape --ignore=N,I,C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 distml 125 | } 126 | 127 | # This flag formats individual files. --files *must* be the first command line 128 | # arg to use this option. 129 | if [[ "$1" == '--files' ]]; then 130 | format "${@:2}" 131 | # If `--all` is passed, then any further arguments are ignored and the 132 | # entire python directory is formatted. 133 | elif [[ "$1" == '--all' ]]; then 134 | format_all 135 | else 136 | # Format only the files that changed in last commit. 137 | format_changed 138 | fi 139 | 140 | if ! git diff --quiet &>/dev/null; then 141 | echo 'Reformatted changed files. Please review and stage the changes.' 142 | echo 'Files updated:' 143 | echo 144 | 145 | git --no-pager diff --name-only 146 | 147 | exit 1 148 | fi 149 | 150 | echo 'Linting check finished successfully.' 
151 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | setup( 5 | name="distml", 6 | version=os.environ.get("VERSION"), 7 | author="The DistML Authors", 8 | author_email="", 9 | description="DistML is a runtime library for distributed " 10 | "deep learning training.", 11 | long_description="DistML is a Ray extension library to support " 12 | "large-scale distributed ML training on heterogeneous " 13 | "multi-node multi-GPU clusters.", 14 | url="https://github.com/ray-project/distml", 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 18 | ], 19 | install_requires=[ 20 | "ray", 21 | "pandas", 22 | "tabulate", 23 | ], 24 | extras_require={ 25 | "dev": [ 26 | "pytest", 27 | "pytest-cov", 28 | "pydocstyle", 29 | "prospector", 30 | ] 31 | }, 32 | packages=find_packages(), 33 | python_requires=">=3.6", 34 | ) 35 | --------------------------------------------------------------------------------